├── .gitignore
├── LICENSE
├── README.md
├── setup.py
├── torchfcpe
│   ├── __init__.py
│   ├── assets
│   │   └── fcpe_c_v001.pt
│   ├── f02midi
│   │   ├── MIDI.py
│   │   ├── featureExtraction.py
│   │   ├── quantization.py
│   │   ├── transpose.py
│   │   └── utils.py
│   ├── mel_extractor.py
│   ├── mel_fn_librosa.py
│   ├── model_conformer_naive.py
│   ├── model_convnext.py
│   ├── models.py
│   ├── models_infer.py
│   ├── tools.py
│   └── torch_interp.py
└── train
    ├── configs
    │   └── config.yaml
    ├── data_loaders_wav.py
    ├── draw.py
    ├── pre_data.py
    ├── redis_coder.py
    ├── savertools
    │   ├── __init__.py
    │   ├── saver.py
    │   └── utils.py
    ├── solver_wav.py
    ├── train_wav.py
    ├── utils_1.py
    └── utils_all.py
/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 | __pycache__/
3 | torchfcpe.egg-info/
4 |
5 | test**
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 CN_ChiTu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TorchFCPE
2 |
3 | ## Overview
4 |
5 | TorchFCPE (Fast Context-based Pitch Estimation) is a PyTorch-based library for audio pitch extraction and MIDI conversion. This README provides a quick guide to audio pitch inference and MIDI extraction.
6 |
7 | Note: the MIDI extractor in FCPE quantizes MIDI from the estimated f0 using non-neural-network methods.
8 |
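Concretely, the first quantization step maps each frame-level f0 value (Hz) to a MIDI note number, as in `torchfcpe/f02midi/transpose.py`; a minimal sketch (the full pipeline also applies tempo-aware median filtering and segmentation):

```python
import numpy as np

def f0_to_midi_note(f0_hz: np.ndarray) -> np.ndarray:
    """Map frame-level f0 (Hz) to MIDI note numbers; unvoiced (0 Hz) frames become 0."""
    note = 69 + 12 * np.log2(f0_hz / 440 + 1e-4)  # 440 Hz -> MIDI note 69 (A4)
    note = np.round(note).astype(int)
    return np.clip(note, 0, 127)  # unvoiced frames fall far below 0 and are clipped to 0
```
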
9 | Note: I won't be updating FCPE (or the benchmark) soon, but I will release a version with cleaned-up code no later than next year.
10 |
11 | ## Installation
12 |
13 | Before using the library, make sure you have the necessary dependencies installed:
14 |
15 | ```bash
16 | pip install torchfcpe
17 | ```
18 |
19 | ## Usage
20 |
21 | ### 1. Audio Pitch Inference
22 |
23 | ```python
24 | from torchfcpe import spawn_bundled_infer_model
25 | import torch
26 | import librosa
27 |
28 | # Configure device and target hop size
29 | device = 'cpu' # or 'cuda' if using a GPU
30 | sr = 16000 # Sample rate
31 | hop_size = 160 # Hop size for processing
32 |
33 | # Load and preprocess audio
34 | audio, sr = librosa.load('test.wav', sr=sr)
35 | audio = librosa.to_mono(audio)
36 | audio_length = len(audio)
37 | f0_target_length = (audio_length // hop_size) + 1
38 | audio = torch.from_numpy(audio).float().unsqueeze(0).unsqueeze(-1).to(device)
39 |
40 | # Load the model
41 | model = spawn_bundled_infer_model(device=device)
42 |
43 | # Perform pitch inference
44 | f0 = model.infer(
45 | audio,
46 | sr=sr,
47 | decoder_mode='local_argmax', # Recommended mode
48 | threshold=0.006, # Threshold for V/UV decision
49 | f0_min=80, # Minimum pitch
50 | f0_max=880, # Maximum pitch
51 | interp_uv=False, # Whether to interpolate unvoiced frames
52 | output_interp_target_length=f0_target_length, # Interpolate to target length
53 | )
54 |
55 | print(f0)
56 | ```
57 |
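For example, with `sr = 16000` and `hop_size = 160` (10 ms per frame), a 1-second clip of 16,000 samples gives `16000 // 160 + 1 = 101` target frames for `output_interp_target_length`.
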
58 | ### 2. MIDI Extraction
59 |
60 | ```python
61 | # Extract MIDI from audio
62 | midi = model.extact_midi(
63 | audio,
64 | sr=sr,
65 | decoder_mode='local_argmax', # Recommended mode
66 | threshold=0.006, # Threshold for V/UV decision
67 | f0_min=80, # Minimum pitch
68 | f0_max=880, # Maximum pitch
69 | output_path="test.mid", # Save MIDI to file
70 | )
71 |
72 | print(midi)
73 | ```
74 |
75 | ### Notes
76 |
77 | - **Inference Parameters:**
78 |
79 | - `audio`: Input audio as a `torch.Tensor`.
80 | - `sr`: Sample rate of the audio.
81 | - `decoder_mode` (Optional): Mode for decoding, 'local_argmax' is recommended.
82 | - `threshold` (Optional): Threshold for the voiced/unvoiced decision; default is 0.006.
83 | - `f0_min` (Optional): Minimum pitch value; default is 80 Hz.
84 | - `f0_max` (Optional): Maximum pitch value; default is 880 Hz.
85 | - `interp_uv` (Optional): Whether to interpolate unvoiced frames; default is False.
86 | - `output_interp_target_length` (Optional): Length to which the output pitch should be interpolated.
87 |
88 | - **MIDI Extraction Parameters:**
89 | - `audio`: Input audio as a `torch.Tensor`.
90 | - `sr`: Sample rate of the audio.
91 | - `decoder_mode` (Optional): Mode for decoding; 'local_argmax' is recommended.
92 | - `threshold` (Optional): Threshold for the voiced/unvoiced decision; default is 0.006.
93 | - `f0_min` (Optional): Minimum pitch value; default is 80 Hz.
94 | - `f0_max` (Optional): Maximum pitch value; default is 880 Hz.
95 | - `output_path` (Optional): File path to save the MIDI file. If not provided, only the MIDI structure is returned (see the sketch below).
96 | - `tempo` (Optional): BPM for the MIDI file. If None, BPM is automatically predicted.
97 |
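Based on `torchfcpe/f02midi/transpose.py`, the returned MIDI structure is a list of `(start_seconds, end_seconds, midi_pitch)` segments; a minimal sketch of consuming it (assuming that return format, see `models_infer.py` for the exact behavior):

```python
# Iterate over the note segments returned by extact_midi
for start_s, end_s, pitch in midi:
    print(f"note {pitch}: {start_s:.2f}s - {end_s:.2f}s")
```
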
98 | ## Additional Features
99 |
100 | - **Model as a PyTorch Module:**
101 | You can use the model as a standard PyTorch module. For example:
102 |
103 | ```python
104 | # Change device
105 | model = model.to(device)
106 |
107 | # Compile model
108 | model = torch.compile(model)
109 | ```
110 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 |
4 | with open('README.md', encoding='utf8') as file:
5 | long_description = file.read()
6 |
7 |
8 | setup(
9 | name='torchfcpe',
10 | description='The official PyTorch implementation of Fast Context-based Pitch Estimation (FCPE)',
11 | version='0.0.4',
12 | author='CNChTu',
13 | author_email='2921046558@qq.com',
14 | url='https://github.com/CNChTu/FCPE',
15 | install_requires=['einops', 'local_attention', 'torch', 'torchaudio', 'numpy', 'scipy', 'librosa', 'pydub', 'pretty_midi'],
16 | packages=['torchfcpe'],
17 | package_data={'torchfcpe': ['assets/*']},
18 | long_description=long_description,
19 | long_description_content_type='text/markdown',
20 | keywords=['pitch', 'audio', 'speech', 'music', 'pytorch', 'fcpe'],
21 | classifiers=['License :: OSI Approved :: MIT License'],
22 | license='MIT')
23 |
--------------------------------------------------------------------------------
/torchfcpe/__init__.py:
--------------------------------------------------------------------------------
1 | from .tools import (
2 | spawn_wav2mel,
3 | )
4 | from .models_infer import (
5 | spawn_model,
6 | spawn_infer_model_from_pt,
7 | spawn_infer_model_from_onnx,
8 | spawn_bundled_infer_model,
9 | bundled_infer_model_unit_test
10 | )
11 |
--------------------------------------------------------------------------------
/torchfcpe/assets/fcpe_c_v001.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CNChTu/FCPE/d55c1b636dc3564dc35be19675998688931e870c/torchfcpe/assets/fcpe_c_v001.pt
--------------------------------------------------------------------------------
/torchfcpe/f02midi/MIDI.py:
--------------------------------------------------------------------------------
1 | #%%
2 | import pretty_midi
3 | import numpy as np
4 | import librosa.display
5 |
6 |
7 | #%%
8 | def plot_piano_roll(pm, start_pitch, end_pitch, fs=100):
9 | """ Plot piano roll from .mid file
10 | ----------
11 | Parameters:
12 | pm: pretty_midi.PrettyMIDI object
13 | start/end_pitch: lowest/highest note number (int)
14 | fs: sampling freq. (int)
15 |
16 | """
17 | # Use librosa's specshow function for displaying the piano roll
18 | librosa.display.specshow(
19 | pm.get_piano_roll(fs)[start_pitch:end_pitch],
20 | hop_length=1,
21 | sr=fs,
22 | x_axis="time",
23 | y_axis="cqt_note",
24 | fmin=pretty_midi.note_number_to_hz(start_pitch),
25 | )
26 |
27 |
28 | def midi_to_note(file_name, pitch_shift, fs=100, start_note=40, end_note=95):
29 | """ Convert .mid to note
30 | ----------
31 | Parameters:
32 | file_name: '.mid' (str)
33 | pitch_shift: shift the pitch to adjust notes correctly (int)
34 | fs: sampling freq. (int)
35 | start/end_pitch: lowest/highest note(int)
36 |
37 | ----------
38 | Returns:
39 | notes: note/10ms (array)
40 | """
41 |
42 | pm = pretty_midi.PrettyMIDI(file_name)
43 | frame_note = pm.get_piano_roll(fs)[start_note:end_note]
44 |
45 | length_audio = frame_note.shape[1]
46 | notes = np.zeros(length_audio)
47 |
48 | for i in range(length_audio):
49 | note_tmp = np.argmax(frame_note[:, i])
50 | if note_tmp > 0:
51 | notes[i] = (note_tmp + start_note) + pitch_shift
52 | # note[i] = 2 ** ((note_tmp -69) / 12.) * 440
53 | return notes
54 |
55 |
56 | def midi_to_segment(filename):
57 | """ Convert .mid to segment
58 | ----------
59 | Parameters:
60 | filename: .mid (str)
61 |
62 | ----------
63 | Returns:
64 | segments: [start(s),end(s),pitch] (list)
65 | """
66 |
67 | pm = pretty_midi.PrettyMIDI(filename)
68 | segment = []
69 | for note in pm.instruments[0].notes:
70 | segment.append([note.start, note.end, note.pitch])
71 | return segment
72 |
73 |
74 | def segment_to_midi(segments, path_output, tempo=120):
75 | """ Convert segment to .mid
76 | ----------
77 | Parameters:
78 | segments: [start(s),end(s),pitch] (list)
79 | path_output: path of save file (str)
80 | """
81 | pm = pretty_midi.PrettyMIDI(initial_tempo=int(tempo))
82 | inst_program = pretty_midi.instrument_name_to_program("Acoustic Grand Piano")
83 | inst = pretty_midi.Instrument(program=inst_program)
84 | for segment in segments:
85 | note = pretty_midi.Note(
86 | velocity=100, start=segment[0], end=segment[1], pitch=np.int32(segment[2])
87 | )
88 | inst.notes.append(note)
89 | pm.instruments.append(inst)
90 | pm.write(f"{path_output}")
91 |
92 |
93 | def note_to_segment(note):
94 | """ Convert note to segment
95 | ----------
96 | Parameters:
97 | note: note/10ms (array)
98 | ----------
99 | Returns:
100 | segments: [start(s),end(s),pitch] (list)
101 | """
102 | startSeg = []
103 | endSeg = []
104 | notes = []
105 | flag = -1
106 |
107 | if note[0] > 0:
108 | startSeg.append(0)
109 | notes.append(np.int32(note[0]))
110 | flag *= -1
111 | for i in range(0, len(note) - 1):
112 | if note[i] != note[i + 1]:
113 | if flag < 0:
114 | startSeg.append(0.01 * (i + 1))
115 | notes.append(np.int32(note[i + 1]))
116 | flag *= -1
117 | else:
118 | if note[i + 1] == 0:
119 | endSeg.append(0.01 * i)
120 | flag *= -1
121 | else:
122 | endSeg.append(0.01 * i)
123 | startSeg.append(0.01 * (i + 1))
124 | notes.append(np.int32(note[i + 1]))
125 |
126 | return list(zip(startSeg, endSeg, notes))
127 |
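# Illustrative example (not part of the original file): with 10 ms frames,
# note_to_segment([0, 60, 60, 60, 0]) -> [(0.01, 0.03, 60)], i.e. one segment
# of MIDI note 60 lasting from 10 ms to 30 ms.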
128 |
129 | def note2Midi(frame_level_pitchscore, path_output, tempo):
130 | # note = np.loadtxt(path_input_note)
131 | # note = note[:, 1]
132 | segment = note_to_segment(frame_level_pitchscore)
133 | segment_to_midi(segment, path_output=path_output, tempo=tempo)
134 |
135 |
136 | # def note2Midi(path_input_note, path_output, tempo):
137 | # note = np.loadtxt(path_input_note)
138 | # note = note[:, 1]
139 | # segment = note_to_segment(note)
140 | # segment_to_midi(segment, path_output=path_output, tempo=tempo)
141 |
142 |
--------------------------------------------------------------------------------
/torchfcpe/f02midi/featureExtraction.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import librosa
3 | from pydub import AudioSegment
4 | import pathlib
5 |
6 | # from pydub.playback import play
7 | import numpy as np
8 | import os
9 |
10 | PATH_PROJECT = os.path.dirname(os.path.realpath(__file__))
11 |
12 |
13 | def read_audio(filepath, sr=None):
14 | path = pathlib.Path(filepath)
15 | extension = path.suffix.replace(".", "")
16 | if extension == "mp3":
17 | sound = AudioSegment.from_mp3(filepath)
18 | else:
19 | sound = AudioSegment.from_file(filepath)
20 | # sound = sound[start * 1000 : end * 1000]
21 | sound = sound.set_channels(1)
22 | if sr is None:
23 | sr = sound.frame_rate
24 | sound = sound.set_frame_rate(sr)
25 | samples = sound.get_array_of_samples()
26 | y = np.array(samples).T.astype(np.float32)
27 |
28 | return y, sr
29 |
30 |
31 | def spec_extraction(file_name, win_size):
32 |
33 | y, _ = read_audio(file_name, sr=8000)
34 |
35 | S = librosa.core.stft(y, n_fft=1024, hop_length=80, win_length=1024)
36 | x_spec = np.abs(S)
37 | x_spec = librosa.core.power_to_db(x_spec, ref=np.max)
38 | x_spec = x_spec.astype(np.float32)
39 | num_frames = x_spec.shape[1]
40 |
41 | # for padding
42 | padNum = num_frames % win_size
43 | if padNum != 0:
44 | len_pad = win_size - padNum
45 | padding_feature = np.zeros(shape=(513, len_pad))
46 | x_spec = np.concatenate((x_spec, padding_feature), axis=1)
47 | num_frames = num_frames + len_pad
48 |
49 | x_test = []
50 | for j in range(0, num_frames, win_size):
51 | x_test_tmp = x_spec[:, range(j, j + win_size)].T
52 | x_test.append(x_test_tmp)
53 | x_test = np.array(x_test)
54 |
55 | # for standardization
56 | path_project = pathlib.Path(__file__).parent.parent
57 | x_train_mean = np.load(f"{path_project}/data/x_train_mean.npy")
58 | x_train_std = np.load(f"{path_project}/data/x_train_std.npy")
59 | x_test = (x_test - x_train_mean) / (x_train_std + 0.0001)
60 | x_test = x_test[:, :, :, np.newaxis]
61 | return x_test, x_spec
62 |
--------------------------------------------------------------------------------
/torchfcpe/f02midi/quantization.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import numpy as np
3 | import librosa
4 | import librosa.display
5 |
6 | from scipy.signal import medfilt
7 | from matplotlib import pyplot as plt
8 | from .featureExtraction import read_audio
9 | from .utils import *
10 |
11 |
12 | # %%
13 | def calc_tempo(path_audio):
14 | """ Calculate audio tempo
15 | ----------
16 | Parameters:
17 | path_audio: str
18 |
19 | ----------
20 | Returns:
21 | tempo: float
22 |
23 | """
24 | target_sr = 22050
25 | y, _ = read_audio(path_audio, sr=target_sr)
26 | onset_strength = librosa.onset.onset_strength(y=y, sr=target_sr)
27 | tempo = librosa.beat.tempo(onset_envelope=onset_strength, sr=target_sr)
28 | return tempo
29 |
30 |
31 | def one_beat_frame_size(tempo):
32 | """ Calculate frame size of 1 beat
33 | ----------
34 | Parameters:
35 | tempo: float
36 |
37 | ----------
38 | Returns:
39 | frame_size: int
40 |
41 | """
42 | return np.int32(np.round(60 / tempo * 100))
43 |
44 |
45 | def median_filter_pitch(pitch, medfilt_size, weight):
46 | """ Smoothing pitch using median filter
47 | ----------
48 | Parameters:
49 | pitch: array
50 | medfilt_size: int
51 | weight: float
52 |
53 | ----------
54 | Returns:
55 | pitch: array
56 |
57 | """
58 |
59 | medfilt_size = np.int32(medfilt_size * weight)
60 | if medfilt_size % 2 == 0:
61 | medfilt_size += 1
62 | return np.round(medfilt(pitch, medfilt_size))
63 |
64 |
65 | def clean_note_frames(note, min_note_len=5):
66 | """ Remove short pitch frames
67 | ----------
68 | Parameters:
69 | note: array
70 | min_note_len: int
71 |
72 | ----------
73 | Returns:
74 | output: array
75 |
76 | """
77 |
78 | prev_pitch = 0
79 | prev_pitch_start = 0
80 | output = np.copy(note)
81 | for i in range(len(note)):
82 | pitch = note[i]
83 | if pitch != prev_pitch:
84 | prev_pitch_duration = i - prev_pitch_start
85 | if prev_pitch_duration < min_note_len:
86 | output[prev_pitch_start:i] = [0] * prev_pitch_duration
87 | prev_pitch = pitch
88 | prev_pitch_start = i
89 | return output
90 |
91 |
92 | def makeSegments(note):
93 | """ Make segments of notes
94 | ----------
95 | Parameters:
96 | note: array
97 |
98 | ----------
99 | Returns:
100 | startSeg: starting points (array)
101 | endSeg: ending points (array)
102 |
103 | """
104 | startSeg = []
105 | endSeg = []
106 | flag = -1
107 | if note[0] > 0:
108 | startSeg.append(0)
109 | flag *= -1
110 | for i in range(0, len(note) - 1):
111 | if note[i] != note[i + 1]:
112 | if flag < 0:
113 | startSeg.append(i + 1)
114 | flag *= -1
115 | else:
116 | if note[i + 1] == 0:
117 | endSeg.append(i)
118 | flag *= -1
119 | else:
120 | endSeg.append(i)
121 | startSeg.append(i + 1)
122 | return startSeg, endSeg
123 |
124 |
125 | def remove_short_segment(idx, note_cleaned, start, end, minLength):
126 | """ Remove short segments
127 | ----------
128 | Parameters:
129 | idx: (int)
130 | note_cleaned: (array)
131 | start: starting points (array)
132 | end: ending points (array)
133 | minLength: (int)
134 |
135 | ----------
136 | Returns:
137 | note_cleaned: (array)
138 |
139 | """
140 |
141 | len_seg = end[idx] - start[idx]
142 | if len_seg < minLength:
143 | if (start[idx + 1] - end[idx] > minLength) and (start[idx] - end[idx - 1] > minLength):
144 | note_cleaned[start[idx] : end[idx] + 1] = [0] * (len_seg + 1)
145 | return note_cleaned
146 |
147 |
148 | def remove_octave_error(idx, note_cleaned, start, end):
149 | """ Remove octave error
150 | ----------
151 | Parameters:
152 | idx: (int)
153 | note_cleaned: (array)
154 | start: starting points (array)
155 | end: ending points (array)
156 |
157 | ----------
158 | Returns:
159 | note_cleaned: (array)
160 |
161 | """
162 | len_seg = end[idx] - start[idx]
163 | if (note_cleaned[start[idx - 1]] == note_cleaned[start[idx + 1]]) and (
164 | note_cleaned[start[idx]] != note_cleaned[start[idx + 1]]
165 | ):
166 | if np.abs(note_cleaned[start[idx]] - note_cleaned[start[idx + 1]]) % 12 == 0:
167 | note_cleaned[start[idx] - 1 : end[idx] + 1] = [note_cleaned[start[idx + 1]]] * (
168 | len_seg + 2
169 | )
170 | return note_cleaned
171 |
172 |
173 | def clean_segment(note, minLength):
174 | """ clean note segments
175 | ----------
176 | Parameters:
177 | note: (array)
178 | minLength: (int)
179 |
180 | ----------
181 | Returns:
182 | note_cleaned: (array)
183 |
184 | """
185 |
186 | note_cleaned = np.copy(note)
187 | start, end = makeSegments(note_cleaned)
188 |
189 | for i in range(1, len(start) - 1):
190 | note_cleaned = remove_short_segment(i, note_cleaned, start, end, minLength)
191 | note_cleaned = remove_octave_error(i, note_cleaned, start, end)
192 | return note_cleaned
193 |
194 |
195 | def refine_note(est_note, tempo):
196 | """ main: refine note segments
197 | ----------
198 | Parameters:
199 | est_note: (array)
200 | tempo: (float)
201 |
202 | ----------
203 | Returns:
204 | est_pitch_mf3_v: (array)
205 |
206 | """
207 | one_beat_size = one_beat_frame_size(tempo)
208 | est_note_mf1 = median_filter_pitch(est_note, one_beat_size, 1 / 6)
209 | est_note_mf2 = median_filter_pitch(est_note_mf1, one_beat_size, 1 / 3)
210 | est_note_mf3 = median_filter_pitch(est_note_mf2, one_beat_size, 1 / 2)
211 |
212 | voicing = est_note_mf1 > 0
213 | est_pitch_mf3_v = voicing * est_note_mf3
214 | est_pitch_mf3_v = clean_note_frames(est_pitch_mf3_v, int(one_beat_size * 1 / 4))
215 | est_pitch_mf3_v = clean_segment(est_pitch_mf3_v, int(one_beat_size * 1 / 4))
216 | return est_pitch_mf3_v
217 |
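# Illustrative numbers (not part of the original file): at tempo = 120 BPM,
# one_beat_frame_size(120) == 50 frames of 10 ms each, so refine_note discards
# notes shorter than int(50 * 1 / 4) == 12 frames (~120 ms).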
218 |
--------------------------------------------------------------------------------
/torchfcpe/f02midi/transpose.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # %%
3 | import argparse
4 | import numpy as np
5 | from pathlib import Path
6 | from .featureExtraction import *
7 | from .quantization import *
8 | from .utils import *
9 | from .MIDI import *
10 | import librosa
11 |
12 | def f0_to_note(f0):
13 | """ convert frame-level pitch score(hz) to note-level (time-axis) """
14 | note = 69 + 12 * np.log2(f0 / 440 + 1e-4)
15 | note = np.round(note)
16 | note = note.astype(int)
17 | note[note < 0] = 0
18 | note[note > 127] = 127
19 | return note
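# e.g. f0_to_note(np.array([440.0, 0.0])) -> array([69, 0]): 440 Hz is MIDI note 69 (A4),
# and unvoiced frames (0 Hz) fall far below 0 and are clipped to 0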
20 |
21 | def f02midi(f0, tempo = None, y = None, sr = None, output_path = None):
22 | """ f0 shape: (n_frames,) """
23 |
24 | if tempo is None:
25 | if y is not None:
26 | target_sr = 22050
27 | y = librosa.resample(y = y, orig_sr = sr, target_sr = target_sr)
28 | onset_strength = librosa.onset.onset_strength(y = y, sr=target_sr)
29 | tempo = librosa.beat.tempo(onset_envelope=onset_strength, sr=target_sr)
30 | else:
31 | tempo = 120
32 |
33 | f0 = f0_to_note(f0)
34 | refined_fl_note = refine_note(f0, tempo) # frame-level pitch score
35 |
36 | """ convert frame-level pitch score to note-level (time-axis) """
37 | segment = note_to_segment(refined_fl_note) # note-level pitch score
38 | if output_path is None:
39 | return segment
40 | else:
41 | """ save ouput to .mid """
42 | segment_to_midi(segment, path_output=output_path, tempo=tempo)
43 | return segment
44 |
--------------------------------------------------------------------------------
/torchfcpe/f02midi/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from pydub import AudioSegment
4 | import pathlib
5 |
6 |
7 | def check_and_make_dir(path_dir):
8 | if not os.path.exists(os.path.dirname(path_dir)):
9 | os.makedirs(os.path.dirname(path_dir))
10 |
11 |
12 | def get_filename_wo_extension(path_dir):
13 | return pathlib.Path(path_dir).stem
14 |
15 |
16 | def note2pitch(pitch):
17 | """ Convert MIDI number to freq.
18 | ----------
19 | Parameters:
20 | pitch: MIDI note numbers of pitch (array)
21 |
22 | ----------
23 | Returns:
24 | pitch: frequency of pitch (array)
25 | """
26 |
27 | pitch = np.array(pitch)
28 | pitch[pitch > 0] = 2 ** ((pitch[pitch > 0] - 69) / 12.0) * 440
29 | return pitch
30 |
31 |
32 | def pitch2note(pitch):
33 | """ Convert freq to MIDI number
34 | ----------
35 | Parameters:
36 | pitch: frequency of pitch (array)
37 |
38 | ----------
39 | Returns:
40 | pitch: MIDI note numbers of pitch (array)
41 | """
42 | pitch = np.array(pitch)
43 | pitch[pitch > 0] = np.round((69.0 + 12.0 * np.log2(pitch[pitch > 0] / 440.0)))
44 | return pitch
45 |
46 |
47 | if __name__ == "__main__":  # debug leftover, guarded so it no longer runs on import
48 |     a = np.array([0, 0, 0, 1, 2, 3, 5, 0, 0, 0, 1, 2, 4, 5])
49 |     print(a[a > 0] * 2)
50 |
--------------------------------------------------------------------------------
/torchfcpe/mel_extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from torchaudio.transforms import Resample
6 |
7 | import os
8 |
9 | os.environ["LRU_CACHE_CAPACITY"] = "3"
10 |
11 | try:
12 | from librosa.filters import mel as librosa_mel_fn
13 | except ImportError:
14 | print(' [INFO] torchfcpe.mel_extractor: librosa not found,'
15 | ' falling back to torchfcpe.mel_fn_librosa.')
16 | from .mel_fn_librosa import mel as librosa_mel_fn
17 |
18 |
19 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
20 | return torch.log(torch.clamp(x, min=clip_val) * C)
21 |
22 |
23 | class HannWindow(torch.nn.Module):
24 | def __init__(self, win_size):
25 | super().__init__()
26 | self.register_buffer('window', torch.hann_window(win_size), persistent=False)
27 |
28 | def forward(self):
29 | return self.window
30 |
31 |
32 | class MelModule(torch.nn.Module):
33 | """Mel extractor
34 |
35 | Args:
36 | sr (int): Sampling rate. Defaults to 16000.
37 | n_mels (int): Number of mel bins. Defaults to 128.
38 | n_fft (int): FFT size. Defaults to 1024.
39 | win_size (int): Window size. Defaults to 1024.
40 | hop_length (int): Hop length. Defaults to 160.
41 | fmin (float, optional): Minimum frequency. Defaults to 0.
42 | fmax (float, optional): Maximum frequency. Defaults to sr/2.
43 | clip_val (float, optional): Clipping value. Defaults to 1e-5.
44 | """
45 |
46 | def __init__(self,
47 | sr: [int, float],
48 | n_mels: int,
49 | n_fft: int,
50 | win_size: int,
51 | hop_length: int,
52 | fmin: float = None,
53 | fmax: float = None,
54 | clip_val: float = 1e-5,
55 | out_stft: bool = False,
56 | ):
57 | super().__init__()
58 | if fmin is None:
59 | fmin = 0
60 | if fmax is None:
61 | fmax = sr / 2
62 | self.target_sr = sr
63 | self.n_mels = n_mels
64 | self.n_fft = n_fft
65 | self.win_size = win_size
66 | self.hop_length = hop_length
67 | self.fmin = fmin
68 | self.fmax = fmax
69 | self.clip_val = clip_val
70 | # self.mel_basis = {}
71 | self.register_buffer(
72 | 'mel_basis',
73 | torch.tensor(librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)).float(),
74 | persistent=False
75 | )
76 | self.hann_window = torch.nn.ModuleDict()
77 | self.out_stft = out_stft
78 |
79 | @torch.no_grad()
80 | def __call__(self,
81 | y: torch.Tensor, # (B, T, 1)
82 | key_shift: [int, float] = 0,
83 | speed: [int, float] = 1,
84 | center: bool = False,
85 | no_cache_window: bool = False
86 | ) -> torch.Tensor: # (B, T, n_mels)
87 | """Get mel spectrogram
88 |
89 | Args:
90 | y (torch.Tensor): Input waveform, shape=(B, T, 1).
91 | key_shift (int, optional): Key shift. Defaults to 0.
92 | speed (int, optional): Variable speed enhancement factor. Defaults to 1.
93 | center (bool, optional): center for torch.stft. Defaults to False.
94 | no_cache_window (bool, optional): If True will clear cache. Defaults to False.
95 | return:
96 | spec (torch.Tensor): Mel spectrogram, shape=(B, T, n_mels).
97 | """
98 |
99 | n_fft = self.n_fft
100 | win_size = self.win_size
101 | hop_length = self.hop_length
102 | clip_val = self.clip_val
103 |
104 | factor = 2 ** (key_shift / 12)
105 | n_fft_new = int(np.round(n_fft * factor))
106 | win_size_new = int(np.round(win_size * factor))
107 | hop_length_new = int(np.round(hop_length * speed))
108 |
109 | y = y.squeeze(-1)
110 |
111 | if torch.min(y) < -1.:
112 | print('[error with torchfcpe.mel_extractor.MelModule]min value is ', torch.min(y))
113 | if torch.max(y) > 1.:
114 | print('[error with torchfcpe.mel_extractor.MelModule]max value is ', torch.max(y))
115 |
116 | key_shift_key = str(key_shift)
117 | if not no_cache_window:
118 | if key_shift_key in self.hann_window:
119 | hann_window = self.hann_window[key_shift_key]
120 | else:
121 | hann_window = HannWindow(win_size_new).to(self.mel_basis.device)
122 | self.hann_window[key_shift_key] = hann_window
123 | hann_window_tensor = hann_window()
124 | else:
125 | hann_window_tensor = torch.hann_window(win_size_new).to(self.mel_basis.device)
126 |
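# Pad about (win_size_new - hop_length_new) samples in total so that, with
# center=False, the STFT yields roughly one frame per hop of the input.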
127 | pad_left = (win_size_new - hop_length_new) // 2
128 | pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)
129 | if pad_right < y.size(-1):
130 | mode = 'reflect'
131 | else:
132 | mode = 'constant'
133 | y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode)
134 | y = y.squeeze(1)
135 |
136 | spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new,
137 | window=hann_window_tensor,
138 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
139 | spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-9)
140 | if key_shift != 0:
141 | size = n_fft // 2 + 1
142 | resize = spec.size(1)
143 | if resize < size:
144 | spec = F.pad(spec, (0, 0, 0, size - resize))
145 | spec = spec[:, :size, :] * win_size / win_size_new
146 | if self.out_stft:
147 | spec = spec[:, :512, :]
148 | else:
149 | spec = torch.matmul(self.mel_basis, spec)
150 | spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
151 | spec = spec.transpose(-1, -2)
152 | return spec # (B, T, n_mels)
153 |
154 |
155 | class Wav2MelModule(torch.nn.Module):
156 | """
157 | Wav to mel converter
158 | NOTE: Use this module for inference; the Wav2Mel class below is reserved for training only.
159 |
160 | Args:
161 | sr (int): Sampling rate. Defaults to 16000.
162 | n_mels (int): Number of mel bins. Defaults to 128.
163 | n_fft (int): FFT size. Defaults to 1024.
164 | win_size (int): Window size. Defaults to 1024.
165 | hop_length (int): Hop length. Defaults to 160.
166 | fmin (float, optional): Minimum frequency. Defaults to 0.
167 | fmax (float, optional): Maximum frequency. Defaults to sr/2.
168 | clip_val (float, optional): Clipping value. Defaults to 1e-5.
169 | device (str, optional): Device. Defaults to 'cpu'.
170 | """
171 |
172 | def __init__(self,
173 | sr: [int, float],
174 | n_mels: int,
175 | n_fft: int,
176 | win_size: int,
177 | hop_length: int,
178 | fmin: float = None,
179 | fmax: float = None,
180 | clip_val: float = 1e-5,
181 | mel_type="default",
182 | ):
183 | super().__init__()
184 | # catch None
185 | if fmin is None:
186 | fmin = 0
187 | if fmax is None:
188 | fmax = sr / 2
189 | # init
190 | self.sampling_rate = sr
191 | self.n_mels = n_mels
192 | self.n_fft = n_fft
193 | self.win_size = win_size
194 | self.hop_size = hop_length
195 | self.fmin = fmin
196 | self.fmax = fmax
197 | self.clip_val = clip_val
198 | # self.device = device
199 | self.register_buffer(
200 | 'tensor_device_marker',
201 | torch.tensor(1.0).float(),
202 | persistent=False
203 | )
204 | self.resample_kernel = torch.nn.ModuleDict()
205 | if mel_type == "default":
206 | self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val,
207 | out_stft=False)
208 | elif mel_type == "stft":
209 | self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val,
210 | out_stft=True)
211 | self.mel_type = mel_type
212 |
213 | def device(self):
214 | """Get device"""
215 | return self.tensor_device_marker.device
216 |
217 | @torch.no_grad()
218 | def __call__(self,
219 | audio: torch.Tensor, # (B, T, 1)
220 | sample_rate: [int, float],
221 | keyshift: [int, float] = 0,
222 | no_cache_window: bool = False
223 | ) -> torch.Tensor: # (B, T, n_mels)
224 | """
225 | Get mel spectrogram
226 |
227 | Args:
228 | audio (torch.Tensor): Input waveform, shape=(B, T, 1).
229 | sample_rate (int): Sampling rate.
230 | keyshift (int, optional): Key shift. Defaults to 0.
231 | no_cache_window (bool, optional): If True will clear cache. Defaults to False.
232 | return:
233 | spec (torch.Tensor): Mel spectrogram, shape=(B, T, n_mels).
234 | """
235 |
236 | # resample
237 | if sample_rate == self.sampling_rate:
238 | audio_res = audio
239 | else:
240 | key_str = str(sample_rate)
241 | if key_str not in self.resample_kernel:
242 | if len(self.resample_kernel) > 8:
243 | self.resample_kernel.clear()
244 | self.resample_kernel[key_str] = Resample(
245 | sample_rate,
246 | self.sampling_rate,
247 | lowpass_filter_width=128
248 | ).to(self.tensor_device_marker.device)
249 | audio_res = self.resample_kernel[key_str](audio.squeeze(-1)).unsqueeze(-1)
250 |
251 | # extract
252 | mel = self.mel_extractor(audio_res, keyshift, no_cache_window=no_cache_window)
253 | n_frames = int(audio.shape[1] // self.hop_size) + 1
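# Align the mel length with the expected frame count: duplicate the last
# frame if one is missing, or truncate any extra frames.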
254 | if n_frames > int(mel.shape[1]):
255 | mel = torch.cat((mel, mel[:, -1:, :]), 1)
256 | if n_frames < int(mel.shape[1]):
257 | mel = mel[:, :n_frames, :]
258 |
259 | return mel # (B, T, n_mels)
260 |
261 |
262 | class MelExtractor:
263 | """Mel extractor
264 | NOTE: This class of code is reserved for training only, please use MelModule for inference
265 |
266 | Args:
267 | sr (int): Sampling rate. Defaults to 16000.
268 | n_mels (int): Number of mel bins. Defaults to 128.
269 | n_fft (int): FFT size. Defaults to 1024.
270 | win_size (int): Window size. Defaults to 1024.
271 | hop_length (int): Hop length. Defaults to 160.
272 | fmin (float, optional): Minimum frequency. Defaults to 0.
273 | fmax (float, optional): Maximum frequency. Defaults to sr/2.
274 | clip_val (float, optional): Clipping value. Defaults to 1e-5.
275 | """
276 |
277 | def __init__(self,
278 | sr: [int, float],
279 | n_mels: int,
280 | n_fft: int,
281 | win_size: int,
282 | hop_length: int,
283 | fmin: float = None,
284 | fmax: float = None,
285 | clip_val: float = 1e-5,
286 | out_stft: bool = False,
287 | ):
288 | if fmin is None:
289 | fmin = 0
290 | if fmax is None:
291 | fmax = sr / 2
292 | self.target_sr = sr
293 | self.n_mels = n_mels
294 | self.n_fft = n_fft
295 | self.win_size = win_size
296 | self.hop_length = hop_length
297 | self.fmin = fmin
298 | self.fmax = fmax
299 | self.clip_val = clip_val
300 | self.mel_basis = {}
301 | self.hann_window = {}
302 | self.out_stft = out_stft
303 |
304 | @torch.no_grad()
305 | def __call__(self,
306 | y: torch.Tensor, # (B, T, 1)
307 | key_shift: [int, float] = 0,
308 | speed: [int, float] = 1,
309 | center: bool = False,
310 | no_cache_window: bool = False
311 | ) -> torch.Tensor: # (B, T, n_mels)
312 | """Get mel spectrogram
313 |
314 | Args:
315 | y (torch.Tensor): Input waveform, shape=(B, T, 1).
316 | key_shift (int, optional): Key shift. Defaults to 0.
317 | speed (int, optional): Variable speed enhancement factor. Defaults to 1.
318 | center (bool, optional): center for torch.stft. Defaults to False.
319 | no_cache_window (bool, optional): If True will clear cache. Defaults to False.
320 | return:
321 | spec (torch.Tensor): Mel spectrogram, shape=(B, T, n_mels).
322 | """
323 |
324 | sampling_rate = self.target_sr
325 | n_mels = self.n_mels
326 | n_fft = self.n_fft
327 | win_size = self.win_size
328 | hop_length = self.hop_length
329 | fmin = self.fmin
330 | fmax = self.fmax
331 | clip_val = self.clip_val
332 |
333 | factor = 2 ** (key_shift / 12)
334 | n_fft_new = int(np.round(n_fft * factor))
335 | win_size_new = int(np.round(win_size * factor))
336 | hop_length_new = int(np.round(hop_length * speed))
337 | if not no_cache_window:
338 | mel_basis = self.mel_basis
339 | hann_window = self.hann_window
340 | else:
341 | mel_basis = {}
342 | hann_window = {}
343 |
344 | y = y.squeeze(-1)
345 |
346 | if torch.min(y) < -1.:
347 | print('min value is ', torch.min(y))
348 | if torch.max(y) > 1.:
349 | print('max value is ', torch.max(y))
350 |
351 | mel_basis_key = str(fmax) + '_' + str(y.device)
352 | if (mel_basis_key not in mel_basis) and (not self.out_stft):
353 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
354 | mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device)
355 |
356 | key_shift_key = str(key_shift) + '_' + str(y.device)
357 | if key_shift_key not in hann_window:
358 | hann_window[key_shift_key] = torch.hann_window(win_size_new).to(y.device)
359 |
360 | pad_left = (win_size_new - hop_length_new) // 2
361 | pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)
362 | if pad_right < y.size(-1):
363 | mode = 'reflect'
364 | else:
365 | mode = 'constant'
366 | y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode)
367 | y = y.squeeze(1)
368 |
369 | spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new,
370 | window=hann_window[key_shift_key],
371 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
372 | spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-9)
373 | if key_shift != 0:
374 | size = n_fft // 2 + 1
375 | resize = spec.size(1)
376 | if resize < size:
377 | spec = F.pad(spec, (0, 0, 0, size - resize))
378 | spec = spec[:, :size, :] * win_size / win_size_new
379 | if self.out_stft:
380 | spec = spec[:, :512, :]
381 | else:
382 | spec = torch.matmul(mel_basis[mel_basis_key], spec)
383 | spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
384 | spec = spec.transpose(-1, -2)
385 | return spec # (B, T, n_mels)
386 |
387 |
388 | # init nv_mel_extractor cache
389 | # will remove this when we have a better solution
390 | # mel_extractor = MelExtractor(16000, 128, 1024, 1024, 160, 0, 8000)
391 |
392 |
393 | class Wav2Mel:
394 | """
395 | Wav to mel converter
396 | NOTE: This class of code is reserved for training only, please use Wav2MelModule for inference
397 |
398 | Args:
399 | sr (int): Sampling rate. Defaults to 16000.
400 | n_mels (int): Number of mel bins. Defaults to 128.
401 | n_fft (int): FFT size. Defaults to 1024.
402 | win_size (int): Window size. Defaults to 1024.
403 | hop_length (int): Hop length. Defaults to 160.
404 | fmin (float, optional): Minimum frequency. Defaults to 0.
405 | fmax (float, optional): Maximum frequency. Defaults to sr/2.
406 | clip_val (float, optional): Clipping value. Defaults to 1e-5.
407 | device (str, optional): Device. Defaults to 'cpu'.
408 | """
409 |
410 | def __init__(self,
411 | sr: [int, float],
412 | n_mels: int,
413 | n_fft: int,
414 | win_size: int,
415 | hop_length: int,
416 | fmin: float = None,
417 | fmax: float = None,
418 | clip_val: float = 1e-5,
419 | device='cpu',
420 | mel_type="default",
421 | ):
422 | # catch None
423 | if fmin is None:
424 | fmin = 0
425 | if fmax is None:
426 | fmax = sr / 2
427 | # init
428 | self.sampling_rate = sr
429 | self.n_mels = n_mels
430 | self.n_fft = n_fft
431 | self.win_size = win_size
432 | self.hop_size = hop_length
433 | self.fmin = fmin
434 | self.fmax = fmax
435 | self.clip_val = clip_val
436 | self.device = device
437 | self.resample_kernel = {}
438 | if mel_type == "default":
439 | self.mel_extractor = MelExtractor(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val,
440 | out_stft=False)
441 | elif mel_type == "stft":
442 | self.mel_extractor = MelExtractor(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val,
443 | out_stft=True)
444 | self.mel_type = mel_type
445 |
446 | def device(self):
447 | """Get device"""
448 | return self.device
449 |
450 | @torch.no_grad()
451 | def __call__(self,
452 | audio: torch.Tensor, # (B, T, 1)
453 | sample_rate: [int, float],
454 | keyshift: [int, float] = 0,
455 | no_cache_window: bool = False
456 | ) -> torch.Tensor: # (B, T, n_mels)
457 | """
458 | Get mel spectrogram
459 |
460 | Args:
461 | audio (torch.Tensor): Input waveform, shape=(B, T, 1).
462 | sample_rate (int): Sampling rate.
463 | keyshift (int, optional): Key shift. Defaults to 0.
464 | no_cache_window (bool, optional): If True will clear cache. Defaults to False.
465 | return:
466 | spec (torch.Tensor): Mel spectrogram, shape=(B, T, n_mels).
467 | """
468 |
469 | # resample
470 | if sample_rate == self.sampling_rate:
471 | audio_res = audio
472 | else:
473 | key_str = str(sample_rate)
474 | if key_str not in self.resample_kernel:
475 | self.resample_kernel[key_str] = Resample(
476 | sample_rate,
477 | self.sampling_rate,
478 | lowpass_filter_width=128
479 | ).to(self.device)
480 | audio_res = self.resample_kernel[key_str](audio.squeeze(-1)).unsqueeze(-1)
481 |
482 | # extract
483 | mel = self.mel_extractor(audio_res, keyshift, no_cache_window=no_cache_window)
484 | n_frames = int(audio.shape[1] // self.hop_size) + 1
485 | if n_frames > int(mel.shape[1]):
486 | mel = torch.cat((mel, mel[:, -1:, :]), 1)
487 | if n_frames < int(mel.shape[1]):
488 | mel = mel[:, :n_frames, :]
489 |
490 | return mel # (B, T, n_mels)
491 |
492 |
493 | def unit_test():
494 | """
495 | Unit test for mel_extractor.py.
496 | Set audio_path below to your test audio file.
497 | Plotting requires matplotlib and librosa:
498 | pip install matplotlib librosa
499 | """
500 | import time
501 |
502 | try:
503 | import matplotlib.pyplot as plt
504 | import librosa
505 | import librosa.display
506 | except ImportError:
507 | print(' [UNIT_TEST] torchfcpe.mel_extractor: matplotlib or librosa not found,'
508 | ' cannot plot; aborting test.')
509 | exit(1)
510 |
511 | # spawn mel extractor and wav2mel
512 | mel_extractor_test = MelExtractor(16000, 128, 1024, 1024, 160, 0, 8000)
513 | wav2mel_test = Wav2Mel(16000, 128, 1024, 1024, 160, 0, 8000)
514 |
515 | # load audio
516 | audio_path = r'E:\AUFSe04BPyProgram\AUFSd04BPyProgram\ddsp-svc\20230308\diffusion-svc\samples\GJ2.wav'
517 | audio, sr = librosa.load(audio_path, sr=16000)
518 | audio = torch.from_numpy(audio).unsqueeze(0).unsqueeze(-1)
519 | audio = audio.to('cuda')
520 | print(' [UNIT_TEST] torchfcpe.mel_tools.mel_extractor: Audio shape: {}'.format(audio.shape))
521 |
522 | # test mel extractor
523 | start_time = time.time()
524 | mel1 = mel_extractor_test(audio, 0, 1, False)
525 | print(' [UNIT_TEST] torchfcpe.mel_extractor: Mel extractor time cost: {:.3f}s'.format(
526 | time.time() - start_time))
527 | print(' [UNIT_TEST] torchfcpe.mel_extractor: Mel extractor output shape: {}'.format(mel1.shape))
528 |
529 | # test wav2mel
530 | start_time = time.time()
531 | mel2 = wav2mel_test(audio, 16000, 0)
532 | print(' [UNIT_TEST] torchfcpe.mel_extractor: Wav2mel time cost: {:.3f}s'.format(
533 | time.time() - start_time))
534 | print(' [UNIT_TEST] torchfcpe.mel_extractor: Wav2mel output shape: {}'.format(mel2.shape))
535 |
536 | # test melModule
537 | mel_module = MelModule(16000, 128, 1024, 1024, 160, 0, 8000).to('cuda')
538 | mel3 = mel_module(audio, 0, 1, False).to('cuda')
539 | print(' [UNIT_TEST] torchfcpe.mel_extractor: MelModule output shape: {}'.format(mel3.shape))
540 |
541 | # test Wav2MelModule
542 | wav2mel_module = Wav2MelModule(16000, 128, 1024, 1024, 160, 0, 8000).to('cuda')
543 | mel4 = wav2mel_module(audio, 16000, 0).to('cuda')
544 | print(' [UNIT_TEST] torchfcpe.mel_extractor: Wav2MelModule output shape: {}'.format(mel4.shape))
545 |
546 | # plot
547 | plt.figure(figsize=(12, 4))
548 | plt.subplot(1, 5, 1)
549 | librosa.display.waveshow(audio.squeeze().cpu().numpy(), sr=16000)
550 | plt.title('Audio')
551 | plt.subplot(1, 5, 2)
552 | librosa.display.specshow(mel1.squeeze().cpu().numpy().T, sr=16000, hop_length=160, x_axis='time', y_axis='mel')
553 | plt.title('Mel extractor')
554 | plt.subplot(1, 5, 3)
555 | librosa.display.specshow(mel2.squeeze().cpu().numpy().T, sr=16000, hop_length=160, x_axis='time', y_axis='mel')
556 | plt.title('Wav2mel')
557 |
558 | plt.subplot(1, 5, 4)
559 | librosa.display.specshow(mel3.squeeze().cpu().numpy().T, sr=16000, hop_length=160, x_axis='time', y_axis='mel')
560 | plt.title('MelModule')
561 | plt.subplot(1, 5, 5)
562 | librosa.display.specshow(mel4.squeeze().cpu().numpy().T, sr=16000, hop_length=160, x_axis='time', y_axis='mel')
563 | plt.title('Wav2MelModule')
564 |
565 | plt.tight_layout()
566 | plt.show()
567 |
568 |
569 | if __name__ == '__main__':
570 | unit_test()
571 |
--------------------------------------------------------------------------------
/torchfcpe/mel_fn_librosa.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | """
4 | from librosa.filters
5 | """
6 |
7 |
8 | def mel(
9 | *,
10 | sr,
11 | n_fft,
12 | n_mels=128,
13 | fmin=0.0,
14 | fmax=None,
15 | htk=False,
16 | norm="slaney",
17 | dtype=np.float32,
18 | ):
19 | """Create a Mel filter-bank.
20 |
21 | This produces a linear transformation matrix to project
22 | FFT bins onto Mel-frequency bins.
23 |
24 | Parameters
25 | ----------
26 | sr : number > 0 [scalar]
27 | sampling rate of the incoming signal
28 |
29 | n_fft : int > 0 [scalar]
30 | number of FFT components
31 |
32 | n_mels : int > 0 [scalar]
33 | number of Mel bands to generate
34 |
35 | fmin : float >= 0 [scalar]
36 | lowest frequency (in Hz)
37 |
38 | fmax : float >= 0 [scalar]
39 | highest frequency (in Hz).
40 | If `None`, use ``fmax = sr / 2.0``
41 |
42 | htk : bool [scalar]
43 | use HTK formula instead of Slaney
44 |
45 | norm : {None, 'slaney', or number} [scalar]
46 | If 'slaney', divide the triangular mel weights by the width of the mel band
47 | (area normalization).
48 |
49 | If numeric, use `librosa.util.normalize` to normalize each filter by to unit l_p norm.
50 | See `librosa.util.normalize` for a full description of supported norm values
51 | (including `+-np.inf`).
52 |
53 | Otherwise, leave all the triangles aiming for a peak value of 1.0
54 |
55 | dtype : np.dtype
56 | The data type of the output basis.
57 | By default, uses 32-bit (single-precision) floating point.
58 |
59 | Returns
60 | -------
61 | M : np.ndarray [shape=(n_mels, 1 + n_fft/2)]
62 | Mel transform matrix
63 |
64 | See Also
65 | --------
66 | librosa.util.normalize
67 |
68 | Notes
69 | -----
70 | This function caches at level 10.
71 |
72 | Examples
73 | --------
74 | # >>> melfb = librosa.filters.mel(sr=22050, n_fft=2048)
75 | # >>> melfb
76 | array([[ 0. , 0.016, ..., 0. , 0. ],
77 | [ 0. , 0. , ..., 0. , 0. ],
78 | ...,
79 | [ 0. , 0. , ..., 0. , 0. ],
80 | [ 0. , 0. , ..., 0. , 0. ]])
81 |
82 | Clip the maximum frequency to 8KHz
83 |
84 | # >>> librosa.filters.mel(sr=22050, n_fft=2048, fmax=8000)
85 | array([[ 0. , 0.02, ..., 0. , 0. ],
86 | [ 0. , 0. , ..., 0. , 0. ],
87 | ...,
88 | [ 0. , 0. , ..., 0. , 0. ],
89 | [ 0. , 0. , ..., 0. , 0. ]])
90 |
91 | # >>> import matplotlib.pyplot as plt
92 | # >>> fig, ax = plt.subplots()
93 | # >>> img = librosa.display.specshow(melfb, x_axis='linear', ax=ax)
94 | # >>> ax.set(ylabel='Mel filter', title='Mel filter bank')
95 | # >>> fig.colorbar(img, ax=ax)
96 | """
97 |
98 | if fmax is None:
99 | fmax = float(sr) / 2
100 |
101 | # Initialize the weights
102 | n_mels = int(n_mels)
103 | weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
104 |
105 | # Center freqs of each FFT bin
106 | fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft)
107 |
108 | # 'Center freqs' of mel bands - uniformly spaced between limits
109 | mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax, htk=htk)
110 |
111 | fdiff = np.diff(mel_f)
112 | ramps = np.subtract.outer(mel_f, fftfreqs)
113 |
114 | for i in range(n_mels):
115 | # lower and upper slopes for all bins
116 | lower = -ramps[i] / fdiff[i]
117 | upper = ramps[i + 2] / fdiff[i + 1]
118 |
119 | # .. then intersect them with each other and zero
120 | weights[i] = np.maximum(0, np.minimum(lower, upper))
121 |
122 | if norm == "slaney":
123 | # Slaney-style mel is scaled to be approx constant energy per channel
124 | enorm = 2.0 / (mel_f[2: n_mels + 2] - mel_f[:n_mels])
125 | weights *= enorm[:, np.newaxis]
126 | else:
127 | weights = normalize(weights, norm=norm, axis=-1)
128 |
129 | # Only check weights if f_mel[0] is positive
130 | if not np.all((mel_f[:-2] == 0) | (weights.max(axis=1) > 0)):
131 | # This means we have an empty channel somewhere
132 | print(
133 | " [WARN] UserWarning:"
134 | "Empty filters detected in mel frequency basis. "
135 | "Some channels will produce empty responses. "
136 | "Try increasing your sampling rate (and fmax) or "
137 | "reducing n_mels."
138 | )
139 |
140 | return weights
141 |
142 |
143 | def fft_frequencies(*, sr=22050, n_fft=2048):
144 | """Alternative implementation of `np.fft.fftfreq`
145 |
146 | Parameters
147 | ----------
148 | sr : number > 0 [scalar]
149 | Audio sampling rate
150 | n_fft : int > 0 [scalar]
151 | FFT window size
152 |
153 | Returns
154 | -------
155 | freqs : np.ndarray [shape=(1 + n_fft/2,)]
156 | Frequencies ``(0, sr/n_fft, 2*sr/n_fft, ..., sr/2)``
157 |
158 | Examples
159 | --------
160 | # >>> librosa.fft_frequencies(sr=22050, n_fft=16)
161 | array([ 0. , 1378.125, 2756.25 , 4134.375,
162 | 5512.5 , 6890.625, 8268.75 , 9646.875, 11025. ])
163 |
164 | """
165 |
166 | return np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
167 |
168 |
169 | def mel_frequencies(n_mels=128, *, fmin=0.0, fmax=11025.0, htk=False):
170 | """Compute an array of acoustic frequencies tuned to the mel scale.
171 |
172 | The mel scale is a quasi-logarithmic function of acoustic frequency
173 | designed such that perceptually similar pitch intervals (e.g. octaves)
174 | appear equal in width over the full hearing range.
175 |
176 | Because the definition of the mel scale is conditioned by a finite number
177 | of subjective psychoacoustical experiments, several implementations coexist
178 | in the audio signal processing literature [#]_. By default, librosa replicates
179 | the behavior of the well-established MATLAB Auditory Toolbox of Slaney [#]_.
180 | According to this default implementation, the conversion from Hertz to mel is
181 | linear below 1 kHz and logarithmic above 1 kHz. Another available implementation
182 | replicates the Hidden Markov Toolkit [#]_ (HTK) according to the following formula::
183 |
184 | mel = 2595.0 * np.log10(1.0 + f / 700.0).
185 |
186 | The choice of implementation is determined by the ``htk`` keyword argument: setting
187 | ``htk=False`` leads to the Auditory toolbox implementation, whereas setting it ``htk=True``
188 | leads to the HTK implementation.
189 |
190 | .. [#] Umesh, S., Cohen, L., & Nelson, D. Fitting the mel scale.
191 | In Proc. International Conference on Acoustics, Speech, and Signal Processing
192 | (ICASSP), vol. 1, pp. 217-220, 1998.
193 |
194 | .. [#] Slaney, M. Auditory Toolbox: A MATLAB Toolbox for Auditory
195 | Modeling Work. Technical Report, version 2, Interval Research Corporation, 1998.
196 |
197 | .. [#] Young, S., Evermann, G., Gales, M., Hain, T., Kershaw, D., Liu, X.,
198 | Moore, G., Odell, J., Ollason, D., Povey, D., Valtchev, V., & Woodland, P.
199 | The HTK book, version 3.4. Cambridge University, March 2009.
200 |
201 | See Also
202 | --------
203 | hz_to_mel
204 | mel_to_hz
205 | librosa.feature.melspectrogram
206 | librosa.feature.mfcc
207 |
208 | Parameters
209 | ----------
210 | n_mels : int > 0 [scalar]
211 | Number of mel bins.
212 | fmin : float >= 0 [scalar]
213 | Minimum frequency (Hz).
214 | fmax : float >= 0 [scalar]
215 | Maximum frequency (Hz).
216 | htk : bool
217 | If True, use HTK formula to convert Hz to mel.
218 | Otherwise (False), use Slaney's Auditory Toolbox.
219 |
220 | Returns
221 | -------
222 | bin_frequencies : ndarray [shape=(n_mels,)]
223 | Vector of ``n_mels`` frequencies in Hz which are uniformly spaced on the Mel
224 | axis.
225 |
226 | Examples
227 | --------
228 | # >>> librosa.mel_frequencies(n_mels=40)
229 | array([ 0. , 85.317, 170.635, 255.952,
230 | 341.269, 426.586, 511.904, 597.221,
231 | 682.538, 767.855, 853.173, 938.49 ,
232 | 1024.856, 1119.114, 1222.042, 1334.436,
233 | 1457.167, 1591.187, 1737.532, 1897.337,
234 | 2071.84 , 2262.393, 2470.47 , 2697.686,
235 | 2945.799, 3216.731, 3512.582, 3835.643,
236 | 4188.417, 4573.636, 4994.285, 5453.621,
237 | 5955.205, 6502.92 , 7101.009, 7754.107,
238 | 8467.272, 9246.028, 10096.408, 11025. ])
239 |
240 | """
241 |
242 | # 'Center freqs' of mel bands - uniformly spaced between limits
243 | min_mel = hz_to_mel(fmin, htk=htk)
244 | max_mel = hz_to_mel(fmax, htk=htk)
245 |
246 | mels = np.linspace(min_mel, max_mel, n_mels)
247 |
248 | return mel_to_hz(mels, htk=htk)
249 |
250 |
251 | def hz_to_mel(frequencies, *, htk=False):
252 | """Convert Hz to Mels
253 |
254 | Examples
255 | --------
256 | # >>> librosa.hz_to_mel(60)
257 | 0.9
258 | # >>> librosa.hz_to_mel([110, 220, 440])
259 | array([ 1.65, 3.3 , 6.6 ])
260 |
261 | Parameters
262 | ----------
263 | frequencies : number or np.ndarray [shape=(n,)] , float
264 | scalar or array of frequencies
265 | htk : bool
266 | use HTK formula instead of Slaney
267 |
268 | Returns
269 | -------
270 | mels : number or np.ndarray [shape=(n,)]
271 | input frequencies in Mels
272 |
273 | See Also
274 | --------
275 | mel_to_hz
276 | """
277 |
278 | frequencies = np.asanyarray(frequencies)
279 |
280 | if htk:
281 | return 2595.0 * np.log10(1.0 + frequencies / 700.0)
282 |
283 | # Fill in the linear part
284 | f_min = 0.0
285 | f_sp = 200.0 / 3
286 |
287 | mels = (frequencies - f_min) / f_sp
288 |
289 | # Fill in the log-scale part
290 |
291 | min_log_hz = 1000.0 # beginning of log region (Hz)
292 | min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
293 | logstep = np.log(6.4) / 27.0 # step size for log region
294 |
295 | if frequencies.ndim:
296 | # If we have array data, vectorize
297 | log_t = frequencies >= min_log_hz
298 | mels[log_t] = min_log_mel + np.log(frequencies[log_t] / min_log_hz) / logstep
299 | elif frequencies >= min_log_hz:
300 | # If we have scalar data, check directly
301 | mels = min_log_mel + np.log(frequencies / min_log_hz) / logstep
302 |
303 | return mels
304 |
305 |
306 | def mel_to_hz(mels, *, htk=False):
307 | """Convert mel bin numbers to frequencies
308 |
309 | Examples
310 | --------
311 | # >>> librosa.mel_to_hz(3)
312 | 200.
313 |
314 | # >>> librosa.mel_to_hz([1,2,3,4,5])
315 | array([ 66.667, 133.333, 200. , 266.667, 333.333])
316 |
317 | Parameters
318 | ----------
319 | mels : np.ndarray [shape=(n,)], float
320 | mel bins to convert
321 | htk : bool
322 | use HTK formula instead of Slaney
323 |
324 | Returns
325 | -------
326 | frequencies : np.ndarray [shape=(n,)]
327 | input mels in Hz
328 |
329 | See Also
330 | --------
331 | hz_to_mel
332 | """
333 |
334 | mels = np.asanyarray(mels)
335 |
336 | if htk:
337 | return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
338 |
339 | # Fill in the linear scale
340 | f_min = 0.0
341 | f_sp = 200.0 / 3
342 | freqs = f_min + f_sp * mels
343 |
344 | # And now the nonlinear scale
345 | min_log_hz = 1000.0 # beginning of log region (Hz)
346 | min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
347 | logstep = np.log(6.4) / 27.0 # step size for log region
348 |
349 | if mels.ndim:
350 | # If we have vector data, vectorize
351 | log_t = mels >= min_log_mel
352 | freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
353 | elif mels >= min_log_mel:
354 | # If we have scalar data, check directly
355 | freqs = min_log_hz * np.exp(logstep * (mels - min_log_mel))
356 |
357 | return freqs
358 |
359 |
360 | def normalize(S, *, norm=np.inf, axis=0, threshold=None, fill=None):
361 | """Normalize an array along a chosen axis.
362 |
363 | Given a norm (described below) and a target axis, the input
364 | array is scaled so that::
365 |
366 | norm(S, axis=axis) == 1
367 |
368 | For example, ``axis=0`` normalizes each column of a 2-d array
369 | by aggregating over the rows (0-axis).
370 | Similarly, ``axis=1`` normalizes each row of a 2-d array.
371 |
372 | This function also supports thresholding small-norm slices:
373 | any slice (i.e., row or column) with norm below a specified
374 | ``threshold`` can be left un-normalized, set to all-zeros, or
375 | filled with uniform non-zero values that normalize to 1.
376 |
377 | Note: the semantics of this function differ from
378 | `scipy.linalg.norm` in two ways: multi-dimensional arrays
379 | are supported, but matrix-norms are not.
380 |
381 | Parameters
382 | ----------
383 | S : np.ndarray
384 | The array to normalize
385 |
386 | norm : {np.inf, -np.inf, 0, float > 0, None}
387 | - `np.inf` : maximum absolute value
388 | - `-np.inf` : minimum absolute value
389 | - `0` : number of non-zeros (the support)
390 | - float : corresponding l_p norm
391 | See `scipy.linalg.norm` for details.
392 | - None : no normalization is performed
393 |
394 | axis : int [scalar]
395 | Axis along which to compute the norm.
396 |
397 | threshold : number > 0 [optional]
398 | Only the columns (or rows) with norm at least ``threshold`` are
399 | normalized.
400 |
401 | By default, the threshold is determined from
402 | the numerical precision of ``S.dtype``.
403 |
404 | fill : None or bool
405 | If None, then columns (or rows) with norm below ``threshold``
406 | are left as is.
407 |
408 | If False, then columns (rows) with norm below ``threshold``
409 | are set to 0.
410 |
411 | If True, then columns (rows) with norm below ``threshold``
412 | are filled uniformly such that the corresponding norm is 1.
413 |
414 | .. note:: ``fill=True`` is incompatible with ``norm=0`` because
415 | no uniform vector exists with l0 "norm" equal to 1.
416 |
417 | Returns
418 | -------
419 | S_norm : np.ndarray [shape=S.shape]
420 | Normalized array
421 |
422 | Raises
423 | ------
424 | ParameterError
425 | If ``norm`` is not among the valid types defined above
426 |
427 | If ``S`` is not finite
428 |
429 | If ``fill=True`` and ``norm=0``
430 |
431 | See Also
432 | --------
433 | scipy.linalg.norm
434 |
435 | Notes
436 | -----
437 | This function caches at level 40.
438 |
439 | Examples
440 | --------
441 | # >>> # Construct an example matrix
442 | # >>> S = np.vander(np.arange(-2.0, 2.0))
443 | # >>> S
444 | array([[-8., 4., -2., 1.],
445 | [-1., 1., -1., 1.],
446 | [ 0., 0., 0., 1.],
447 | [ 1., 1., 1., 1.]])
448 | # >>> # Max (l-infinity)-normalize the columns
449 | # >>> librosa.util.normalize(S)
450 | array([[-1. , 1. , -1. , 1. ],
451 | [-0.125, 0.25 , -0.5 , 1. ],
452 | [ 0. , 0. , 0. , 1. ],
453 | [ 0.125, 0.25 , 0.5 , 1. ]])
454 | # >>> # Max (l-infinity)-normalize the rows
455 | # >>> librosa.util.normalize(S, axis=1)
456 | array([[-1. , 0.5 , -0.25 , 0.125],
457 | [-1. , 1. , -1. , 1. ],
458 | [ 0. , 0. , 0. , 1. ],
459 | [ 1. , 1. , 1. , 1. ]])
460 | # >>> # l1-normalize the columns
461 | # >>> librosa.util.normalize(S, norm=1)
462 | array([[-0.8 , 0.667, -0.5 , 0.25 ],
463 | [-0.1 , 0.167, -0.25 , 0.25 ],
464 | [ 0. , 0. , 0. , 0.25 ],
465 | [ 0.1 , 0.167, 0.25 , 0.25 ]])
466 | # >>> # l2-normalize the columns
467 | # >>> librosa.util.normalize(S, norm=2)
468 | array([[-0.985, 0.943, -0.816, 0.5 ],
469 | [-0.123, 0.236, -0.408, 0.5 ],
470 | [ 0. , 0. , 0. , 0.5 ],
471 | [ 0.123, 0.236, 0.408, 0.5 ]])
472 |
473 | # >>> # Thresholding and filling
474 | # >>> S[:, -1] = 1e-308
475 | # >>> S
476 | array([[ -8.000e+000, 4.000e+000, -2.000e+000,
477 | 1.000e-308],
478 | [ -1.000e+000, 1.000e+000, -1.000e+000,
479 | 1.000e-308],
480 | [ 0.000e+000, 0.000e+000, 0.000e+000,
481 | 1.000e-308],
482 | [ 1.000e+000, 1.000e+000, 1.000e+000,
483 | 1.000e-308]])
484 |
485 | # >>> # By default, small-norm columns are left untouched
486 | # >>> librosa.util.normalize(S)
487 | array([[ -1.000e+000, 1.000e+000, -1.000e+000,
488 | 1.000e-308],
489 | [ -1.250e-001, 2.500e-001, -5.000e-001,
490 | 1.000e-308],
491 | [ 0.000e+000, 0.000e+000, 0.000e+000,
492 | 1.000e-308],
493 | [ 1.250e-001, 2.500e-001, 5.000e-001,
494 | 1.000e-308]])
495 | # >>> # Small-norm columns can be zeroed out
496 | # >>> librosa.util.normalize(S, fill=False)
497 | array([[-1. , 1. , -1. , 0. ],
498 | [-0.125, 0.25 , -0.5 , 0. ],
499 | [ 0. , 0. , 0. , 0. ],
500 | [ 0.125, 0.25 , 0.5 , 0. ]])
501 | # >>> # Or set to constant with unit-norm
502 | # >>> librosa.util.normalize(S, fill=True)
503 | array([[-1. , 1. , -1. , 1. ],
504 | [-0.125, 0.25 , -0.5 , 1. ],
505 | [ 0. , 0. , 0. , 1. ],
506 | [ 0.125, 0.25 , 0.5 , 1. ]])
507 | # >>> # With an l1 norm instead of max-norm
508 | # >>> librosa.util.normalize(S, norm=1, fill=True)
509 | array([[-0.8 , 0.667, -0.5 , 0.25 ],
510 | [-0.1 , 0.167, -0.25 , 0.25 ],
511 | [ 0. , 0. , 0. , 0.25 ],
512 | [ 0.1 , 0.167, 0.25 , 0.25 ]])
513 | """
514 |
515 | # Avoid div-by-zero
516 | if threshold is None:
517 | threshold = tiny(S)
518 |
519 | elif threshold <= 0:
520 | raise ValueError(
521 | "threshold={} must be strictly " "positive".format(threshold)
522 | )
523 |
524 | if fill not in [None, False, True]:
525 | raise ValueError("fill={} must be None or boolean".format(fill))
526 |
527 | if not np.all(np.isfinite(S)):
528 | raise ValueError("Input must be finite")
529 |
530 | # All norms only depend on magnitude, let's do that first
531 | mag = np.abs(S).astype(float)
532 |
533 | # For max/min norms, filling with 1 works
534 | fill_norm = 1
535 |
536 | if norm == np.inf:
537 | length = np.max(mag, axis=axis, keepdims=True)
538 |
539 | elif norm == -np.inf:
540 | length = np.min(mag, axis=axis, keepdims=True)
541 |
542 | elif norm == 0:
543 | if fill is True:
544 | raise ValueError("Cannot normalize with norm=0 and fill=True")
545 |
546 | length = np.sum(mag > 0, axis=axis, keepdims=True, dtype=mag.dtype)
547 |
548 | elif np.issubdtype(type(norm), np.number) and norm > 0:
549 | length = np.sum(mag ** norm, axis=axis, keepdims=True) ** (1.0 / norm)
550 |
551 | if axis is None:
552 | fill_norm = mag.size ** (-1.0 / norm)
553 | else:
554 | fill_norm = mag.shape[axis] ** (-1.0 / norm)
555 |
556 | elif norm is None:
557 | return S
558 |
559 | else:
560 | raise NotImplementedError("Unsupported norm: {}".format(repr(norm)))
561 |
562 | # indices where norm is below the threshold
563 | small_idx = length < threshold
564 |
565 | S_norm = np.empty_like(S)
566 | if fill is None:
567 | # Leave small indices un-normalized
568 | length[small_idx] = 1.0
569 | S_norm[:] = S / length
570 |
571 | elif fill:
572 | # If we have a non-zero fill value, we locate those entries by
573 | # doing a nan-divide.
574 | # If S was finite, then length is finite (except for small positions)
575 | length[small_idx] = np.nan
576 | S_norm[:] = S / length
577 | S_norm[np.isnan(S_norm)] = fill_norm
578 | else:
579 | # Set small values to zero by doing an inf-divide.
580 | # This is safe (by IEEE-754) as long as S is finite.
581 | length[small_idx] = np.inf
582 | S_norm[:] = S / length
583 |
584 | return S_norm
585 |
586 |
587 | def tiny(x):
588 | """Compute the tiny-value corresponding to an input's data type.
589 |
590 | This is the smallest "usable" number representable in ``x.dtype``
591 | (e.g., float32).
592 |
593 | This is primarily useful for determining a threshold for
594 | numerical underflow in division or multiplication operations.
595 |
596 | Parameters
597 | ----------
598 | x : number or np.ndarray
599 | The array to compute the tiny-value for.
600 | All that matters here is ``x.dtype``
601 |
602 | Returns
603 | -------
604 | tiny_value : float
605 | The smallest positive usable number for the type of ``x``.
606 | If ``x`` is integer-typed, then the tiny value for ``np.float32``
607 | is returned instead.
608 |
609 | See Also
610 | --------
611 | numpy.finfo
612 |
613 | Examples
614 | --------
615 | For a standard double-precision floating point number:
616 |
617 | # >>> librosa.util.tiny(1.0)
618 | 2.2250738585072014e-308
619 |
620 | Or explicitly as double-precision
621 |
622 | # >>> librosa.util.tiny(np.asarray(1e-5, dtype=np.float64))
623 | 2.2250738585072014e-308
624 |
625 | Or complex numbers
626 |
627 | # >>> librosa.util.tiny(1j)
628 | 2.2250738585072014e-308
629 |
630 | Single-precision floating point:
631 |
632 | # >>> librosa.util.tiny(np.asarray(1e-5, dtype=np.float32))
633 | 1.1754944e-38
634 |
635 | Integer
636 |
637 | # >>> librosa.util.tiny(5)
638 | 1.1754944e-38
639 | """
640 |
641 | # Make sure we have an array view
642 | x = np.asarray(x)
643 |
644 | # Only floating types generate a tiny
645 | if np.issubdtype(x.dtype, np.floating) or np.issubdtype(
646 | x.dtype, np.complexfloating
647 | ):
648 | dtype = x.dtype
649 | else:
650 | dtype = np.float32
651 |
652 | return np.finfo(dtype).tiny
653 |
654 |
655 | if __name__ == '__main__':
656 |     """
657 |     Unit test for mel_fn_librosa.py: checks that mel() matches librosa.filters.mel.
658 |     The "is same:" value printed at the end should be True.
659 |     Requires librosa:
660 |         pip install librosa
661 |     """
662 | try:
663 | import librosa
664 | except ImportError:
665 | print(' [UNIT_TEST] torchfcpe.mel_tools.mel_fn_librosa: librosa not installed,'
666 | ' if you want check this file with librosa, please install it first.')
667 | exit(1)
668 | from librosa.filters import mel as librosa_mel_fn
669 |
670 | raw_fn = librosa_mel_fn(sr=16000, n_fft=1024, n_mels=128, fmin=0, fmax=8000)
671 | self_fn = mel(sr=16000, n_fft=1024, n_mels=128, fmin=0, fmax=8000)
672 | print(" [UNIT_TEST] torchfcpe.mel_tools.mel_fn_librosa: raw_fn.shape", raw_fn.shape)
673 | print(" [UNIT_TEST] torchfcpe.mel_tools.mel_fn_librosa: self_fn.shape", self_fn.shape)
674 | check = np.allclose(raw_fn, self_fn)
675 | print(" [UNIT_TEST] torchfcpe.mel_tools.mel_fn_librosa: np.allclose(raw_fn, self_fn) is same:", check)
676 |
--------------------------------------------------------------------------------
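A quick sketch (not part of the repo) of the `normalize` / `tiny` helpers defined in `mel_fn_librosa.py` above; the matrix values are invented to show the small-norm threshold and the three `fill` behaviours described in the docstring.

```python
import numpy as np
from torchfcpe.mel_fn_librosa import normalize, tiny

# column 0 is well-scaled, column 1 is all zeros, column 2 is below tiny(float64)
S = np.array([[3.0, 0.0, 1e-310],
              [4.0, 0.0, 1e-310]])

print(tiny(S))                                    # ~2.225e-308 for float64 input
print(normalize(S, norm=2, axis=0))               # fill=None: small-norm columns left as-is
print(normalize(S, norm=2, axis=0, fill=False))   # small-norm columns zeroed
print(normalize(S, norm=2, axis=0, fill=True))    # small-norm columns filled to 1/sqrt(2)
```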
/torchfcpe/model_conformer_naive.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch import nn
4 | import math
5 | from functools import partial
6 | from einops import rearrange, repeat
7 |
8 | from local_attention import LocalAttention
9 | import torch.nn.functional as F
10 |
11 | # From https://github.com/CNChTu/Diffusion-SVC/ by CNChTu
12 | # License: MIT
13 |
14 |
15 | class ConformerNaiveEncoder(nn.Module):
16 | """
17 | Conformer Naive Encoder
18 |
19 | Args:
20 | dim_model (int): Dimension of model
21 | num_layers (int): Number of layers
22 | num_heads (int): Number of heads
23 |         use_norm (bool): Whether to normalize q/k in FastAttention; must be True for bf16/fp16 inference, default False
24 | conv_only (bool): Whether to use only conv module without attention, default False
25 | conv_dropout (float): Dropout rate of conv module, default 0.
26 | atten_dropout (float): Dropout rate of attention module, default 0.
27 | use_pre_norm (bool): Whether to use pre-norm, default False
28 | """
29 |
30 | def __init__(self,
31 | num_layers: int,
32 | num_heads: int,
33 | dim_model: int,
34 | use_norm: bool = False,
35 | conv_only: bool = False,
36 | conv_dropout: float = 0.,
37 | atten_dropout: float = 0.,
38 | ):
39 | super().__init__()
40 | self.num_layers = num_layers
41 | self.num_heads = num_heads
42 | self.dim_model = dim_model
43 | self.use_norm = use_norm
44 |         self.residual_dropout = 0.1 # deprecated; kept only for backward compatibility
45 |         self.attention_dropout = 0.1 # deprecated; kept only for backward compatibility
46 |
47 | self.encoder_layers = nn.ModuleList(
48 | [
49 | CFNEncoderLayer(dim_model, num_heads, use_norm, conv_only, conv_dropout, atten_dropout)
50 | for _ in range(num_layers)
51 | ]
52 | )
53 |
54 | def forward(self, x, mask=None) -> torch.Tensor:
55 | """
56 | Args:
57 | x (torch.Tensor): Input tensor (#batch, length, dim_model)
58 | mask (torch.Tensor): Mask tensor, default None
59 | return:
60 | torch.Tensor: Output tensor (#batch, length, dim_model)
61 | """
62 |
63 | for (i, layer) in enumerate(self.encoder_layers):
64 | x = layer(x, mask)
65 | return x # (#batch, length, dim_model)
66 |
67 |
68 | class CFNEncoderLayer(nn.Module):
69 | """
70 | Conformer Naive Encoder Layer
71 |
72 | Args:
73 | dim_model (int): Dimension of model
74 | num_heads (int): Number of heads
75 |         use_norm (bool): Whether to normalize q/k in FastAttention; must be True for bf16/fp16 inference, default False
76 | conv_only (bool): Whether to use only conv module without attention, default False
77 |         conv_dropout (float): Dropout rate of conv module, default 0.
78 |         atten_dropout (float): Dropout rate of attention module, default 0.
79 | use_pre_norm (bool): Whether to use pre-norm, default False
80 | """
81 |
82 | def __init__(self,
83 | dim_model: int,
84 | num_heads: int = 8,
85 | use_norm: bool = False,
86 | conv_only: bool = False,
87 | conv_dropout: float = 0.,
88 | atten_dropout: float = 0.,
89 | ):
90 | super().__init__()
91 |
92 | if conv_dropout > 0.:
93 | self.conformer = nn.Sequential(
94 | ConformerConvModule(dim_model),
95 | nn.Dropout(conv_dropout)
96 | )
97 | else:
98 | self.conformer = ConformerConvModule(dim_model)
99 | self.norm = nn.LayerNorm(dim_model)
100 |
101 |         self.dropout = nn.Dropout(0.1)  # deprecated; kept only for backward compatibility
102 |
103 | # selfatt -> fastatt: performer!
104 | if not conv_only:
105 | self.attn = SelfAttention(dim=dim_model,
106 | heads=num_heads,
107 | causal=False,
108 | use_norm=use_norm,
109 | dropout=atten_dropout, )
110 | else:
111 | self.attn = None
112 |
113 | def forward(self, x, mask=None) -> torch.Tensor:
114 | """
115 | Args:
116 | x (torch.Tensor): Input tensor (#batch, length, dim_model)
117 | mask (torch.Tensor): Mask tensor, default None
118 | return:
119 | torch.Tensor: Output tensor (#batch, length, dim_model)
120 | """
121 | if self.attn is not None:
122 | x = x + (self.attn(self.norm(x), mask=mask))
123 |
124 | x = x + (self.conformer(x))
125 |
126 | return x # (#batch, length, dim_model)
127 |
128 |
129 | class ConformerConvModule(nn.Module):
130 | def __init__(
131 | self,
132 | dim,
133 | expansion_factor=2,
134 | kernel_size=31,
135 | dropout=0.,
136 | ):
137 | super().__init__()
138 |
139 | inner_dim = dim * expansion_factor
140 | padding = calc_same_padding(kernel_size)
141 |
142 | _norm = nn.LayerNorm(dim)
143 |
144 | self.net = nn.Sequential(
145 | _norm,
146 | Transpose((1, 2)),
147 | nn.Conv1d(dim, inner_dim * 2, 1),
148 | nn.GLU(dim=1),
149 | DepthWiseConv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=padding[0], groups=inner_dim),
150 | nn.SiLU(),
151 | nn.Conv1d(inner_dim, dim, 1),
152 | Transpose((1, 2)),
153 | nn.Dropout(dropout)
154 | )
155 |
156 | def forward(self, x):
157 | return self.net(x)
158 |
159 |
160 | class DepthWiseConv1d(nn.Module):
161 | def __init__(self, chan_in, chan_out, kernel_size, padding, groups):
162 | super().__init__()
163 | self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=kernel_size, padding=padding, groups=groups)
164 |
165 | def forward(self, x):
166 | return self.conv(x)
167 |
168 |
169 | def calc_same_padding(kernel_size):
170 | pad = kernel_size // 2
171 | return (pad, pad - (kernel_size + 1) % 2)
172 |
173 |
174 | class Transpose(nn.Module):
175 | def __init__(self, dims):
176 | super().__init__()
177 | assert len(dims) == 2, 'dims must be a tuple of two dimensions'
178 | self.dims = dims
179 |
180 | def forward(self, x):
181 | return x.transpose(*self.dims)
182 |
183 |
184 | class SelfAttention(nn.Module):
185 | def __init__(self, dim, causal=False, heads=8, dim_head=64, local_heads=0, local_window_size=256, nb_features=None,
186 | feature_redraw_interval=1000, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False,
187 | dropout=0., no_projection=False, use_norm=False):
188 | super().__init__()
189 | assert dim % heads == 0, 'dimension must be divisible by number of heads'
190 | dim_head = default(dim_head, dim // heads)
191 | inner_dim = dim_head * heads
192 | self.fast_attention = FastAttention(dim_head, nb_features, causal=causal,
193 | generalized_attention=generalized_attention, kernel_fn=kernel_fn,
194 | qr_uniform_q=qr_uniform_q, no_projection=no_projection,
195 | use_norm=use_norm)
196 |
197 | self.heads = heads
198 | self.global_heads = heads - local_heads
199 | self.local_attn = LocalAttention(window_size=local_window_size, causal=causal, autopad=True, dropout=dropout,
200 | look_forward=int(not causal),
201 | rel_pos_emb_config=(dim_head, local_heads)) if local_heads > 0 else None
202 |
203 | # print (heads, nb_features, dim_head)
204 | # name_embedding = torch.zeros(110, heads, dim_head, dim_head)
205 | # self.name_embedding = nn.Parameter(name_embedding, requires_grad=True)
206 |
207 | self.to_q = nn.Linear(dim, inner_dim)
208 | self.to_k = nn.Linear(dim, inner_dim)
209 | self.to_v = nn.Linear(dim, inner_dim)
210 | self.to_out = nn.Linear(inner_dim, dim)
211 | self.dropout = nn.Dropout(dropout)
212 |
213 | @torch.no_grad()
214 | def redraw_projection_matrix(self):
215 | self.fast_attention.redraw_projection_matrix()
216 | # torch.nn.init.zeros_(self.name_embedding)
217 | # print (torch.sum(self.name_embedding))
218 |
219 | def forward(self, x, context=None, mask=None, context_mask=None, name=None, inference=False, **kwargs):
220 | b, n, _, h, gh = *x.shape, self.heads, self.global_heads
221 |
222 | cross_attend = exists(context)
223 |
224 | context = default(context, x)
225 | context_mask = default(context_mask, mask) if not cross_attend else context_mask
226 | # print (torch.sum(self.name_embedding))
227 | q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
228 |
229 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
230 | (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))
231 |
232 | attn_outs = []
233 | # print (name)
234 | # print (self.name_embedding[name].size())
235 | if not empty(q):
236 | if exists(context_mask):
237 | global_mask = context_mask[:, None, :, None]
238 | v.masked_fill_(~global_mask, 0.)
239 | if cross_attend:
240 | pass
241 | # print (torch.sum(self.name_embedding))
242 | # out = self.fast_attention(q,self.name_embedding[name],None)
243 | # print (torch.sum(self.name_embedding[...,-1:]))
244 | # attn_outs.append(out)
245 | else:
246 | out = self.fast_attention(q, k, v)
247 | attn_outs.append(out)
248 |
249 | if not empty(lq):
250 | assert not cross_attend, 'local attention is not compatible with cross attention'
251 | out = self.local_attn(lq, lk, lv, input_mask=mask)
252 | attn_outs.append(out)
253 |
254 | out = torch.cat(attn_outs, dim=1)
255 | out = rearrange(out, 'b h n d -> b n (h d)')
256 | out = self.to_out(out)
257 | return self.dropout(out)
258 |
259 |
260 | class FastAttention(nn.Module):
261 | def __init__(self, dim_heads, nb_features=None, ortho_scaling=0, causal=False, generalized_attention=False,
262 | kernel_fn=nn.ReLU(), qr_uniform_q=False, no_projection=False, use_norm=False):
263 | super().__init__()
264 | nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
265 |
266 | self.dim_heads = dim_heads
267 | self.nb_features = nb_features
268 | self.ortho_scaling = ortho_scaling
269 |
270 | self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows=self.nb_features,
271 | nb_columns=dim_heads, scaling=ortho_scaling, qr_uniform_q=qr_uniform_q)
272 | projection_matrix = self.create_projection()
273 | self.register_buffer('projection_matrix', projection_matrix)
274 |
275 | self.generalized_attention = generalized_attention
276 | self.kernel_fn = kernel_fn
277 |
278 | # if this is turned on, no projection will be used
279 | # queries and keys will be softmax-ed as in the original efficient attention paper
280 | self.no_projection = no_projection
281 |
282 | self.causal = causal
283 | self.use_norm = use_norm
284 | '''
285 | if causal:
286 | try:
287 | import fast_transformers.causal_product.causal_product_cuda
288 | self.causal_linear_fn = partial(causal_linear_attention)
289 | except ImportError:
290 | print(
291 | 'unable to import cuda code for auto-regressive Performer. will default to the memory inefficient non-cuda version')
292 | self.causal_linear_fn = causal_linear_attention_noncuda
293 | '''
294 | if self.causal or self.generalized_attention:
295 | raise NotImplementedError('Causal and generalized attention not implemented yet')
296 |
297 | @torch.no_grad()
298 | def redraw_projection_matrix(self):
299 | projections = self.create_projection()
300 | self.projection_matrix.copy_(projections)
301 | del projections
302 |
303 | def forward(self, q, k, v):
304 | device = q.device
305 |
306 | if self.use_norm:
307 | q = q / (q.norm(dim=-1, keepdim=True) + 1e-8)
308 | k = k / (k.norm(dim=-1, keepdim=True) + 1e-8)
309 |
310 | if self.no_projection:
311 | q = q.softmax(dim=-1)
312 | k = torch.exp(k) if self.causal else k.softmax(dim=-2)
313 |
314 | elif self.generalized_attention:
315 | '''
316 | create_kernel = partial(generalized_kernel, kernel_fn=self.kernel_fn,
317 | projection_matrix=self.projection_matrix, device=device)
318 | q, k = map(create_kernel, (q, k))
319 | '''
320 | raise NotImplementedError('generalized attention not implemented yet')
321 |
322 | else:
323 | create_kernel = partial(softmax_kernel, projection_matrix=self.projection_matrix, device=device)
324 |
325 | q = create_kernel(q, is_query=True)
326 | k = create_kernel(k, is_query=False)
327 |
328 | attn_fn = linear_attention if not self.causal else self.causal_linear_fn
329 | if v is None:
330 | out = attn_fn(q, k, None)
331 | return out
332 | else:
333 | out = attn_fn(q, k, v)
334 | return out
335 |
336 |
337 | def linear_attention(q, k, v):
338 | if v is None:
339 | # print (k.size(), q.size())
340 | out = torch.einsum('...ed,...nd->...ne', k, q)
341 | return out
342 |
343 | else:
344 | k_cumsum = k.sum(dim=-2)
345 | # k_cumsum = k.sum(dim = -2)
346 | D_inv = 1. / (torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q)) + 1e-8)
347 |
348 | context = torch.einsum('...nd,...ne->...de', k, v)
349 | # print ("TRUEEE: ", context.size(), q.size(), D_inv.size())
350 | out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
351 | return out
352 |
353 |
354 | def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device=None):
355 | b, h, *_ = data.shape
356 | # (batch size, head, length, model_dim)
357 |
358 | # normalize model dim
359 | data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.
360 |
361 |     # ratio = 1/sqrt(nb_features); projection_matrix.shape[0] is the number of random features
362 |
363 | ratio = (projection_matrix.shape[0] ** -0.5)
364 |
365 | projection = repeat(projection_matrix, 'j d -> b h j d', b=b, h=h)
366 | projection = projection.type_as(data)
367 |
368 | # data_dash = w^T x
369 | data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)
370 |
371 | # diag_data = D**2
372 | diag_data = data ** 2
373 | diag_data = torch.sum(diag_data, dim=-1)
374 | diag_data = (diag_data / 2.0) * (data_normalizer ** 2)
375 | diag_data = diag_data.unsqueeze(dim=-1)
376 |
377 | # print ()
378 | if is_query:
379 | data_dash = ratio * (
380 | torch.exp(data_dash - diag_data -
381 | torch.max(data_dash, dim=-1, keepdim=True).values) + eps)
382 | else:
383 | data_dash = ratio * (
384 | torch.exp(data_dash - diag_data + eps)) # - torch.max(data_dash)) + eps)
385 |
386 | return data_dash.type_as(data)
387 |
388 |
389 | def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling=0, qr_uniform_q=False, device=None):
390 | nb_full_blocks = int(nb_rows / nb_columns)
391 | # print (nb_full_blocks)
392 | block_list = []
393 |
394 | for _ in range(nb_full_blocks):
395 | q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device)
396 | block_list.append(q)
397 | # block_list[n] is a orthogonal matrix ... (model_dim * model_dim)
398 | # print (block_list[0].size(), torch.einsum('...nd,...nd->...n', block_list[0], torch.roll(block_list[0],1,1)))
399 | # print (nb_rows, nb_full_blocks, nb_columns)
400 | remaining_rows = nb_rows - nb_full_blocks * nb_columns
401 | # print (remaining_rows)
402 | if remaining_rows > 0:
403 | q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device)
404 | # print (q[:remaining_rows].size())
405 | block_list.append(q[:remaining_rows])
406 |
407 | final_matrix = torch.cat(block_list)
408 |
409 | if scaling == 0:
410 | multiplier = torch.randn((nb_rows, nb_columns), device=device).norm(dim=1)
411 | elif scaling == 1:
412 | multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device=device)
413 | else:
414 | raise ValueError(f'Invalid scaling {scaling}')
415 |
416 | return torch.diag(multiplier) @ final_matrix
417 |
418 |
419 | def orthogonal_matrix_chunk(cols, qr_uniform_q=False, device=None):
420 | unstructured_block = torch.randn((cols, cols), device=device)
421 | q, r = torch.linalg.qr(unstructured_block.cpu(), mode='reduced')
422 | q, r = map(lambda t: t.to(device), (q, r))
423 |
424 | # proposed by @Parskatt
425 | # to make sure Q is uniform https://arxiv.org/pdf/math-ph/0609050.pdf
426 | if qr_uniform_q:
427 | d = torch.diag(r, 0)
428 | q *= d.sign()
429 | return q.t()
430 |
431 |
432 | def default(val, d):
433 | return val if exists(val) else d
434 |
435 |
436 | def exists(val):
437 | return val is not None
438 |
439 |
440 | def empty(tensor):
441 | return tensor.numel() == 0
442 |
--------------------------------------------------------------------------------
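A minimal smoke test for `ConformerNaiveEncoder` (a sketch, not part of the repo). It uses `conv_only=True` so only the convolution branch runs; the sizes below are arbitrary, and importing the module still requires `einops` and `local_attention`, which it imports at module level.

```python
import torch
from torchfcpe.model_conformer_naive import ConformerNaiveEncoder

encoder = ConformerNaiveEncoder(
    num_layers=2,     # the shipped config uses 6
    num_heads=8,
    dim_model=64,
    conv_only=True,   # skip the Performer-style attention branch
)
x = torch.randn(1, 100, 64)   # (batch, length, dim_model)
y = encoder(x)
print(y.shape)                # torch.Size([1, 100, 64])
```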
/torchfcpe/model_convnext.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
7 | class ConvNeXtBlock(nn.Module):
8 | """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
9 |
10 | Args:
11 | dim (int): Number of input channels.
12 | intermediate_dim (int): Dimensionality of the intermediate layer.
13 | dilation (int, optional): Dilation factor for the depthwise convolution. Defaults to 1.
14 | kernel_size (int, optional): Kernel size for the depthwise convolution. Defaults to 7.
15 | layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
16 | Defaults to 1e-6.
17 | """
18 |
19 | def __init__(
20 | self,
21 | dim: int,
22 | intermediate_dim: int,
23 | dilation: int = 1,
24 | kernel_size: int = 7,
25 | layer_scale_init_value: Optional[float] = 1e-6,
26 | ):
27 | super().__init__()
28 | self.dwconv = nn.Conv1d(
29 | dim,
30 | dim,
31 | kernel_size=kernel_size,
32 | groups=dim,
33 | dilation=dilation,
34 | padding=int(dilation * (kernel_size - 1) / 2),
35 | ) # depthwise conv
36 | self.norm = nn.LayerNorm(dim, eps=1e-6)
37 | self.pwconv1 = nn.Linear(
38 | dim, intermediate_dim
39 | ) # pointwise/1x1 convs, implemented with linear layers
40 | self.act = nn.GELU()
41 | self.pwconv2 = nn.Linear(intermediate_dim, dim)
42 | self.gamma = (
43 | nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
44 | if layer_scale_init_value is not None and layer_scale_init_value > 0
45 | else None
46 | )
47 |
48 | def forward(self, x: torch.Tensor) -> torch.Tensor:
49 | residual = x
50 |
51 | x = self.dwconv(x)
52 | x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
53 | x = self.norm(x)
54 | x = self.pwconv1(x)
55 | x = self.act(x)
56 | x = self.pwconv2(x)
57 | if self.gamma is not None:
58 | x = self.gamma * x
59 | x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
60 |
61 | x = residual + x
62 | return x
63 |
64 |
65 | class ConvNeXt(nn.Module):
66 | """ConvNeXt layers
67 |
68 | Args:
69 | dim (int): Number of input channels.
70 | num_layers (int): Number of ConvNeXt layers.
71 | mlp_factor (int, optional): Factor for the intermediate layer dimensionality. Defaults to 4.
72 | dilation_cycle (int, optional): Cycle for the dilation factor. Defaults to 4.
73 | kernel_size (int, optional): Kernel size for the depthwise convolution. Defaults to 7.
74 | layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
75 | Defaults to 1e-6.
76 | """
77 |
78 | def __init__(
79 | self,
80 | dim: int,
81 | num_layers: int = 20,
82 | mlp_factor: int = 4,
83 | dilation_cycle: int = 4,
84 | kernel_size: int = 7,
85 | layer_scale_init_value: Optional[float] = 1e-6,
86 | ):
87 | super().__init__()
88 | self.dim = dim
89 | self.num_layers = num_layers
90 | self.mlp_factor = mlp_factor
91 | self.dilation_cycle = dilation_cycle
92 | self.kernel_size = kernel_size
93 | self.layer_scale_init_value = layer_scale_init_value
94 |
95 | self.layers = nn.ModuleList(
96 | [
97 | ConvNeXtBlock(
98 | dim,
99 | dim * mlp_factor,
100 | dilation=(2 ** (i % dilation_cycle)),
101 | kernel_size=kernel_size,
102 | layer_scale_init_value=1e-6,
103 | )
104 | for i in range(num_layers)
105 | ]
106 | )
107 |
108 | def forward(self, x: torch.Tensor) -> torch.Tensor:
109 | for layer in self.layers:
110 | x = layer(x)
111 | return x
112 |
--------------------------------------------------------------------------------
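A shape check for the 1-D `ConvNeXt` stack above (a sketch; the dimensions are arbitrary). Unlike the conformer encoder, it expects channel-first input `(B, C, T)`.

```python
import torch
from torchfcpe.model_convnext import ConvNeXt

net = ConvNeXt(dim=64, num_layers=4)   # small stack for illustration
x = torch.randn(2, 64, 100)            # (batch, channels=dim, frames)
y = net(x)
print(y.shape)                         # torch.Size([2, 64, 100]); shape is preserved
```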
/torchfcpe/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | # import weight_norm from different version of pytorch
6 | try:
7 | from torch.nn.utils.parametrizations import weight_norm
8 | except ImportError:
9 | from torch.nn.utils import weight_norm
10 |
11 | from .model_conformer_naive import ConformerNaiveEncoder
12 |
13 |
14 | class CFNaiveMelPE(nn.Module):
15 | """
16 |     Conformer-based Mel-spectrogram Prediction Encoder in Fast Context-based Pitch Estimation
17 |
18 | Args:
19 | input_channels (int): Number of input channels, should be same as the number of bins of mel-spectrogram.
20 | out_dims (int): Number of output dimensions, also class numbers.
21 | hidden_dims (int): Number of hidden dimensions.
22 | n_layers (int): Number of conformer layers.
23 | f0_max (float): Maximum frequency of f0.
24 | f0_min (float): Minimum frequency of f0.
25 | use_fa_norm (bool): Whether to use fast attention norm, default False
26 | conv_only (bool): Whether to use only conv module without attention, default False
27 | conv_dropout (float): Dropout rate of conv module, default 0.
28 | atten_dropout (float): Dropout rate of attention module, default 0.
29 | use_harmonic_emb (bool): Whether to use harmonic embedding, default False
30 |         n_heads (int): Number of attention heads, default 8
31 | """
32 |
33 | def __init__(self,
34 | input_channels: int,
35 | out_dims: int,
36 | hidden_dims: int = 512,
37 | n_layers: int = 6,
38 | n_heads: int = 8,
39 | f0_max: float = 1975.5,
40 | f0_min: float = 32.70,
41 | use_fa_norm: bool = False,
42 | conv_only: bool = False,
43 | conv_dropout: float = 0.,
44 | atten_dropout: float = 0.,
45 | use_harmonic_emb: bool = False,
46 | ):
47 | super().__init__()
48 | self.input_channels = input_channels
49 | self.out_dims = out_dims
50 | self.hidden_dims = hidden_dims
51 | self.n_layers = n_layers
52 | self.n_heads = n_heads
53 | self.f0_max = f0_max
54 | self.f0_min = f0_min
55 | self.use_fa_norm = use_fa_norm
56 |         self.residual_dropout = 0.1  # deprecated; kept only for backward compatibility
57 |         self.attention_dropout = 0.1  # deprecated; kept only for backward compatibility
58 |
59 | # Harmonic embedding
60 | if use_harmonic_emb:
61 | self.harmonic_emb = nn.Embedding(9, hidden_dims)
62 | else:
63 | self.harmonic_emb = None
64 |
65 | # Input stack, convert mel-spectrogram to hidden_dims
66 | self.input_stack = nn.Sequential(
67 | nn.Conv1d(input_channels, hidden_dims, 3, 1, 1),
68 | nn.GroupNorm(4, hidden_dims),
69 | nn.LeakyReLU(),
70 | nn.Conv1d(hidden_dims, hidden_dims, 3, 1, 1)
71 | )
72 | # Conformer Encoder
73 | self.net = ConformerNaiveEncoder(
74 | num_layers=n_layers,
75 | num_heads=n_heads,
76 | dim_model=hidden_dims,
77 | use_norm=use_fa_norm,
78 | conv_only=conv_only,
79 | conv_dropout=conv_dropout,
80 | atten_dropout=atten_dropout,
81 | )
82 | # LayerNorm
83 | self.norm = nn.LayerNorm(hidden_dims)
84 | # Output stack, convert hidden_dims to out_dims
85 | self.output_proj = weight_norm(
86 | nn.Linear(hidden_dims, out_dims)
87 | )
88 | # Cent table buffer
89 | """
90 | self.cent_table_b = torch.Tensor(
91 | np.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0],
92 | out_dims))
93 | """
94 |         # torch.linspace differs slightly from numpy here (around 1e-4, up to 1e-3); numpy might be preferable
95 | self.cent_table_b = torch.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0],
96 | self.f0_to_cent(torch.Tensor([f0_max]))[0],
97 | out_dims).detach()
98 | self.register_buffer("cent_table", self.cent_table_b)
99 | # gaussian_blurred_cent_mask_b buffer
100 | self.gaussian_blurred_cent_mask_b = (1200. * torch.log2(torch.Tensor([self.f0_max / 10.])))[0].detach()
101 | self.register_buffer("gaussian_blurred_cent_mask", self.gaussian_blurred_cent_mask_b)
102 |
103 | def forward(self, x: torch.Tensor, _h_emb=None) -> torch.Tensor:
104 | """
105 | Args:
106 | x (torch.Tensor): Input mel-spectrogram, shape (B, T, input_channels) or (B, T, mel_bins).
107 | _h_emb (int): Harmonic embedding index, like 0, 1, 2, only use in train. Default: None.
108 | return:
109 | torch.Tensor: Predicted f0 latent, shape (B, T, out_dims).
110 | """
111 | x = self.input_stack(x.transpose(-1, -2)).transpose(-1, -2)
112 | if self.harmonic_emb is not None:
113 | if _h_emb is None:
114 | x = x + self.harmonic_emb(torch.LongTensor([0]).to(x.device))
115 | else:
116 | x = x + self.harmonic_emb(torch.LongTensor([int(_h_emb)]).to(x.device))
117 | x = self.net(x)
118 | x = self.norm(x)
119 | x = self.output_proj(x)
120 | x = torch.sigmoid(x)
121 | return x # latent (B, T, out_dims)
122 |
123 | @torch.no_grad()
124 | def latent2cents_decoder(self,
125 | y: torch.Tensor,
126 | threshold: float = 0.05,
127 | mask: bool = True
128 | ) -> torch.Tensor:
129 | """
130 | Convert latent to cents.
131 | Args:
132 | y (torch.Tensor): Latent, shape (B, T, out_dims).
133 | threshold (float): Threshold to mask. Default: 0.05.
134 | mask (bool): Whether to mask. Default: True.
135 | return:
136 | torch.Tensor: Cents, shape (B, T, 1).
137 | """
138 | B, N, _ = y.size()
139 | ci = self.cent_table[None, None, :].expand(B, N, -1)
140 | rtn = torch.sum(ci * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True) # cents: [B,N,1]
141 | if mask:
142 | confident = torch.max(y, dim=-1, keepdim=True)[0]
143 | confident_mask = torch.ones_like(confident)
144 | confident_mask[confident <= threshold] = float("-INF")
145 | rtn = rtn * confident_mask
146 | return rtn # (B, T, 1)
147 |
148 | @torch.no_grad()
149 | def latent2cents_local_decoder(self,
150 | y: torch.Tensor,
151 | threshold: float = 0.05,
152 | mask: bool = True
153 | ) -> torch.Tensor:
154 | """
155 | Convert latent to cents. Use local argmax.
156 | Args:
157 | y (torch.Tensor): Latent, shape (B, T, out_dims).
158 | threshold (float): Threshold to mask. Default: 0.05.
159 | mask (bool): Whether to mask. Default: True.
160 | return:
161 | torch.Tensor: Cents, shape (B, T, 1).
162 | """
163 | B, N, _ = y.size()
164 | ci = self.cent_table[None, None, :].expand(B, N, -1)
165 | confident, max_index = torch.max(y, dim=-1, keepdim=True)
166 | local_argmax_index = torch.arange(0, 9).to(max_index.device) + (max_index - 4)
167 | local_argmax_index[local_argmax_index < 0] = 0
168 | local_argmax_index[local_argmax_index >= self.out_dims] = self.out_dims - 1
169 | ci_l = torch.gather(ci, -1, local_argmax_index)
170 | y_l = torch.gather(y, -1, local_argmax_index)
171 | rtn = torch.sum(ci_l * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True) # cents: [B,N,1]
172 | if mask:
173 | confident_mask = torch.ones_like(confident)
174 | confident_mask[confident <= threshold] = float("-INF")
175 | rtn = rtn * confident_mask
176 | return rtn # (B, T, 1)
177 |
178 | @torch.no_grad()
179 | def gaussian_blurred_cent2latent(self, cents): # cents: [B,N,1]
180 | """
181 | Convert cents to latent.
182 | Args:
183 | cents (torch.Tensor): Cents, shape (B, T, 1).
184 | return:
185 | torch.Tensor: Latent, shape (B, T, out_dims).
186 | """
187 | mask = (cents > 0.1) & (cents < self.gaussian_blurred_cent_mask)
188 | # mask = (cents>0.1) & (cents<(1200.*np.log2(self.f0_max/10.)))
189 | B, N, _ = cents.size()
190 | ci = self.cent_table[None, None, :].expand(B, N, -1)
191 | return torch.exp(-torch.square(ci - cents) / 1250) * mask.float()
192 |
193 | @torch.no_grad()
194 | def infer(self,
195 | mel: torch.Tensor,
196 | decoder: str = "local_argmax", # "argmax" or "local_argmax"
197 | threshold: float = 0.05,
198 | ) -> torch.Tensor:
199 | """
200 | Args:
201 | mel (torch.Tensor): Input mel-spectrogram, shape (B, T, input_channels) or (B, T, mel_bins).
202 | decoder (str): Decoder type. Default: "local_argmax".
203 | threshold (float): Threshold to mask. Default: 0.05.
204 | """
205 | latent = self.forward(mel)
206 | if decoder == "argmax":
207 | cents = self.latent2cents_decoder(latent, threshold=threshold)
208 | elif decoder == "local_argmax":
209 | cents = self.latent2cents_local_decoder(latent, threshold=threshold)
210 | else:
211 | raise ValueError(f" [x] Unknown decoder type {decoder}.")
212 | f0 = self.cent_to_f0(cents)
213 | return f0 # (B, T, 1)
214 |
215 | def train_and_loss(self, mel, gt_f0, loss_scale=10):
216 | """
217 | Args:
218 | mel (torch.Tensor): Input mel-spectrogram, shape (B, T, input_channels) or (B, T, mel_bins).
219 | gt_f0 (torch.Tensor): Ground truth f0, shape (B, T, 1).
220 | loss_scale (float): Loss scale. Default: 10.
221 | return: loss
222 | """
223 | if mel.shape[-2] != gt_f0.shape[-2]:
224 | _len = min(mel.shape[-2], gt_f0.shape[-2])
225 | mel = mel[:, :_len, :]
226 | gt_f0 = gt_f0[:, :_len, :]
227 | gt_cent_f0 = self.f0_to_cent(gt_f0) # mel f0, [B,N,1]
228 | x_gt = self.gaussian_blurred_cent2latent(gt_cent_f0) # [B,N,out_dim]
229 | if self.harmonic_emb is not None:
230 | x = self.forward(mel, _h_emb=0)
231 | x_half = self.forward(mel, _h_emb=1)
232 | x_gt_half = self.gaussian_blurred_cent2latent(gt_cent_f0 / 2)
233 | x_gt_double = self.gaussian_blurred_cent2latent(gt_cent_f0 * 2)
234 | x_double = self.forward(mel, _h_emb=2)
235 | loss = F.binary_cross_entropy(x, x_gt)
236 | loss_half = F.binary_cross_entropy(x_half, x_gt_half)
237 | loss_double = F.binary_cross_entropy(x_double, x_gt_double)
238 | loss = loss + (loss_half + loss_double) / 2
239 | loss = loss * loss_scale
240 | else:
241 | x = self.forward(mel) # [B,N,out_dim]
242 | loss = F.binary_cross_entropy(x, x_gt) * loss_scale
243 | return loss
244 |
245 | @torch.no_grad()
246 | def cent_to_f0(self, cent: torch.Tensor) -> torch.Tensor:
247 | """
248 | Convert cent to f0. Args: cent (torch.Tensor): Cent, shape = (B, T, 1). return: torch.Tensor: f0, shape = (B, T, 1).
249 | """
250 | f0 = 10. * 2 ** (cent / 1200.)
251 | return f0 # (B, T, 1)
252 |
253 | @torch.no_grad()
254 | def f0_to_cent(self, f0: torch.Tensor) -> torch.Tensor:
255 | """
256 | Convert f0 to cent. Args: f0 (torch.Tensor): f0, shape = (B, T, 1). return: torch.Tensor: Cent, shape = (B, T, 1).
257 | """
258 | cent = 1200. * torch.log2(f0 / 10.)
259 | return cent # (B, T, 1)
260 |
--------------------------------------------------------------------------------
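A sanity-check sketch for `CFNaiveMelPE.infer` on an untrained model, so the f0 values are meaningless and only the shapes matter. For real use, load the bundled checkpoint via `spawn_bundled_infer_model` in `models_infer.py`; the sizes below are invented.

```python
import torch
from torchfcpe.models import CFNaiveMelPE

pe = CFNaiveMelPE(
    input_channels=128,   # number of mel bins
    out_dims=360,         # number of cent classes
    hidden_dims=256,      # the shipped config uses 512
    n_layers=2,
    conv_only=True,       # attention-free, like the shipped config, for a quick CPU run
)
mel = torch.randn(1, 50, 128)                      # (B, T, mel_bins)
f0 = pe.infer(mel, decoder="local_argmax", threshold=0.05)
print(f0.shape)                                    # torch.Size([1, 50, 1]), f0 in Hz
```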
/torchfcpe/models_infer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pathlib
3 |
4 | import torch
5 | from einops import rearrange
6 |
7 | from torchfcpe.f02midi.transpose import f02midi
8 |
9 | from .models import CFNaiveMelPE
10 | from .tools import (
11 | DotDict,
12 | catch_none_args_must,
13 | catch_none_args_opti,
14 | get_config_json_in_same_path,
15 | get_device,
16 | spawn_wav2mel,
17 | )
18 | from .torch_interp import batch_interp_with_replacement_detach
19 |
20 |
21 | def ensemble_f0(f0s, key_shift_list, tta_uv_penalty):
22 |     """Ensemble key-shifted f0 estimates into a single f0 track via dynamic programming.
23 |
24 | Args:
25 | f0s (torch.Tensor): (B, T, len(key_shift_list))
26 | key_shift_list (list): list of key shifts
27 | tta_uv_penalty (float,int): uv penalty
28 |
29 | Returns:
30 | f0: (B, T, 1)
31 | """
32 | device = f0s.device
33 | # convert f0 to note
34 | f0s = f0s / (
35 | torch.pow(
36 | 2,
37 | torch.tensor(key_shift_list, device=device)
38 | .to(device)
39 | .unsqueeze(0)
40 | .unsqueeze(0)
41 | / 12,
42 | )
43 | )
44 | notes = torch.log2(f0s / 440) * 12 + 69
45 | notes[notes < 0] = 0
46 |
47 | # select best note
48 |     # use dynamic programming to pick the optimal pitch per frame
49 |     # penalty 1: an unvoiced (uv) frame costs the fixed hyperparameter uv_penalty ** 2; a v-to-uv switch is penalized twice extra
50 |     # penalty 2: L2 distance between adjacent-frame pitches (except across uv/v switches); distances below 0.5 are ignored
51 | uv_penalty = tta_uv_penalty**2
52 | dp = torch.zeros_like(notes, device=device)
53 |     # dp[b,t,c]: for sample b, the minimum penalty over frames 0..t when the c-th f0 is chosen at frame t
54 | backtrack = torch.zeros_like(notes, device=device).long()
55 |     # backtrack[b,t,c]: for sample b, the index chosen at frame t-1 on the best path ending with the c-th f0 at frame t; values in [0, len(f0_list)-1]
56 | # init
57 | dp[:, 0, :] = (notes[:, 0, :] <= 0) * uv_penalty
58 | # forward
59 | for t in range(1, notes.size(1)):
60 | penalty = torch.zeros(
61 | [notes.size(0), notes.size(2), notes.size(2)], device=device
62 | )
63 |         # [b,c1,c2]: penalty for sample b when c1 is chosen at frame t-1 and c2 at frame t
64 |
65 |         # case: frame t is unvoiced
66 | t_uv = notes[:, t, :] <= 0
67 | penalty += uv_penalty * t_uv.unsqueeze(1)
68 |
69 |         # case: frame t is voiced
70 |         # sub-case: frame t-1 is also voiced
71 | t1_uv = notes[:, t - 1, :] <= 0
72 | l2 = torch.pow(
73 | (notes[:, t - 1, :].unsqueeze(-1) - notes[:, t, :].unsqueeze(1))
74 | * (~t1_uv).unsqueeze(-1)
75 | * (~t_uv).unsqueeze(1),
76 | 2,
77 | )
78 | l2 = l2 - 0.5
79 | l2 = l2 * (l2 > 0)
80 | penalty += l2
81 |
82 |         # sub-case: frame t-1 is unvoiced; penalty for the uv-to-v switch
83 | penalty += t1_uv.unsqueeze(-1) * (~t_uv).unsqueeze(1) * uv_penalty * 2
84 |
85 |         # keep the minimum-penalty transition
86 | min_value, min_indices = torch.min(
87 | dp[:, t - 1, :].unsqueeze(-1) + penalty, dim=1
88 | )
89 | dp[:, t, :] = min_value
90 | backtrack[:, t, :] = min_indices
91 |
92 | # backtrack
93 | t = f0s.size(1) - 1
94 | f0_result = torch.zeros_like(f0s[:, :, 0], device=device)
95 | min_indices = torch.argmin(dp[:, t, :], dim=-1)
96 | for i in range(0, t + 1):
97 | f0_result[:, t - i] = f0s[:, t - i, min_indices]
98 | min_indices = backtrack[:, t - i, min_indices]
99 |
100 | return f0_result.unsqueeze(-1)
101 |
102 |
103 | class InferCFNaiveMelPE(torch.nn.Module):
104 | """Infer CFNaiveMelPE
105 | Args:
106 | args (DotDict): Config.
107 | state_dict (dict): Model state dict.
108 | """
109 |
110 | def __init__(self, args, state_dict):
111 | super().__init__()
112 | self.wav2mel = spawn_wav2mel(args, device="cpu")
113 | self.model = spawn_model(args)
114 | self.model.load_state_dict(state_dict)
115 | self.model.eval()
116 | self.args_dict = dict(args)
117 | self.register_buffer(
118 | "tensor_device_marker", torch.tensor(1.0).float(), persistent=False
119 | )
120 |
121 | def forward(
122 | self,
123 | wav: torch.Tensor,
124 | sr: [int, float],
125 | decoder_mode: str = "local_argmax",
126 | threshold: float = 0.006,
127 | key_shifts: list = [0],
128 | ) -> torch.Tensor:
129 | """Infer
130 | Args:
131 | wav (torch.Tensor): Input wav, (B, n_sample, 1).
132 | sr (int, float): Input wav sample rate.
133 |             decoder_mode (str): Decoder type. Default: "local_argmax", supports "argmax" or "local_argmax".
134 | threshold (float): Threshold to mask. Default: 0.006.
135 | key_shifts (list): Key shifts. Default: [0].
136 | return: f0 (torch.Tensor): f0 Hz, shape (B, (n_sample//hop_size + 1), 1).
137 | """
138 | with torch.no_grad():
139 | wav = wav.to(self.tensor_device_marker.device)
140 | mels = torch.stack(
141 | [self.wav2mel(wav, sr, keyshift=keyshift) for keyshift in key_shifts],
142 | -1,
143 | )
144 | mels = rearrange(mels, "B T C K -> (B K) T C")
145 | f0s = self.model.infer(mels, decoder=decoder_mode, threshold=threshold)
146 | f0s = rearrange(f0s, "(B K) T 1 -> B T (K 1)", K=len(key_shifts))
147 | return f0s # (B, T, len(key_shifts))
148 |
149 | def infer(
150 | self,
151 | wav: torch.Tensor,
152 | sr: [int, float],
153 | decoder_mode: str = "local_argmax",
154 | threshold: float = 0.006,
155 | f0_min: float = None,
156 | f0_max: float = None,
157 | interp_uv: bool = False,
158 | output_interp_target_length: int = None,
159 | return_uv: bool = False,
160 | test_time_augmentation: bool = False,
161 | tta_uv_penalty: float = 12.0,
162 | tta_key_shifts: list = [0, -12, 12],
163 | tta_use_origin_uv=False,
164 | ) -> torch.Tensor or (torch.Tensor, torch.Tensor):
165 | """Infer
166 | Args:
167 | wav (torch.Tensor): Input wav, (B, n_sample, 1).
168 | sr (int, float): Input wav sample rate.
169 |             decoder_mode (str): Decoder type. Default: "local_argmax", supports "argmax" or "local_argmax".
170 | threshold (float): Threshold to mask. Default: 0.006.
171 | f0_min (float): Minimum f0. Default: None. Use in post-processing.
172 | f0_max (float): Maximum f0. Default: None. Use in post-processing.
173 | interp_uv (bool): Interpolate unvoiced frames. Default: False.
174 | output_interp_target_length (int): Output interpolation target length. Default: None.
175 | return_uv (bool): Return unvoiced frames. Default: False.
176 | test_time_augmentation (bool): Test time augmentation. If enabled, the output may be better but slower. Default: False.
177 | tta_uv_penalty (float): Test time augmentation unvoiced penalty. Default: 12.0.
178 | tta_key_shifts (list): Test time augmentation key shifts. Default: [0, -12, 12].
179 | tta_use_origin_uv (bool): Use origin uv. Default: False
180 | return: f0 (torch.Tensor): f0 Hz, shape (B, (n_sample//hop_size + 1) or output_interp_target_length, 1).
181 | if return_uv is True, return f0, uv. the shape of uv(torch.Tensor) is like f0.
182 | """
183 | # infer
184 | if test_time_augmentation:
185 | assert len(tta_key_shifts) > 0
186 | flag = 0
187 | if tta_use_origin_uv:
188 | if 0 not in tta_key_shifts:
189 | flag = 1
190 | tta_key_shifts.append(0)
191 | tta_key_shifts.sort(key=lambda x: (x if x >= 0 else -x / 2))
192 | f0s = self.__call__(wav, sr, decoder_mode, threshold, tta_key_shifts)
193 | f0 = ensemble_f0(
194 | f0s[:, :, flag:],
195 | tta_key_shifts[flag:],
196 | tta_uv_penalty,
197 | )
198 | if tta_use_origin_uv:
199 | f0_for_uv = f0s[:, :, [0]]
200 | else:
201 | f0_for_uv = f0
202 | else:
203 | f0 = self.__call__(wav, sr, decoder_mode, threshold)
204 | f0_for_uv = f0
205 | if f0_min is None:
206 | f0_min = self.args_dict["model"]["f0_min"]
207 | uv = (f0_for_uv < f0_min).type(f0_for_uv.dtype)
208 | f0 = f0 * (1 - uv)
209 | # interp
210 | if interp_uv:
211 | f0 = batch_interp_with_replacement_detach(
212 | uv.squeeze(-1).bool(), f0.squeeze(-1)
213 | ).unsqueeze(-1)
214 | if f0_max is not None:
215 | f0[f0 > f0_max] = f0_max
216 | if output_interp_target_length is not None:
217 | f0 = torch.where(f0 == 0, float("nan"), f0)
218 | f0 = torch.nn.functional.interpolate(
219 | f0.transpose(1, 2),
220 | size=int(output_interp_target_length),
221 | mode="linear",
222 | ).transpose(1, 2)
223 | f0 = torch.where(f0.isnan(), float(0.0), f0)
224 | # if return_uv is True, interp and return uv
225 | if return_uv:
226 | uv = torch.nn.functional.interpolate(
227 | uv.transpose(1, 2),
228 | size=int(output_interp_target_length),
229 | mode="nearest",
230 | ).transpose(1, 2)
231 | return f0, uv
232 | else:
233 | return f0
234 |
235 | def extact_midi(
236 | self,
237 | wav: torch.Tensor,
238 | sr: [int, float],
239 | output_path: str,
240 | decoder_mode: str = "local_argmax",
241 | threshold: float = 0.006,
242 | f0_min: float = None,
243 | f0_max: float = None,
244 | tempo: float = None,
245 | ):
246 | f0 = self.infer(
247 | wav,
248 | sr,
249 | decoder_mode,
250 | threshold,
251 | f0_min,
252 | f0_max,
253 | )
254 | f0 = f0.squeeze(-1).squeeze(0).cpu().numpy()
255 | wav = wav.squeeze(0).squeeze(-1).cpu().numpy()
256 | return f02midi(f0, tempo=tempo, output_path=output_path, sr=sr, y=wav)
257 |
258 | def get_hop_size(self) -> int:
259 | """Get hop size"""
260 | return DotDict(self.args_dict).mel.hop_size
261 |
262 | def get_hop_size_ms(self) -> float:
263 | """Get hop size in ms"""
264 | return (
265 | DotDict(self.args_dict).mel.hop_size / DotDict(self.args_dict).mel.sr * 1000
266 | )
267 |
268 | def get_model_sr(self) -> int:
269 | """Get model sample rate"""
270 | return DotDict(self.args_dict).mel.sr
271 |
272 | def get_mel_config(self) -> dict:
273 | """Get mel config"""
274 | return dict(DotDict(self.args_dict).mel)
275 |
276 | def get_device(self) -> str:
277 | """Get device"""
278 | return self.tensor_device_marker.device
279 |
280 | def get_model_f0_range(self) -> dict:
281 | """Get model f0 range like {'f0_min': 32.70, 'f0_max': 1975.5}"""
282 | return {
283 | "f0_min": DotDict(self.args_dict).model.f0_min,
284 | "f0_max": DotDict(self.args_dict).model.f0_max,
285 | }
286 |
287 |
288 | class InferCFNaiveMelPEONNX:
289 | """Infer CFNaiveMelPE ONNX
290 | Args:
291 | args (DotDict): Config.
292 | onnx_path (str): Path to onnx file.
293 | device (str): Device. must be not None.
294 | """
295 |
296 | def __init__(self, args, onnx_path, device):
297 | raise NotImplementedError
298 |
299 |
300 | def spawn_bundled_infer_model(device: str = None) -> InferCFNaiveMelPE:
301 | """
302 | Spawn bundled infer model
303 | This model has been trained on our dataset and comes with the package.
304 | You can use it directly without anything else.
305 | Args:
306 | device (str): Device. Default: None.
307 | """
308 | file_path = pathlib.Path(__file__)
309 | model_path = file_path.parent / "assets" / "fcpe_c_v001.pt"
310 | model = spawn_infer_model_from_pt(str(model_path), device, bundled_model=True)
311 | return model
312 |
313 |
314 | def spawn_infer_model_from_onnx(
315 | onnx_path: str, device: str = None
316 | ) -> InferCFNaiveMelPEONNX:
317 | """
318 | Spawn infer model from onnx file
319 | Args:
320 | onnx_path (str): Path to onnx file.
321 | device (str): Device. Default: None.
322 | """
323 | device = get_device(device, "torchfcpe.tools.spawn_infer_cf_naive_mel_pe_from_onnx")
324 | config_path = get_config_json_in_same_path(onnx_path)
325 | with open(config_path, "r", encoding="utf-8") as f:
326 | config_dict = json.load(f)
327 | args = DotDict(config_dict)
328 | if (args.is_onnx is None) or (args.is_onnx is False):
329 | raise ValueError(
330 | " [ERROR] spawn_infer_model_from_onnx: this model is not onnx model."
331 | )
332 |
333 | if args.model.type == "CFNaiveMelPEONNX":
334 | infer_model = InferCFNaiveMelPEONNX(args, onnx_path, device)
335 | else:
336 | raise ValueError(
337 | f" [ERROR] args.model.type is {args.model.type}, but only support CFNaiveMelPEONNX"
338 | )
339 |
340 | return infer_model
341 |
342 |
343 | def spawn_infer_model_from_pt(
344 | pt_path: str, device: str = None, bundled_model: bool = False
345 | ) -> InferCFNaiveMelPE:
346 | """
347 | Spawn infer model from pt file
348 | Args:
349 | pt_path (str): Path to pt file.
350 | device (str): Device. Default: None.
351 | bundled_model (bool): Whether this model is bundled model, only used in spawn_bundled_infer_model.
352 | """
353 | device = get_device(device, "torchfcpe.tools.spawn_infer_cf_naive_mel_pe_from_pt")
354 | ckpt = torch.load(pt_path, map_location=torch.device(device))
355 | if bundled_model:
356 | ckpt["config_dict"]["model"]["conv_dropout"] = 0.0
357 | ckpt["config_dict"]["model"]["atten_dropout"] = 0.0
358 | args = DotDict(ckpt["config_dict"])
359 | if (args.is_onnx is not None) and (args.is_onnx is True):
360 | raise ValueError(
361 | " [ERROR] spawn_infer_model_from_pt: this model is an onnx model."
362 | )
363 |
364 | if args.model.type == "CFNaiveMelPE":
365 | infer_model = InferCFNaiveMelPE(args, ckpt["model"])
366 | infer_model = infer_model.to(device)
367 | infer_model.eval()
368 | else:
369 | raise ValueError(
370 | f" [ERROR] args.model.type is {args.model.type}, but only support CFNaiveMelPE"
371 | )
372 |
373 | return infer_model
374 |
375 |
376 | def spawn_model(args: DotDict) -> CFNaiveMelPE:
377 | """Spawn conformer naive model"""
378 | if args.model.type == "CFNaiveMelPE":
379 | pe_model = CFNaiveMelPE(
380 | input_channels=catch_none_args_must(
381 | args.mel.num_mels,
382 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe",
383 | warning_str="args.mel.num_mels is None",
384 | ),
385 | out_dims=catch_none_args_must(
386 | args.model.out_dims,
387 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe",
388 | warning_str="args.model.out_dims is None",
389 | ),
390 | hidden_dims=catch_none_args_must(
391 | args.model.hidden_dims,
392 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe",
393 | warning_str="args.model.hidden_dims is None",
394 | ),
395 | n_layers=catch_none_args_must(
396 | args.model.n_layers,
397 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe",
398 | warning_str="args.model.n_layers is None",
399 | ),
400 | n_heads=catch_none_args_must(
401 | args.model.n_heads,
402 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe",
403 | warning_str="args.model.n_heads is None",
404 | ),
405 | f0_max=catch_none_args_must(
406 | args.model.f0_max,
407 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe",
408 | warning_str="args.model.f0_max is None",
409 | ),
410 | f0_min=catch_none_args_must(
411 | args.model.f0_min,
412 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe",
413 | warning_str="args.model.f0_min is None",
414 | ),
415 | use_fa_norm=catch_none_args_must(
416 | args.model.use_fa_norm,
417 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe",
418 | warning_str="args.model.use_fa_norm is None",
419 | ),
420 | conv_only=catch_none_args_opti(
421 | args.model.conv_only,
422 | default=False,
423 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe",
424 | warning_str="args.model.conv_only is None",
425 | ),
426 | conv_dropout=catch_none_args_opti(
427 | args.model.conv_dropout,
428 | default=0.0,
429 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe",
430 | warning_str="args.model.conv_dropout is None",
431 | ),
432 | atten_dropout=catch_none_args_opti(
433 | args.model.atten_dropout,
434 | default=0.0,
435 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe",
436 | warning_str="args.model.atten_dropout is None",
437 | ),
438 | use_harmonic_emb=catch_none_args_opti(
439 | args.model.use_harmonic_emb,
440 | default=False,
441 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe",
442 | warning_str="args.model.use_harmonic_emb is None",
443 | ),
444 | )
445 | else:
446 | raise ValueError(
447 | f" [ERROR] args.model.type is {args.model.type}, but only support CFNaiveMelPE"
448 | )
449 | return pe_model
450 |
451 |
452 | def bundled_infer_model_unit_test(wav_path):
453 | """Unit test for bundled infer model"""
454 | # wav_path is your wav file path
455 | try:
456 | import librosa
457 | import matplotlib.pyplot as plt
458 | except ImportError:
459 | print(
460 | " [UNIT_TEST] torchfcpe.tools.spawn_infer_model_from_pt: matplotlib or librosa not found, skip test"
461 | )
462 | exit(1)
463 |
464 | infer_model = spawn_bundled_infer_model(device="cpu")
465 | wav, sr = librosa.load(wav_path, sr=16000)
466 | f0 = infer_model.infer(torch.tensor(wav).unsqueeze(0), sr, interp_uv=False)
467 | f0_interp = infer_model.infer(torch.tensor(wav).unsqueeze(0), sr, interp_uv=True)
468 | plt.plot(f0.squeeze(-1).squeeze(0).numpy(), color="r", linestyle="-")
469 | plt.plot(f0_interp.squeeze(-1).squeeze(0).numpy(), color="g", linestyle="-")
470 |     # add a legend
471 | plt.legend(["f0", "f0_interp"])
472 | plt.xlabel("frame")
473 | plt.ylabel("f0")
474 | plt.title("f0")
475 | plt.show()
476 |
--------------------------------------------------------------------------------
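A hedged sketch of the post-processing options on `InferCFNaiveMelPE.infer`. The input is a synthetic 220 Hz sine rather than real audio, so only the shapes and flags (uv interpolation, fixed-length output) are of interest; the bundled checkpoint ships with the package, so no extra files are needed.

```python
import torch
from torchfcpe import spawn_bundled_infer_model

model = spawn_bundled_infer_model(device="cpu")

# 1 s of a 220 Hz sine at 16 kHz stands in for real audio, shaped (B, n_sample, 1)
t = torch.arange(16000) / 16000.0
wav = torch.sin(2 * torch.pi * 220.0 * t).reshape(1, -1, 1)

f0, uv = model.infer(
    wav,
    sr=16000,
    interp_uv=True,                   # linearly interpolate over unvoiced frames
    output_interp_target_length=200,  # resample the f0 track to exactly 200 frames
    return_uv=True,
)
print(f0.shape, uv.shape)             # torch.Size([1, 200, 1]) torch.Size([1, 200, 1])
print(model.get_hop_size(), model.get_model_sr(), model.get_model_f0_range())
```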
/torchfcpe/tools.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .mel_extractor import Wav2Mel, Wav2MelModule
3 | import pathlib
4 |
5 |
6 | class DotDict(dict):
7 | """
8 | DotDict, used for config
9 |
10 | Example:
11 |     # >>> config = DotDict({'a': 1, 'b': {'c': 2}})
12 | # >>> config.a
13 | # 1
14 | # >>> config.b.c
15 | # 2
16 | """
17 |
18 | def __getattr__(*args):
19 | val = dict.get(*args)
20 | return DotDict(val) if type(val) is dict else val
21 |
22 | __setattr__ = dict.__setitem__
23 | __delattr__ = dict.__delitem__
24 |
25 |
26 | def spawn_wav2mel(args: DotDict, device: str = None) -> Wav2MelModule:
27 | """Spawn wav2mel"""
28 | _type = args.mel.type
29 | if (str(_type).lower() == 'none') or (str(_type).lower() == 'default'):
30 | _type = 'default'
31 | elif str(_type).lower() == 'stft':
32 | _type = 'stft'
33 | else:
34 | raise ValueError(f' [ERROR] torchfcpe.tools.args.spawn_wav2mel: {_type} is not a supported args.mel.type')
35 | wav2mel = Wav2MelModule(
36 | sr=catch_none_args_opti(
37 | args.mel.sr,
38 | default=16000,
39 | func_name='torchfcpe.tools.spawn_wav2mel',
40 | warning_str='args.mel.sr is None',
41 | ),
42 | n_mels=catch_none_args_opti(
43 | args.mel.num_mels,
44 | default=128,
45 | func_name='torchfcpe.tools.spawn_wav2mel',
46 | warning_str='args.mel.num_mels is None',
47 | ),
48 | n_fft=catch_none_args_opti(
49 | args.mel.n_fft,
50 | default=1024,
51 | func_name='torchfcpe.tools.spawn_wav2mel',
52 | warning_str='args.mel.n_fft is None',
53 | ),
54 | win_size=catch_none_args_opti(
55 | args.mel.win_size,
56 | default=1024,
57 | func_name='torchfcpe.tools.spawn_wav2mel',
58 | warning_str='args.mel.win_size is None',
59 | ),
60 | hop_length=catch_none_args_opti(
61 | args.mel.hop_size,
62 | default=160,
63 | func_name='torchfcpe.tools.spawn_wav2mel',
64 | warning_str='args.mel.hop_size is None',
65 | ),
66 | fmin=catch_none_args_opti(
67 | args.mel.fmin,
68 | default=0,
69 | func_name='torchfcpe.tools.spawn_wav2mel',
70 | warning_str='args.mel.fmin is None',
71 | ),
72 | fmax=catch_none_args_opti(
73 | args.mel.fmax,
74 | default=8000,
75 | func_name='torchfcpe.tools.spawn_wav2mel',
76 | warning_str='args.mel.fmax is None',
77 | ),
78 | clip_val=1e-05,
79 | mel_type=_type,
80 | )
81 | device = catch_none_args_opti(
82 | device,
83 | default='cpu',
84 | func_name='torchfcpe.tools.spawn_wav2mel',
85 | warning_str='.device is None',
86 | )
87 | return wav2mel.to(torch.device(device))
88 |
89 |
90 | def catch_none_args_opti(x, default, func_name, warning_str=None, level='WARN'):
91 | """Catch None, optional"""
92 | if x is None:
93 | if warning_str is not None:
94 | print(f' [{level}] {warning_str}; use default {default}')
95 | print(f' [{level}] > call by:{func_name}')
96 | return default
97 | else:
98 | return x
99 |
100 |
101 | def catch_none_args_must(x, func_name, warning_str):
102 | """Catch None, must"""
103 | level = "ERROR"
104 | if x is None:
105 | print(f' [{level}] {warning_str}')
106 | print(f' [{level}] > call by:{func_name}')
107 | raise ValueError(f' [{level}] {warning_str}')
108 | else:
109 | return x
110 |
111 |
112 | def get_device(device: str, func_name: str) -> str:
113 | """Get device"""
114 |
115 | if device is None:
116 | if torch.cuda.is_available():
117 | device = 'cuda'
118 | elif torch.backends.mps.is_available():
119 | device = 'mps'
120 | else:
121 | device = 'cpu'
122 |
123 | print(f' [INFO]: Using {device} automatically.')
124 | print(f' [INFO] > call by: {func_name}')
125 | else:
126 | print(f' [INFO]: device is not None, use {device}')
127 | print(f' [INFO] > call by:{func_name}')
128 | device = device
129 |
130 | # Check if the specified device is available, if not, switch to cpu
131 | if ((device == 'cuda' and not torch.cuda.is_available()) or
132 | (device == 'mps' and not torch.backends.mps.is_available())):
133 | print(f' [WARN]: Specified device ({device}) is not available, switching to cpu.')
134 | device = 'cpu'
135 |
136 | return device
137 |
138 |
139 | def get_config_json_in_same_path(path: str) -> str:
140 | """Get config json in same path"""
141 | path = pathlib.Path(path)
142 | config_json = path.parent / 'config.json'
143 | if config_json.exists():
144 | return str(config_json)
145 | else:
146 | raise FileNotFoundError(f' [ERROR] {config_json} not found.')
147 |
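
A minimal sketch of how `catch_none_args_opti` backs a missing config field with its default (assuming the package is installed; the function and defaults are the ones defined above):

```python
# Minimal sketch: how spawn_wav2mel-style defaults are filled in (assumes torchfcpe is installed).
from torchfcpe.tools import catch_none_args_opti

hop_size = catch_none_args_opti(
    None,                                # the value missing from the config
    default=160,
    func_name='example',
    warning_str='args.mel.hop_size is None',
)
# Prints the warning plus the calling function name, then returns the default.
assert hop_size == 160
```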
--------------------------------------------------------------------------------
/torchfcpe/torch_interp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | '''
3 | adapted from https://github.com/autumn-DL/TorchInterp
4 | (distributed under the MIT license)
5 | '''
6 |
7 |
8 | def torch_interp(x, xp, fp):
9 | # if not isinstance(x, torch.Tensor):
10 | # x = torch.tensor(x)
11 | # if not isinstance(xp, torch.Tensor):
12 | # xp = torch.tensor(xp)
13 | # if not isinstance(fp, torch.Tensor):
14 | # fp = torch.tensor(fp)
15 |
16 | sort_idx = torch.argsort(xp)
17 | xp = xp[sort_idx]
18 | fp = fp[sort_idx]
19 |
20 | right_idxs = torch.searchsorted(xp, x)
21 |
22 | right_idxs = right_idxs.clamp(max=len(xp) - 1)
23 |
24 | left_idxs = (right_idxs - 1).clamp(min=0)
25 |
26 | x_left = xp[left_idxs]
27 | x_right = xp[right_idxs]
28 | y_left = fp[left_idxs]
29 | y_right = fp[right_idxs]
30 |
31 | interp_vals = y_left + ((x - x_left) * (y_right - y_left) / (x_right - x_left))
32 |
33 | interp_vals[x < xp[0]] = fp[0]
34 | interp_vals[x > xp[-1]] = fp[-1]
35 |
36 | return interp_vals
37 |
38 |
39 | def batch_interp_with_replacement_detach(uv, f0):
40 | '''
41 | :param uv: B T
42 | :param f0: B T
43 | :return: f0 B T
44 | '''
45 |
46 | result = f0.clone()
47 |
48 | for i in range(uv.shape[0]):
49 | x = torch.where(uv[i])[-1]
50 | xp = torch.where(~uv[i])[-1]
51 | fp = f0[i][~uv[i]]
52 |
53 | interp_vals = torch_interp(x, xp, fp).detach()
54 |
55 | result[i][uv[i]] = interp_vals
56 | return result
57 |
58 |
59 | def unit_test():
60 | try:
61 | import matplotlib.pyplot as plt
62 | except ImportError:
63 |         # matplotlib is optional here (nothing is plotted below), so warn and continue
64 |         print(' [UNIT_TEST] torchfcpe.torch_interp: matplotlib not found, skip plotting.')
65 |
66 | # f0
67 | f0 = torch.tensor([1, 0, 3, 0, 0, 3, 4, 5, 0, 0]).float()
68 | uv = torch.tensor([0, 1, 0, 1, 1, 0, 0, 0, 1, 1]).bool()
69 |
70 | interp_f0 = batch_interp_with_replacement_detach(uv.unsqueeze(0), f0.unsqueeze(0)).squeeze(0)
71 |
72 | print(interp_f0)
73 |
74 |
75 | if __name__ == '__main__':
76 |     unit_test()
77 |
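
For reference, `torch_interp` mirrors `numpy.interp`: inside the sampled range the two agree, and outside it both clamp to the edge values. A minimal sketch (assuming torchfcpe is installed):

```python
# Minimal comparison sketch between torch_interp and numpy.interp.
import numpy as np
import torch
from torchfcpe.torch_interp import torch_interp

xp = torch.tensor([0.0, 1.0, 2.0])
fp = torch.tensor([0.0, 10.0, 20.0])
x = torch.tensor([0.5, 1.5, 3.0])

print(torch_interp(x, xp, fp))                        # tensor([ 5., 15., 20.])
print(np.interp(x.numpy(), xp.numpy(), fp.numpy()))   # [ 5. 15. 20.]
```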
--------------------------------------------------------------------------------
/train/configs/config.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | duration: 0.3483 # Audio duration during training, must be less than the duration of the shortest audio clip
3 |   train_path: 'yx2/train' # Create a folder named "audio" under this path and put the audio clips in it
4 |   valid_path: 'yx2/val' # Create a folder named "audio" under this path and put the audio clips in it
5 |   extensions: # List of file extensions included in the data collection
6 | - wav
7 | mel:
8 |   type: 'stft' # use "none" or "default" for the default mel extractor
9 | sr: 16000
10 |   num_mels: 512 # when using 'stft', this needs to match the other settings
11 | n_fft: 1024
12 | win_size: 1024
13 | hop_size: 160
14 | fmin: 0
15 | fmax: 8000
16 | model:
17 | type: 'CFNaiveMelPE'
18 | out_dims: 360
19 | hidden_dims: 512
20 | n_layers: 6
21 | n_heads: 8
22 | f0_min: 32.70
23 | f0_max: 1975.5
24 | use_fa_norm: true
25 | conv_only: true
26 | conv_dropout: 0.0
27 | atten_dropout: 0.0
28 | use_harmonic_emb: false
29 | loss:
30 | loss_scale: 10
31 | device: cuda
32 | env:
33 | expdir: exp/yx2_001ac_stft
34 | gpu_id: 0
35 | train:
36 | aug_add_music: true
37 | aug_keyshift: true
38 | f0_shift_mode: 'keyshift'
39 | keyshift_min: -6
40 | keyshift_max: 6
41 | aug_noise: true
42 | noise_ratio: 0.7
43 | brown_noise_ratio: 1
44 | aug_mask: true
45 | aug_mask_v_o: true
46 | aug_mask_vertical_factor: 0.05
47 | aug_mask_vertical_factor_v_o: 0.3
48 |     aug_mask_iszeropad_mode: 'noise' # 'random', 'zero' or 'noise'
49 | aug_mask_block_num: 1
50 | aug_mask_block_num_v_o: 1
51 |     num_workers: 14 # If both your CPU and GPU are very fast, setting this to 0 may be faster!
52 |     amp_dtype: fp32 # only fp32 works here; other dtypes lead to NaN losses
53 | batch_size: 128
54 | use_redis: true
55 |     cache_all_data: true # Setting this to false saves RAM / VRAM, but may be slow
56 |     cache_device: 'cpu' # Set to 'cuda' to cache the data in GPU memory; fastest with a strong GPU
57 | epochs: 100000
58 | interval_log: 100
59 | interval_val: 5000
60 | interval_force_save: 5000
61 | lr: 0.0005
62 | decay_step: 100000
63 | gamma: 0.7071
64 | weight_decay: 0.0001
65 | save_opt: false
66 |
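
For reference, this is how the training code consumes the file (a minimal sketch, assuming it is run from the `train/` directory with `torchfcpe` installed):

```python
# Minimal sketch of loading this config (assumption: current working directory is train/).
import torchfcpe
from savertools import utils

args = utils.load_config('configs/config.yaml')    # DotDict: missing keys read as None
wav2mel = torchfcpe.spawn_wav2mel(args, device='cpu')
print(args.mel.hop_size, args.model.out_dims)       # 160 360
```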
--------------------------------------------------------------------------------
/train/data_loaders_wav.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import numpy as np
4 | import librosa
5 | import torch
6 | import random
7 | from tqdm import tqdm
8 | from torch.utils.data import Dataset
9 | from concurrent.futures import ProcessPoolExecutor
10 | import torch.multiprocessing as mp
11 | import utils_all as ut
12 | import pandas as pd
13 | import torchfcpe
14 | from redis_coder import RedisService, encode_wb, decode_wb
15 | DATABUFFER = RedisService(
16 | host='localhost',
17 | password='',
18 | port=6379,
19 | max_connections=16
20 | )
21 | def traverse_dir(
22 | root_dir,
23 | extensions,
24 | amount=None,
25 | str_include=None,
26 | str_exclude=None,
27 | is_pure=False,
28 | is_sort=False,
29 | is_ext=True):
30 | file_list = []
31 | cnt = 0
32 | for root, _, files in os.walk(root_dir):
33 | for file in files:
34 | if any([file.endswith(f".{ext}") for ext in extensions]):
35 | # path
36 | mix_path = os.path.join(root, file)
37 | pure_path = mix_path[len(root_dir) + 1:] if is_pure else mix_path
38 |
39 | # amount
40 | if (amount is not None) and (cnt == amount):
41 | if is_sort:
42 | file_list.sort()
43 | return file_list
44 |
45 | # check string
46 | if (str_include is not None) and (str_include not in pure_path):
47 | continue
48 | if (str_exclude is not None) and (str_exclude in pure_path):
49 | continue
50 |
51 | if not is_ext:
52 | ext = pure_path.split('.')[-1]
53 | pure_path = pure_path[:-(len(ext) + 1)]
54 | file_list.append(pure_path)
55 | cnt += 1
56 | if is_sort:
57 | file_list.sort()
58 | return file_list
59 |
60 |
61 | def get_data_loaders(args, jump=False):
62 | wav2mel = torchfcpe.spawn_wav2mel(args, device='cpu')
63 | data_train = F0Dataset(
64 | path_root=args.data.train_path,
65 | waveform_sec=args.data.duration,
66 | hop_size=args.mel.hop_size,
67 | sample_rate=args.mel.sr,
68 | duration=args.data.duration,
69 | load_all_data=args.train.cache_all_data,
70 | whole_audio=False,
71 | extensions=args.data.extensions,
72 | device=args.train.cache_device,
73 | wav2mel=wav2mel,
74 | aug_noise=args.train.aug_noise,
75 | noise_ratio=args.train.noise_ratio,
76 | brown_noise_ratio=args.train.brown_noise_ratio,
77 | aug_mask=args.train.aug_mask,
78 | aug_mask_v_o=args.train.aug_mask_v_o,
79 | aug_mask_vertical_factor=args.train.aug_mask_vertical_factor,
80 | aug_mask_vertical_factor_v_o=args.train.aug_mask_vertical_factor_v_o,
81 | aug_mask_iszeropad_mode=args.train.aug_mask_iszeropad_mode,
82 | aug_mask_block_num=args.train.aug_mask_block_num,
83 | aug_mask_block_num_v_o=args.train.aug_mask_block_num_v_o,
84 | aug_keyshift=args.train.aug_keyshift,
85 | keyshift_min=args.train.keyshift_min,
86 | keyshift_max=args.train.keyshift_max,
87 | f0_min=args.model.f0_min,
88 | f0_max=args.model.f0_max,
89 | f0_shift_mode='keyshift',
90 | load_data_num_processes=8,
91 | use_redis=args.train.use_redis,
92 | jump=jump
93 | )
94 | loader_train = torch.utils.data.DataLoader(
95 | data_train,
96 | batch_size=args.train.batch_size,
97 | shuffle=True,
98 | num_workers=args.train.num_workers if args.train.cache_device == 'cpu' else 0,
99 | persistent_workers=(args.train.num_workers > 0) if args.train.cache_device == 'cpu' else False,
100 | pin_memory=True if args.train.cache_device == 'cpu' else False
101 | )
102 | data_valid = F0Dataset(
103 | path_root=args.data.valid_path,
104 | waveform_sec=args.data.duration,
105 | hop_size=args.mel.hop_size,
106 | sample_rate=args.mel.sr,
107 | duration=args.data.duration,
108 | load_all_data=args.train.cache_all_data,
109 | whole_audio=True,
110 | extensions=args.data.extensions,
111 | wav2mel=wav2mel,
112 | aug_noise=args.train.aug_noise,
113 | noise_ratio=args.train.noise_ratio,
114 | brown_noise_ratio=args.train.brown_noise_ratio,
115 | aug_mask=args.train.aug_mask,
116 | aug_mask_v_o=args.train.aug_mask_v_o,
117 | aug_mask_vertical_factor=args.train.aug_mask_vertical_factor,
118 | aug_mask_vertical_factor_v_o=args.train.aug_mask_vertical_factor_v_o,
119 | aug_mask_iszeropad_mode=args.train.aug_mask_iszeropad_mode,
120 | aug_mask_block_num=args.train.aug_mask_block_num,
121 | aug_mask_block_num_v_o=args.train.aug_mask_block_num_v_o,
122 | aug_keyshift=False,
123 | )
124 | loader_valid = torch.utils.data.DataLoader(
125 | data_valid,
126 | batch_size=1,
127 | shuffle=False,
128 | num_workers=0,
129 | pin_memory=True
130 | )
131 | return loader_train, loader_valid
132 |
133 |
134 | class F0Dataset(Dataset):
135 | def __init__(
136 | self,
137 | path_root,
138 | waveform_sec,
139 | hop_size,
140 | sample_rate,
141 | duration,
142 | load_all_data=True,
143 | whole_audio=False,
144 | extensions=['wav'],
145 | device='cpu',
146 | wav2mel=None,
147 | aug_noise=False,
148 | noise_ratio=0.7,
149 | brown_noise_ratio=1.,
150 | aug_mask=False,
151 | aug_mask_v_o=False,
152 | aug_mask_vertical_factor=0.05,
153 | aug_mask_vertical_factor_v_o=0.3,
154 |             aug_mask_iszeropad_mode='random',  # 'random', 'zero' or 'noise'
155 | aug_mask_block_num=1,
156 | aug_mask_block_num_v_o=4,
157 | aug_keyshift=True,
158 | keyshift_min=-5,
159 | keyshift_max=12,
160 | f0_min=32.70,
161 | f0_max=1975.5,
162 | f0_shift_mode='keyshift',
163 | load_data_num_processes=1,
164 | snb_noise=None,
165 | noise_beta=0,
166 | use_redis=False,
167 | jump=False
168 | ):
169 | super().__init__()
170 | self.music_spk_id = 1
171 | self.wav2mel = wav2mel
172 | self.waveform_sec = waveform_sec
173 | self.sample_rate = sample_rate
174 | self.hop_size = hop_size
175 | self.path_root = path_root
176 | self.duration = duration
177 | self.aug_noise = aug_noise
178 | self.noise_ratio = noise_ratio
179 | self.brown_noise_ratio = brown_noise_ratio
180 | self.aug_mask = aug_mask
181 | self.aug_mask_v_o = aug_mask_v_o
182 | self.aug_mask_vertical_factor = aug_mask_vertical_factor
183 | self.aug_mask_vertical_factor_v_o = aug_mask_vertical_factor_v_o
184 | self.aug_mask_iszeropad_mode = aug_mask_iszeropad_mode
185 | self.aug_mask_block_num = aug_mask_block_num
186 | self.aug_mask_block_num_v_o = aug_mask_block_num_v_o
187 | self.aug_keyshift = aug_keyshift
188 | self.keyshift_min = keyshift_min
189 | self.keyshift_max = keyshift_max
190 | self.f0_min = f0_min
191 | self.f0_max = f0_max
192 | self.f0_shift_mode = f0_shift_mode
193 | self.n_spk = 4
194 | self.device = device
195 | self.load_all_data = load_all_data
196 | self.snb_noise = snb_noise
197 | self.noise_beta = noise_beta
198 | self.use_redis = use_redis
199 | self.jump = jump
200 |
201 | self.paths = traverse_dir(
202 | os.path.join(path_root, 'audio'),
203 | extensions=extensions,
204 | is_pure=True,
205 | is_sort=True,
206 | is_ext=True
207 | )
208 |
209 | self.whole_audio = whole_audio
210 | if self.use_redis:
211 | self.data_buffer = None
212 | else:
213 | self.data_buffer = {}
214 | self.device = device
215 | if load_all_data:
216 | print('Load all the data from :', path_root)
217 | else:
218 | print('Load the f0, volume data from :', path_root)
219 |
220 | if self.use_redis:
221 | _ = self.load_data(self.paths)
222 | else:
223 | with torch.no_grad():
224 | with ProcessPoolExecutor(max_workers=load_data_num_processes) as executor:
225 | tasks = []
226 | for i in range(load_data_num_processes):
227 | start = int(i * len(self.paths) / load_data_num_processes)
228 | end = int((i + 1) * len(self.paths) / load_data_num_processes)
229 | file_chunk = self.paths[start:end]
230 | tasks.append(file_chunk)
231 | for data_buffer in executor.map(self.load_data, tasks):
232 | self.data_buffer.update(data_buffer)
233 |
234 | self.paths = np.array(self.paths, dtype=object)
235 | self.data_buffer = pd.DataFrame(self.data_buffer)
236 |
237 | def load_data(self, paths):
238 | with torch.no_grad():
239 | data_buffer = {}
240 | rank = mp.current_process()._identity
241 | rank = rank[0] if len(rank) > 0 else 0
242 | for name_ext in tqdm(paths):
243 | path_audio = os.path.join(self.path_root, 'audio', name_ext)
244 | duration = librosa.get_duration(filename=path_audio, sr=self.sample_rate)
245 |
246 | path_f0 = os.path.join(self.path_root, 'f0', name_ext) + '.npy'
247 | f0 = np.load(path_f0)[:, None]
248 | # f0 = torch.from_numpy(f0).float().unsqueeze(-1).to(self.device)
249 |
250 | if self.n_spk is not None and self.n_spk > 1:
251 | dirname_split = re.split(r"_|\-", os.path.dirname(name_ext), 2)[0]
252 | t_spk_id = spk_id = int(dirname_split) if str.isdigit(dirname_split) else 0
253 | if spk_id < 1 or spk_id > self.n_spk:
254 | raise ValueError(
255 |                     ' [x] Multi-speaker training error: spk_id must be a positive integer from 1 to n_spk ')
256 | else:
257 | pass
258 | # spk_id = 1
259 | # t_spk_id = spk_id
260 | # spk_id = torch.LongTensor(np.array([spk_id])).to(self.device)
261 | # spk_id = np.array([spk_id])
262 |
263 | if self.load_all_data:
264 | audio, sr = librosa.load(path_audio, sr=self.sample_rate)
265 | if len(audio.shape) > 1:
266 | audio = librosa.to_mono(audio)
267 | # audio = torch.from_numpy(audio).to(device)
268 |
269 | # path_audio = os.path.join(self.path_root, 'npaudiodir', name_ext) + '.npy'
270 | # audio = np.load(path_audio)
271 |
272 | if spk_id == self.music_spk_id:
273 | path_music = os.path.join(self.path_root, 'music', name_ext)# + '.npy'
274 | # audio_music = np.load(path_music)
275 | audio_music, _ = librosa.load(path_music, sr=self.sample_rate)
276 | if len(audio_music.shape) > 1:
277 | audio_music = librosa.to_mono(audio_music)
278 | else:
279 | audio_music = None
280 |
281 | """
282 | data_buffer[name_ext] = {
283 | 'duration': duration,
284 | 'audio': audio,
285 | 'f0': f0,
286 | 'spk_id': spk_id,
287 | 't_spk_id': t_spk_id,
288 | }
289 | """
290 | if self.use_redis:
291 | f0 = encode_wb(f0, f0.dtype, f0.shape)
292 | audio = encode_wb(audio, audio.dtype, audio.shape)
293 | if audio_music is not None:
294 | audio_music = encode_wb(audio_music, audio_music.dtype, audio_music.shape)
295 | else:
296 | audio_music = int(0)
297 | if self.use_redis:
298 | if not self.jump:
299 | DATABUFFER[name_ext] = list((duration, f0, audio, audio_music))
300 | data_buffer = None
301 | else:
302 | data_buffer[name_ext] = (duration, f0, audio, audio_music)
303 | else:
304 | if spk_id == self.music_spk_id:
305 | use_music = True
306 | else:
307 | use_music = None
308 | """
309 | data_buffer[name_ext] = {
310 | 'duration': duration,
311 | 'f0': f0,
312 | 'spk_id': spk_id,
313 | 't_spk_id': t_spk_id
314 | }
315 | """
316 | data_buffer[name_ext] = (duration, f0, use_music)
317 | return data_buffer
318 |
319 | def __getitem__(self, file_idx):
320 | with torch.no_grad():
321 | name_ext = self.paths[file_idx]
322 | if self.use_redis:
323 | data_buffer = DATABUFFER[name_ext]
324 | else:
325 | data_buffer = self.data_buffer[name_ext]
326 | # check duration. if too short, then skip
327 | if float(data_buffer[0]) < (self.waveform_sec + 0.1):
328 | return self.__getitem__((file_idx + 1) % len(self.paths))
329 |
330 | # get item
331 | return self.get_data(name_ext, tuple(data_buffer))
332 |
333 | def get_data(self, name_ext, data_buffer):
334 | with torch.no_grad():
335 | name = os.path.splitext(name_ext)[0]
336 | frame_resolution = self.hop_size / self.sample_rate
337 | duration = float(data_buffer[0])
338 | waveform_sec = duration if self.whole_audio else self.waveform_sec
339 |
340 | # load audio
341 | idx_from = 0 if self.whole_audio else random.uniform(0, duration - waveform_sec - 0.1)
342 | start_frame = int(idx_from / frame_resolution)
343 | units_frame_len = int(waveform_sec / frame_resolution)
344 |
345 | # load f0
346 | if self.use_redis and self.load_all_data:
347 | f0 = decode_wb(data_buffer[1])
348 | f0 = f0.copy()
349 | else:
350 | f0 = data_buffer[1].copy()
351 | f0 = torch.from_numpy(f0).float().cpu()
352 |
353 | # load mel
354 | # audio = data_buffer.get('audio')
355 | if len(data_buffer) == 3:
356 | #path_audio = os.path.join(self.path_root, 'npaudiodir', name_ext) + '.npy'
357 | #audio = np.load(path_audio)
358 | path_audio = os.path.join(self.path_root, 'audio', name_ext)
359 | audio, _ = librosa.load(path_audio, sr=self.sample_rate)
360 | if len(audio.shape) > 1:
361 | audio = librosa.to_mono(audio)
362 | if random.choice((False, True)) and (data_buffer[2] is not None):
363 | path_music = os.path.join(self.path_root, 'music', name_ext)
364 | audio_music, _ = librosa.load(path_music, sr=self.sample_rate)
365 | if len(audio_music.shape) > 1:
366 | audio_music = librosa.to_mono(audio_music)
367 | audio = audio + audio_music
368 | del audio_music
369 | audio = 0.98 * audio / (np.abs(audio).max())
370 |
371 | else:
372 | if self.use_redis:
373 | audio = decode_wb(data_buffer[2])
374 | audio = audio.copy()
375 | else:
376 | audio = data_buffer[2].copy()
377 | if len(data_buffer) == 4:
378 | if self.use_redis:
379 | if len(data_buffer[3]) == 1:
380 | pass
381 | else:
382 | audio_music = decode_wb(data_buffer[3])
383 | audio_music = audio_music.copy()
384 | audio = audio + audio_music
385 | del audio_music
386 | audio = 0.98 * audio / (np.abs(audio).max())
387 | else:
388 | if data_buffer[3] is not None:
389 | if random.choice((False, True)):
390 | audio_music = data_buffer[3].copy()
391 | audio = audio + audio_music
392 | del audio_music
393 | audio = 0.98 * audio / (np.abs(audio).max())
394 |
395 | if random.choice((False, True)) and self.aug_keyshift:
396 | if self.f0_shift_mode == 'keyshift':
397 | _f0_shift_mode = 'keyshift'
398 | elif self.f0_shift_mode == 'automax':
399 | _f0_shift_mode = 'automax'
400 | elif self.f0_shift_mode == 'random':
401 | _f0_shift_mode = random.choice(('keyshift', 'automax'))
402 | else:
403 | raise ValueError('f0_shift_mode must be keyshift, automax or random')
404 |
405 | if _f0_shift_mode == 'keyshift':
406 | keyshift = random.uniform(self.keyshift_min, self.keyshift_max)
407 | elif _f0_shift_mode == 'automax':
408 |                     keyshift_max = 12 * np.log2(self.f0_max / f0.max())  # f0 is a tensor, so call .max()
409 |                     keyshift_min = 12 * np.log2(self.f0_min / f0.min())
410 | keyshift = random.uniform(keyshift_min, keyshift_max)
411 | with torch.no_grad():
412 | f0 = 2 ** (keyshift / 12) * f0
413 | else:
414 | keyshift = 0
415 |
416 | is_aug_noise = bool(random.randint(0, 1))
417 |
418 | if self.snb_noise is not None:
419 | audio = ut.add_noise_snb(audio, self.snb_noise, self.noise_beta)
420 |
421 | if self.aug_noise and is_aug_noise:
422 | if bool(random.randint(0, 1)):
423 | audio = ut.add_noise(audio, noise_ratio=self.noise_ratio)
424 | else:
425 | audio = ut.add_noise_slice(audio, self.sample_rate, self.duration, noise_ratio=self.noise_ratio,
426 | brown_noise_ratio=self.brown_noise_ratio)
427 |
428 | peak = np.abs(audio).max()
429 | audio = 0.98 * audio / peak
430 | audio = torch.from_numpy(audio).float().unsqueeze(0).cpu()
431 | with torch.no_grad():
432 | mel = self.wav2mel(audio, sample_rate=self.sample_rate, keyshift=keyshift, no_cache_window=True).squeeze(0).cpu()
433 |
434 | if self.aug_mask and bool(random.randint(0, 1)) and not is_aug_noise:
435 | v_o = bool(random.randint(0, 1)) and self.aug_mask_v_o
436 | mel = mel.transpose(-1, -2)
437 | if self.aug_mask_iszeropad_mode == 'zero':
438 | iszeropad = True
439 | elif self.aug_mask_iszeropad_mode == 'noise':
440 | iszeropad = False
441 | else:
442 | iszeropad = bool(random.randint(0, 1))
443 | mel = ut.add_mel_mask_slice(mel, self.sample_rate, self.duration, hop_size=self.hop_size,
444 | vertical_factor=self.aug_mask_vertical_factor_v_o if v_o else self.aug_mask_vertical_factor,
445 | vertical_offset=v_o, iszeropad=iszeropad,
446 | block_num=self.aug_mask_block_num_v_o if v_o else self.aug_mask_block_num)
447 | mel = mel.transpose(-1, -2)
448 |
449 | mel = mel[start_frame: start_frame + units_frame_len].detach()
450 |
451 | f0_frames = f0[start_frame: start_frame + units_frame_len].detach()
452 |
453 | # load spk_id
454 | # spk_id = data_buffer.get('spk_id')
455 | # spk_id = torch.LongTensor(spk_id).to(self.device)
456 |
457 |
458 | del audio
459 | # return dict(mel=mel, f0=f0_frames, spk_id=spk_id, name=name, name_ext=name_ext)
460 | output = (mel, f0_frames, name, name_ext)
461 | return output
462 |
463 | def __len__(self):
464 | return len(self.paths)
465 |
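
Each dataset item is the `(mel, f0_frames, name, name_ext)` tuple built in `get_data`, so a training batch looks roughly like the sketch below (assumptions: run from `train/`, the dataset paths from `config.yaml` exist, and a local Redis server is running if `use_redis` is true):

```python
# Minimal sketch of pulling one batch from the training loader (assumptions as above).
from savertools import utils
from data_loaders_wav import get_data_loaders

args = utils.load_config('configs/config.yaml')
loader_train, loader_valid = get_data_loaders(args)
mel, f0, name, name_ext = next(iter(loader_train))
print(mel.shape)   # roughly (batch_size, n_frames, num_mels)
print(f0.shape)    # roughly (batch_size, n_frames, 1)
```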
--------------------------------------------------------------------------------
/train/draw.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import random
4 | from tqdm import tqdm
5 | import shutil
6 |
7 | train_dir = r'yx1\train'
8 | val_dir = r'yx1\val'
9 | """
10 | for spker in os.listdir(train_dir):
11 | train_list = os.listdir(os.path.join(train_dir, spker, "audio"))
12 | val_list = []
13 | for i in range(10):
14 |         # pick a random index to randomly select a file
15 | _random = random.randint(0, len(train_list) - 1)
16 | val_list.append(train_list[_random])
17 | train_list.pop(_random)
18 | print(_random, len(train_list))
19 | for i in val_list:
20 | print(i)
21 | os.makedirs(os.path.join(val_dir, spker, "audio"), exist_ok=True)
22 | shutil.move(os.path.join(train_dir, spker, "audio", i), os.path.join(val_dir, spker, "audio", i))
23 | os.makedirs(os.path.join(val_dir, spker, "f0"), exist_ok=True)
24 | shutil.move(os.path.join(train_dir, spker, "f0", i + '.npy'), os.path.join(val_dir, spker, "f0", i + '.npy'))
25 | if int(spker) == 1:
26 | os.makedirs(os.path.join(val_dir, spker, "music"), exist_ok=True)
27 | shutil.move(os.path.join(train_dir, spker, "music", i), os.path.join(val_dir, spker, "music", i))
28 | """
29 |
--------------------------------------------------------------------------------
/train/pre_data.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | import os
3 | import librosa
4 | import soundfile as sf
5 | root_dir = r'XXX'
6 | raw_dir = os.path.join(root_dir, 'raw')
7 | music_dir = os.path.join(root_dir, 'music')
8 | f0_mir1k_dir = os.path.join(root_dir, 'f0_mir1k')
9 | f0_ptdb = os.path.join(root_dir, 'f0')
10 | audio_mir1k = os.path.join(root_dir, 'audio_mir1k')
11 | audio_ptdb = os.path.join(root_dir, 'audio')
12 |
13 |
14 | file_list = os.listdir(raw_dir)
15 |
16 | for file_name in file_list:
17 | print(file_name)
18 | if file_name[-3:] == '.pv':
19 | pipe_mode = 'PV'
20 | else:
21 | pipe_mode = 'WAV'
22 | if file_name[:4] == 'mic_':
23 | dataset_mode = 'ptdb'
24 | else:
25 | dataset_mode = 'mir1k'
26 |
27 |     dataset_mode = 'ptdb'  # force PTDB handling regardless of the filename check above
28 |
29 | if pipe_mode == 'PV':
30 | pv_np = numpy.loadtxt(os.path.join(raw_dir, file_name))
31 | pv_np = numpy.insert(pv_np, 0, 0)
32 | mask = pv_np == 0
33 | f0_np = 440 * (2.0 ** ((pv_np - 69.0) / 12.0))
34 | f0_np[mask] = 0
35 | if dataset_mode == 'ptdb':
36 | numpy.save(os.path.join(f0_ptdb, file_name[:-3] + '.wav.npy'), f0_np)
37 | else:
38 | numpy.save(os.path.join(f0_mir1k_dir, file_name[:-3] + '.wav.npy'), f0_np)
39 | else:
40 | if dataset_mode == 'ptdb':
41 | audio, sr = librosa.load(os.path.join(raw_dir, file_name), mono=False, sr=None)
42 | assert sr == 16000
43 | assert len(audio.shape) == 1
44 | sf.write(os.path.join(audio_ptdb, file_name), audio, sr)
45 | else:
46 | audio, sr = librosa.load(os.path.join(raw_dir, file_name), mono=False, sr=None)
47 | assert sr == 16000
48 | assert len(audio.shape) == 2
49 | sf.write(os.path.join(audio_mir1k, file_name), audio[1], sr)
50 | sf.write(os.path.join(music_dir, file_name), audio[0], sr)
51 |
52 |
53 | """
54 | in_wav, in_sr = librosa.load(os.path.join(raw_dir, "abjones_1_02.wav"), mono=False, sr=None)
55 | sf.write("test_0.wav", in_wav[0], in_sr) # music
56 | sf.write("test_1.wav", in_wav[1], in_sr) # voice
57 | print(in_wav.shape)
58 | print(in_sr)
59 | print(os.path.join(raw_dir, "abjones_1_02.wav"))
60 | """
61 | """
62 | """
63 | """
64 | a = numpy.array([1, 2, 0, 0, 0, 90, 58.03715, 66, 66, 66, 66])
65 | # prepend a 0 to a
66 | b = numpy.insert(a, 0, 0)
67 | # mask is True where b is 0, False elsewhere
68 | mask = b == 0
69 | # f0 (Hz)= 440 * (2.0 ** ((b - 69.0) / 12.0))
70 | c = 440 * (2.0 ** ((b - 69.0) / 12.0))
71 | # zero out c wherever the mask is set
72 | c[mask] = 0
73 | print(c)
74 | print(mask)
75 | print(b)
76 | """
77 |
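
The `.pv` files store pitch as note numbers on the MIDI scale, which the loop above converts to Hz with f0 = 440 * 2 ** ((pv - 69) / 12) while keeping 0 as "unvoiced". A worked example:

```python
# Worked example of the .pv (MIDI note number) -> Hz conversion used above.
import numpy as np

pv = np.array([0.0, 57.0, 69.0, 81.0])
f0 = 440 * (2.0 ** ((pv - 69.0) / 12.0))
f0[pv == 0] = 0
print(f0)   # [  0. 220. 440. 880.]
```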
--------------------------------------------------------------------------------
/train/redis_coder.py:
--------------------------------------------------------------------------------
1 | import redis
2 | import numpy as np
3 | import struct
4 |
5 |
6 | class RedisPool:
7 | def __init__(self, host, password, port, max_connections=10):
8 |         # create the Redis connection pool
9 | self.pool = redis.ConnectionPool(host=host, port=port, password=password, max_connections=max_connections)
10 |
11 | def get_redis_conn(self):
12 |         # get a Redis connection
13 | redis_conn = redis.StrictRedis(connection_pool=self.pool)
14 | try:
15 |             # check that the connection is alive
16 | redis_conn.ping()
17 | except Exception:
18 |             # if the connection is unavailable, rebuild it
19 | redis_conn.connection_pool.disconnect()
20 | redis_conn = redis.StrictRedis(connection_pool=self.pool)
21 | return redis_conn
22 |
23 |
24 | class RedisService:
25 | def __init__(self, host, password, port, max_connections=10):
26 |         # create the Redis connection pool object
27 | self.pool = RedisPool(host=host, port=port, password=password, max_connections=max_connections)
28 |         # get a Redis connection
29 | self.redis_conn = self.pool.get_redis_conn()
30 |
31 | def set(self, **kwargs):
32 | for key, value in kwargs.items():
33 | self.redis_conn.__setitem__(key, value)
34 |
35 | def push(self, key, value):
36 | self.redis_conn.lpush(key, value)
37 |
38 | def pop(self, key):
39 | value = self.redis_conn.rpop(key)
40 | self.redis_conn.lrem(key, 0, value)
41 | return value
42 |
43 | def list_get_index(self, key, index):
44 | return self.redis_conn.lindex(key, index)
45 |
46 | def llen(self, key):
47 | return self.redis_conn.llen(key)
48 |
49 | def set_add(self, key, *value):
50 | self.redis_conn.sadd(key, *value)
51 |
52 | def set_member_exists(self, key, value):
53 | return self.redis_conn.sismember(key, value)
54 |
55 | def exitst(self, key):
56 | return self.redis_conn.exists(key)
57 |
58 | def __setitem__(self, key, value):
59 | if self.redis_conn.exists(key):
60 | self.redis_conn.delete(key)
61 | if isinstance(value, str):
62 | self.redis_conn.set(key, value)
63 | elif isinstance(value, dict):
64 | self.redis_conn.hmset(key, value)
65 | elif isinstance(value, list):
66 | self.redis_conn.rpush(key, *value)
67 | else:
68 | self.redis_conn.set(key, value)
69 |
70 | def __getitem__(self, key):
71 | key_type = self.redis_conn.type(key)
72 | if key_type == b'none':
73 | return None
74 | elif key_type == b'list':
75 | return self.redis_conn.lrange(key, 0, -1)
76 | elif key_type == b'string':
77 | return self.redis_conn.get(key)
78 | elif key_type == b'hash':
79 | return self.redis_conn.hgetall(key)
80 | else:
81 | print(f"Key {key} type {key_type} not support")
82 | return None
83 |
84 |
85 | def encode_wb(wb_np: np.ndarray,
86 | dtype: np.dtype,
87 | shape: tuple,
88 | ) -> bytes:
89 |     # serialize a numpy array to bytes: a variable-length header (shape and dtype) followed by the raw data
90 | wb_bytes = struct.pack('i', len(shape))
91 | for i in shape:
92 | wb_bytes += struct.pack('i', i)
93 |
94 | wb_bytes += struct.pack('i', len(dtype.name))
95 | wb_bytes += dtype.name.encode('utf-8')
96 | wb_bytes += wb_np.tobytes()
97 | return wb_bytes
98 |
99 |
100 | def decode_wb(wb_bytes: bytes) -> np.ndarray:
101 |     # decode bytes produced by encode_wb back into a numpy array
102 | shape_len = struct.unpack('i', wb_bytes[:4])[0]
103 | shape = []
104 | for i in range(shape_len):
105 | shape.append(struct.unpack('i', wb_bytes[4 + i * 4: 8 + i * 4])[0])
106 | dtype_len = struct.unpack('i', wb_bytes[4 + shape_len * 4: 8 + shape_len * 4])[0]
107 | dtype = np.dtype(struct.unpack(f'{dtype_len}s', wb_bytes[8 + shape_len * 4: 8 + shape_len * 4 + dtype_len])[0])
108 | wb_np = np.frombuffer(wb_bytes[8 + shape_len * 4 + dtype_len:], dtype=dtype)
109 | wb_np = wb_np.reshape(shape)
110 | return wb_np
111 |
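
`encode_wb` / `decode_wb` define the byte layout used to cache arrays in Redis: a small header (number of dimensions, each dimension, the dtype-name length and the dtype name) followed by the raw array bytes. A minimal round-trip sketch (assuming the same numpy behaviour as in the training environment):

```python
# Minimal round-trip sketch for the Redis array codec (assumption: run from train/).
import numpy as np
from redis_coder import encode_wb, decode_wb

arr = np.arange(6, dtype=np.float32).reshape(2, 3)
blob = encode_wb(arr, arr.dtype, arr.shape)   # header (ndim, shape, dtype name) + raw bytes
restored = decode_wb(blob)
print(restored.shape, restored.dtype)         # (2, 3) float32
print(np.array_equal(restored, arr))          # True
```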
--------------------------------------------------------------------------------
/train/savertools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CNChTu/FCPE/d55c1b636dc3564dc35be19675998688931e870c/train/savertools/__init__.py
--------------------------------------------------------------------------------
/train/savertools/saver.py:
--------------------------------------------------------------------------------
1 | '''
2 | author: wayn391@mastertones
3 | '''
4 |
5 | import os
6 | import time
7 | import yaml
8 | import datetime
9 | import torch
10 | import matplotlib.pyplot as plt
11 | plt.switch_backend('agg')
12 | from . import utils
13 | import numpy as np
14 | from torch.utils.tensorboard import SummaryWriter
15 |
16 |
17 | class Saver(object):
18 | def __init__(
19 | self,
20 | args,
21 | initial_global_step=-1):
22 |
23 | self.expdir = args.env.expdir
24 |         self.sample_rate = args.mel.sr  # the training config stores the sample rate under mel.sr
25 |
26 | # cold start
27 | self.global_step = initial_global_step
28 | self.init_time = time.time()
29 | self.last_time = time.time()
30 |
31 | # makedirs
32 | os.makedirs(self.expdir, exist_ok=True)
33 |
34 | # path
35 | self.path_log_info = os.path.join(self.expdir, 'log_info.txt')
36 |
37 | # ckpt
38 | os.makedirs(self.expdir, exist_ok=True)
39 |
40 | # writer
41 | self.writer = SummaryWriter(os.path.join(self.expdir, 'logs'))
42 |
43 | # save config
44 | path_config = os.path.join(self.expdir, 'config.yaml')
45 | with open(path_config, "w") as out_config:
46 | yaml.dump(dict(args), out_config)
47 |
48 | # save spk_emb_dict
49 | if args.model.use_speaker_encoder:
50 | import numpy as np
51 | path_from_spk_emb_dict = os.path.join(args.data.train_path, 'spk_emb_dict.npy')
52 | path_save_spk_emb_dict = os.path.join(self.expdir, 'spk_emb_dict.npy')
53 | temp_spk_emb_dict = np.load(path_from_spk_emb_dict, allow_pickle=True).item()
54 | np.save(path_save_spk_emb_dict, temp_spk_emb_dict)
55 |
56 | def log_info(self, msg):
57 | '''log method'''
58 | if isinstance(msg, dict):
59 | msg_list = []
60 | for k, v in msg.items():
61 | tmp_str = ''
62 | if isinstance(v, int):
63 | tmp_str = '{}: {:,}'.format(k, v)
64 | else:
65 | tmp_str = '{}: {}'.format(k, v)
66 |
67 | msg_list.append(tmp_str)
68 | msg_str = '\n'.join(msg_list)
69 | else:
70 | msg_str = msg
71 |
72 |         # display
73 | print(msg_str)
74 |
75 | # save
76 | with open(self.path_log_info, 'a') as fp:
77 | fp.write(msg_str + '\n')
78 |
79 | def log_value(self, dict):
80 | for k, v in dict.items():
81 | self.writer.add_scalar(k, v, self.global_step)
82 |
83 | def log_spec(self, name, spec, spec_out, vmin=-14, vmax=3.5):
84 | spec_cat = torch.cat([(spec_out - spec).abs() + vmin, spec, spec_out], -1)
85 | spec = spec_cat[0]
86 | if isinstance(spec, torch.Tensor):
87 | spec = spec.cpu().numpy()
88 | fig = plt.figure(figsize=(12, 9))
89 | plt.pcolor(spec.T, vmin=vmin, vmax=vmax)
90 | plt.tight_layout()
91 | self.writer.add_figure(name, fig, self.global_step)
92 |
93 | def log_audio(self, dict):
94 | for k, v in dict.items():
95 | self.writer.add_audio(k, v, global_step=self.global_step, sample_rate=self.sample_rate)
96 |
97 | def log_f0(self, name, f0_pr, f0_gt, inuv=False):
98 | #f0_gt = (1 + f0_gt / 700).log()
99 | name = (name + '_f0_inuv') if inuv else (name + '_f0')
100 | f0_pr = f0_pr.squeeze().cpu().numpy()
101 | f0_gt = f0_gt.squeeze().cpu().numpy()
102 | if inuv:
103 | uv = f0_pr == 0
104 | if len(f0_pr[~uv]) > 0:
105 | f0_pr[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0_pr[~uv])
106 | uv = f0_gt == 0
107 | if len(f0_gt[~uv]) > 0:
108 | f0_gt[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0_gt[~uv])
109 | fig = plt.figure()
110 | plt.plot(f0_gt, color='b', linestyle='-')
111 | plt.plot(f0_pr, color='r', linestyle='-')
112 | self.writer.add_figure(name, fig, self.global_step)
113 |
114 | def get_interval_time(self, update=True):
115 | cur_time = time.time()
116 | time_interval = cur_time - self.last_time
117 | if update:
118 | self.last_time = cur_time
119 | return time_interval
120 |
121 | def get_total_time(self, to_str=True):
122 | total_time = time.time() - self.init_time
123 | if to_str:
124 | total_time = str(datetime.timedelta(
125 | seconds=total_time))[:-5]
126 | return total_time
127 |
128 | def save_model(
129 | self,
130 | model,
131 | optimizer,
132 | name='model',
133 | postfix='',
134 | to_json=False,
135 | config_dict=None):
136 | # path
137 | if postfix:
138 | postfix = '_' + postfix
139 | path_pt = os.path.join(
140 | self.expdir, name + postfix + '.pt')
141 |
142 | # check
143 | print(' [*] model checkpoint saved: {}'.format(path_pt))
144 |
145 | # save
146 | if optimizer is not None:
147 | torch.save({
148 | 'global_step': self.global_step,
149 | 'model': model.state_dict(),
150 | 'optimizer': optimizer.state_dict(),
151 | 'config_dict': config_dict
152 | }, path_pt)
153 | else:
154 | torch.save({
155 | 'global_step': self.global_step,
156 | 'model': model.state_dict(),
157 | 'config_dict': config_dict
158 | }, path_pt)
159 |
160 | # to json
161 | if to_json:
162 | path_json = os.path.join(
163 | self.expdir, name + '.json')
164 | utils.to_json(path_pt, path_json)
165 |
166 | def delete_model(self, name='model', postfix=''):
167 | # path
168 | if postfix:
169 | postfix = '_' + postfix
170 | path_pt = os.path.join(
171 | self.expdir, name + postfix + '.pt')
172 |
173 | # delete
174 | if os.path.exists(path_pt):
175 | os.remove(path_pt)
176 | print(' [*] model checkpoint deleted: {}'.format(path_pt))
177 |
178 | def global_step_increment(self):
179 | self.global_step += 1
180 |
--------------------------------------------------------------------------------
/train/savertools/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import json
4 | import torch
5 |
6 |
7 | def traverse_dir(
8 | root_dir,
9 | extensions,
10 | amount=None,
11 | str_include=None,
12 | str_exclude=None,
13 | is_pure=False,
14 | is_sort=False,
15 | is_ext=True):
16 | file_list = []
17 | cnt = 0
18 | for root, _, files in os.walk(root_dir):
19 | for file in files:
20 | if any([file.endswith(f".{ext}") for ext in extensions]):
21 | # path
22 | mix_path = os.path.join(root, file)
23 | pure_path = mix_path[len(root_dir) + 1:] if is_pure else mix_path
24 |
25 | # amount
26 | if (amount is not None) and (cnt == amount):
27 | if is_sort:
28 | file_list.sort()
29 | return file_list
30 |
31 | # check string
32 | if (str_include is not None) and (str_include not in pure_path):
33 | continue
34 | if (str_exclude is not None) and (str_exclude in pure_path):
35 | continue
36 |
37 | if not is_ext:
38 | ext = pure_path.split('.')[-1]
39 | pure_path = pure_path[:-(len(ext) + 1)]
40 | file_list.append(pure_path)
41 | cnt += 1
42 | if is_sort:
43 | file_list.sort()
44 | return file_list
45 |
46 |
47 | class DotDict(dict):
48 | def __getattr__(*args):
49 | val = dict.get(*args)
50 | return DotDict(val) if type(val) is dict else val
51 |
52 | __setattr__ = dict.__setitem__
53 | __delattr__ = dict.__delitem__
54 |
55 |
56 | def get_network_paras_amount(model_dict):
57 | info = dict()
58 | for model_name, model in model_dict.items():
59 | # all_params = sum(p.numel() for p in model.parameters())
60 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
61 |
62 | info[model_name] = trainable_params
63 | return info
64 |
65 |
66 | def load_config(path_config):
67 | with open(path_config, "r") as config:
68 | args = yaml.safe_load(config)
69 | args = DotDict(args)
70 | # print(args)
71 | return args
72 |
73 |
74 | def to_json(path_params, path_json):
75 | params = torch.load(path_params, map_location=torch.device('cpu'))
76 | raw_state_dict = {}
77 | for k, v in params.items():
78 | val = v.flatten().numpy().tolist()
79 | raw_state_dict[k] = val
80 |
81 | with open(path_json, 'w') as outfile:
82 | json.dump(raw_state_dict, outfile, indent="\t")
83 |
84 |
85 | def convert_tensor_to_numpy(tensor, is_squeeze=True):
86 | if is_squeeze:
87 | tensor = tensor.squeeze()
88 | if tensor.requires_grad:
89 | tensor = tensor.detach()
90 | if tensor.is_cuda:
91 | tensor = tensor.cpu()
92 | return tensor.numpy()
93 |
94 |
95 | def load_model(
96 | expdir,
97 | model,
98 | optimizer,
99 | name='model',
100 | postfix='',
101 | device='cpu'):
102 | if postfix == '':
103 | postfix = '_' + postfix
104 | path = os.path.join(expdir, name + postfix)
105 | path_pt = traverse_dir(expdir, ['pt'], is_ext=False)
106 | global_step = 0
107 | if len(path_pt) > 0:
108 | steps = [s[len(path):] for s in path_pt]
109 | maxstep = max([int(s) if s.isdigit() else 0 for s in steps])
110 | if maxstep >= 0:
111 | path_pt = path + str(maxstep) + '.pt'
112 | else:
113 | path_pt = path + 'best.pt'
114 | print(' [*] restoring model from', path_pt)
115 | ckpt = torch.load(path_pt, map_location=torch.device(device))
116 | global_step = ckpt['global_step']
117 | model.load_state_dict(ckpt['model'], strict=False)
118 | if ckpt.get('optimizer') != None and optimizer != None:
119 | optimizer.load_state_dict(ckpt['optimizer'])
120 | return global_step, model, optimizer
121 |
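
`DotDict` is what makes `args.mel.sr`-style access work on the loaded YAML: nested dicts are wrapped on the fly, and missing keys read as `None` instead of raising (which is what `catch_none_args_opti` in `torchfcpe/tools.py` relies on to fall back to defaults). A minimal sketch:

```python
# Minimal sketch of DotDict behaviour (assumption: run from train/).
from savertools.utils import DotDict

cfg = DotDict({'mel': {'sr': 16000}})
print(cfg.mel.sr)     # 16000
print(cfg.mel.fmax)   # None -> downstream code substitutes its hard-coded default
```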
--------------------------------------------------------------------------------
/train/solver_wav.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import time
4 | from savertools.saver import Saver
5 | from savertools.saver import utils
6 | from torch import autocast
7 | from torch.cuda.amp import GradScaler
8 |
9 | from mir_eval.melody import raw_pitch_accuracy, to_cent_voicing, raw_chroma_accuracy, overall_accuracy
10 | from mir_eval.melody import voicing_recall, voicing_false_alarm
11 | import gc
12 |
13 | USE_MIR = True
14 |
15 |
16 | def test(args, model, loader_test, saver):
17 | print(' [*] testing...')
18 | model.eval()
19 |
20 | # losses
21 | _rpa = _rca = _oa = _vfa = _vr = test_loss = 0.
22 | _num_a = 0
23 |
24 |     # initialization
25 | num_batches = len(loader_test)
26 | rtf_all = []
27 |
28 | # run
29 | with torch.no_grad():
30 | for bidx, data in enumerate(loader_test):
31 | fn = data[2][0]
32 | print('--------')
33 | print('{}/{} - {}'.format(bidx, num_batches, fn))
34 |
35 | # unpack data
36 | # for k in data.keys():
37 | # if not k.startswith('name'):
38 | # data[k] = data[k].to(args.device)
39 | for k in range(len(data)):
40 | if k < 2:
41 | data[k] = data[k].to(args.device)
42 | # print('>>', data[2][0])
43 |
44 | # forward
45 | st_time = time.time()
46 | f0 = model.infer(mel=data[0])
47 | ed_time = time.time()
48 |
49 | if USE_MIR:
50 | _f0 = f0.squeeze().cpu().numpy()
51 | _df0 = data[1].squeeze().cpu().numpy()
52 |
53 | time_slice = np.array([i * args.mel.hop_size * 1000 / args.mel.sr for i in range(len(_df0))])
54 | ref_v, ref_c, est_v, est_c = to_cent_voicing(time_slice, _df0, time_slice, _f0)
55 |
56 | rpa = raw_pitch_accuracy(ref_v, ref_c, est_v, est_c)
57 | rca = raw_chroma_accuracy(ref_v, ref_c, est_v, est_c)
58 | oa = overall_accuracy(ref_v, ref_c, est_v, est_c)
59 | vfa = voicing_false_alarm(ref_v, est_v)
60 | vr = voicing_recall(ref_v, est_v)
61 |
62 | # RTF
63 | run_time = ed_time - st_time
64 | song_time = f0.shape[1] * args.mel.hop_size / args.mel.sr
65 | rtf = run_time / song_time
66 | print('RTF: {} | {} / {}'.format(rtf, run_time, song_time))
67 | if USE_MIR:
68 | print('RPA: {} | RCA: {} | OA: {} | VFA: {} | VR: {} |'.format(rpa, rca, oa, vfa, vr))
69 | rtf_all.append(rtf)
70 |
71 | # loss
72 | for i in range(args.train.batch_size):
73 | loss = model.train_and_loss(mel=data[0], gt_f0=data[1], loss_scale=args.loss.loss_scale)
74 | test_loss += loss.item()
75 |
76 | if USE_MIR:
77 | _rpa = _rpa + rpa
78 | _rca = _rca + rca
79 | _oa = _oa + oa
80 | _vfa = _vfa + vfa
81 | _vr = _vr + vr
82 | _num_a = _num_a + 1
83 |
84 | # log mel
85 | saver.log_spec(data[3][0], data[0], data[0])
86 |
87 | saver.log_f0(data[3][0], f0, data[1])
88 | saver.log_f0(data[3][0], f0, data[1], inuv=True)
89 |
90 | # report
91 | test_loss /= args.train.batch_size
92 | test_loss /= num_batches
93 |
94 | if USE_MIR:
95 | _rpa /= _num_a
96 |
97 | _rca /= _num_a
98 |
99 | _oa /= _num_a
100 |
101 | _vfa /= _num_a
102 |
103 | _vr /= _num_a
104 |
105 | # check
106 | print(' [test_loss] test_loss:', test_loss)
107 | print(' Real Time Factor', np.mean(rtf_all))
108 | return test_loss, _rpa, _rca, _oa, _vfa, _vr
109 |
110 |
111 | def train(args, initial_global_step, model, optimizer, scheduler, loader_train, loader_test):
112 | # saver
113 | saver = Saver(args, initial_global_step=initial_global_step)
114 |
115 | # model size
116 | params_count = utils.get_network_paras_amount({'model': model})
117 | saver.log_info('--- model size ---')
118 | saver.log_info(params_count)
119 |
120 | # run
121 | num_batches = len(loader_train)
122 | model.train()
123 | saver.log_info('======= start training =======')
124 | scaler = GradScaler()
125 | if args.train.amp_dtype == 'fp32':
126 | dtype = torch.float32
127 | elif args.train.amp_dtype == 'fp16':
128 | dtype = torch.float16
129 | elif args.train.amp_dtype == 'bf16':
130 | dtype = torch.bfloat16
131 | else:
132 | raise ValueError(' [x] Unknown amp_dtype: ' + args.train.amp_dtype)
133 | for epoch in range(args.train.epochs):
134 | train_one_epoch(loader_train, saver, optimizer, model, dtype, scaler, epoch, args, scheduler, num_batches,
135 | loader_test)
136 |         # run gc manually to guard against memory leaks
137 | gc.collect()
138 |
139 |
140 | def train_one_epoch(loader_train, saver, optimizer, model, dtype, scaler, epoch, args, scheduler, num_batches,
141 | loader_test):
142 | for batch_idx, data in enumerate(loader_train):
143 | train_one_step(batch_idx, data, saver, optimizer, model, dtype, scaler, epoch, args, scheduler, num_batches,
144 | loader_test)
145 |
146 |
147 | def train_one_step(batch_idx, data, saver, optimizer, model, dtype, scaler, epoch, args, scheduler, num_batches,
148 | loader_test):
149 | saver.global_step_increment()
150 | optimizer.zero_grad()
151 |
152 | # unpack data
153 | for k in range(len(data)):
154 | if k < 2:
155 | data[k] = data[k].to(args.device)
156 | # print('>>', data[2][0])
157 |
158 | # forward
159 | if dtype == torch.float32:
160 | loss = model.train_and_loss(mel=data[0], gt_f0=data[1], loss_scale=args.loss.loss_scale)
161 | else:
162 | with autocast(device_type=args.device, dtype=dtype):
163 | loss = model.train_and_loss(mel=data[0], gt_f0=data[1], loss_scale=args.loss.loss_scale)
164 |
165 | # handle nan loss
166 | if torch.isnan(loss):
167 | # raise ValueError(' [x] nan loss ')
168 | print(' [x] nan loss ')
169 | loss = None
170 | return
171 | else:
172 | # backpropagate
173 | if dtype == torch.float32:
174 | loss.backward()
175 | optimizer.step()
176 | else:
177 | scaler.scale(loss).backward()
178 | scaler.step(optimizer)
179 | scaler.update()
180 | scheduler.step()
181 |
182 | # log loss
183 | if saver.global_step % args.train.interval_log == 0:
184 | current_lr = optimizer.param_groups[0]['lr']
185 | saver.log_info(
186 | 'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format(
187 | epoch,
188 | batch_idx,
189 | num_batches,
190 | args.env.expdir,
191 | args.train.interval_log / saver.get_interval_time(),
192 | current_lr,
193 | loss.item(),
194 | saver.get_total_time(),
195 | saver.global_step
196 | )
197 | )
198 |
199 | saver.log_value({
200 | 'train/loss': loss.item()
201 | })
202 |
203 | saver.log_value({
204 | 'train/lr': current_lr
205 | })
206 |
207 | # validation
208 | if saver.global_step % args.train.interval_val == 0:
209 | optimizer_save = optimizer if args.train.save_opt else None
210 |
211 | # save latest
212 | saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}', config_dict=dict(args))
213 | last_val_step = saver.global_step - args.train.interval_val
214 | if last_val_step % args.train.interval_force_save != 0:
215 | saver.delete_model(postfix=f'{last_val_step}')
216 |
217 | # run testing set
218 | test_loss, rpa, rca, oa, vfa, vr = test(args, model, loader_test, saver)
219 |
220 | # log loss
221 | saver.log_info(
222 | ' --- --- \nloss: {:.3f}. '.format(
223 | test_loss,
224 | )
225 | )
226 |
227 | saver.log_value({
228 | 'validation/loss': test_loss
229 | })
230 | if USE_MIR:
231 | saver.log_value({
232 | 'validation/rpa': rpa
233 | })
234 | saver.log_value({
235 | 'validation/rca': rca
236 | })
237 | saver.log_value({
238 | 'validation/oa': oa
239 | })
240 | saver.log_value({
241 | 'validation/vfa': vfa
242 | })
243 | saver.log_value({
244 | 'validation/vr': vr
245 | })
246 |
247 | model.train()
248 |
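
The validation metrics come straight from `mir_eval.melody`: both f0 curves are converted to cent/voicing arrays on a shared time grid and then scored. A toy example with made-up values (a minimal sketch, not a benchmark):

```python
# Minimal sketch of the mir_eval scoring used in test() (values are illustrative only).
import numpy as np
from mir_eval.melody import to_cent_voicing, raw_pitch_accuracy

hop_size, sr = 160, 16000
ref_f0 = np.array([0.0, 220.0, 220.0, 440.0])   # 0 Hz marks unvoiced frames
est_f0 = np.array([0.0, 221.0, 219.0, 441.0])
t = np.array([i * hop_size * 1000 / sr for i in range(len(ref_f0))])   # frame times, as in test()
ref_v, ref_c, est_v, est_c = to_cent_voicing(t, ref_f0, t, est_f0)
print(raw_pitch_accuracy(ref_v, ref_c, est_v, est_c))   # 1.0 -> every voiced frame within 50 cents
```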
--------------------------------------------------------------------------------
/train/train_wav.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from torch.optim import lr_scheduler
4 | from savertools import utils
5 | from data_loaders_wav import get_data_loaders
6 | from solver_wav import train
7 | import torchfcpe
8 |
9 |
10 | def parse_args(args=None, namespace=None):
11 | """Parse command-line arguments."""
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument(
14 | "-c",
15 | "--config",
16 | type=str,
17 | required=True,
18 | help="path to the config file")
19 | parser.add_argument(
20 | "-j",
21 | "--jump",
22 | type=str,
23 | required=False,
24 | default=None,
25 | help="jump cache for redis")
26 | return parser.parse_args(args=args, namespace=namespace)
27 |
28 |
29 | if __name__ == '__main__':
30 | # parse commands
31 | cmd = parse_args()
32 |
33 | # load config
34 | args = utils.load_config(cmd.config)
35 | print(' > config:', cmd.config)
36 | print(' > exp:', args.env.expdir)
37 |
38 | # load model
39 | model = torchfcpe.spawn_model(args)
40 |
41 | # load parameters
42 | optimizer = torch.optim.AdamW(model.parameters())
43 | initial_global_step, model, optimizer = utils.load_model(args.env.expdir, model, optimizer, device=args.device)
44 | for param_group in optimizer.param_groups:
45 | param_group['initial_lr'] = args.train.lr
46 | param_group['lr'] = args.train.lr * args.train.gamma ** max((initial_global_step - 2) // args.train.decay_step,
47 | 0)
48 | param_group['weight_decay'] = args.train.weight_decay
49 | scheduler = lr_scheduler.StepLR(optimizer, step_size=args.train.decay_step, gamma=args.train.gamma,
50 | last_epoch=initial_global_step - 2)
51 |
52 | # device
53 | if args.device == 'cuda':
54 | torch.cuda.set_device(args.env.gpu_id)
55 | model.to(args.device)
56 |
57 | for state in optimizer.state.values():
58 | for k, v in state.items():
59 | if torch.is_tensor(v):
60 | state[k] = v.to(args.device)
61 |
62 | # datas
63 | if (str(cmd.jump) == 'True') or (str(cmd.jump) == 'true'):
64 | jump = True
65 | else:
66 | jump = False
67 |     loader_train, loader_valid = get_data_loaders(args, jump=jump)  # pass the -j/--jump flag through to the dataset
68 |
69 | # run
70 | train(args, initial_global_step, model, optimizer, scheduler, loader_train, loader_valid)
71 |
--------------------------------------------------------------------------------
/train/utils_1.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from scipy import signal as sg
4 | import pdb
5 |
6 |
7 | def make_lowshelf(g, fc, Q, fs=44100):
8 | """Generate filter coefficients for 2nd order Lowshelf filter.
9 | This function follows the code from the JUCE DSP library
10 | which can be found in `juce_IIRFilter.cpp`.
11 |
12 | The design equations are based upon those found in the Cookbook
13 | formulae for audio equalizer biquad filter coefficients
14 | by Robert Bristow-Johnson.
15 | https://www.w3.org/2011/audio/audio-eq-cookbook.html
16 | Args:
17 | g (float): Gain factor in dB.
18 | fc (float): Cutoff frequency in Hz.
19 | Q (float): Q factor.
20 | fs (float): Sampling frequency in Hz.
21 | Returns:
22 |         ndarray: second-order section of shape (1, 6): [b0, b1, b2, 1, a1, a2] (normalized by a0)
23 | """
24 | # convert gain from dB to linear
25 | g = np.power(10,(g/20))
26 |
27 | # initial values
28 | A = np.max([0.0, np.sqrt(g)])
29 | aminus1 = A - 1
30 | aplus1 = A + 1
31 | omega = (2 * np.pi * np.max([fc, 2.0])) / fs
32 | coso = np.cos(omega)
33 | beta = np.sin(omega) * np.sqrt(A) / Q
34 | aminus1TimesCoso = aminus1 * coso
35 |
36 | # coefs calculation
37 | b0 = A * (aplus1 - aminus1TimesCoso + beta)
38 | b1 = A * 2 * (aminus1 - aplus1 * coso)
39 | b2 = A * (aplus1 - aminus1TimesCoso - beta)
40 | a0 = aplus1 + aminus1TimesCoso + beta
41 | a1 = -2 * (aminus1 + aplus1 * coso)
42 | a2 = aplus1 + aminus1TimesCoso - beta
43 |
44 | # output coefs
45 | #b = np.array([b0/a0, b1/a0, b2/a0])
46 | #a = np.array([a0/a0, a1/a0, a2/a0])
47 |
48 | return np.array([[b0/a0, b1/a0, b2/a0, 1.0, a1/a0, a2/a0]])
49 |
50 |
51 |
52 | def make_highself(g, fc, Q, fs=44100):
53 | """Generate filter coefficients for 2nd order Highshelf filter.
54 | This function follows the code from the JUCE DSP library
55 | which can be found in `juce_IIRFilter.cpp`.
56 |
57 | The design equations are based upon those found in the Cookbook
58 | formulae for audio equalizer biquad filter coefficients
59 | by Robert Bristow-Johnson.
60 | https://www.w3.org/2011/audio/audio-eq-cookbook.html
61 | Args:
62 | g (float): Gain factor in dB.
63 | fc (float): Cutoff frequency in Hz.
64 | Q (float): Q factor.
65 | fs (float): Sampling frequency in Hz.
66 | Returns:
67 |         ndarray: second-order section of shape (1, 6): [b0, b1, b2, 1, a1, a2] (normalized by a0)
68 | """
69 | # convert gain from dB to linear
70 | g = np.power(10,(g/20))
71 |
72 | # initial values
73 | A = np.max([0.0, np.sqrt(g)])
74 | aminus1 = A - 1
75 | aplus1 = A + 1
76 | omega = (2 * np.pi * np.max([fc, 2.0])) / fs
77 | coso = np.cos(omega)
78 | beta = np.sin(omega) * np.sqrt(A) / Q
79 | aminus1TimesCoso = aminus1 * coso
80 |
81 | # coefs calculation
82 | b0 = A * (aplus1 + aminus1TimesCoso + beta)
83 | b1 = A * -2 * (aminus1 + aplus1 * coso)
84 | b2 = A * (aplus1 + aminus1TimesCoso - beta)
85 | a0 = aplus1 - aminus1TimesCoso + beta
86 | a1 = 2 * (aminus1 - aplus1 * coso)
87 | a2 = aplus1 - aminus1TimesCoso - beta
88 |
89 | # output coefs
90 | #b = np.array([b0/a0, b1/a0, b2/a0])
91 | #a = np.array([a0/a0, a1/a0, a2/a0])
92 |
93 | return np.array([[b0/a0, b1/a0, b2/a0, 1.0, a1/a0, a2/a0]])
94 |
95 |
96 |
97 | def make_peaking(g, fc, Q, fs=44100):
98 | """Generate filter coefficients for 2nd order Peaking EQ.
99 | This function follows the code from the JUCE DSP library
100 | which can be found in `juce_IIRFilter.cpp`.
101 |
102 | The design equations are based upon those found in the Cookbook
103 | formulae for audio equalizer biquad filter coefficients
104 | by Robert Bristow-Johnson.
105 | https://www.w3.org/2011/audio/audio-eq-cookbook.html
106 | Args:
107 | g (float): Gain factor in dB.
108 | fc (float): Cutoff frequency in Hz.
109 | Q (float): Q factor.
110 | fs (float): Sampling frequency in Hz.
111 | Returns:
112 |         ndarray: second-order section of shape (1, 6): [b0, b1, b2, 1, a1, a2] (normalized by a0)
113 | """
114 | # convert gain from dB to linear
115 | g = np.power(10,(g/20))
116 |
117 | # initial values
118 | A = np.max([0.0, np.sqrt(g)])
119 | omega = (2 * np.pi * np.max([fc, 2.0])) / fs
120 | alpha = np.sin(omega) / (Q * 2)
121 | c2 = -2 * np.cos(omega)
122 | alphaTimesA = alpha * A
123 | alphaOverA = alpha / A
124 |
125 | # coefs calculation
126 | b0 = 1 + alphaTimesA
127 | b1 = c2
128 | b2 = 1 - alphaTimesA
129 | a0 = 1 + alphaOverA
130 | a1 = c2
131 | a2 = 1 - alphaOverA
132 |
133 | # output coefs
134 | #b = np.array([b0/a0, b1/a0, b2/a0])
135 | #a = np.array([a0/a0, a1/a0, a2/a0])
136 |
137 | return np.array([[b0/a0, b1/a0, b2/a0, 1.0, a1/a0, a2/a0]])
138 |
139 |
140 |
141 | def params2sos(G, Fc, Q, fs):
142 |     """Convert 10-band EQ parameters to 2nd order sections.
143 |     Takes vectors of gains, center frequencies and Q factors (10 values
144 |     each) and calculates filter coefficients for each of the 10 filters.
145 |     These coefficients (2nd order sections) are then stored into a
146 |     single (10,6) matrix. This matrix can be fed to `scipy.signal.sosfreqz()`
147 |     in order to determine the frequency response of the cascade of
148 |     all ten biquad filters.
149 |     Args:
150 |         G, Fc, Q (ndarray): Gains in dB, center frequencies in Hz and Q factors, shape (10,) each.
151 |         fs (float): Sampling frequency in Hz.
152 |     Returns:
153 |         ndarray: filter coefficients for the 10-band EQ stored in a (10,6) matrix.
154 |         [[b1_0, b1_1, b1_2, a1_0, a1_1, a1_2],    # lowshelf coefficients
155 |          [b2_0, b2_1, b2_2, a2_0, a2_1, a2_2],    # first peaking band coefficients
156 |          ...
157 |          [b9_0, b9_1, b9_2, a9_0, a9_1, a9_2],    # eighth peaking band coefficients
158 |          [b10_0, b10_1, b10_2, a10_0, a10_1, a10_2]]  # highshelf coefficients
159 |     """
160 | # generate filter coefficients from eq params
161 | c0 = make_lowshelf(G[0], Fc[0], Q[0], fs=fs)
162 | c1 = make_peaking (G[1], Fc[1], Q[1], fs=fs)
163 | c2 = make_peaking (G[2], Fc[2], Q[2], fs=fs)
164 | c3 = make_peaking (G[3], Fc[3], Q[3], fs=fs)
165 | c4 = make_peaking (G[4], Fc[4], Q[4], fs=fs)
166 | c5 = make_peaking (G[5], Fc[5], Q[5], fs=fs)
167 | c6 = make_peaking (G[6], Fc[6], Q[6], fs=fs)
168 | c7 = make_peaking (G[7], Fc[7], Q[7], fs=fs)
169 | c8 = make_peaking (G[8], Fc[8], Q[8], fs=fs)
170 | c9 = make_highself(G[9], Fc[9], Q[9], fs=fs)
171 |
172 | # stuff coefficients into second order sections structure
173 | sos = np.concatenate([c0,c1,c2,c3,c4,c5,c6,c7,c8,c9], axis=0)
174 |
175 | return sos
176 |
177 |
178 | import parselmouth
179 | def change_gender(x, fs, lo, hi, ratio_fs, ratio_ps, ratio_pr):
180 | s = parselmouth.Sound(x, sampling_frequency=fs)
181 | f0 = s.to_pitch_ac(pitch_floor=lo, pitch_ceiling=hi, time_step=0.8/lo)
182 | f0_np = f0.selected_array['frequency']
183 | f0_med = np.median(f0_np[f0_np!=0]).item()
184 | ss = parselmouth.praat.call([s, f0], "Change gender", ratio_fs, f0_med*ratio_ps, ratio_pr, 1.0)
185 | return ss.values.squeeze(0)
186 |
187 | def change_gender_f0(x, fs, lo, hi, ratio_fs, new_f0_med, ratio_pr):
188 | s = parselmouth.Sound(x, sampling_frequency=fs)
189 | ss = parselmouth.praat.call(s, "Change gender", lo, hi, ratio_fs, new_f0_med, ratio_pr, 1.0)
190 | return ss.values.squeeze(0)
191 |
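
`params2sos` is consumed by `random_eq` in `utils_all.py`; the resulting (10, 6) matrix can be passed to `scipy.signal.sosfilt` or `scipy.signal.sosfreqz`. A minimal sketch that builds a flat (0 dB) 10-band EQ and checks that its response is transparent:

```python
# Minimal sketch (assumes scipy is installed); an all-0 dB EQ should leave the signal untouched.
import numpy as np
from scipy import signal as sg
from utils_1 import params2sos

fs = 16000
G = np.zeros(10)                                         # all band gains at 0 dB
Fc = np.exp(np.linspace(np.log(60), np.log(7600), 10))   # same band centers as random_eq
Q = np.full(10, 2.0)
sos = params2sos(G, Fc, Q, fs)
w, h = sg.sosfreqz(sos, worN=512, fs=fs)
print(sos.shape)                      # (10, 6)
print(np.allclose(np.abs(h), 1.0))    # True -> flat magnitude response
```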
--------------------------------------------------------------------------------
/train/utils_all.py:
--------------------------------------------------------------------------------
1 | import colorednoise as cn
2 | import random
3 | import numpy as np
4 | from sklearn.metrics import mean_squared_error
5 | import torch
6 | import librosa
7 | import pyworld
8 | import soundfile
9 |
10 | from utils_1 import params2sos
11 | from scipy.signal import sosfilt
12 |
13 | Qmin, Qmax = 2, 5
14 |
15 |
16 | def add_noise(wav, noise_ratio=0.7, brown_noise_ratio=1.):
17 | beta = random.random() * 2 # the exponent
18 | y = cn.powerlaw_psd_gaussian(beta, wav.shape[0])
19 | m = np.sqrt(mean_squared_error(wav, np.zeros_like(y)))
20 |
21 | if beta >= 0 and beta <= 1.5:
22 | wav += (noise_ratio * random.random()) * m * y
23 | else:
24 | wav += (brown_noise_ratio * random.random()) * m * y
25 | return wav
26 |
27 |
28 | def add_noise_slice(wav, sr, duration, add_factor=0.50, noise_ratio=0.7, brown_noise_ratio=1.):
29 | slice_length = int(duration * sr)
30 | n_frames = int(wav.shape[-1] // slice_length)
31 | slice_length_noise = int(slice_length * add_factor)
32 | for n in range(n_frames):
33 | left, right = int(n * slice_length), int((n + 1) * slice_length)
34 | offset = random.randint(left, right - slice_length_noise)
35 | if wav[offset:offset + slice_length_noise].shape[0] != 0:
36 | wav[offset:offset + slice_length_noise] = add_noise(wav[offset:offset + slice_length_noise],
37 | noise_ratio=noise_ratio,
38 | brown_noise_ratio=brown_noise_ratio)
39 | return wav
40 |
41 |
42 | def add_mel_mask(mel, iszeropad=False, esp=1e-5):
43 | if iszeropad:
44 | return torch.ones_like(mel) * esp
45 | else:
46 | return (random.random() * 0.9 + 0.1) * torch.randn_like(mel)
47 |
48 |
49 | def add_mel_mask_slice(mel, sr, duration, hop_size=512, add_factor=0.3, vertical_offset=True, vertical_factor=0.05,
50 | iszeropad=True, islog=True, block_num=5, esp=1e-5):
51 | if islog:
52 | mel = torch.exp(mel)
53 | slice_length = int(duration * sr) // hop_size
54 | n_frames = int(mel.shape[-1] // slice_length)
55 | n_mels = mel.shape[-2]
56 | for n in range(n_frames):
57 | line_num = n_mels // block_num
58 | for i in range(block_num):
59 | now_vertical_factor = vertical_factor + random.random() * 0.1
60 | now_add_factor = add_factor + random.random() * 0.1
61 | slice_length_noise = int(slice_length * now_add_factor)
62 | if vertical_offset:
63 | v_offset = int(random.uniform(line_num * i, line_num * (i + 1) - now_vertical_factor))
64 | n_v_down = v_offset
65 | n_v_up = int(v_offset + now_vertical_factor * n_mels)
66 | else:
67 | n_v_down = 0
68 | n_v_up = n_mels
69 | left, right = int(n * slice_length), int((n + 1) * slice_length)
70 | offset = int(random.uniform(left, right - slice_length_noise))
71 | if mel[n_v_down:n_v_up, offset:offset + slice_length_noise].shape[-1] != 0:
72 | mel[n_v_down:n_v_up, offset:offset + slice_length_noise] = add_mel_mask(
73 | mel[n_v_down:n_v_up, offset:offset + slice_length_noise], iszeropad, esp)
74 | if islog:
75 | mel = torch.log(torch.clamp(mel, min=esp))
76 | return mel
77 |
78 |
79 | def random_eq(wav, sr):
80 | rng = np.random.default_rng()
81 | z = rng.uniform(0, 1, size=(10,))
82 | Q = Qmin * (Qmax / Qmin) ** z
83 | G = rng.uniform(-12, 12, size=(10,))
84 | Fc = np.exp(np.linspace(np.log(60), np.log(7600), 10))
85 | sos = params2sos(G, Fc, Q, sr)
86 | wav = sosfilt(sos, wav)
87 | peak = np.abs(wav).max()
88 | if peak > 0.98:
89 | wav = 0.98 * wav / peak
90 | return wav
91 |
92 |
93 | def worldSynthesize(wav, target_sr=44100, hop_length=512, fft_size=2048, f0_in=None):
94 | f0, t = pyworld.dio(wav.astype(np.double), fs=target_sr, frame_period=1000 * hop_length / target_sr)
95 | f0 = pyworld.stonemask(wav.astype(np.double), f0, t, target_sr)
96 |     if f0_in is not None:  # use a user-provided f0 curve instead of the DIO/StoneMask estimate
97 |         f0 = f0_in.astype(np.double)
98 | ap = pyworld.d4c(wav.astype(np.double), f0, t, target_sr, fft_size=fft_size)
99 | sp = pyworld.cheaptrick(wav.astype(np.double), f0, t, target_sr, fft_size=fft_size)
100 | synthesized = pyworld.synthesize(f0, sp, ap, target_sr, frame_period=1000 * hop_length / target_sr)
101 |
102 | peak = np.abs(synthesized).max()
103 | synthesized = 0.98 * synthesized / peak
104 |
105 | return synthesized, f0
106 |
107 |
108 | # soundfile.write(f'world_{wav_name}.wav', synthesized, target_sr)
109 | # np.save(f"f0_{wav_name}.npy",f0)
110 |
111 |
112 | def add_noise_snb(wav, snb, beta):
113 |     # convert the SNR (in dB) to a linear signal-to-noise energy ratio
114 | snb = 10 ** (snb / 10)
115 | noise = cn.powerlaw_psd_gaussian(beta, wav.shape[0])
116 | rms_signal = np.sqrt(np.mean(wav ** 2))
117 | rms_noise = np.sqrt(np.mean(noise ** 2))
118 | wav = wav + noise * (rms_signal / rms_noise) / snb
119 | return wav
120 |
121 |
122 | def add_noise_slice_snb(wav, sr, duration, add_factor=0.50, snb=0.7, beta=1.0):
123 | slice_length = int(duration * sr)
124 | n_frames = int(wav.shape[-1] // slice_length)
125 | slice_length_noise = int(slice_length * add_factor)
126 | for n in range(n_frames):
127 | left, right = int(n * slice_length), int((n + 1) * slice_length)
128 | offset = random.randint(left, right - slice_length_noise)
129 | if wav[offset:offset + slice_length_noise].shape[0] != 0:
130 | wav[offset:offset + slice_length_noise] = add_noise_snb(wav[offset:offset + slice_length_noise], snb, beta)
131 | return wav
132 |
133 |
134 | def add_pub_noise_snb(wav, snb):
135 |     # mix in a randomly chosen real noise recording at the given SNR (dB, converted to an energy ratio below)
136 | import os
137 | noise_path = r'path/to/noise/data/dir'
138 | noise_list = os.listdir(noise_path)
139 | noise_path = os.path.join(noise_path, random.choice(noise_list))
140 | snb = 10 ** (snb / 10)
141 | noise, sr = librosa.load(noise_path, sr=16000)
142 | if len(wav) > len(noise):
143 | noise = np.tile(noise, len(wav) // len(noise) + 1)
144 | if len(wav) < len(noise):
145 | noise = noise[:len(wav)]
146 | rms_signal = np.sqrt(np.mean(wav ** 2))
147 | rms_noise = np.sqrt(np.mean(noise ** 2))
148 | wav = wav + noise * (rms_signal / rms_noise) / snb
149 | return wav
150 |
151 |
152 | def add_pub_noise_snb2(wav, snb):
153 |     # same as add_pub_noise_snb, but the noise segment is taken from a random offset; SNR in dB converted to an energy ratio below
154 | import os
155 | noise_path = r'path/to/noise/data/dir'
156 | noise_list = os.listdir(noise_path)
157 | noise_path = os.path.join(noise_path, random.choice(noise_list))
158 | snb = 10 ** (snb / 10)
159 | noise, sr = librosa.load(noise_path, sr=16000)
160 | if len(wav) > len(noise):
161 | noise = np.tile(noise, len(wav) // len(noise) + 1)
162 | if len(wav) < len(noise):
163 | offset = int(random.choice(range(len(noise) - len(wav))))
164 | noise = noise[offset:offset + len(wav)]
165 | rms_signal = np.sqrt(np.mean(wav ** 2))
166 | rms_noise = np.sqrt(np.mean(noise ** 2))
167 | wav = wav + noise * (rms_signal / rms_noise) / snb
168 | return wav
169 |
--------------------------------------------------------------------------------
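
As with `utils_1.py`, here is a minimal sketch of how these waveform augmentations might be chained when preparing training data. It is not part of the repository; the file name and parameter values are illustrative assumptions.

```python
import librosa
from utils_all import add_noise_slice, random_eq, add_noise_snb, worldSynthesize

sr = 16000
wav, _ = librosa.load('example.wav', sr=sr)  # hypothetical input file

wav = add_noise_slice(wav, sr, duration=2.0)   # colored noise on random 2-second slices
wav = random_eq(wav, sr)                       # random 10-band parametric EQ
wav = add_noise_snb(wav, snb=20, beta=1.0)     # pink noise at roughly 20 dB SNR

# WORLD analysis/re-synthesis (hop of 160 samples = 10 ms at 16 kHz)
resynth, f0 = worldSynthesize(wav, target_sr=sr, hop_length=160)
```
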