├── .gitignore
├── LICENSE.md
├── README.md
├── configs
│   ├── config.yaml
│   ├── data
│   │   └── default.yaml
│   ├── experiment
│   │   ├── ploss.yaml
│   │   ├── pretrain.yaml
│   │   ├── resume_real.yaml
│   │   └── resume_synth.yaml
│   ├── model
│   │   └── default.yaml
│   ├── schedule
│   │   ├── even.yaml
│   │   ├── param.yaml
│   │   ├── spec.yaml
│   │   ├── switch.yaml
│   │   └── switch_perc.yaml
│   ├── synth
│   │   ├── dataset
│   │   │   ├── h1of.yaml
│   │   │   ├── h1of_envno.yaml
│   │   │   ├── h1ofcf.yaml
│   │   │   ├── h1ofcf_envno.yaml
│   │   │   ├── h1ofcfrev_envno.yaml
│   │   │   └── h2of.yaml
│   │   ├── h1of.yaml
│   │   ├── h1of_f0.yaml
│   │   ├── h1of_f0_envno.yaml
│   │   ├── h1ofcf_f0.yaml
│   │   ├── h1ofcf_f0_envno.yaml
│   │   ├── h1ofcfrev_envno.yaml
│   │   ├── h1ofcfrev_f0_envno.yaml
│   │   ├── h2of.yaml
│   │   ├── h2of_cf.yaml
│   │   └── h2of_f0.yaml
│   └── trainer
│       └── default.yaml
├── diffsynth
│   ├── __init__.py
│   ├── data.py
│   ├── estimator.py
│   ├── f0.py
│   ├── layers.py
│   ├── loss.py
│   ├── model.py
│   ├── modelutils.py
│   ├── modules
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── delay.py
│   │   ├── envelope.py
│   │   ├── filter.py
│   │   ├── fm.py
│   │   ├── frequency.py
│   │   ├── generators.py
│   │   ├── harmor.py
│   │   ├── lfo.py
│   │   └── reverb.py
│   ├── perceptual
│   │   ├── __init__.py
│   │   ├── crepe.py
│   │   ├── openl3.py
│   │   ├── perceptual.py
│   │   └── wav2vec.py
│   ├── processor.py
│   ├── schedules.py
│   ├── spectral.py
│   ├── synthesizer.py
│   ├── transforms.py
│   └── util.py
├── final_dataset.ipynb
├── gen_dataset.py
├── plot.py
├── test.py
├── train.py
└── trainclassifier.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .vscode/
3 | note/
4 | __pycache__
5 | results/
6 | result/
7 | use_results/
8 | old/
9 | *.pyc
10 | runs/
11 | scripts/
12 | output/
13 | outputs/
14 | myenv.yaml
15 | configs/exp
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Naotake Masuda
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Diffsynth - a Differentiable Musical Synthesizer in PyTorch
2 | 
3 | Synthesizer Sound Matching with Differentiable DSP @ ISMIR2021
4 | https://hyakuchiki.github.io/DiffSynthISMIR/
5 | 
6 | ## Features
7 | 
8 | - Additive-subtractive synthesizer
9 | - FM synthesizer
10 | - ADSR envelopes, LFOs
11 | - Chorus/flanger, reverb effects
12 | - Parameter estimator network
13 | 
14 | ## To-do
15 | 
16 | - Training with perceptual loss doesn't work
17 | 
18 | ## Training
19 | 
20 | - p-loss model
21 |   - `python train.py experiment=only_param_h2of trainer.gpus=1`
22 | - pretrain
23 |   - `python train.py experiment=pretrain_h2of trainer.gpus=1`
24 | - resume real model
25 |   - `python train.py experiment=resume_real_h2of trainer.gpus=1 trainer.resume_from_checkpoint=[pretrain ckpt absolute path]`
26 | - resume synth model
27 |   - `python train.py experiment=resume_synth_h2of trainer.gpus=1 trainer.resume_from_checkpoint=[pretrain ckpt absolute path]`
28 | 
29 | ## Notes
30 | 
31 | Some code was ported to pytorch from [DDSP](https://github.com/magenta/ddsp) (Copyright 2019 Google LLC.).
32 | 
33 | Several features have been added since ISMIR2021 (code is more readable, chorus effect, etc.)
34 | To reproduce results from ISMIR2021, revert to [Ver. May2021](https://github.com/hyakuchiki/diffsynth/commit/aca9585a8c0f8466166830dfed97bf222d7e1f40)
--------------------------------------------------------------------------------
/configs/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - trainer: default.yaml
3 |   - model: default.yaml
4 |   - data: trainid.yaml
5 |   - synth: h2of.yaml
6 |   - schedule: switch.yaml
7 |   - experiment: null
8 |   - exp: null
--------------------------------------------------------------------------------
/configs/data/default.yaml:
--------------------------------------------------------------------------------
1 | _target_: diffsynth.data.IdOodDataModule
2 | 
3 | id_dir: null
4 | ood_dir: null
5 | train_type: id
6 | batch_size: 64
7 | sample_rate: 16000
8 | length: 4.0
9 | num_workers: 8
10 | splits: [0.8, 0.1, 0.1]
--------------------------------------------------------------------------------
/configs/experiment/ploss.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | 
3 | defaults:
4 |   - override /trainer: default.yaml
5 |   - override /model: default.yaml
6 |   - override /data: default.yaml
7 |   - override /synth: h2of.yaml
8 |   - override /schedule: param.yaml
9 | 
10 | data:
11 |   id_dir: data/diffsynth_5-6/harmor_2oscfree
12 |   ood_dir: data/nsynth-train
13 |   batch_size: 64
14 |   train_type: id
15 | 
16 | trainer:
17 |   max_epochs: 400
18 |   gradient_clip_val: 1.0
19 | 
20 | model:
21 |   lr: 0.001
22 |   decay_rate: 0.99
23 |   estimator:
24 |     _target_: diffsynth.estimator.MelEstimator
25 |     hidden_size: 512
26 |   sw_loss:
27 |     fft_sizes: [64, 128, 256, 512, 1024, 2048]
28 |   log_grad: true
--------------------------------------------------------------------------------
/configs/experiment/pretrain.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | 
3 | defaults:
4 |   - override /trainer: default.yaml
5 |   - override /model: default.yaml
6 |   - override /data: default.yaml
7 |   - override /synth: h2of.yaml
8 |   - override /schedule: switch.yaml
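# Hydra note: the `# @package _global_` directive places this file's keys at the
# root of the composed config, and each `override /group: option` line swaps the
# default option for that config group in configs/config.yaml. Selecting an
# experiment file with an `experiment=...` command-line override (as in the
# README's Training section) therefore reconfigures trainer/model/data/synth/schedule
# in one step; the keys below the defaults list then override individual values.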
9 | 10 | data: 11 | id_dir: data/diffsynth_5-6/harmor_2oscfree 12 | ood_dir: data/nsynth-train 13 | batch_size: 64 14 | train_type: id 15 | 16 | trainer: 17 | max_epochs: 200 18 | gradient_clip_val: 1.0 19 | 20 | model: 21 | lr: 0.001 22 | decay_rate: 0.99 23 | estimator: 24 | _target_: diffsynth.estimator.MelEstimator 25 | hidden_size: 512 26 | sw_loss: 27 | fft_sizes: [64, 128, 256, 512, 1024, 2048] 28 | log_grad: true -------------------------------------------------------------------------------- /configs/experiment/resume_real.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /trainer: default.yaml 5 | - override /model: default.yaml 6 | - override /data: default.yaml 7 | - override /synth: h2of.yaml 8 | - override /schedule: switch.yaml 9 | 10 | data: 11 | id_dir: data/diffsynth_5-6/harmor_2oscfree 12 | ood_dir: data/nsynth-train 13 | batch_size: 64 14 | train_type: ood 15 | 16 | trainer: 17 | max_epochs: 200 18 | gradient_clip_val: 1.0 19 | 20 | model: 21 | lr: 0.001 22 | decay_rate: 0.99 23 | estimator: 24 | _target_: diffsynth.estimator.MelEstimator 25 | hidden_size: 512 26 | sw_loss: 27 | fft_sizes: [64, 128, 256, 512, 1024, 2048] 28 | log_grad: true -------------------------------------------------------------------------------- /configs/experiment/resume_synth.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /trainer: default.yaml 5 | - override /model: default.yaml 6 | - override /data: default.yaml 7 | - override /synth: h2of.yaml 8 | - override /schedule: switch.yaml 9 | 10 | data: 11 | id_dir: data/diffsynth_5-6/harmor_2oscfree 12 | ood_dir: data/nsynth-train 13 | batch_size: 64 14 | train_type: id 15 | 16 | trainer: 17 | max_epochs: 200 18 | gradient_clip_val: 1.0 19 | 20 | model: 21 | lr: 0.001 22 | decay_rate: 0.99 23 | estimator: 24 | _target_: diffsynth.estimator.MelEstimator 25 | hidden_size: 512 26 | sw_loss: 27 | fft_sizes: [64, 128, 256, 512, 1024, 2048] 28 | log_grad: true -------------------------------------------------------------------------------- /configs/model/default.yaml: -------------------------------------------------------------------------------- 1 | lr: 0.001 2 | decay_rate: 0.99 3 | 4 | estimator: 5 | _target_: diffsynth.estimator.MelEstimator 6 | hidden_size: 512 7 | channels: 64 8 | sample_rate: ${data.sample_rate} 9 | n_fft: 1024 10 | hop: 256 11 | 12 | sw_loss: 13 | _target_: diffsynth.loss.SpecWaveLoss 14 | mag_w: 1.0 15 | log_mag_w: 1.0 16 | fft_sizes: [64, 128, 256, 512, 1024, 2048] 17 | 18 | perc_model: null 19 | 20 | log_grad: true 21 | 22 | f0_encoder: null -------------------------------------------------------------------------------- /configs/schedule/even.yaml: -------------------------------------------------------------------------------- 1 | name: even 2 | 3 | param_w: 4 | type: linear 5 | start: 12500 6 | warm: 37500 7 | start_v: 10.0 8 | end_v: 5.0 9 | 10 | sw_w: 11 | type: linear 12 | start: 12500 13 | warm: 37500 14 | start_v: 0.0 15 | end_v: 0.5 -------------------------------------------------------------------------------- /configs/schedule/param.yaml: -------------------------------------------------------------------------------- 1 | name: param 2 | param_w: 10.0 -------------------------------------------------------------------------------- /configs/schedule/spec.yaml: 
-------------------------------------------------------------------------------- 1 | name: spec 2 | 3 | sw_w: 1.0 -------------------------------------------------------------------------------- /configs/schedule/switch.yaml: -------------------------------------------------------------------------------- 1 | name: switch 2 | param_w: 3 | type: linear 4 | start: 12500 5 | warm: 37500 6 | start_v: 10.0 7 | end_v: 0.0 8 | 9 | sw_w: 10 | type: linear 11 | start: 12500 12 | warm: 37500 13 | start_v: 0.0 14 | end_v: 1.0 -------------------------------------------------------------------------------- /configs/schedule/switch_perc.yaml: -------------------------------------------------------------------------------- 1 | name: switch_perc 2 | 3 | param_w: 4 | type: linear 5 | start: 12500 6 | warm: 37500 7 | start_v: 10.0 8 | end_v: 0.0 9 | 10 | sw_w: 11 | type: linear 12 | start: 12500 13 | warm: 37500 14 | start_v: 0.0 15 | end_v: 1.0 16 | 17 | perc_w: 18 | type: linear 19 | start: 12500 20 | warm: 37500 21 | start_v: 0.0 22 | end_v: 0.1 -------------------------------------------------------------------------------- /configs/synth/dataset/h1of.yaml: -------------------------------------------------------------------------------- 1 | name: harmor_1oscfree 2 | # used for generating dataset for h2of.yaml 3 | # generates dynamic parameters (cutoff, amplitudes) using envelope 4 | 5 | dag: 6 | enva: 7 | config: 8 | _target_: diffsynth.modules.envelope.ADSREnvelope 9 | max_value: 1.0 10 | channels: 1 11 | connections: 12 | floor: AMP_FLOOR 13 | peak: PEAK_A 14 | attack: AT_A 15 | decay: DE_A 16 | sus_level: SU_A 17 | release: RE_A 18 | note_off: NOTE_OFF 19 | noise_mag: NOISE_A 20 | envc: 21 | config: 22 | _target_: diffsynth.modules.envelope.ADSREnvelope 23 | channels: 1 24 | connections: 25 | floor: CUT_FLOOR 26 | peak: PEAK_C 27 | attack: AT_C 28 | decay: DE_C 29 | sus_level: SU_C 30 | release: RE_C 31 | note_off: NOTE_OFF 32 | noise_mag: NOISE_C 33 | harmor: 34 | config: 35 | _target_: diffsynth.modules.harmor.Harmor 36 | sample_rate: 16000 37 | sep_amp: true 38 | n_oscs: 1 39 | connections: 40 | amplitudes: enva 41 | osc_mix: M_OSC 42 | f0_hz: BFRQ 43 | f0_mult: MULT 44 | cutoff: envc 45 | q: Q_FILT 46 | 47 | fixed_params: 48 | NOTE_OFF: 0.75 49 | AMP_FLOOR: 0 50 | static_params: [BFRQ, M_OSC, MULT, Q_FILT, NOISE_C, NOISE_A, AT_A, DE_A, SU_A, RE_A, AMP_FLOOR, PEAK_A, AT_C, DE_C, SU_C, RE_C, CUT_FLOOR, PEAK_C] 51 | save_params: [harmor_amplitudes, harmor_osc_mix, harmor_f0_hz, harmor_cutoff, harmor_q] -------------------------------------------------------------------------------- /configs/synth/dataset/h1of_envno.yaml: -------------------------------------------------------------------------------- 1 | name: h1of_envno 2 | # generates dynamic parameters (cutoff, amplitudes) using envelope 3 | # note off position is random 4 | # save envelope params 5 | dag: 6 | enva: 7 | config: 8 | _target_: diffsynth.modules.envelope.ADSREnvelope 9 | max_value: 1.0 10 | channels: 1 11 | connections: 12 | floor: AMP_FLOOR 13 | peak: PEAK_A 14 | attack: AT_A 15 | decay: DE_A 16 | sus_level: SU_A 17 | release: RE_A 18 | note_off: NOTE_OFF 19 | noise_mag: NOISE_A 20 | envc: 21 | config: 22 | _target_: diffsynth.modules.envelope.ADSREnvelope 23 | channels: 1 24 | connections: 25 | floor: CUT_FLOOR 26 | peak: PEAK_C 27 | attack: AT_C 28 | decay: DE_C 29 | sus_level: SU_C 30 | release: RE_C 31 | note_off: NOTE_OFF 32 | noise_mag: NOISE_C 33 | harmor: 34 | config: 35 | _target_: diffsynth.modules.harmor.Harmor 36 | 
sample_rate: 16000 37 | sep_amp: true 38 | n_oscs: 1 39 | connections: 40 | amplitudes: enva 41 | osc_mix: M_OSC 42 | f0_hz: BFRQ 43 | f0_mult: DUMMY 44 | cutoff: envc 45 | q: Q_FILT 46 | 47 | fixed_params: 48 | AMP_FLOOR: 0 49 | static_params: [BFRQ, M_OSC, DUMMY, Q_FILT, NOISE_C, NOISE_A, AT_A, DE_A, SU_A, RE_A, AMP_FLOOR, PEAK_A, AT_C, DE_C, SU_C, RE_C, CUT_FLOOR, PEAK_C, NOTE_OFF] -------------------------------------------------------------------------------- /configs/synth/dataset/h1ofcf.yaml: -------------------------------------------------------------------------------- 1 | name: h1ofcf 2 | # generates dynamic parameters (cutoff, amplitudes) using envelope 3 | # note off position is random 4 | # save envelope output not envelope params 5 | dag: 6 | enva: 7 | config: 8 | _target_: diffsynth.modules.envelope.ADSREnvelope 9 | max_value: 1.0 10 | channels: 1 11 | connections: 12 | floor: AMP_FLOOR 13 | peak: PEAK_A 14 | attack: AT_A 15 | decay: DE_A 16 | sus_level: SU_A 17 | release: RE_A 18 | note_off: NOTE_OFF 19 | noise_mag: NOISE_A 20 | envc: 21 | config: 22 | _target_: diffsynth.modules.envelope.ADSREnvelope 23 | channels: 1 24 | connections: 25 | floor: CUT_FLOOR 26 | peak: PEAK_C 27 | attack: AT_C 28 | decay: DE_C 29 | sus_level: SU_C 30 | release: RE_C 31 | note_off: NOTE_OFF 32 | noise_mag: NOISE_C 33 | harmor: 34 | config: 35 | _target_: diffsynth.modules.harmor.Harmor 36 | sample_rate: 16000 37 | sep_amp: true 38 | n_oscs: 1 39 | connections: 40 | amplitudes: enva 41 | osc_mix: M_OSC 42 | f0_hz: BFRQ 43 | f0_mult: DUMMY 44 | cutoff: envc 45 | q: Q_FILT 46 | chorus: 47 | config: 48 | _target_: diffsynth.modules.delay.ChorusFlanger 49 | sample_rate: 16000 50 | connections: 51 | audio: harmor 52 | delay_ms: CF_DELAY 53 | rate: CF_RATE 54 | depth: CF_DEPTH 55 | mix: CF_MIX 56 | fixed_params: 57 | AMP_FLOOR: 0 58 | NOISE_A: 0 59 | NOISE_C: 0 60 | static_params: [M_OSC, DUMMY, Q_FILT, NOISE_C, NOISE_A, AT_A, DE_A, SU_A, RE_A, AMP_FLOOR, PEAK_A, AT_C, DE_C, SU_C, RE_C, CUT_FLOOR, PEAK_C, NOTE_OFF, CF_DELAY, CF_RATE, CF_DEPTH, CF_MIX] 61 | save_params: [harmor_amplitudes, harmor_osc_mix, harmor_f0_hz, harmor_f0_mult, harmor_cutoff, harmor_q, chorus_delay_ms, chorus_rate, chorus_depth, chorus_mix] -------------------------------------------------------------------------------- /configs/synth/dataset/h1ofcf_envno.yaml: -------------------------------------------------------------------------------- 1 | name: h1ofcf_envno 2 | # generates dynamic parameters (cutoff, amplitudes) using envelope 3 | # note off position is random 4 | # save envelope params 5 | dag: 6 | enva: 7 | config: 8 | _target_: diffsynth.modules.envelope.ADSREnvelope 9 | max_value: 1.0 10 | channels: 1 11 | connections: 12 | floor: AMP_FLOOR 13 | peak: PEAK_A 14 | attack: AT_A 15 | decay: DE_A 16 | sus_level: SU_A 17 | release: RE_A 18 | note_off: NOTE_OFF 19 | noise_mag: NOISE_A 20 | envc: 21 | config: 22 | _target_: diffsynth.modules.envelope.ADSREnvelope 23 | channels: 1 24 | connections: 25 | floor: CUT_FLOOR 26 | peak: PEAK_C 27 | attack: AT_C 28 | decay: DE_C 29 | sus_level: SU_C 30 | release: RE_C 31 | note_off: NOTE_OFF 32 | noise_mag: NOISE_C 33 | harmor: 34 | config: 35 | _target_: diffsynth.modules.harmor.Harmor 36 | sample_rate: 16000 37 | sep_amp: true 38 | n_oscs: 1 39 | connections: 40 | amplitudes: enva 41 | osc_mix: M_OSC 42 | f0_hz: BFRQ 43 | f0_mult: DUMMY 44 | cutoff: envc 45 | q: Q_FILT 46 | chorus: 47 | config: 48 | _target_: diffsynth.modules.delay.ChorusFlanger 49 | sample_rate: 16000 50 | 
connections: 51 | audio: harmor 52 | delay_ms: CF_DELAY 53 | rate: CF_RATE 54 | depth: CF_DEPTH 55 | mix: CF_MIX 56 | fixed_params: 57 | AMP_FLOOR: 0 58 | NOISE_A: 0 59 | NOISE_C: 0 60 | static_params: [M_OSC, DUMMY, Q_FILT, NOISE_C, NOISE_A, AT_A, DE_A, SU_A, RE_A, AMP_FLOOR, PEAK_A, AT_C, DE_C, SU_C, RE_C, CUT_FLOOR, PEAK_C, NOTE_OFF, CF_DELAY, CF_RATE, CF_DEPTH, CF_MIX] -------------------------------------------------------------------------------- /configs/synth/dataset/h1ofcfrev_envno.yaml: -------------------------------------------------------------------------------- 1 | name: h1ofcfrev_envno 2 | # generates dynamic parameters (cutoff, amplitudes) using envelope 3 | # note off position is random 4 | # save envelope params 5 | dag: 6 | enva: 7 | config: 8 | _target_: diffsynth.modules.envelope.ADSREnvelope 9 | max_value: 1.0 10 | channels: 1 11 | connections: 12 | floor: AMP_FLOOR 13 | peak: PEAK_A 14 | attack: AT_A 15 | decay: DE_A 16 | sus_level: SU_A 17 | release: RE_A 18 | note_off: NOTE_OFF 19 | noise_mag: NOISE_A 20 | envc: 21 | config: 22 | _target_: diffsynth.modules.envelope.ADSREnvelope 23 | channels: 1 24 | connections: 25 | floor: CUT_FLOOR 26 | peak: PEAK_C 27 | attack: AT_C 28 | decay: DE_C 29 | sus_level: SU_C 30 | release: RE_C 31 | note_off: NOTE_OFF 32 | noise_mag: NOISE_C 33 | harmor: 34 | config: 35 | _target_: diffsynth.modules.harmor.Harmor 36 | sample_rate: 16000 37 | sep_amp: true 38 | n_oscs: 1 39 | connections: 40 | amplitudes: enva 41 | osc_mix: M_OSC 42 | f0_hz: BFRQ 43 | f0_mult: DUMMY 44 | cutoff: envc 45 | q: Q_FILT 46 | chorus: 47 | config: 48 | _target_: diffsynth.modules.delay.ChorusFlanger 49 | sample_rate: 16000 50 | connections: 51 | audio: harmor 52 | delay_ms: CF_DELAY 53 | rate: CF_RATE 54 | depth: CF_DEPTH 55 | mix: CF_MIX 56 | reverb: 57 | config: 58 | _target_: diffsynth.modules.reverb.DecayReverb 59 | connections: 60 | audio: chorus 61 | gain: REV_G 62 | decay: REV_D 63 | fixed_params: 64 | AMP_FLOOR: 0 65 | NOISE_A: 0 66 | NOISE_C: 0 67 | CF_DEPTH: 0.1 68 | static_params: [M_OSC, DUMMY, Q_FILT, NOISE_C, NOISE_A, AT_A, DE_A, SU_A, RE_A, AMP_FLOOR, PEAK_A, AT_C, DE_C, SU_C, RE_C, CUT_FLOOR, PEAK_C, NOTE_OFF, CF_DELAY, CF_RATE, CF_DEPTH, CF_MIX, REV_G, REV_D, BFRQ] -------------------------------------------------------------------------------- /configs/synth/dataset/h2of.yaml: -------------------------------------------------------------------------------- 1 | name: harmor_2oscfree 2 | # used for generating dataset for h2of.yaml 3 | # generates dynamic parameters (cutoff, amplitudes) using envelope 4 | 5 | n_oscs: 2 6 | max_value: 0.6 7 | 8 | dag: 9 | enva: 10 | config: 11 | _target_: diffsynth.modules.envelope.ADSREnvelope 12 | max_value: 0.6 13 | channels: 2 14 | connections: 15 | floor: AMP_FLOOR 16 | peak: PEAK_A 17 | attack: AT_A 18 | decay: DE_A 19 | sus_level: SU_A 20 | release: RE_A 21 | note_off: NOTE_OFF 22 | noise_mag: NOISE_A 23 | envc: 24 | config: 25 | _target_: diffsynth.modules.envelope.ADSREnvelope 26 | channels: 1 27 | connections: 28 | floor: CUT_FLOOR 29 | peak: PEAK_C 30 | attack: AT_C 31 | decay: DE_C 32 | sus_level: SU_C 33 | release: RE_C 34 | note_off: NOTE_OFF 35 | noise_mag: NOISE_C 36 | harmor: 37 | config: 38 | _target_: diffsynth.modules.harmor.Harmor 39 | sample_rate: 16000 40 | sep_amp: true 41 | n_oscs: 2 42 | connections: 43 | amplitudes: enva 44 | osc_mix: M_OSC 45 | f0_hz: BFRQ 46 | f0_mult: MULT 47 | cutoff: envc 48 | q: Q_FILT 49 | 50 | fixed_params: 51 | NOTE_OFF: 0.75 52 | AMP_FLOOR: 0 53 | 
static_params: [BFRQ, M_OSC, MULT, Q_FILT, NOISE_C, NOISE_A, AT_A, DE_A, SU_A, RE_A, AMP_FLOOR, PEAK_A, AT_C, DE_C, SU_C, RE_C, CUT_FLOOR, PEAK_C] -------------------------------------------------------------------------------- /configs/synth/h1of.yaml: -------------------------------------------------------------------------------- 1 | name: harmor_1oscfree 2 | dag: 3 | harmor: 4 | config: 5 | _target_: diffsynth.modules.harmor.Harmor 6 | sample_rate: ${data.sample_rate} 7 | sep_amp: true 8 | n_oscs: 1 9 | connections: 10 | amplitudes: AMP 11 | osc_mix: M_OSC 12 | f0_hz: BFRQ 13 | f0_mult: MULT 14 | cutoff: CUTOFF 15 | q: Q_FILT 16 | fixed_params: 17 | static_params: [BFRQ, M_OSC, MULT, Q_FILT] -------------------------------------------------------------------------------- /configs/synth/h1of_f0.yaml: -------------------------------------------------------------------------------- 1 | name: harmor_2oscfree 2 | dag: 3 | harmor: 4 | config: 5 | _target_: diffsynth.modules.harmor.Harmor 6 | sample_rate: ${data.sample_rate} 7 | sep_amp: true 8 | n_oscs: 1 9 | connections: 10 | amplitudes: AMP 11 | osc_mix: M_OSC 12 | f0_hz: f0_hz 13 | f0_mult: MULT 14 | cutoff: CUTOFF 15 | q: Q_FILT 16 | fixed_params: 17 | f0_hz: null 18 | static_params: [M_OSC, MULT, Q_FILT] -------------------------------------------------------------------------------- /configs/synth/h1of_f0_envno.yaml: -------------------------------------------------------------------------------- 1 | name: h1of_envno 2 | # generates dynamic parameters (cutoff, amplitudes) using envelope 3 | # note off position is random 4 | # save envelope params 5 | # f0 is external 6 | dag: 7 | enva: 8 | config: 9 | _target_: diffsynth.modules.envelope.ADSREnvelope 10 | max_value: 1.0 11 | channels: 1 12 | connections: 13 | floor: AMP_FLOOR 14 | peak: PEAK_A 15 | attack: AT_A 16 | decay: DE_A 17 | sus_level: SU_A 18 | release: RE_A 19 | note_off: NOTE_OFF 20 | noise_mag: NOISE_A 21 | envc: 22 | config: 23 | _target_: diffsynth.modules.envelope.ADSREnvelope 24 | channels: 1 25 | connections: 26 | floor: CUT_FLOOR 27 | peak: PEAK_C 28 | attack: AT_C 29 | decay: DE_C 30 | sus_level: SU_C 31 | release: RE_C 32 | note_off: NOTE_OFF 33 | noise_mag: NOISE_C 34 | harmor: 35 | config: 36 | _target_: diffsynth.modules.harmor.Harmor 37 | sample_rate: 16000 38 | sep_amp: true 39 | n_oscs: 1 40 | connections: 41 | amplitudes: enva 42 | osc_mix: M_OSC 43 | f0_hz: f0_hz 44 | f0_mult: DUMMY 45 | cutoff: envc 46 | q: Q_FILT 47 | 48 | fixed_params: 49 | AMP_FLOOR: 0 50 | f0_hz: null 51 | static_params: [M_OSC, DUMMY, Q_FILT, NOISE_C, NOISE_A, AT_A, DE_A, SU_A, RE_A, AMP_FLOOR, PEAK_A, AT_C, DE_C, SU_C, RE_C, CUT_FLOOR, PEAK_C, NOTE_OFF] 52 | -------------------------------------------------------------------------------- /configs/synth/h1ofcf_f0.yaml: -------------------------------------------------------------------------------- 1 | name: harmor_1ofccf_f0 2 | dag: 3 | harmor: 4 | config: 5 | _target_: diffsynth.modules.harmor.Harmor 6 | sample_rate: ${data.sample_rate} 7 | sep_amp: true 8 | n_oscs: 1 9 | connections: 10 | amplitudes: AMP 11 | osc_mix: M_OSC 12 | f0_hz: BFRQ 13 | f0_mult: MULT 14 | cutoff: CUTOFF 15 | q: Q_FILT 16 | chorus: 17 | config: 18 | _target_: diffsynth.modules.delay.ChorusFlanger 19 | sample_rate: 16000 20 | connections: 21 | audio: harmor 22 | delay_ms: CF_DELAY 23 | rate: CF_RATE 24 | depth: CF_DEPTH 25 | mix: CF_MIX 26 | fixed_params: 27 | BFRQ: null 28 | static_params: [M_OSC, MULT, Q_FILT, CF_DELAY, CF_RATE, CF_DEPTH, CF_MIX] 
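# Note on how these synth configs are structured (illustrative): each node under
# `dag:` pairs a `config:` block (a Hydra `_target_` pointing at a module in
# diffsynth/modules, e.g. Harmor, ADSREnvelope, ChorusFlanger) with a `connections:`
# map. A connection value that names another node (e.g. `audio: harmor`) feeds that
# node's output in; an UPPERCASE name is an external synth parameter. `fixed_params`
# pins a parameter to a constant, or marks it as externally supplied when null
# (e.g. BFRQ/f0_hz taken from the dataset's precomputed f0). `static_params` lists
# parameters held constant over the note rather than varying per frame. The real
# graph construction lives in diffsynth/synthesizer.py (not shown in this excerpt);
# the sketch below only shows how the `_target_` blocks can be instantiated, using a
# config without `${...}` interpolations:
#
#   from omegaconf import OmegaConf
#   from hydra.utils import instantiate
#
#   cfg = OmegaConf.load("configs/synth/dataset/h1of.yaml")
#   modules = {name: instantiate(node.config) for name, node in cfg.dag.items()}
#   wiring = {name: dict(node.connections) for name, node in cfg.dag.items()}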
-------------------------------------------------------------------------------- /configs/synth/h1ofcf_f0_envno.yaml: -------------------------------------------------------------------------------- 1 | name: h1ofcf_envno 2 | # generates dynamic parameters (cutoff, amplitudes) using envelope 3 | # note off position is random 4 | # save envelope params 5 | dag: 6 | enva: 7 | config: 8 | _target_: diffsynth.modules.envelope.ADSREnvelope 9 | max_value: 1.0 10 | channels: 1 11 | connections: 12 | floor: AMP_FLOOR 13 | peak: PEAK_A 14 | attack: AT_A 15 | decay: DE_A 16 | sus_level: SU_A 17 | release: RE_A 18 | note_off: NOTE_OFF 19 | noise_mag: NOISE_A 20 | envc: 21 | config: 22 | _target_: diffsynth.modules.envelope.ADSREnvelope 23 | channels: 1 24 | connections: 25 | floor: CUT_FLOOR 26 | peak: PEAK_C 27 | attack: AT_C 28 | decay: DE_C 29 | sus_level: SU_C 30 | release: RE_C 31 | note_off: NOTE_OFF 32 | noise_mag: NOISE_C 33 | harmor: 34 | config: 35 | _target_: diffsynth.modules.harmor.Harmor 36 | sample_rate: 16000 37 | sep_amp: true 38 | n_oscs: 1 39 | connections: 40 | amplitudes: enva 41 | osc_mix: M_OSC 42 | f0_hz: BFRQ 43 | f0_mult: DUMMY 44 | cutoff: envc 45 | q: Q_FILT 46 | chorus: 47 | config: 48 | _target_: diffsynth.modules.delay.ChorusFlanger 49 | sample_rate: 16000 50 | connections: 51 | audio: harmor 52 | delay_ms: CF_DELAY 53 | rate: CF_RATE 54 | depth: CF_DEPTH 55 | mix: CF_MIX 56 | fixed_params: 57 | AMP_FLOOR: 0 58 | NOISE_A: 0 59 | NOISE_C: 0 60 | BFRQ: null 61 | static_params: [M_OSC, DUMMY, Q_FILT, NOISE_C, NOISE_A, AT_A, DE_A, SU_A, RE_A, AMP_FLOOR, PEAK_A, AT_C, DE_C, SU_C, RE_C, CUT_FLOOR, PEAK_C, NOTE_OFF, CF_DELAY, CF_RATE, CF_DEPTH, CF_MIX] -------------------------------------------------------------------------------- /configs/synth/h1ofcfrev_envno.yaml: -------------------------------------------------------------------------------- 1 | name: h1ofcfrev_envno 2 | # generates dynamic parameters (cutoff, amplitudes) using envelope 3 | # note off position is random 4 | # save envelope params 5 | dag: 6 | enva: 7 | config: 8 | _target_: diffsynth.modules.envelope.ADSREnvelope 9 | max_value: 1.0 10 | channels: 1 11 | connections: 12 | floor: AMP_FLOOR 13 | peak: PEAK_A 14 | attack: AT_A 15 | decay: DE_A 16 | sus_level: SU_A 17 | release: RE_A 18 | note_off: NOTE_OFF 19 | noise_mag: NOISE_A 20 | envc: 21 | config: 22 | _target_: diffsynth.modules.envelope.ADSREnvelope 23 | channels: 1 24 | connections: 25 | floor: CUT_FLOOR 26 | peak: PEAK_C 27 | attack: AT_C 28 | decay: DE_C 29 | sus_level: SU_C 30 | release: RE_C 31 | note_off: NOTE_OFF 32 | noise_mag: NOISE_C 33 | harmor: 34 | config: 35 | _target_: diffsynth.modules.harmor.Harmor 36 | sample_rate: ${data.sample_rate} 37 | sep_amp: true 38 | n_oscs: 1 39 | connections: 40 | amplitudes: enva 41 | osc_mix: M_OSC 42 | f0_hz: BFRQ 43 | f0_mult: DUMMY 44 | cutoff: envc 45 | q: Q_FILT 46 | chorus: 47 | config: 48 | _target_: diffsynth.modules.delay.ModulatedDelay 49 | sample_rate: ${data.sample_rate} 50 | connections: 51 | audio: harmor 52 | delay_ms: CF_DELAY 53 | rate: CF_RATE 54 | depth: CF_DEPTH 55 | mix: CF_MIX 56 | reverb: 57 | config: 58 | _target_: diffsynth.modules.reverb.DecayReverb 59 | connections: 60 | audio: chorus 61 | gain: REV_G 62 | decay: REV_D 63 | fixed_params: 64 | AMP_FLOOR: 0 65 | NOISE_A: 0 66 | NOISE_C: 0 67 | CF_DEPTH: 0.1 68 | static_params: [M_OSC, DUMMY, Q_FILT, NOISE_C, NOISE_A, AT_A, DE_A, SU_A, RE_A, AMP_FLOOR, PEAK_A, AT_C, DE_C, SU_C, RE_C, CUT_FLOOR, PEAK_C, NOTE_OFF, CF_DELAY, CF_RATE, 
CF_DEPTH, CF_MIX, REV_G, REV_D, BFRQ] -------------------------------------------------------------------------------- /configs/synth/h1ofcfrev_f0_envno.yaml: -------------------------------------------------------------------------------- 1 | name: h1ofcfrev_envno 2 | # generates dynamic parameters (cutoff, amplitudes) using envelope 3 | # note off position is random 4 | # save envelope params 5 | dag: 6 | enva: 7 | config: 8 | _target_: diffsynth.modules.envelope.ADSREnvelope 9 | max_value: 1.0 10 | channels: 1 11 | connections: 12 | floor: AMP_FLOOR 13 | peak: PEAK_A 14 | attack: AT_A 15 | decay: DE_A 16 | sus_level: SU_A 17 | release: RE_A 18 | note_off: NOTE_OFF 19 | noise_mag: NOISE_A 20 | envc: 21 | config: 22 | _target_: diffsynth.modules.envelope.ADSREnvelope 23 | channels: 1 24 | connections: 25 | floor: CUT_FLOOR 26 | peak: PEAK_C 27 | attack: AT_C 28 | decay: DE_C 29 | sus_level: SU_C 30 | release: RE_C 31 | note_off: NOTE_OFF 32 | noise_mag: NOISE_C 33 | harmor: 34 | config: 35 | _target_: diffsynth.modules.harmor.Harmor 36 | sample_rate: ${data.sample_rate} 37 | sep_amp: true 38 | n_oscs: 1 39 | connections: 40 | amplitudes: enva 41 | osc_mix: M_OSC 42 | f0_hz: BFRQ 43 | f0_mult: DUMMY 44 | cutoff: envc 45 | q: Q_FILT 46 | chorus: 47 | config: 48 | _target_: diffsynth.modules.delay.ChorusFlanger 49 | sample_rate: ${data.sample_rate} 50 | connections: 51 | audio: harmor 52 | delay_ms: CF_DELAY 53 | rate: CF_RATE 54 | depth: CF_DEPTH 55 | mix: CF_MIX 56 | reverb: 57 | config: 58 | _target_: diffsynth.modules.reverb.DecayReverb 59 | connections: 60 | audio: chorus 61 | gain: REV_G 62 | decay: REV_D 63 | fixed_params: 64 | AMP_FLOOR: 0 65 | NOISE_A: 0 66 | NOISE_C: 0 67 | BFRQ: null 68 | CF_DEPTH: 0.1 69 | static_params: [M_OSC, DUMMY, Q_FILT, NOISE_C, NOISE_A, AT_A, DE_A, SU_A, RE_A, AMP_FLOOR, PEAK_A, AT_C, DE_C, SU_C, RE_C, CUT_FLOOR, PEAK_C, NOTE_OFF, CF_DELAY, CF_RATE, CF_DEPTH, CF_MIX, REV_G, REV_D] -------------------------------------------------------------------------------- /configs/synth/h2of.yaml: -------------------------------------------------------------------------------- 1 | name: harmor_2oscfree 2 | dag: 3 | harmor: 4 | config: 5 | _target_: diffsynth.modules.harmor.Harmor 6 | sample_rate: ${data.sample_rate} 7 | sep_amp: true 8 | n_oscs: 2 9 | connections: 10 | amplitudes: AMP 11 | osc_mix: M_OSC 12 | f0_hz: BFRQ 13 | f0_mult: MULT 14 | cutoff: CUTOFF 15 | q: Q_FILT 16 | fixed_params: 17 | static_params: [BFRQ, M_OSC, MULT, Q_FILT] -------------------------------------------------------------------------------- /configs/synth/h2of_cf.yaml: -------------------------------------------------------------------------------- 1 | name: harmor_cf 2 | dag: 3 | harmor: 4 | config: 5 | _target_: diffsynth.modules.harmor.Harmor 6 | sample_rate: ${data.sample_rate} 7 | sep_amp: true 8 | n_oscs: 2 9 | connections: 10 | amplitudes: AMP 11 | osc_mix: M_OSC 12 | f0_hz: BFRQ 13 | f0_mult: MULT 14 | cutoff: CUTOFF 15 | q: Q_FILT 16 | chorus: 17 | config: 18 | _target_: diffsynth.modules.delay.ModulatedDelay 19 | sample_rate: ${data.sample_rate} 20 | connections: 21 | audio: harmor 22 | delay_ms: MD_DELAY 23 | phase: MD_PHASE 24 | depth: MD_DEPTH 25 | mix: MD_MIX 26 | fixed_params: 27 | static_params: [BFRQ, M_OSC, MULT, Q_FILT, MD_DELAY, MD_DEPTH, MD_MIX] -------------------------------------------------------------------------------- /configs/synth/h2of_f0.yaml: -------------------------------------------------------------------------------- 1 | name: harmor_2oscfree 2 | dag: 
3 | harmor: 4 | config: 5 | _target_: diffsynth.modules.harmor.Harmor 6 | sample_rate: ${data.sample_rate} 7 | sep_amp: true 8 | n_oscs: 2 9 | connections: 10 | amplitudes: AMP 11 | osc_mix: M_OSC 12 | f0_hz: f0_hz 13 | f0_mult: MULT 14 | cutoff: CUTOFF 15 | q: Q_FILT 16 | fixed_params: 17 | f0_hz: null 18 | static_params: [f0_hz, M_OSC, MULT, Q_FILT] -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | max_epochs: 200 4 | gradient_clip_val: 1.0 5 | checkpoint_callback: True 6 | gpus: null 7 | num_nodes: 1 8 | benchmark: False 9 | overfit_batches: 0.0 10 | limit_train_batches: 1.0 11 | auto_lr_find: False 12 | resume_from_checkpoint: null 13 | weights_summary: "top" 14 | 15 | default_root_dir: null 16 | process_position: 0 17 | num_processes: 1 18 | auto_select_gpus: False 19 | tpu_cores: null 20 | log_gpu_memory: null 21 | progress_bar_refresh_rate: 1 22 | track_grad_norm: -1 23 | check_val_every_n_epoch: 1 24 | fast_dev_run: False 25 | accumulate_grad_batches: 1 26 | min_epochs: 1 27 | max_steps: null 28 | min_steps: null 29 | limit_val_batches: 1.0 30 | limit_test_batches: 1.0 31 | val_check_interval: 1.0 32 | flush_logs_every_n_steps: 100 33 | log_every_n_steps: 50 34 | accelerator: null 35 | sync_batchnorm: False 36 | precision: 32 37 | num_sanity_val_steps: 2 38 | truncated_bptt_steps: null 39 | profiler: null 40 | deterministic: False 41 | reload_dataloaders_every_epoch: False 42 | replace_sampler_ddp: True 43 | terminate_on_nan: False 44 | auto_scale_batch_size: False 45 | prepare_data_per_node: True 46 | plugins: null 47 | amp_backend: "native" 48 | amp_level: "O2" 49 | move_metrics_to_cpu: False -------------------------------------------------------------------------------- /diffsynth/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyakuchiki/diffsynth/4fc64ba04608cbaf8e016e4878fc50bedfca4da3/diffsynth/__init__.py -------------------------------------------------------------------------------- /diffsynth/data.py: -------------------------------------------------------------------------------- 1 | import os, glob, functools 2 | import librosa 3 | import torch 4 | from torch.utils.data import Subset, Dataset, DataLoader, random_split, ConcatDataset, SubsetRandomSampler, BatchSampler 5 | import pytorch_lightning as pl 6 | import numpy as np 7 | from diffsynth.f0 import process_f0 8 | 9 | def mix_iterable(dl_a, dl_b): 10 | for i, j in zip(dl_a, dl_b): 11 | yield i 12 | yield j 13 | 14 | class ReiteratableWrapper(): 15 | def __init__(self, f, length): 16 | self._f = f 17 | self.length = length 18 | 19 | def __iter__(self): 20 | # make generator 21 | return self._f() 22 | 23 | def __len__(self): 24 | return self.length 25 | 26 | class WaveParamDataset(Dataset): 27 | def __init__(self, base_dir, sample_rate=16000, length=4.0, params=True, f0=False): 28 | self.base_dir = base_dir 29 | self.audio_dir = os.path.join(base_dir, 'audio') 30 | self.raw_files = sorted(glob.glob(os.path.join(self.audio_dir, '*.wav'))) 31 | print('loaded {0} files'.format(len(self.raw_files))) 32 | self.length = length 33 | self.sample_rate = sample_rate 34 | self.params = params 35 | self.f0 = f0 36 | if f0: 37 | self.f0_dir = os.path.join(base_dir, 'f0') 38 | assert os.path.exists(self.f0_dir) 39 | # all the f0 files should already be written 40 | 
# with the same name as the audio 41 | self.f0_files = sorted(glob.glob(os.path.join(self.f0_dir, '*.pt'))) 42 | if params: 43 | self.param_dir = os.path.join(base_dir, 'param') 44 | assert os.path.exists(self.param_dir) 45 | # all the files should already be written 46 | self.param_files = sorted(glob.glob(os.path.join(self.param_dir, '*.pt'))) 47 | 48 | def __getitem__(self, idx): 49 | raw_path = self.raw_files[idx] 50 | audio, _sr = librosa.load(raw_path, sr=self.sample_rate, duration=self.length) 51 | assert audio.shape[0] == self.length * self.sample_rate 52 | data = {'audio': audio} 53 | if self.f0: 54 | f0, periodicity = torch.load(self.f0_files[idx]) 55 | f0_hz = process_f0(f0, periodicity) 56 | data['BFRQ'] = f0_hz.unsqueeze(-1) 57 | if self.params: 58 | params = torch.load(self.param_files[idx]) 59 | data['params'] = params 60 | return data 61 | 62 | def __len__(self): 63 | return len(self.raw_files) 64 | 65 | class IdOodDataModule(pl.LightningDataModule): 66 | def __init__(self, id_dir, ood_dir, train_type, batch_size, sample_rate=16000, length=4.0, num_workers=8, splits=[.8, .1, .1], f0=False): 67 | super().__init__() 68 | self.id_dir = id_dir 69 | self.ood_dir = ood_dir 70 | assert train_type in ['id', 'ood', 'mixed'] 71 | self.train_type = train_type 72 | self.splits = splits 73 | self.sr = sample_rate 74 | self.l = length 75 | self.batch_size = batch_size 76 | self.num_workers = num_workers 77 | self.f0 = f0 78 | 79 | def create_split(self, dataset): 80 | dset_l = len(dataset) 81 | split_sizes = [int(dset_l*self.splits[0]), int(dset_l*self.splits[1])] 82 | split_sizes.append(dset_l - split_sizes[0] - split_sizes[1]) 83 | # should be seeded fine but probably better to split test set in some other way 84 | dset_train, dset_valid, dset_test = random_split(dataset, lengths=split_sizes) 85 | return {'train': dset_train, 'valid': dset_valid, 'test': dset_test} 86 | 87 | def setup(self, stage): 88 | id_dat = WaveParamDataset(self.id_dir, self.sr, self.l, True, self.f0) 89 | id_datasets = self.create_split(id_dat) 90 | # ood should be the same size as in-domain 91 | ood_dat = WaveParamDataset(self.ood_dir, self.sr, self.l, False, self.f0) 92 | indices = np.random.choice(len(ood_dat), len(id_dat), replace=False) 93 | ood_dat = Subset(ood_dat, indices) 94 | ood_datasets = self.create_split(ood_dat) 95 | self.id_datasets = id_datasets 96 | self.ood_datasets = ood_datasets 97 | assert len(id_datasets['train']) == len(ood_datasets['train']) 98 | if self.train_type == 'mixed': 99 | dat_len = len(id_datasets['train']) 100 | indices = np.random.choice(dat_len, dat_len//2, replace=False) 101 | self.train_set = ConcatDataset([Subset(id_datasets['train'], indices), Subset(ood_datasets['train'], indices)]) 102 | 103 | def train_dataloader(self): 104 | if self.train_type=='id': 105 | return DataLoader(self.id_datasets['train'], batch_size=self.batch_size, 106 | num_workers=self.num_workers, shuffle=True) 107 | elif self.train_type=='ood': 108 | return DataLoader(self.ood_datasets['train'], batch_size=self.batch_size, 109 | num_workers=self.num_workers, shuffle=True) 110 | elif self.train_type=='mixed': 111 | id_indices = list(range(len(self.train_set)//2)) 112 | ood_indices = list(range(len(self.train_set)//2, len(self.train_set))) 113 | id_samp = SubsetRandomSampler(id_indices) 114 | ood_samp = SubsetRandomSampler(ood_indices) 115 | id_batch_samp = BatchSampler(id_samp, batch_size=self.batch_size, drop_last=False) 116 | ood_batch_samp = BatchSampler(ood_samp, batch_size=self.batch_size, 
drop_last=False) 117 | generator = functools.partial(mix_iterable, id_batch_samp, ood_batch_samp) 118 | b_sampler = ReiteratableWrapper(generator, len(id_batch_samp)+len(ood_batch_samp)) 119 | return DataLoader(self.train_set, batch_sampler=b_sampler, num_workers=self.num_workers) 120 | 121 | def val_dataloader(self): 122 | return [DataLoader(self.id_datasets["valid"], batch_size=self.batch_size, num_workers=self.num_workers), 123 | DataLoader(self.ood_datasets["valid"], batch_size=self.batch_size, num_workers=self.num_workers)] 124 | 125 | def test_dataloader(self): 126 | return [DataLoader(self.id_datasets["test"], batch_size=self.batch_size, num_workers=self.num_workers), 127 | DataLoader(self.ood_datasets["test"], batch_size=self.batch_size, num_workers=self.num_workers)] -------------------------------------------------------------------------------- /diffsynth/estimator.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from diffsynth.util import resample_frames 6 | from diffsynth.layers import Resnet1D, Normalize2d 7 | from diffsynth.transforms import LogTransform 8 | from nnAudio.Spectrogram import MelSpectrogram, MFCC 9 | 10 | from diffsynth.f0 import FMIN, FMAX 11 | 12 | class MFCCEstimator(nn.Module): 13 | def __init__(self, output_dim, n_mels=128, n_mfccs=30, n_fft=1024, hop=256, sample_rate=16000, num_layers=2, hidden_size=512, dropout_p=0.0, norm='instance'): 14 | super().__init__() 15 | self.mfcc = MFCC(sr=sample_rate, n_mfcc=n_mfccs, norm='ortho', verbose=True, hop_length=hop, n_fft=n_fft, n_mels=n_mels, center=True, sample_rate=sample_rate) 16 | self.norm = Normalize2d(norm) if norm else None 17 | self.gru = nn.GRU(n_mfccs, hidden_size, num_layers=num_layers, dropout=dropout_p, batch_first=True) 18 | self.output_dim = output_dim 19 | self.out = nn.Linear(hidden_size, output_dim) 20 | 21 | def forward(self, audio): 22 | x = self.mfcc(audio) 23 | x = self.norm(x) if self.norm else x 24 | x = x.permute(0, 2, 1).contiguous() 25 | # batch_size, n_frames, n_mfcc = x.shape 26 | output, _hidden = self.gru(x) 27 | # output: [batch_size, n_frames, self.output_dim] 28 | output = self.out(output) 29 | return torch.sigmoid(output) 30 | 31 | class MelEstimator(nn.Module): 32 | def __init__(self, output_dim, n_mels=128, n_fft=1024, hop=256, sample_rate=16000, channels=64, kernel_size=7, strides=[2,2,2], num_layers=1, hidden_size=512, dropout_p=0.0, bidirectional=False, norm='batch'): 33 | super().__init__() 34 | self.n_mels = n_mels 35 | self.channels = channels 36 | self.logmel = nn.Sequential(MelSpectrogram(sr=sample_rate, n_fft=n_fft, n_mels=n_mels, hop_length=hop, center=True, power=1.0, htk=True, trainable_mel=False, trainable_STFT=False), LogTransform()) 37 | self.norm = Normalize2d(norm) if norm else None 38 | # Regular Conv 39 | self.convs = nn.ModuleList( 40 | [nn.Sequential(nn.Conv1d(1, channels, kernel_size, 41 | padding=kernel_size // 2, 42 | stride=strides[0]), nn.BatchNorm1d(channels), nn.ReLU())] 43 | + [nn.Sequential(nn.Conv1d(channels, channels, kernel_size, 44 | padding=kernel_size // 2, 45 | stride=strides[i]), nn.BatchNorm1d(channels), nn.ReLU()) 46 | for i in range(1, len(strides))]) 47 | self.l_out = self.get_downsampled_length()[-1] # downsampled in frequency dimension 48 | print('output dims after convolution', self.l_out) 49 | self.num_layers = num_layers 50 | self.hidden_size = hidden_size 51 | self.bidirectional = bidirectional 52 | self.gru = nn.GRU(self.l_out 
* channels, hidden_size, num_layers=num_layers, dropout=dropout_p, batch_first=True, bidirectional=bidirectional) 53 | self.out = nn.Linear(hidden_size*2 if bidirectional else hidden_size, output_dim) 54 | self.output_dim = output_dim 55 | 56 | def forward(self, audio): 57 | x = self.logmel(audio) 58 | x = self.norm(x) 59 | batch_size, n_mels, n_frames = x.shape 60 | x = x.permute(0, 2, 1).contiguous() 61 | x = x.view(-1, self.n_mels).unsqueeze(1) 62 | # x: [batch_size*n_frames, 1, n_mels] 63 | for i, conv in enumerate(self.convs): 64 | x = conv(x) 65 | x = x.view(batch_size, n_frames, self.channels, self.l_out) 66 | x = x.view(batch_size, n_frames, -1) 67 | D = 2 if self.bidirectional else 1 68 | output, _hidden = self.gru(x, torch.zeros(D * self.num_layers, batch_size, self.hidden_size, device=x.device)) 69 | # output: [batch_size, n_frames, self.output_dim] 70 | output = self.out(output) 71 | return torch.sigmoid(output) 72 | 73 | def get_downsampled_length(self): 74 | l = self.n_mels 75 | lengths = [l] 76 | for conv in self.convs: 77 | conv_module = conv[0] 78 | l = (l + 2 * conv_module.padding[0] - conv_module.dilation[0] * (conv_module.kernel_size[0] - 1) - 1) // conv_module.stride[0] + 1 79 | lengths.append(l) 80 | return lengths 81 | 82 | class F0MelEstimator(MelEstimator): 83 | def __init__(self, output_dim, n_mels=128, n_fft=1024, hop=256, sample_rate=16000, channels=64, kernel_size=7, strides=[2,2,2], num_layers=1, hidden_size=512, dropout_p=0.0, bidirectional=False, norm='batch'): 84 | super().__init__(output_dim, n_mels, n_fft, hop, sample_rate, channels, kernel_size, strides, num_layers, hidden_size, dropout_p, bidirectional, norm) 85 | self.gru = nn.GRU(self.l_out * channels + 1, hidden_size, num_layers=num_layers, dropout=dropout_p, batch_first=True, bidirectional=bidirectional) 86 | 87 | def forward(self, audio, f0): 88 | x = self.logmel(audio) 89 | x = self.norm(x) 90 | batch_size, n_mels, n_frames = x.shape 91 | x = x.permute(0, 2, 1).contiguous() 92 | x = x.view(-1, self.n_mels).unsqueeze(1) 93 | # x: [batch_size*n_frames, 1, n_mels] 94 | for i, conv in enumerate(self.convs): 95 | x = conv(x) 96 | x = x.view(batch_size, n_frames, self.channels, self.l_out) 97 | x = x.view(batch_size, n_frames, -1) 98 | f0 = resample_frames(f0, n_frames) 99 | f0 = (f0-FMIN)/(FMAX-FMIN) 100 | x = torch.cat([x, f0], dim=-1) 101 | D = 2 if self.bidirectional else 1 102 | output, _hidden = self.gru(x, torch.zeros(D * self.num_layers, batch_size, self.hidden_size, device=x.device)) 103 | # output: [batch_size, n_frames, self.output_dim] 104 | output = self.out(output) 105 | return torch.sigmoid(output) 106 | 107 | frame_setting_stride = { 108 | # n_downsample, stride 109 | "coarse": (5, 4), # hop: 1024 110 | "fine": (9, 2), # 512, too deep? 111 | "finer": (8, 2), # 256 112 | "finest": (6, 2) # 64 113 | } 114 | 115 | class FrameDilatedConvEstimator(nn.Module): 116 | """ 117 | Process raw waveform 118 | Similar to Jukebox 119 | """ 120 | def __init__(self, output_dim, frame_setting='finer', res_depth=4, channels=32, dilation_growth_rate=3, m_conv=1.0): 121 | """ 122 | Args: 123 | output_dims (int): output channels 124 | res_depth (int, optional): depth of each resnet. Defaults to 4. 125 | channels (int, optional): conv channels. Defaults to 32. 126 | dilation_growth_rate (int, optional): exponential growth of dilation. Defaults to 3. 127 | m_conv (float, optional): multiplier for resnet channels. Defaults to 1.0. 
128 | """ 129 | super().__init__() 130 | self.n_downsample, self.stride = frame_setting_stride[frame_setting] 131 | blocks = [] 132 | kernel_size, pad = self.stride * 2, self.stride // 2 133 | for i in range(self.n_downsample): 134 | block = nn.Sequential( 135 | # downsampling conv, output size is L_in/stride 136 | nn.Conv1d(1 if i == 0 else channels, channels, kernel_size, self.stride, pad), 137 | # ResNet with growing dilation 138 | Resnet1D(channels, res_depth, m_conv, dilation_growth_rate), 139 | ) 140 | blocks.append(block) 141 | # # doesn't change size 142 | # block = nn.Conv1d(channels, output_dims, 3, 1, 1) # output:(batch, output_dims, n_frames) 143 | # blocks.append(block) 144 | self.model = nn.Sequential(*blocks) 145 | self.out = nn.Linear(channels, output_dim) 146 | self.output_dim = output_dim 147 | 148 | def get_z_frames(self, n_samples): 149 | n_frames = n_samples // (self.stride ** self.n_downsample) 150 | return n_frames 151 | 152 | def forward(self, audio): 153 | batch_size, n_samples = audio.shape 154 | x = audio.unsqueeze(1) 155 | x = self.model(x) # (batch, channels, n_frames) 156 | x = x.permute(0, 2, 1) 157 | assert x.shape == (batch_size, self.get_z_frames(n_samples), self.output_dims) 158 | x = torch.sigmoid(self.out(x)) 159 | return x -------------------------------------------------------------------------------- /diffsynth/f0.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torchcrepe 5 | import functools 6 | 7 | FMIN = 32 8 | FMAX = 2000 9 | 10 | def process_f0(f0_hz, periodicity): 11 | # Shape [1, 1 + int(time // hop_length,] 12 | # Postprocessing on f0_hz 13 | # replace unvoiced regions with NaN 14 | # win_length = 3 15 | # periodicity = torchcrepe.filter.mean(periodicity, win_length) 16 | threshold = 1e-3 17 | # if all noisy, do not perform thresholding 18 | if (periodicity > threshold).any(): 19 | f0_hz = torchcrepe.threshold.At(1e-3)(f0_hz, periodicity) 20 | # f0_hz = torchcrepe.filter.mean(f0_hz, win_length) 21 | f0_hz = f0_hz[0] 22 | # interpolate Nans 23 | # https://stackoverflow.com/questions/9537543/replace-nans-in-numpy-array-with-closest-non-nan-value 24 | f0_hz = f0_hz.numpy() 25 | mask = np.isnan(f0_hz) 26 | f0_hz[mask] = np.interp(np.flatnonzero(mask), np.flatnonzero(~mask), f0_hz[~mask]) 27 | return torch.from_numpy(f0_hz)# Shape [1 + int(time // hop_length,] 28 | 29 | def compute_f0(audio, sample_rate): 30 | """ For preprocessing 31 | Args: 32 | audio: torch.Tensor of single audio example. Shape [audio_length,]. 33 | sample_rate: Sample rate in Hz. 34 | 35 | Returns: 36 | f0_hz: Fundamental frequency in Hz. Shape [1, 1 + int(time // hop_length,] 37 | periodicity: Basically, confidence of pitch value. Shape [1, 1 + int(time // hop_length,] 38 | """ 39 | audio = audio.unsqueeze(0) 40 | 41 | # Compute f0 with torchcrepe. 
42 | # uses viterbi by default 43 | # pad=False is probably center=False 44 | # [output_shape=(1, 1 + int(time // hop_length))] 45 | f0_hz, periodicity = torchcrepe.predict(audio, sample_rate, hop_length=128, pad=False, device='cuda', batch_size=2048, model='full', fmin=FMIN, fmax=FMAX, return_periodicity=True) 46 | return f0_hz, periodicity 47 | 48 | def write_f0(audiofile, f0_dir, duration, overwrite): 49 | basename = os.path.basename(audiofile) 50 | f0_file = os.path.join(f0_dir, basename[:-4]+'.pt') 51 | if overwrite or not os.path.exists(f0_file): 52 | audio, _sr = librosa.load(audiofile, sr=16000, duration=duration) 53 | f0, periodicity = compute_f0(torch.from_numpy(audio), 16000) 54 | torch.save((f0, periodicity), f0_file) 55 | 56 | 57 | if __name__ == "__main__": 58 | import argparse, os, glob 59 | import librosa 60 | import tqdm 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument('base_dir', type=str, help='') 63 | parser.add_argument('--duration', type=float, default=4.0, help='') 64 | parser.add_argument('--overwrite', action='store_true') 65 | args = parser.parse_args() 66 | 67 | audio_dir = os.path.join(args.base_dir, 'audio') 68 | f0_dir = os.path.join(args.base_dir, 'f0') 69 | os.makedirs(f0_dir, exist_ok=True) 70 | raw_files = sorted(glob.glob(os.path.join(audio_dir, '*.wav'))) 71 | 72 | pool = torch.multiprocessing.Pool(processes=2) 73 | func = functools.partial(write_f0, f0_dir=f0_dir, duration=args.duration, overwrite=args.overwrite) 74 | with tqdm.tqdm(total=len(raw_files)) as t: 75 | for _ in pool.imap_unordered(func, raw_files): 76 | t.update(1) -------------------------------------------------------------------------------- /diffsynth/layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | class MLP(nn.Module): 6 | """ 7 | Copied from pytorch-DDSP 8 | Implementation of the MLP, as described in the original paper 9 | 10 | Parameters : 11 | in_size (int) : input size of the MLP 12 | out_size (int) : output size of the MLP 13 | loop (int) : number of repetition of Linear-Norm-ReLU 14 | """ 15 | def __init__(self, in_size=512, out_size=512, loop=3): 16 | super().__init__() 17 | self.linear = nn.ModuleList( 18 | [nn.Sequential(nn.Linear(in_size, out_size), 19 | nn.modules.normalization.LayerNorm(out_size), 20 | nn.ReLU() 21 | )] + [nn.Sequential(nn.Linear(out_size, out_size), 22 | nn.modules.normalization.LayerNorm(out_size), 23 | nn.ReLU() 24 | ) for i in range(loop - 1)]) 25 | 26 | def forward(self, x): 27 | for lin in self.linear: 28 | x = lin(x) 29 | return x 30 | 31 | class FiLM(nn.Module): 32 | """ 33 | feature-wise linear modulation 34 | """ 35 | def __init__(self, input_dim, attribute_dim): 36 | super().__init__() 37 | self.input_dim = input_dim 38 | self.generator = nn.Linear(attribute_dim, input_dim*2) 39 | 40 | def forward(self, x, c): 41 | """ 42 | x: (*, input_dim) 43 | c: (*, attribute_dim) 44 | """ 45 | c = self.generator(c) 46 | gamma = c[..., :self.input_dim] 47 | beta = c[..., self.input_dim:] 48 | return x*gamma + beta 49 | 50 | class FiLMMLP(nn.Module): 51 | """ 52 | MLP with FiLMs in between 53 | """ 54 | def __init__(self, in_size, out_size, attribute_dim, loop=3): 55 | super().__init__() 56 | self.loop = loop 57 | self.mlps = nn.ModuleList([nn.Linear(in_size, out_size)] 58 | + [nn.Linear(out_size, out_size) for i in range(loop-1)]) 59 | self.films = nn.ModuleList([FiLM(out_size, attribute_dim) for i in range(loop)]) 
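    # Illustrative usage (shapes only, not from the original source): FiLMMLP applies
    # Linear -> ReLU -> FiLM per layer, so each hidden activation is modulated as
    # x * gamma(c) + beta(c) by the conditioning vector c.
    #   mlp = FiLMMLP(in_size=128, out_size=512, attribute_dim=16)
    #   y = mlp(torch.randn(8, 128), torch.randn(8, 16))  # -> (8, 512)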
60 | 61 | def forward(self, x, c): 62 | """ 63 | x: (*, input_dim) 64 | c: (*, attribute_dim) 65 | """ 66 | for i in range(self.loop): 67 | x = self.mlps[i](x) 68 | x = F.relu(x) 69 | x = self.films[i](x, c) 70 | return x 71 | 72 | class Normalize1d(nn.Module): 73 | """ 74 | normalize over the last dimension 75 | ddsp normalizes over time dimension of mfcc 76 | """ 77 | def __init__(self, channels, norm_type='instance', batch_dims=1): 78 | super().__init__() 79 | self.norm_type = norm_type 80 | if norm_type == 'instance': 81 | self.norm = nn.InstanceNorm1d(channels, affine=True) 82 | if norm_type == 'batch': 83 | self.norm = nn.BatchNorm1d(channels, affine=True) 84 | self.flat = nn.Flatten(0, batch_dims-1) 85 | 86 | def forward(self, x): 87 | """ 88 | First b_dim dimensions are batch dimensions 89 | Last dim is normalized 90 | """ 91 | orig_shape = x.shape 92 | x = self.flat(x) 93 | if len(x.shape) == 2: 94 | # no channel dimension 95 | x = x.unsqueeze(1) 96 | x = self.norm(x) 97 | x = x.view(orig_shape) 98 | return x 99 | 100 | class Normalize2d(nn.Module): 101 | """ 102 | take the average over 2 dimensions (time, frequency) 103 | """ 104 | def __init__(self, norm_type='instance'): 105 | super().__init__() 106 | self.norm_type = norm_type 107 | if norm_type == 'instance': 108 | self.norm = nn.InstanceNorm2d(1) 109 | if norm_type == 'batch': 110 | self.norm = nn.BatchNorm2d(1, affine=False) 111 | 112 | def forward(self, x): 113 | """ 114 | 3D input first of which is batch dim 115 | [batch, dim1, dim2] 116 | """ 117 | x = self.norm(x.unsqueeze(1)).squeeze(1) # dummy channel 118 | return x 119 | 120 | class CoordConv1D(nn.Module): 121 | # input dimension needs to be fixed 122 | def __init__(self, in_channels, out_channels, input_dim, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=True): 123 | super().__init__() 124 | # 0~1 125 | pos_embed = torch.arange(input_dim, dtype=torch.float)[None, None, :] / (input_dim) 126 | # -1~1 127 | pos_embed = pos_embed * 2 -1 128 | self.input_dim = input_dim 129 | self.kernel_size = (kernel_size,) 130 | self.stride = (stride,) 131 | self.padding = (padding,) 132 | self.dilation = (dilation,) 133 | 134 | self.register_buffer('pos_embed', pos_embed) 135 | self.conv = nn.Conv1d(in_channels+1, out_channels, kernel_size, stride, padding, dilation, groups, bias) 136 | 137 | def forward(self, x): 138 | # x: batch, C_in, H 139 | batch_size, c_in, h = x.shape 140 | coord = self.pos_embed.expand(batch_size, -1, -1) 141 | x = torch.cat([x, coord], dim=1) 142 | x = self.conv(x) 143 | return x 144 | 145 | class Resnet1D(nn.Module): 146 | """Resnet for encoder/decoder similar to Jukebox 147 | """ 148 | def __init__(self, n_in, n_depth, m_conv=1.0, dilation_growth_rate=3, reverse_dilation=False): 149 | """init 150 | 151 | Args: 152 | n_in (int): input channels 153 | n_depth (int): depth of resnet 154 | m_conv (float, optional): multiplier for intermediate channel. Defaults to 1.0. 155 | dilation_growth_rate (int, optional): rate of exponential dilation growth . Defaults to 1. 156 | reverse_dilation (bool, optional): reverse growing dilation for encoder/decoder symmetry. Defaults to False. 
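        Example (illustrative): each dilated conv uses padding equal to its dilation,
        so the time dimension is preserved and every block is a pure residual update.
            net = Resnet1D(n_in=64, n_depth=4)
            y = net(torch.randn(8, 64, 1000))  # -> (8, 64, 1000)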
157 | """ 158 | super().__init__() 159 | conv_block = lambda input_channels, inner_channels, dilation: nn.Sequential( 160 | nn.ReLU(), 161 | # this conv doesn't change size 162 | nn.Conv1d(input_channels, inner_channels, 3, 1, dilation, dilation), 163 | nn.ReLU(), 164 | #1x1 convolution 165 | nn.Conv1d(inner_channels, input_channels, 1, 1, 0), 166 | ) 167 | # blocks of convolution with growing dilation 168 | conv_blocks = [conv_block(n_in, int(m_conv * n_in), 169 | dilation=dilation_growth_rate ** depth) 170 | for depth in range(n_depth)] 171 | if reverse_dilation: # decoder should be flipped backwards 172 | conv_blocks = conv_blocks[::-1] 173 | self.blocks = nn.ModuleList(conv_blocks) 174 | 175 | def forward(self, x): 176 | for block in self.blocks: 177 | # residual connection 178 | x = x + block(x) 179 | return x 180 | 181 | class Resnet2D(nn.Module): 182 | def __init__(self, n_in, n_depth, m_conv=1.0, dilation_growth_rate=3, reverse_dilation=False): 183 | """init 184 | 185 | Args: 186 | n_in (int): input channels 187 | n_depth (int): depth of resnet 188 | m_conv (float, optional): multiplier for intermediate channel. Defaults to 1.0. 189 | dilation_growth_rate (int, optional): rate of exponential dilation growth . Defaults to 3. 190 | """ 191 | super().__init__() 192 | conv_block = lambda input_channels, inner_channels, dilation: nn.Sequential( 193 | nn.ReLU(), 194 | # this conv doesn't change size 195 | nn.Conv2d(input_channels, inner_channels, 3, 1, dilation, dilation), 196 | nn.ReLU(), 197 | #1x1 convolution 198 | nn.Conv2d(inner_channels, input_channels, 1, 1, 0), 199 | ) 200 | # blocks of convolution with growing dilation 201 | conv_blocks = [conv_block(n_in, int(m_conv * n_in), 202 | dilation=dilation_growth_rate ** depth) 203 | for depth in range(n_depth)] 204 | if reverse_dilation: # decoder should be flipped backwards 205 | conv_blocks = conv_blocks[::-1] 206 | self.blocks = nn.ModuleList(conv_blocks) 207 | 208 | def forward(self, x): 209 | for block in self.blocks: 210 | # residual connection 211 | x = x + block(x) 212 | return x -------------------------------------------------------------------------------- /diffsynth/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from diffsynth.spectral import multiscale_fft, compute_loudness 4 | from diffsynth.util import log_eps 5 | import torch.nn.functional as F 6 | import functools 7 | 8 | def spectrogram_loss(x_audio, target_audio, fft_sizes=[64, 128, 256, 512, 1024, 2048], hop_ls=None, win_ls=None, log_mag_w=0.0, mag_w=1.0, norm=None): 9 | x_specs = multiscale_fft(x_audio, fft_sizes, hop_ls, win_ls) 10 | target_specs = multiscale_fft(target_audio, fft_sizes, hop_ls, win_ls) 11 | loss = 0.0 12 | spec_loss = {} 13 | log_spec_loss = {} 14 | for n_fft, x_spec, target_spec in zip(fft_sizes, x_specs, target_specs): 15 | spec_norm = norm['spec'][n_fft] if norm is not None else 1.0 16 | log_spec_norm = norm['logspec'][n_fft] if norm is not None else 1.0 17 | if mag_w > 0: 18 | spec_loss[n_fft] = mag_w * torch.mean(torch.abs(x_spec - target_spec)) / spec_norm 19 | if log_mag_w > 0: 20 | log_spec_loss[n_fft] = log_mag_w * torch.mean(torch.abs(log_eps(x_spec) - log_eps(target_spec))) / log_spec_norm 21 | return {'spec':spec_loss, 'logspec':log_spec_loss} 22 | 23 | def waveform_loss(x_audio, target_audio, l1_w=0, l2_w=1.0, linf_w=0, linf_k=1024, norm=None): 24 | norm = {'l1':1.0, 'l2':1.0} if norm is None else norm 25 | l1_loss = l1_w * 
torch.mean(torch.abs(x_audio - target_audio)) / norm['l1'] if l1_w > 0 else 0.0 26 | # mse loss 27 | l2_loss = l2_w * torch.mean((x_audio - target_audio)**2) / norm['l2'] if l2_w > 0 else 0.0 28 | if linf_w > 0: 29 | # actually gets k elements 30 | residual = (x_audio - target_audio)**2 31 | values, _ = torch.topk(residual, linf_k, dim=-1) 32 | linf_loss = torch.mean(values) / norm['l2'] 33 | else: 34 | linf_loss = 0.0 35 | return {'l1':l1_loss, 'l2':l2_loss, 'linf':linf_loss} 36 | 37 | class SpecWaveLoss(): 38 | """ 39 | loss for reconstruction with multiscale spectrogram loss and waveform loss 40 | """ 41 | def __init__(self, fft_sizes=[64, 128, 256, 512, 1024, 2048], hop_lengths=None, win_lengths=None, mag_w=1.0, log_mag_w=1.0, l1_w=0, l2_w=0.0, linf_w=0.0, linf_k=1024, norm=None): 42 | super().__init__() 43 | self.fft_sizes = fft_sizes 44 | self.hop_lengths = hop_lengths 45 | self.win_lengths = win_lengths 46 | self.mag_w = mag_w 47 | self.log_mag_w = log_mag_w 48 | self.l1_w=l1_w 49 | self.l2_w=l2_w 50 | self.linf_w=linf_w 51 | self.spec_loss = functools.partial(spectrogram_loss, fft_sizes=fft_sizes, hop_ls=hop_lengths, win_ls=win_lengths, log_mag_w=log_mag_w, mag_w=mag_w, norm=norm) 52 | self.wave_loss = functools.partial(waveform_loss, l1_w=l1_w, l2_w=l2_w, linf_w=linf_w, linf_k=linf_k, norm=norm) 53 | 54 | def __call__(self, x_audio, target_audio): 55 | if (self.mag_w + self.log_mag_w) > 0: 56 | spec_losses = self.spec_loss(x_audio, target_audio) 57 | multi_spec_loss = sum(spec_losses['spec'].values()) + sum(spec_losses['logspec'].values()) 58 | multi_spec_loss /= (len(self.fft_sizes)*(self.mag_w + self.log_mag_w)) 59 | else: # no spec loss 60 | multi_spec_loss = torch.tensor([0.0], device=x_audio.device) 61 | if (self.l1_w + self.l2_w + self.linf_w) > 0: 62 | wave_losses = self.wave_loss(x_audio, target_audio) 63 | waveform_loss = wave_losses['l1'] + wave_losses['l2'] + wave_losses['linf'] 64 | waveform_loss /= (self.l1_w + self.l2_w + self.linf_w) 65 | else: # no waveform loss 66 | waveform_loss = torch.tensor([0.0], device=x_audio.device) 67 | return multi_spec_loss, waveform_loss 68 | 69 | def calculate_norm(loader, fft_sizes, hop_ls, win_ls): 70 | """ 71 | calculate stats for scaling losses 72 | based on jukebox 73 | doesn't really work 74 | """ 75 | n, spec_n = 0, 0 76 | spec_total = {n_fft: 0.0 for n_fft in fft_sizes} 77 | log_spec_total = {n_fft: 0.0 for n_fft in fft_sizes} 78 | total, total_sq, l1_total = 0.0, 0.0, 0.0 79 | print('calculating bandwidth') 80 | for data_dict in loader: 81 | x_audio = data_dict['audio'] 82 | total = torch.sum(x_audio) 83 | total_sq = torch.sum(x_audio**2) 84 | l1_total = torch.sum(torch.abs(x_audio)) 85 | x_specs = multiscale_fft(x_audio, fft_sizes, hop_ls, win_ls) 86 | for n_fft, spec in zip(fft_sizes, x_specs): 87 | # spec: power spectrogram [batch_size, n_bins, time] 88 | spec_total[n_fft] += torch.mean(spec) 89 | # probably not right 90 | log_spec_total[n_fft] += torch.mean(torch.abs(log_eps(spec))) 91 | n += x_audio.shape[0] * x_audio.shape[1] 92 | spec_n += 1 93 | 94 | print('done.') 95 | mean = total / n 96 | for n_fft in fft_sizes: 97 | spec_total[n_fft] /= spec_n 98 | log_spec_total[n_fft] /= spec_n 99 | 100 | return {'l2': total_sq/n - mean**2, 'l1': l1_total/n, 'spec': spec_total, 'logspec': log_spec_total} 101 | -------------------------------------------------------------------------------- /diffsynth/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 
3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import diffsynth.util as util 6 | from diffsynth.spectral import compute_lsd, loudness_loss, Mfcc 7 | import pytorch_lightning as pl 8 | from diffsynth.modelutils import construct_synth_from_conf 9 | from diffsynth.schedules import ParamSchedule 10 | import hydra 11 | from diffsynth.estimator import F0MelEstimator 12 | 13 | class EstimatorSynth(pl.LightningModule): 14 | """ 15 | audio -> Estimator -> Synth -> audio 16 | """ 17 | def __init__(self, model_cfg, synth_cfg, sched_cfg): 18 | super().__init__() 19 | self.synth = construct_synth_from_conf(synth_cfg) 20 | self.estimator = hydra.utils.instantiate(model_cfg.estimator, output_dim=self.synth.ext_param_size) 21 | self.loss_w_sched = ParamSchedule(sched_cfg) # loss weighting 22 | self.sw_loss = hydra.utils.instantiate(model_cfg.sw_loss) # reconstruction loss 23 | if model_cfg.perc_model is not None: 24 | self.perc_model = hydra.utils.instantiate(model_cfg.perc_model) 25 | else: 26 | self.perc_model = None 27 | self.log_grad = model_cfg.log_grad 28 | self.lr = model_cfg.lr 29 | self.decay_rate = model_cfg.decay_rate 30 | self.mfcc = Mfcc(n_fft=1024, hop_length=256, n_mels=40, n_mfcc=20, sample_rate=16000) 31 | self.save_hyperparameters() 32 | 33 | def param_loss(self, synth_output, param_dict): 34 | loss = 0 35 | for k, target in param_dict.items(): 36 | output_name = self.synth.dag_summary[k] 37 | if output_name in self.synth.fixed_param_names: 38 | continue 39 | if target.numel() == 0: 40 | continue 41 | x = synth_output[output_name] 42 | if target.shape[1] > 1: 43 | x = util.resample_frames(x, target.shape[1]) 44 | loss += F.l1_loss(x, target) 45 | loss = loss / len(param_dict.keys()) 46 | return loss 47 | 48 | def estimate_param(self, conditioning): 49 | """ 50 | Args: 51 | conditioning (dict): {'PARAM NAME': Conditioning Tensor, ...} 52 | 53 | Returns: 54 | torch.Tensor: estimated parameters in Tensor ranged 0~1 55 | """ 56 | if isinstance(self.estimator, F0MelEstimator): 57 | return self.estimator(conditioning['audio'], conditioning['f0_hz']) 58 | return self.estimator(conditioning['audio']) 59 | 60 | def log_param_grad(self, params_dict): 61 | def save_grad(name): 62 | def hook(grad): 63 | # batch, n_frames, feat_size 64 | grad_v = grad.abs().mean(dim=(0, 1)) 65 | for i, gv in enumerate(grad_v): 66 | self.log('train/param_grad/'+name+f'_{i}', gv, on_step=False, on_epoch=True) 67 | return hook 68 | 69 | if self.log_grad: 70 | for k, v in params_dict.items(): 71 | if v.requires_grad == True: 72 | v.register_hook(save_grad(k)) 73 | 74 | def forward(self, conditioning): 75 | """ 76 | Args: 77 | conditioning (dict): {'PARAM NAME': Conditioning Tensor, ...} 78 | 79 | Returns: 80 | torch.Tensor: audio 81 | """ 82 | audio_length = conditioning['audio'].shape[1] 83 | est_param = self.estimate_param(conditioning) 84 | params_dict = self.synth.fill_params(est_param, conditioning) 85 | if self.log_grad is not None: 86 | self.log_param_grad(params_dict) 87 | 88 | resyn_audio, outputs = self.synth(params_dict, audio_length) 89 | return resyn_audio, outputs 90 | 91 | def get_params(self, conditioning): 92 | """ 93 | Don't render audio 94 | """ 95 | est_param = self.estimate_param(conditioning) 96 | params_dict = self.synth.fill_params(est_param, conditioning) 97 | if self.log_grad is not None: 98 | self.log_param_grad(params_dict) 99 | 100 | synth_params = self.synth.calculate_params(params_dict) 101 | return synth_params 102 | 103 | def train_losses(self, target, output, 
loss_w=None, sw_loss=None, perc_model=None): 104 | sw_loss = self.sw_loss if sw_loss is None else sw_loss 105 | perc_model = self.perc_model if perc_model is None else perc_model 106 | # always computes mean across batch dimension 107 | if loss_w is None: 108 | loss_w = {'param_w': 1.0, 'sw_w':1.0, 'perc_w':1.0} 109 | loss_dict = {} 110 | # parameter L1 loss 111 | if loss_w['param_w'] > 0.0 and 'params' in target: 112 | loss_dict['param'] = loss_w['param_w'] * self.param_loss(output, target['params']) 113 | else: 114 | loss_dict['param'] = 0.0 115 | # Audio losses 116 | target_audio = target['audio'] 117 | resyn_audio = output['output'] 118 | if loss_w['sw_w'] > 0.0 and sw_loss is not None: 119 | # Reconstruction loss 120 | spec_loss, wave_loss = sw_loss(target_audio, resyn_audio) 121 | loss_dict['spec'], loss_dict['wave'] = loss_w['sw_w'] * spec_loss, loss_w['sw_w'] * wave_loss 122 | else: 123 | loss_dict['spec'], loss_dict['wave'] = (0, 0) 124 | if loss_w['perc_w'] > 0.0 and perc_model is not None: 125 | loss_dict['perc'] = loss_w['perc_w']*perc_model.perceptual_loss(target_audio, resyn_audio) 126 | else: 127 | loss_dict['perc'] = 0 128 | return loss_dict 129 | 130 | def monitor_losses(self, target, output): 131 | mon_losses = {} 132 | # Audio losses 133 | target_audio = target['audio'] 134 | resyn_audio = output['output'] 135 | # losses not used for training 136 | mon_losses['lsd'] = compute_lsd(target_audio, resyn_audio) 137 | mon_losses['loud'] = loudness_loss(resyn_audio, target_audio) 138 | mon_losses['mfcc'] = F.l1_loss(self.mfcc(target_audio), self.mfcc(resyn_audio)) 139 | return mon_losses 140 | 141 | def training_step(self, batch_dict, batch_idx): 142 | # get loss weights 143 | loss_weights = self.loss_w_sched.get_parameters(self.global_step) 144 | self.log_dict({'lw/'+k: v for k, v in loss_weights.items()}, on_epoch=True, on_step=False) 145 | if loss_weights['sw_w']+loss_weights['perc_w'] == 0: 146 | # do not render audio because reconstruction is unnecessary 147 | synth_params = self.get_params(batch_dict) 148 | # Parameter loss 149 | batch_loss = loss_weights['param_w'] * self.param_loss(synth_params, batch_dict['params']) 150 | else: 151 | # render audio 152 | resyn_audio, outputs = self(batch_dict) 153 | losses = self.train_losses(batch_dict, outputs, loss_weights) 154 | self.log_dict({'train/'+k: v for k, v in losses.items()}, on_epoch=True, on_step=False) 155 | batch_loss = sum(losses.values()) 156 | self.log('train/total', batch_loss, prog_bar=True, on_epoch=True, on_step=False) 157 | return batch_loss 158 | 159 | def validation_step(self, batch_dict, batch_idx, dataloader_idx): 160 | # render audio 161 | resyn_audio, outputs = self(batch_dict) 162 | losses = self.train_losses(batch_dict, outputs) 163 | eval_losses = self.monitor_losses(batch_dict, outputs) 164 | losses.update(eval_losses) 165 | prefix = 'val_id/' if dataloader_idx==0 else 'val_ood/' 166 | losses = {prefix+k: v for k, v in losses.items()} 167 | self.log_dict(losses, prog_bar=True, on_epoch=True, on_step=False, add_dataloader_idx=False) 168 | return losses 169 | 170 | def get_progress_bar_dict(self): 171 | # don't show the version number 172 | items = super().get_progress_bar_dict() 173 | items.pop("v_num", None) 174 | items.pop("val_id/wave", None) 175 | return items 176 | 177 | def test_step(self, batch_dict, batch_idx, dataloader_idx): 178 | # render audio 179 | resyn_audio, outputs = self(batch_dict) 180 | losses = self.train_losses(batch_dict, outputs) 181 | eval_losses = 
self.monitor_losses(batch_dict, outputs) 182 | losses.update(eval_losses) 183 | prefix = 'val_id/' if dataloader_idx==0 else 'val_ood/' 184 | losses = {prefix+k: v for k, v in losses.items()} 185 | self.log_dict(losses, prog_bar=True, on_epoch=True, on_step=False, add_dataloader_idx=False) 186 | return losses 187 | 188 | def configure_optimizers(self): 189 | optimizer = torch.optim.Adam(self.estimator.parameters(), self.lr) 190 | return { 191 | "optimizer": optimizer, 192 | "lr_scheduler": { 193 | "scheduler": torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.decay_rate) 194 | } 195 | } -------------------------------------------------------------------------------- /diffsynth/modelutils.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | from diffsynth.modules.generators import SineOscillator 5 | from diffsynth.processor import Add, Mix 6 | from diffsynth.modules.fm import FM2, FM3 7 | from diffsynth.modules.envelope import ADSREnvelope 8 | from diffsynth.synthesizer import Synthesizer 9 | from diffsynth.modules.harmor import Harmor 10 | from diffsynth.modules.delay import ModulatedDelay, ChorusFlanger 11 | from diffsynth.modules.reverb import DecayReverb 12 | 13 | def construct_synth_from_conf(synth_conf): 14 | dag = [] 15 | for module_name, v in synth_conf.dag.items(): 16 | module = hydra.utils.instantiate(v.config, name=module_name) 17 | conn = v.connections 18 | dag.append((module, conn)) 19 | fixed_p = synth_conf.fixed_params 20 | fixed_p = {} if fixed_p is None else fixed_p 21 | fixed_p = {k: None if v is None else v*torch.ones(1) for k, v in fixed_p.items()} 22 | synth = Synthesizer(dag, fixed_params=fixed_p, static_params=synth_conf.static_params) 23 | return synth 24 | 25 | # defunct 26 | def construct_synths(name, sr=16000): 27 | static_params = [] 28 | if name == 'fm2_fixed': 29 | fmosc = FM2(sample_rate=sr, name='fm2') 30 | dag = [ 31 | (fmosc, {'mod_amp': 'envm', 'car_amp': 'envc', 'mod_freq': 'FRQ_M', 'car_freq': 'FRQ_C'}) 32 | ] 33 | fixed_params = {'FRQ_M': torch.ones(1)*440, 'FRQ_C': torch.ones(1)*440} 34 | static_params=['FRQ_M', 'FRQ_C'] 35 | elif name == 'fm2_free': 36 | fmosc = FM2(sample_rate=sr, name='fm2') 37 | dag = [ 38 | (fmosc, {'mod_amp': 'AMP_M', 'car_amp': 'AMP_C', 'mod_freq': 'FRQ_M', 'car_freq': 'FRQ_C'}) 39 | ] 40 | static_params=['FRQ_M', 'FRQ_C'] 41 | fixed_params = {} 42 | elif name == 'fm2_half': 43 | fmosc = FM2(sample_rate=sr, name='fm2') 44 | dag = [ 45 | (fmosc, {'mod_amp': 'AMP_M', 'car_amp': 'AMP_C', 'mod_freq': 'FRQ_M', 'car_freq': 'FRQ_C'}) 46 | ] 47 | static_params=['FRQ_M'] 48 | fixed_params = {'FRQ_C': torch.ones(1)*440} 49 | elif name == 'fm2_free_env': 50 | fmosc = FM2(sample_rate=sr, name='fm2') 51 | envm = ADSREnvelope(name='envm') 52 | envc = ADSREnvelope(name='envc') 53 | dag = [ 54 | (envm, {'floor': 'AMP_FLOOR', 'peak': 'PEAK_M', 'attack': 'AT_M', 'decay': 'DE_M', 'sus_level': 'SU_M', 'release': 'RE_M', 'note_off': 'NO'}), 55 | (envc, {'floor': 'AMP_FLOOR', 'peak': 'PEAK_C', 'attack': 'AT_C', 'decay': 'DE_C', 'sus_level': 'SU_C', 'release': 'RE_C', 'note_off': 'NO'}), 56 | (fmosc, {'mod_amp': 'envm', 'car_amp': 'envc', 'mod_freq': 'FRQ_M', 'car_freq': 'FRQ_C'}) 57 | ] 58 | fixed_params = {'AMP_FLOOR':torch.zeros(1), 'NO': torch.ones(1)*0.8} 59 | static_params=['FRQ_M', 'FRQ_C', 'PEAK_M', 'AT_M', 'DE_M', 'SU_M', 'RE_M', 'PEAK_C', 'AT_C', 'DE_C', 'SU_C', 'RE_C', 'AMP_FLOOR', 'NO'] 60 | elif name == 'fm3_free': 61 | fmosc = 
FM3(sample_rate=sr, name='fm3') 62 | dag = [ 63 | (fmosc, {'amp_1': 'AMP_1', 'amp_2': 'AMP_2', 'amp_3': 'AMP_3', 'freq_1': 'FRQ_1', 'freq_2': 'FRQ_2', 'freq_3': 'FRQ_3'}) 64 | ] 65 | static_params=['FRQ_1', 'FRQ_2', 'FRQ_3'] 66 | fixed_params = {} 67 | elif name == 'fm2x2_free': 68 | fm2_1 = FM2(sample_rate=sr, name='fm2_1') 69 | fm2_2 = FM2(sample_rate=sr, name='fm2_2') 70 | mix = Mix(name='add') 71 | dag = [ 72 | (fm2_1, {'mod_amp': 'AMP_1', 'car_amp': 'AMP_2', 'mod_freq': 'FRQ_1', 'car_freq': 'FRQ_2'}), 73 | (fm2_2, {'mod_amp': 'AMP_3', 'car_amp': 'AMP_4', 'mod_freq': 'FRQ_3', 'car_freq': 'FRQ_4'}), 74 | (mix, {'signal_a': 'fm2_1', 'signal_b': 'fm2_2', 'mix_a': 'MIX_A', 'mix_b': 'MIX_B'}) 75 | ] 76 | static_params=['FRQ_1', 'FRQ_2', 'FRQ_3', 'FRQ_4', 'MIX_A', 'MIX_B'] 77 | fixed_params = {'MIX_A': torch.ones(1)*0.5, 'MIX_B': torch.ones(1)*0.5} 78 | elif name == 'fm6_free': 79 | fm3_1 = FM3(sample_rate=sr, name='fm3_1') 80 | fm3_2 = FM3(sample_rate=sr, name='fm3_2') 81 | add = Add(name='add') 82 | dag = [ 83 | (fm3_1, {'amp_1': 'AMP_1', 'amp_2': 'AMP_2', 'amp_3': 'AMP_3', 'freq_1': 'FRQ_1', 'freq_2': 'FRQ_2', 'freq_3': 'FRQ_3'}), 84 | (fm3_2, {'amp_1': 'AMP_4', 'amp_2': 'AMP_5', 'amp_3': 'AMP_6', 'freq_1': 'FRQ_4', 'freq_2': 'FRQ_5', 'freq_3': 'FRQ_6'}), 85 | (add, {'signal_a': 'fm3_1', 'signal_b': 'fm3_2'}) 86 | ] 87 | static_params=['FRQ_1', 'FRQ_2', 'FRQ_3', 'FRQ_4', 'FRQ_5', 'FRQ_6'] 88 | fixed_params = {} 89 | elif name == 'sin': 90 | sin = SineOscillator(sample_rate=sr, name='sin') 91 | dag = [ 92 | (sin, {'amplitudes': 'AMP','frequencies': 'FRQ'}) 93 | ] 94 | static_params=['FRQ'] 95 | fixed_params = {} 96 | elif name == 'harmor_fixed': 97 | harmor = Harmor(sample_rate=sr, name='harmor', n_oscs=1) 98 | dag = [ 99 | (harmor, {'amplitudes': 'AMP', 'osc_mix': 'M_OSC', 'f0_hz': 'BFRQ', 'f0_mult': 'DUMMY', 'cutoff': 'CUTOFF', 'q': 'Q_FILT'}) 100 | ] 101 | fixed_params = {'BFRQ': torch.ones(1)*440} 102 | static_params = ['M_OSC', 'Q_FILT', 'BFRQ', 'DUMMY'] 103 | elif name == 'harmor_1oscfree': 104 | harmor = Harmor(sample_rate=sr, name='harmor', n_oscs=1) 105 | dag = [ 106 | (harmor, {'amplitudes': 'AMP', 'osc_mix': 'M_OSC', 'f0_hz': 'BFRQ', 'f0_mult': 'DUMMY', 'cutoff': 'CUTOFF', 'q': 'Q_FILT'}) 107 | ] 108 | fixed_params = {} 109 | static_params = ['M_OSC', 'Q_FILT', 'DUMMY', 'BFRQ'] 110 | elif name == 'harmor_2oscfree': 111 | harmor = Harmor(sample_rate=sr, name='harmor', sep_amp=True, n_oscs=2) 112 | dag = [ 113 | (harmor, {'amplitudes': 'AMP', 'osc_mix': 'M_OSC', 'f0_hz': 'BFRQ', 'f0_mult': 'MULT', 'cutoff': 'CUTOFF', 'q': 'Q_FILT'}), 114 | ] 115 | fixed_params = {} 116 | static_params=['BFRQ', 'M_OSC', 'MULT', 'Q_FILT'] 117 | elif name == 'harmor_cf': 118 | harmor = Harmor(sample_rate=sr, name='harmor', sep_amp=True, n_oscs=2) 119 | md = ModulatedDelay(name='md', sr=sr) 120 | dag = [ 121 | (harmor, {'amplitudes': 'AMP', 'osc_mix': 'M_OSC', 'f0_hz': 'BFRQ', 'f0_mult': 'MULT', 'cutoff': 'CUTOFF', 'q': 'Q_FILT'}), 122 | (md, {'audio': 'harmor', 'delay_ms': 'MD_DELAY', 'phase': 'MD_PHASE', 'depth': 'MD_DEPTH', 'mix': 'MD_MIX'}) 123 | ] 124 | fixed_params = {} 125 | static_params=['BFRQ', 'M_OSC', 'MULT', 'Q_FILT', 'MD_DELAY', 'MD_DEPTH', 'MD_MIX'] 126 | elif name == 'harmor_cffixed': 127 | harmor = Harmor(sample_rate=sr, name='harmor', sep_amp=True, n_oscs=2) 128 | cf = ChorusFlanger(name='cf', sr=sr, delay_range=(1.0, 40.0)) 129 | dag = [ 130 | (harmor, {'amplitudes': 'AMP', 'osc_mix': 'M_OSC', 'f0_hz': 'BFRQ', 'f0_mult': 'MULT', 'cutoff': 'CUTOFF', 'q': 'Q_FILT'}), 131 | (cf, 
{'audio': 'harmor', 'delay_ms': 'CF_DELAY', 'rate': 'CF_RATE', 'depth': 'CF_DEPTH', 'mix': 'CF_MIX'}) 132 | ] 133 | fixed_params = {'CF_RATE': torch.ones(1), 'CF_DEPTH': torch.ones(1)*0.1} 134 | static_params=['BFRQ', 'M_OSC', 'MULT', 'Q_FILT', 'CF_DELAY', 'CF_MIX', 'CF_RATE', 'CF_DEPTH'] 135 | elif name == 'harmor_cffenv': 136 | harmor = Harmor(name='harmor', sep_amp=True, n_oscs=2) 137 | enva = ADSREnvelope(name='enva', max_value=0.6, channels=2) 138 | envc = ADSREnvelope(name='envc', channels=1) 139 | cf = ChorusFlanger(name='cf', sr=sr, delay_range=(1.0, 40.0)) 140 | dag = [ 141 | (enva, {'floor': 'AMP_FLOOR', 'peak': 'PEAK_A', 'attack': 'AT_A', 'decay': 'DE_A', 'sus_level': 'SU_A', 'release': 'RE_A', 'note_off': 'NO', 'noise_mag': 'NOISE_A'}), 142 | (envc, {'floor': 'CUT_FLOOR', 'peak': 'PEAK_C', 'attack': 'AT_C', 'decay': 'DE_C', 'sus_level': 'SU_C', 'release': 'RE_C', 'note_off': 'NO', 'noise_mag': 'NOISE_C'}), 143 | (harmor, {'amplitudes': 'enva', 'osc_mix': 'M_OSC', 'f0_hz': 'BFRQ', 'f0_mult': 'MULT', 'cutoff': 'envc', 'q': 'Q_FILT'}), 144 | (cf, {'audio': 'harmor', 'delay_ms': 'CF_DELAY', 'rate': 'CF_RATE', 'depth': 'CF_DEPTH', 'mix': 'CF_MIX'}) 145 | ] 146 | fixed_params = {'AMP_FLOOR':torch.zeros(1), 'NO': torch.ones(1)*0.75, 'CF_RATE': torch.ones(1), 'CF_DEPTH': torch.ones(1)*0.1, 'NOISE_A':torch.zeros(1), 'NOISE_C':torch.zeros(1)} 147 | static_params=['AMP_FLOOR', 'PEAK_A', 'AT_A', 'DE_A', 'SU_A', 'RE_A', 'CUT_FLOOR', 'PEAK_C', 'AT_C', 'DE_C', 'SU_C', 'RE_C', 'BFRQ', 'MULT', 'M_OSC', 'Q_FILT', 'NOISE_A', 'NOISE_C', 'CF_DELAY', 'CF_MIX', 'NO', 'CF_DEPTH', 'CF_RATE'] 148 | elif name == 'harmor_rev': 149 | harmor = Harmor(sample_rate=sr, name='harmor', sep_amp=True, n_oscs=2) 150 | reverb = DecayReverb(name='reverb', ir_length=16000) 151 | dag = [ 152 | (harmor, {'amplitudes': 'AMP', 'osc_mix': 'M_OSC', 'f0_hz': 'BFRQ', 'f0_mult': 'MULT', 'cutoff': 'CUTOFF', 'q': 'Q_FILT'}), 153 | (reverb, {'audio': 'harmor', 'gain': 'RE_GAIN', 'decay': 'RE_DECAY'}) 154 | ] 155 | fixed_params = {} 156 | static_params=['BFRQ', 'M_OSC', 'MULT', 'Q_FILT', 'RE_GAIN', 'RE_DECAY'] 157 | synth = Synthesizer(dag, fixed_params=fixed_params, static_params=static_params) 158 | 159 | return synth 160 | 161 | -------------------------------------------------------------------------------- /diffsynth/modules/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyakuchiki/diffsynth/4fc64ba04608cbaf8e016e4878fc50bedfca4da3/diffsynth/modules/.DS_Store -------------------------------------------------------------------------------- /diffsynth/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyakuchiki/diffsynth/4fc64ba04608cbaf8e016e4878fc50bedfca4da3/diffsynth/modules/__init__.py -------------------------------------------------------------------------------- /diffsynth/modules/delay.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from diffsynth.processor import Processor 4 | import diffsynth.util as util 5 | import numpy as np 6 | import math 7 | 8 | class ModulatedDelay(Processor): 9 | """ 10 | Use with LFO to create flanger/vibrato/chorus 11 | """ 12 | def __init__(self, name='chorus', sample_rate=16000): 13 | super().__init__(name) 14 | self.sr = sample_rate 15 | self.param_desc = { 16 | 'audio': {'size': 1, 'range': (-1, 1), 'type': 'raw'}, 17 | 'delay_ms': {'size': 
1, 'range': (1, 10.0), 'type': 'sigmoid'}, #ms 18 | 'phase': {'size': 1, 'range': (-1.0, 1.0), 'type': 'sigmoid'}, 19 | 'depth': {'size': 1, 'range': (0, 0.25), 'type': 'sigmoid'}, 20 | 'mix': {'size': 1, 'range': (0, 1.0), 'type': 'sigmoid'} 21 | } 22 | 23 | def forward(self, audio, delay_ms, phase, depth, mix): 24 | """pass audio through chorus/flanger 25 | Args: 26 | audio (torch.Tensor): [batch_size, n_samples] 27 | # static parameters 28 | delay_ms (torch.Tensor): Average delay in ms [batch_size, 1, 1] 29 | phase (torch.Tensor): -1->delay_ms*(1-depth), 1->delay_ms*(1+depth) [batch_size, n_frames, 1] 30 | depth (torch.Tensor): Lfo depth relative to delay_ms (0~1) [batch_size, 1, 1] 31 | mix (torch.Tensor): wet/dry ratio (0: all dry, 1: all wet) [batch_size, n_samples or 1, 1] 32 | 33 | Returns: 34 | [torch.Tensor]: Mixed audio. Shape [batch, n_samples] 35 | """ 36 | # delay: delay_ms*(1-depth) <-> delay_ms*(1+depth) 37 | 38 | delay_ms = delay_ms.squeeze(-1) 39 | depth = depth.squeeze(-1) 40 | mix = mix.squeeze(-1) 41 | phase = util.resample_frames(phase, audio.shape[1]) 42 | phase = phase.squeeze(-1) 43 | 44 | max_delay = self.param_desc['delay_ms']['range'][1] * 2 / 1000.0 * self.sr # samples 45 | delay_center = delay_ms / 1000.0 * self.sr # samples 46 | delay_value = phase * (depth * delay_center) + delay_center 47 | delay_phase = delay_value / max_delay # 0-> no delay 1: max_delay 48 | delayed = util.variable_delay(delay_phase, audio, buf_size=math.ceil(max_delay)) 49 | return mix * delayed + (1-mix)*audio 50 | 51 | class ChorusFlanger(Processor): 52 | """ 53 | LFO modulated delay 54 | no feedback 55 | delay_ms: 56 | Flanger: 1ms~5ms 57 | Chorus: 5ms~ 58 | """ 59 | 60 | def __init__(self, name='chorus', sample_rate=16000, delay_range=(1.0, 40.0)): 61 | super().__init__(name) 62 | self.sr = sample_rate 63 | self.param_desc = { 64 | 'delay_ms': {'size': 1, 'range': delay_range, 'type': 'sigmoid'}, #ms 65 | 'rate': {'size': 1, 'range': (0.1, 10.0), 'type': 'sigmoid'}, #Hz 66 | 'depth': {'size': 1, 'range': (0, 0.25), 'type': 'sigmoid'}, 67 | 'mix': {'size': 1, 'range': (0, 0.5), 'type': 'sigmoid'} 68 | } 69 | 70 | def forward(self, audio, delay_ms, rate, depth, mix): 71 | """pass audio through chorus/flanger 72 | Args: 73 | audio (torch.Tensor): [batch_size, n_samples] 74 | # static parameters 75 | delay_ms (torch.Tensor): Average delay in ms [batch_size, 1, 1] 76 | rate (torch.Tensor): LFO rate in Hz [batch_size, 1, 1] 77 | depth (torch.Tensor): Lfo depth relative to delay_ms (0~1) [batch_size, 1, 1] 78 | mix (torch.Tensor): wet/dry ratio (0: all dry, 1: all wet) [batch_size, n_samples or 1, 1] 79 | 80 | Returns: 81 | [torch.Tensor]: Mixed audio. 
Shape [batch, n_samples] 82 | """ 83 | # delay: delay_ms*(1-depth) <-> delay_ms*(1+depth) 84 | 85 | delay_ms = delay_ms.squeeze(-1) 86 | rate = rate.squeeze(-1) 87 | depth = depth.squeeze(-1) 88 | mix = mix.squeeze(-1) 89 | 90 | max_delay = self.param_desc['delay_ms']['range'][1] * 2 / 1000.0 * self.sr # samples 91 | delay_center = delay_ms / 1000.0 * self.sr # samples 92 | n_samples = audio.shape[1] 93 | delay_lfo = torch.sin(torch.linspace(0, n_samples/self.sr, n_samples, device=mix.device)[None, :]*math.pi*2*rate) 94 | delay_value = delay_lfo * (depth*delay_center) + delay_center 95 | delay_phase = delay_value / max_delay 96 | delayed = util.variable_delay(delay_phase, audio, buf_size=math.ceil(max_delay)) 97 | return mix * delayed + (1-mix)*audio 98 | 99 | -------------------------------------------------------------------------------- /diffsynth/modules/envelope.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from diffsynth.processor import Processor 4 | 5 | def soft_clamp_min(x, min_v, T=100): 6 | return torch.sigmoid((min_v-x)*T)*(min_v-x)+x 7 | 8 | class ADSREnvelope(Processor): 9 | def __init__(self, n_frames=250, name='env', min_value=0.0, max_value=1.0, channels=1): 10 | super().__init__(name=name) 11 | self.n_frames = int(n_frames) 12 | self.param_names = ['total_level', 'attack', 'decay', 'sus_level', 'release'] 13 | self.min_value = min_value 14 | self.max_value = max_value 15 | self.channels = channels 16 | self.param_desc = { 17 | 'floor': {'size':self.channels, 'range': (0, 1), 'type': 'sigmoid'}, 18 | 'peak': {'size':self.channels, 'range': (0, 1), 'type': 'sigmoid'}, 19 | 'attack': {'size':self.channels, 'range': (0, 1), 'type': 'sigmoid'}, 20 | 'decay': {'size':self.channels, 'range': (0, 1), 'type': 'sigmoid'}, 21 | 'sus_level': {'size':self.channels, 'range': (0, 1), 'type': 'sigmoid'}, 22 | 'release': {'size':self.channels, 'range': (0, 1), 'type': 'sigmoid'}, 23 | 'noise_mag': {'size':self.channels, 'range': (0, 0.1), 'type': 'sigmoid'}, 24 | 'note_off': {'size':self.channels, 'range': (0, 1), 'type': 'sigmoid'}, 25 | } 26 | 27 | def forward(self, floor, peak, attack, decay, sus_level, release, noise_mag=0.0, note_off=0.8, n_frames=None): 28 | """generate envelopes from parameters 29 | 30 | Args: 31 | floor (torch.Tensor): floor level of the signal 0~1, 0=min_value (batch, 1, channels) 32 | peak (torch.Tensor): peak level of the signal 0~1, 1=max_value (batch, 1, channels) 33 | attack (torch.Tensor): relative attack point 0~1 (batch, 1, channels) 34 | decay (torch.Tensor): actual decay point is attack+decay (batch, 1, channels) 35 | sus_level (torch.Tensor): sustain level 0~1 (batch, 1, channels) 36 | release (torch.Tensor): release point is attack+decay+release (batch, 1, channels) 37 | note_off (float or torch.Tensor, optional): note off position. Defaults to 0.8. 38 | n_frames (int, optional): number of frames. Defaults to None. 
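            noise_mag (float or torch.Tensor, optional): magnitude of Gaussian noise added to the envelope; 0.0 disables it. Defaults to 0.0.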
39 | 40 | Returns: 41 | torch.Tensor: envelope signal (batch_size, n_frames, channels) 42 | """ 43 | floor = torch.clamp(floor, min=0, max=1) 44 | peak = torch.clamp(peak, min=0, max=1) 45 | attack = torch.clamp(attack, min=0, max=1) 46 | decay = torch.clamp(decay, min=0, max=1) 47 | sus_level = torch.clamp(sus_level, min=0, max=1) 48 | release = torch.clamp(release, min=0, max=1) 49 | 50 | batch_size = attack.shape[0] 51 | if n_frames is None: 52 | n_frames = self.n_frames 53 | # batch, n_frames, channels 54 | x = torch.linspace(0, 1.0, n_frames)[None, :, None].repeat(batch_size, 1, self.channels) 55 | x = x.to(attack.device) 56 | attack = attack * note_off 57 | A = x / (attack) 58 | A = torch.clamp(A, max=1.0) 59 | D = (x - attack) * (sus_level - 1) / (decay+1e-5) 60 | D = torch.clamp(D, max=0.0) 61 | D = soft_clamp_min(D, sus_level-1) 62 | S = (x - note_off) * (-sus_level / (release+1e-5)) 63 | S = torch.clamp(S, max=0.0) 64 | S = soft_clamp_min(S, -sus_level) 65 | peak = peak * self.max_value + (1 - peak) * self.min_value 66 | floor = floor * self.max_value + (1 - floor) * self.min_value 67 | signal = (A + D + S + torch.randn_like(A)*noise_mag)*(peak - floor) + floor 68 | return torch.clamp(signal, min=self.min_value, max=self.max_value) -------------------------------------------------------------------------------- /diffsynth/modules/filter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from diffsynth.processor import Processor 4 | import diffsynth.util as util 5 | import numpy as np 6 | import math 7 | 8 | class FIRFilter(Processor): 9 | """ 10 | uses frequency sampling 11 | """ 12 | 13 | def __init__(self, filter_size=64, name='firfilter', scale_fn=util.exp_sigmoid, initial_bias=-5.0): 14 | super().__init__(name) 15 | self.filter_size = filter_size 16 | self.scale_fn = scale_fn 17 | self.initial_bias = initial_bias 18 | self.param_desc = { 19 | 'freq_response': {'size': self.filter_size // 2 + 1, 'range': (1e-7, 2.0), 'type': 'exp_sigmoid'}, 20 | 'audio': {'size':1, 'range': (-1, 1), 'type': 'raw'}, 21 | } 22 | 23 | def forward(self, audio, freq_response): 24 | """pass audio through FIRfilter 25 | Args: 26 | audio (torch.Tensor): [batch, n_samples] 27 | freq_response (torch.Tensor): frequency response (only magnitude) [batch, n_frames, filter_size // 2 + 1] 28 | 29 | Returns: 30 | [torch.Tensor]: Filtered audio. 
Shape [batch, n_samples] 31 | """ 32 | if self.scale_fn is not None: 33 | freq_response = self.scale_fn(freq_response + self.initial_bias) 34 | return util.fir_filter(audio, freq_response, self.filter_size) 35 | 36 | # class SVFCell(nn.Module): 37 | # def __init__(self): 38 | # super().__init__() 39 | 40 | # def forward(self, x, h_1, h_2, g, twoR, coeff_1, coeff_2): 41 | # # parameter [batch_size] 42 | # y_bp = coeff_2 * (x - h_2) + coeff_1 * h_1 43 | # y_lp = g * y_bp + h_2 44 | # y_hp = x - y_lp - twoR * y_bp 45 | # h_1 = 2 * y_bp - h_1 46 | # h_2 = 2 * y_lp - h_2 47 | # return y_bp, y_lp, y_hp, h_1, h_2 48 | 49 | """ 50 | NOTE: SVF Filter is very slow 51 | """ 52 | 53 | class SVFLayer(nn.Module): 54 | def __init__(self): 55 | super().__init__() 56 | # self.cell = cell 57 | 58 | def forward(self, audio, g, twoR, mix): 59 | """pass audio through SVF 60 | Args: 61 | *** time-first, batch-second *** 62 | audio (torch.Tensor): [n_samples, batch_size] 63 | All filter parameters are [n_samples, batch_size, 1or3] 64 | g (torch.Tensor): Cutoff parameter 65 | twoR (torch.Tensor): Damping parameter 66 | mix (torch.Tensor): Mixing coefficient of bp, lp and hp 67 | 68 | Returns: 69 | [torch.Tensor]: Filtered audio. Shape [batch, n_samples] 70 | """ 71 | seq_len, batch_size = audio.shape 72 | T = 1.0 / (1.0 + g * (g + twoR)) 73 | H = T.unsqueeze(-1) * torch.cat([torch.ones_like(g), -g, g, twoR*g+1], dim=-1).reshape(seq_len, batch_size, 2, 2) 74 | 75 | # Y = gHBx + Hs 76 | gHB = g * T * torch.cat([torch.ones_like(g), g], dim=-1) 77 | # [n_samples, batch_size, 2] 78 | gHBx = gHB * audio.unsqueeze(-1) 79 | 80 | Y = torch.empty(seq_len, batch_size, 2, device=audio.device) 81 | # initialize filter state 82 | state = torch.ones(batch_size, 2, device=audio.device) 83 | for t in range(seq_len): 84 | Y[t] = gHBx[t] + torch.bmm(H[t], state.unsqueeze(-1)).squeeze(-1) 85 | state = 2 * Y[t] - state 86 | 87 | # HP = x - LP - 2R*BP 88 | y_hps = audio - twoR.squeeze(-1) * Y[:, :, 0] - Y[:, :, 1] 89 | 90 | y_mixed = twoR.squeeze(-1) * mix[:, :, 0] * Y[:, :, 0] + mix[:, :, 1] * Y[:, :, 1] + mix[:, :, 2] * y_hps 91 | y_mixed = y_mixed.permute(1,0).contiguous() 92 | return y_mixed 93 | 94 | class SVFilter(Processor): 95 | def __init__(self, name='svf'): 96 | super().__init__(name) 97 | self.svf = torch.jit.script(SVFLayer()) 98 | self.param_desc = { 99 | 'audio': {'size': 1, 'range': (-1, 1), 'type': 'sigmoid'}, 100 | 'g': {'size': 1, 'range': (1e-6, 1), 'type': 'sigmoid'}, 101 | 'twoR': {'size': 1, 'range': (1e-6, np.sqrt(2)), 'type': 'sigmoid'}, 102 | 'mix': {'size': 3, 'range': (0, 1.0), 'type': 'sigmoid'} 103 | } 104 | 105 | def forward(self, audio, g, twoR, mix): 106 | """pass audio through SVF 107 | Args: 108 | *** batch-first *** 109 | audio (torch.Tensor): [batch_size, n_samples] 110 | All filter parameters are [batch_size, frame_size, 1or3] 111 | g (torch.Tensor): Cutoff parameter 112 | twoR (torch.Tensor): Damping parameter 113 | mix (torch.Tensor): Mixing coefficient of bp, lp and hp 114 | 115 | Returns: 116 | [torch.Tensor]: Filtered audio. 
Shape [batch, n_samples] 117 | """ 118 | batch_size, seq_len = audio.shape 119 | audio = audio.permute(1, 0).contiguous() 120 | 121 | # g = torch.clamp(g, min=1e-6, max=1) 122 | # twoR = torch.clamp(twoR, min=1e-6, max=np.sqrt(2)) 123 | 124 | if g.ndim == 2: # not time changing 125 | g = g[None, :, :].expand(seq_len, -1, -1) 126 | else: 127 | if g.shape[1] != seq_len: 128 | g = util.resample_frames(g, seq_len) 129 | g = g.permute(1, 0, 2).contiguous() 130 | 131 | if twoR.ndim == 2: # not time changing 132 | twoR = twoR[None, :, :].expand(seq_len, -1, -1) 133 | else: 134 | if twoR.shape[1] != seq_len: 135 | twoR = util.resample_frames(twoR, seq_len) 136 | twoR = twoR.permute(1, 0, 2).contiguous() 137 | 138 | # normalize mixing coefficient 139 | mix = mix / mix.sum(dim=-1, keepdim=True) 140 | if mix.ndim == 2: # not time changing 141 | mix = mix[None, :, :].expand(seq_len, -1, -1) 142 | else: 143 | if mix.shape[1] != seq_len: 144 | mix = util.resample_frames(mix, seq_len) 145 | mix = mix.permute(1, 0, 2).contiguous() 146 | 147 | # time, batch, (1~3) 148 | filt_audio = self.svf(audio, g, twoR, mix) 149 | return filt_audio 150 | -------------------------------------------------------------------------------- /diffsynth/modules/fm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from diffsynth.processor import Gen, FREQ_RANGE 4 | import diffsynth.util as util 5 | import numpy as np 6 | 7 | class FM2(Gen): 8 | """ 9 | FM Synth with one carrier and one modulator both sine waves 10 | """ 11 | def __init__(self, sample_rate=16000, max_mod_index=14, name='fm2'): 12 | super().__init__(name=name) 13 | self.sample_rate = sample_rate 14 | self.mod_ratio = np.log(max_mod_index+1) 15 | self.param_desc = { 16 | 'mod_amp': {'size': 1, 'range': (0, 1), 'type': 'sigmoid'}, 17 | 'mod_freq': {'size': 1, 'range': FREQ_RANGE, 'type': 'freq_sigmoid'}, 18 | 'car_amp': {'size': 1, 'range': (0, 1), 'type': 'sigmoid'}, 19 | 'car_freq': {'size': 1, 'range': FREQ_RANGE, 'type': 'freq_sigmoid'} 20 | } 21 | 22 | def forward(self, mod_amp, mod_freq, car_amp, car_freq, n_samples): 23 | # https://sound.stackexchange.com/questions/31709/what-is-the-level-of-frequency-modulation-of-many-synthesizers 24 | mod_amp = torch.exp(mod_amp**3.4*self.mod_ratio) - 1 25 | 26 | mod_signal = util.sin_synthesis(mod_freq, mod_amp, n_samples, self.sample_rate) 27 | car_signal = util.sin_synthesis(car_freq, car_amp, n_samples, self.sample_rate, mod_signal) 28 | return car_signal 29 | 30 | class FM3(Gen): 31 | """ 32 | Osc1 -> Osc2 -> Osc3 -> output 33 | All sin waves 34 | """ 35 | def __init__(self, sample_rate=16000, max_mod_index=14, name='fm3'): 36 | super().__init__(name=name) 37 | self.sample_rate = sample_rate 38 | # self.mod_ratio = np.log(max_mod_index+1) 39 | self.max_mod_index = max_mod_index 40 | self.param_desc = { 41 | 'amp_1': {'size': 1, 'range': (0, 1), 'type': 'sigmoid'}, 42 | 'freq_1': {'size': 1, 'range': FREQ_RANGE, 'type': 'freq_sigmoid'}, 43 | 'amp_2': {'size': 1, 'range': (0, 1), 'type': 'sigmoid'}, 44 | 'freq_2': {'size': 1, 'range': FREQ_RANGE, 'type': 'freq_sigmoid'}, 45 | 'amp_3': {'size': 1, 'range': (0, 1), 'type': 'sigmoid'}, 46 | 'freq_3': {'size': 1, 'range': FREQ_RANGE, 'type': 'freq_sigmoid'} 47 | } 48 | 49 | def forward(self, amp_1, freq_1, amp_2, freq_2, amp_3, freq_3, n_samples): 50 | audio_1 = util.sin_synthesis(freq_1, amp_1, n_samples, self.sample_rate) 51 | audio_2 = util.sin_synthesis(freq_2, amp_2, n_samples, 
self.sample_rate, fm_signal=audio_1 * self.max_mod_index) 52 | audio_3 = util.sin_synthesis(freq_3, amp_3, n_samples, self.sample_rate, fm_signal=audio_2 * self.max_mod_index) 53 | return audio_3 -------------------------------------------------------------------------------- /diffsynth/modules/frequency.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from diffsynth.processor import Processor 5 | 6 | class FreqMultiplier(Processor): 7 | 8 | def __init__(self, mult_low=0.5, mult_high=8, name='frq'): 9 | super().__init__(name=name) 10 | self.mult_low = mult_low 11 | self.mult_high = mult_high 12 | self.param_desc = { 13 | 'base_freq':{'size': 1, 'range': (32.7, 2093), 'type': 'freq_sigmoid'}, 14 | 'mult': {'size': 1, 'range': (self.mult_low, self.mult_high), 'type': 'sigmoid'}, 15 | } 16 | 17 | def forward(self, base_freq, mult): 18 | frq = base_freq * mult 19 | return frq 20 | 21 | class FreqKnobsCoarse(Processor): 22 | # DX7 oscillator frequency knobs without fine 23 | 24 | def __init__(self, name='frq', coarse_scale_fn='gumbel'): 25 | super().__init__(name=name) 26 | # coarse: 0.5, 1, 2, 3, 4, ..., 31 27 | multipliers = torch.arange(0, 8, dtype=torch.float) 28 | multipliers[0] = 0.5 29 | self.register_buffer('multipliers', multipliers) 30 | self.coarse_scale_fn = coarse_scale_fn 31 | self.param_desc = { 32 | 'base_freq':{'size': 1, 'range': (32.7, 2093), 'type': 'freq_sigmoid'}, 33 | 'coarse': {'size': 8, 'range': (-np.inf, np.inf), 'type': 'raw'}, 34 | 'detune': {'size': 8, 'range': (-7, 7), 'type': 'sigmoid'}, 35 | } 36 | 37 | def forward(self, base_freq, coarse, detune): 38 | if self.coarse_scale_fn == 'gumbel': 39 | # coarse - logits over multipliers 40 | one_hot = F.gumbel_softmax(coarse, tau=1, hard=True, dim=-1) 41 | coarse_value = (one_hot * self.multipliers).sum(dim=-1) 42 | elif self.coarse_scale_fn is not None: 43 | coarse_value = self.coarse_scale_fn(coarse) 44 | else: 45 | one_hot = torch.argmax(coarse, dim=-1) 46 | coarse_value = (one_hot * self.multipliers).sum(dim=-1) 47 | 48 | coarse_value = coarse_value.unsqueeze(-1) 49 | frq = base_freq * coarse_value 50 | frq = (frq + detune) #Hz 51 | return frq -------------------------------------------------------------------------------- /diffsynth/modules/generators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from diffsynth.processor import Gen, FREQ_RANGE 4 | import diffsynth.util as util 5 | import numpy as np 6 | from diffsynth.util import midi_to_hz 7 | 8 | class Additive(Gen): 9 | """Synthesize audio with a bank of harmonic sinusoidal oscillators. 10 | code mostly borrowed from DDSP""" 11 | 12 | def __init__(self, sample_rate=16000, normalize_below_nyquist=True, name='harmonic', n_harmonics=64): 13 | super().__init__(name=name) 14 | self.sample_rate = sample_rate 15 | self.normalize_below_nyquist = normalize_below_nyquist 16 | self.n_harmonics = n_harmonics 17 | self.param_desc = { 18 | 'amplitudes': {'size': 1, 'range': (0, 1), 'type': 'exp_sigmoid'}, 19 | 'harmonic_distribution': {'size': self.n_harmonics, 'range': (0, 1), 'type': 'exp_sigmoid'}, 20 | 'f0_hz': {'size': 1, 'range': FREQ_RANGE, 'type': 'freq_sigmoid'} 21 | } 22 | 23 | def forward(self, amplitudes, harmonic_distribution, f0_hz, n_samples): 24 | """Synthesize audio with additive synthesizer from controls. 
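        When normalize_below_nyquist is set, harmonics above the Nyquist frequency are removed, and the harmonic distribution is re-normalized to sum to one before synthesis.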
25 | 26 | Args: 27 | amplitudes: Amplitude tensor of shape [batch, n_frames, 1]. 28 | harmonic_distribution: Tensor of shape [batch, n_frames, n_harmonics]. 29 | f0_hz: The fundamental frequency in Hertz. Tensor of shape [batch, 30 | n_frames, 1]. 31 | 32 | Returns: 33 | signal: A tensor of harmonic waves of shape [batch, n_samples]. 34 | """ 35 | if len(f0_hz.shape) < 3: # when given as a condition 36 | f0_hz = f0_hz[:, :, None] 37 | # Bandlimit the harmonic distribution. 38 | if self.normalize_below_nyquist: 39 | n_harmonics = int(harmonic_distribution.shape[-1]) 40 | harmonic_frequencies = util.get_harmonic_frequencies(f0_hz, n_harmonics) 41 | harmonic_distribution = util.remove_above_nyquist(harmonic_frequencies, harmonic_distribution, self.sample_rate) 42 | 43 | # Normalize 44 | harmonic_distribution /= torch.sum(harmonic_distribution, axis=-1, keepdim=True) 45 | 46 | signal = util.harmonic_synthesis(frequencies=f0_hz, amplitudes=amplitudes, harmonic_distribution=harmonic_distribution, n_samples=n_samples, sample_rate=self.sample_rate) 47 | return signal 48 | 49 | class Sinusoids(Gen): 50 | def __init__(self, sample_rate=16000, name='sinusoids', n_sinusoids=64): 51 | super().__init__(name=name) 52 | self.sample_rate = sample_rate 53 | self.n_sinusoids = n_sinusoids 54 | self.param_desc = { 55 | 'amplitudes': {'size': self.n_sinusoids, 'range': (0, 2), 'type': 'exp_sigmoid'}, 56 | 'frequencies': {'size': self.n_sinusoids, 'range': FREQ_RANGE, 'type': 'freq_sigmoid'}, 57 | } 58 | 59 | def forward(self, amplitudes, frequencies, n_samples): 60 | """Synthesize audio with sinusoid oscillators 61 | 62 | Args: 63 | amplitudes: Amplitude tensor of shape [batch, n_frames, n_sinusoids]. 64 | frequencies: Tensor of shape [batch, n_frames, n_sinusoids]. 65 | 66 | Returns: 67 | signal: A tensor of harmonic waves of shape [batch, n_samples]. 68 | """ 69 | 70 | # resample to n_samples 71 | amplitudes_envelope = util.resample_frames(amplitudes, n_samples) 72 | frequency_envelope = util.resample_frames(frequencies, n_samples) 73 | 74 | signal = util.oscillator_bank(frequency_envelope, amplitudes_envelope, self.sample_rate) 75 | return signal 76 | 77 | class FilteredNoise(Gen): 78 | """ 79 | taken from ddsp-pytorch and ddsp 80 | uses frequency sampling 81 | """ 82 | 83 | def __init__(self, filter_size=257, scale_fn=util.exp_sigmoid, name='noise', initial_bias=-5.0, amplitude=1.0): 84 | super().__init__(name=name) 85 | self.filter_size = filter_size 86 | self.scale_fn = scale_fn 87 | self.initial_bias = initial_bias 88 | self.amplitude = amplitude 89 | self.param_desc = { 90 | 'freq_response': {'size': self.filter_size // 2 + 1, 'range': (1e-7, 2.0), 'type': 'exp_sigmoid'}, 91 | } 92 | 93 | def forward(self, freq_response, n_samples): 94 | """generate Gaussian white noise through FIRfilter 95 | Args: 96 | freq_response (torch.Tensor): frequency response (only magnitude) [batch, n_frames, filter_size // 2 + 1] 97 | 98 | Returns: 99 | [torch.Tensor]: Filtered audio. Shape [batch, n_samples] 100 | """ 101 | 102 | batch_size = freq_response.shape[0] 103 | if self.scale_fn: 104 | freq_response = self.scale_fn(freq_response + self.initial_bias) 105 | 106 | audio = (torch.rand(batch_size, n_samples)*2.0-1.0).to(freq_response.device) * self.amplitude 107 | filtered = util.fir_filter(audio, freq_response, self.filter_size) 108 | return filtered 109 | 110 | class Wavetable(Gen): 111 | """Synthesize audio from a wavetable (series of single cycle waveforms). 
112 | wavetable is parameterized 113 | code mostly borrowed from DDSP 114 | """ 115 | 116 | def __init__(self, len_waveform, sample_rate=16000, name='wavetable'): 117 | super().__init__(name=name) 118 | self.sample_rate = sample_rate 119 | self.len_waveform = len_waveform 120 | self.param_desc = { 121 | 'amplitudes': {'size': 1, 'range': (0, 1.0), 'type': 'sigmoid'}, 122 | 'wavetable': {'size': self.len_waveform, 'range': (-1, 1), 'type': 'sigmoid'}, 123 | 'f0_hz': {'size': 1, 'range': FREQ_RANGE, 'type': 'freq_sigmoid'}, 124 | } 125 | 126 | def forward(self, amplitudes, wavetable, f0_hz, n_samples): 127 | """forward pass 128 | 129 | Args: 130 | amplitudes: (batch_size, n_frames) 131 | wavetable ([type]): (batch_size, n_frames, len_waveform) 132 | f0_hz ([type]): frequency of oscillator at each frame (batch_size, n_frames) 133 | 134 | Returns: 135 | signal: synthesized signal ([batch_size, n_samples]) 136 | """ 137 | 138 | signal = util.wavetable_synthesis(f0_hz, amplitudes, wavetable, n_samples, self.sample_rate) 139 | return signal 140 | 141 | class SawOscillator(Gen): 142 | """Synthesize audio from a saw oscillator 143 | """ 144 | 145 | def __init__(self, sample_rate=16000, name='wavetable'): 146 | super().__init__(name=name) 147 | self.sample_rate = sample_rate 148 | # saw waveform 149 | waveform = torch.roll(torch.linspace(1.0, -1.0, 64), 32) # aliasing? 150 | self.register_buffer('waveform', waveform) 151 | self.param_desc = { 152 | 'amplitudes': {'size': 1, 'range': (0, 1.0), 'type': 'sigmoid'}, 153 | 'f0_hz': {'size': 1, 'range': FREQ_RANGE, 'type': 'freq_sigmoid'}, 154 | } 155 | 156 | def forward(self, amplitudes, f0_hz, n_samples): 157 | """forward pass of saw oscillator 158 | 159 | Args: 160 | amplitudes: (batch_size, n_frames, 1) 161 | f0_hz: frequency of oscillator at each frame (batch_size, n_frames, 1) 162 | 163 | Returns: 164 | signal: synthesized signal ([batch_size, n_samples]) 165 | """ 166 | 167 | signal = util.wavetable_synthesis(f0_hz, amplitudes, self.waveform, n_samples, self.sample_rate) 168 | return signal 169 | 170 | class SineOscillator(Gen): 171 | """Synthesize audio from a saw oscillator 172 | """ 173 | 174 | def __init__(self, sample_rate=16000, name='sin'): 175 | super().__init__(name=name) 176 | self.sample_rate = sample_rate 177 | self.param_desc = { 178 | 'amplitudes': {'size': 1, 'range': (0, 1.0), 'type': 'sigmoid'}, 179 | 'frequencies': {'size': 1, 'range': FREQ_RANGE, 'type': 'freq_sigmoid'}, 180 | } 181 | 182 | def forward(self, amplitudes, frequencies, n_samples): 183 | """forward pass of saw oscillator 184 | 185 | Args: 186 | amplitudes: (batch_size, n_frames, 1) 187 | f0_hz: frequency of oscillator at each frame (batch_size, n_frames, 1) 188 | 189 | Returns: 190 | signal: synthesized signal ([batch_size, n_samples]) 191 | """ 192 | 193 | signal = util.sin_synthesis(frequencies, amplitudes, n_samples, self.sample_rate) 194 | return signal 195 | 196 | #TODO: wavetable scanner 197 | -------------------------------------------------------------------------------- /diffsynth/modules/harmor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from diffsynth.processor import Gen, FREQ_RANGE 4 | import diffsynth.util as util 5 | import numpy as np 6 | 7 | def low_pass(freq, cutoff, q): 8 | ratio = freq/cutoff 9 | s = ratio*1j 10 | freq_response = abs(1/(s**2 + 1/q*s+1)) 11 | return freq_response 12 | 13 | class Harmor(Gen): 14 | """ 15 | Subtractive synth-like additive synth 16 
| Mixes 3 oscillators 17 | Each can be interpolated between saw <-> square 18 | Each has separate amplitude/frequencies or not (sep_f0s) 19 | Then a low-pass filter applied to all 20 | """ 21 | 22 | def __init__(self, sample_rate=16000, name='harmor', n_harmonics=24, sep_amp=False, n_oscs=2): 23 | super().__init__(name=name) 24 | self.sample_rate = sample_rate 25 | self.n_harmonics = n_harmonics 26 | self.sep_amp = sep_amp 27 | self.n_oscs = n_oscs 28 | 29 | # harmonic distribution of saw/sqr waves 30 | k = torch.arange(1, n_harmonics+1) 31 | saw_harm_dist = 2 / np.pi * (1/k) 32 | self.register_buffer('saw_harm_dist', saw_harm_dist) 33 | odd = torch.ones(n_harmonics) 34 | odd[1::2] = 0 35 | sqr_harm_dist = 4/np.pi * (1/k) * odd 36 | self.register_buffer('sqr_harm_dist', sqr_harm_dist) 37 | n_amps = self.n_oscs if self.sep_amp else 1 38 | self.param_desc = { 39 | 'amplitudes': {'size': n_amps, 'range': (0, 1), 'type': 'sigmoid'}, 40 | 'osc_mix': {'size': self.n_oscs, 'range': (0, 1), 'type': 'sigmoid'}, 41 | 'f0_hz': {'size': 1, 'range': FREQ_RANGE, 'type': 'freq_sigmoid'}, 42 | 'f0_mult': {'size': self.n_oscs-1, 'range': (1, 8), 'type': 'sigmoid'}, 43 | 'cutoff': {'size': 1, 'range': (30.0, self.sample_rate/2), 'type': 'freq_sigmoid'}, 44 | 'q': {'size': 1, 'range': (0.1, 2.0), 'type': 'sigmoid'} 45 | } 46 | 47 | 48 | def forward(self, amplitudes, osc_mix, f0_hz, f0_mult, cutoff, q, n_samples): 49 | """Synthesize audio with additive synthesizer from controls. 50 | 51 | Args: 52 | amplitudes: Amplitudes tensor of shape. [batch, n_frames, self.n_oscs or 1] 53 | osc_mix: saw<->sqr mix. [batch, n_frames, self.n_oscs] 54 | f0_hz: f0 of each oscillators. [batch, n_frames, 1] 55 | f0_mult: f0 of each oscillators. [batch, n_frames or 1, self.n_oscs-1] 56 | cutoff: cutoff frequency in hz. [batch, n_frames, 1] 57 | q: resonance param 0~around 1.5 is ok. [batch, n_frames or 1, 1] 58 | 59 | Returns: 60 | signal: A tensor of harmonic waves of shape [batch, n_samples]. 61 | """ 62 | batch, n_frames, _ = amplitudes.shape 63 | if f0_hz.shape[1] != n_frames: 64 | f0_hz = util.resample_frames(f0_hz, n_frames) 65 | first_mult = torch.ones(batch, n_frames, 1).to(f0_hz.device) 66 | f0_mult = f0_mult.expand(-1, n_frames, -1) 67 | f0_mult = torch.cat([first_mult, f0_mult], dim=-1) 68 | f0_hz = f0_hz.expand(-1, -1, self.n_oscs) 69 | f0_hz = f0_hz * f0_mult 70 | 71 | if not self.sep_amp: 72 | amplitudes = amplitudes.expand(-1, -1, self.n_oscs) 73 | harm_dist = (1-osc_mix).unsqueeze(-1) * self.saw_harm_dist + osc_mix.unsqueeze(-1) * self.sqr_harm_dist 74 | audio = 0 75 | amps = [] 76 | frqs = [] 77 | for k in range(self.n_oscs): 78 | # create harmonic distributions for each oscs from osc_mix and f0 79 | harmonic_amplitudes = amplitudes[:, :, k:k+1] * harm_dist[:, :, k, :] 80 | harmonic_frequencies = util.get_harmonic_frequencies(f0_hz[:, :, k:k+1], self.n_harmonics) 81 | amps.append(harmonic_amplitudes) 82 | frqs.append(harmonic_frequencies) 83 | amplitude_envelopes = torch.cat(amps, dim=-1) 84 | frequency_envelopes = torch.cat(frqs, dim=-1) 85 | lowpass_multiplier = low_pass(frequency_envelopes, cutoff, q) 86 | filt_amplitude = lowpass_multiplier * amplitude_envelopes 87 | 88 | filt_amplitude = util.resample_frames(filt_amplitude, n_samples) 89 | frequency_envelopes = util.resample_frames(frequency_envelopes, n_samples) 90 | # TODO: Phaser? 91 | 92 | # removes sinusoids above nyquist freq. 
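        # sum the n_oscs * n_harmonics sinusoid tracks into a single waveform;
        # filt_amplitude already folds in the resonant low-pass response computed by low_pass() above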
93 | audio = util.oscillator_bank(frequency_envelopes, filt_amplitude, sample_rate=self.sample_rate) 94 | return audio -------------------------------------------------------------------------------- /diffsynth/modules/lfo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from diffsynth.processor import Processor 4 | 5 | def soft_clamp_min(x, min_v, T=100): 6 | x = torch.sigmoid((min_v-x)*T)*(min_v-x)+x 7 | return x 8 | 9 | class LFO(Processor): 10 | def __init__(self, n_frames=250, n_secs=4, channels=1, rate_range=(1, 100), level_range=(0, 1), name='lfo'): 11 | super().__init__(name=name) 12 | self.n_secs = n_secs 13 | self.n_frames = n_frames 14 | self.channels = channels 15 | self.param_desc = { 16 | 'rate': {'size': self.channels, 'range': rate_range, 'type': 'sigmoid'}, 17 | 'level': {'size': self.channels, 'range': level_range, 'type': 'sigmoid'}, 18 | } 19 | 20 | def forward(self, rate, level, n_frames=None): 21 | """ 22 | Args: 23 | rate (torch.Tensor): in Hz (batch, 1, self.channels) 24 | level (torch.Tensor): LFO level (batch, 1, self.channels) 25 | n_frames (int, optional): number of frames to generate. Defaults to None. 26 | 27 | Returns: 28 | torch.Tensor: lfo signal (batch_size, n_frames, self.channels) 29 | """ 30 | if n_frames is None: 31 | n_frames = self.n_frames 32 | 33 | batch_size = rate.shape[0] 34 | final_phase = rate * self.n_secs * np.pi * 2 35 | x = torch.linspace(0, 1, n_frames, device=rate.device)[None, :, None].repeat(batch_size, 1, self.channels) # batch, n_frames, channels 36 | phase = x * final_phase 37 | wave = level * torch.sin(phase) 38 | return wave -------------------------------------------------------------------------------- /diffsynth/modules/reverb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from diffsynth.processor import Processor 4 | import diffsynth.util as util 5 | import numpy as np 6 | 7 | # input_audio decay gain 8 | # make ir 9 | # set ir_0=0 (cut dry signal) 10 | # 11 | 12 | 13 | class DecayReverb(Processor): 14 | """ 15 | Reverb with exponential decay 16 | 1. Make IR based on decay and gain 17 | - Exponentially decaying white noise -> avoids flanging? 18 | 2. convolve with IR 19 | 3. 
Cut tail 20 | """ 21 | 22 | def __init__(self, name='reverb', ir_length=16000): 23 | super().__init__(name) 24 | noise = torch.rand(1, ir_length)*2-1 # [-1, 1) 25 | noise[:, 0] = 0.0 # initial value should be zero to mask dry signal 26 | self.register_buffer('noise', noise) 27 | time = torch.linspace(0.0, 1.0, ir_length)[None, :] 28 | self.register_buffer('time', time) 29 | self.ir_length = ir_length 30 | self.param_desc = { 31 | 'audio': {'size': 1, 'range': (-1, 1), 'type': 'raw'}, 32 | 'gain': {'size': 1, 'range': (0, 0.25), 'type': 'exp_sigmoid'}, 33 | 'decay': {'size': 1, 'range': (10.0, 25.0), 'type': 'sigmoid'}, 34 | } 35 | 36 | def forward(self, audio, gain, decay): 37 | """ 38 | gain: gain of reverb ir (batch, 1, 1) 39 | decay: decay rate - larger the faster (batch, 1, 1) 40 | """ 41 | gain = gain.squeeze(1) 42 | decay = decay.squeeze(1) 43 | ir = gain * torch.exp(-decay * self.time) * self.noise # batch, time 44 | wet = util.fft_convolve(audio, ir, padding='same', delay_compensation=0) 45 | return audio+wet -------------------------------------------------------------------------------- /diffsynth/perceptual/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyakuchiki/diffsynth/4fc64ba04608cbaf8e016e4878fc50bedfca4da3/diffsynth/perceptual/__init__.py -------------------------------------------------------------------------------- /diffsynth/perceptual/crepe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchcrepe 3 | import torch.nn.functional as F 4 | 5 | from .perceptual import Perceptual 6 | from diffsynth.util import slice_windows 7 | 8 | class CREPELoss(Perceptual): 9 | def __init__(self, model_spec='tiny'): 10 | super().__init__() 11 | torchcrepe.load.model('cpu', model_spec) 12 | self.model = torchcrepe.infer.model 13 | 14 | def process_frames(self, frames): 15 | # https://github.com/maxrmorrison/torchcrepe/blob/08b36ebe8b443ac1d2a9655192e268b3f1b19f34/torchcrepe/core.py#L625 16 | frames = frames - frames.mean(dim=-1, keepdim=True) 17 | frames = frames / torch.clamp(frames.std(dim=-1, keepdim=True), min=1e-10) 18 | return frames 19 | 20 | def perceptual_loss(self, target_audio, input_audio): 21 | target_frames = self.process_frames(slice_windows(target_audio, 1024, 512, window=None)) 22 | input_frames = self.process_frames(slice_windows(input_audio, 1024, 512, window=None)) 23 | # [batch, n_frames, 1024]-> batch', 1024 24 | target_frames = target_frames.flatten(0,1) 25 | target_embed = self.model.embed(target_frames) 26 | input_frames = input_frames.flatten(0,1) 27 | input_embed = self.model.embed(input_frames) 28 | return F.l1_loss(target_embed, input_embed) -------------------------------------------------------------------------------- /diffsynth/perceptual/openl3.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import librosa 4 | import torch 5 | import torch.nn.functional as F 6 | import torchopenl3 as ol3 7 | from torchopenl3.models import CustomSTFT 8 | 9 | from .perceptual import Perceptual 10 | from diffsynth.util import slice_windows 11 | 12 | HOP_48K = 242 13 | NDFT_48K_MEL = 2048 14 | ORIG_SIZE = (257, 197) 15 | 16 | def custom_pad(x, n_dft, n_hop, sr): 17 | """ 18 | Taken from torchopenl3 19 | Pad sequence. 
20 | Implemented similar to keras version used in kapre=0.1.4 21 | """ 22 | # x: (batch, 1, 16000) 23 | if sr % n_hop == 0: 24 | pad_along_width = max(n_dft - n_hop, 0) 25 | else: 26 | pad_along_width = max(n_dft - (sr % n_hop), 0) 27 | 28 | pad_left = pad_along_width // 2 29 | pad_right = pad_along_width - pad_left 30 | 31 | x = F.pad(x, (pad_left, pad_right)) 32 | return x 33 | 34 | def amplitude_to_decibel(x, amin=1e-10, dynamic_range=80.0): 35 | """ 36 | Taken from torchopenl3 37 | Convert (linear) amplitude to decibel (log10(x)). 38 | Implemented similar to kapre=0.1.4 39 | """ 40 | 41 | log_spec = ( 42 | 10 * torch.log(torch.clamp(x, min=amin)) / np.log(10).astype(np.float32) 43 | ) 44 | if x.ndim > 1: 45 | axis = tuple(range(x.ndim)[1:]) 46 | else: 47 | axis = None 48 | 49 | log_spec = log_spec - torch.amax(log_spec, dim=axis, keepdims=True) 50 | log_spec = torch.clamp(log_spec, min=-1 * dynamic_range) 51 | return log_spec 52 | 53 | class PerceptualOpenl3(Perceptual): 54 | def __init__(self, input_repr='mel256', data='music', embed_size=512, sr=16000, hop_s=1.0, use_layers=None): 55 | super().__init__() 56 | assert input_repr in ['mel256', 'mel128'] 57 | assert data in ['music', 'env'] 58 | assert embed_size in [512, 6144] 59 | self.orig_sr = sr 60 | self.hop_s = hop_s 61 | self.layers = use_layers 62 | model_name='torchopenl3_{0}_{1}_{2}.pth.tar'.format(input_repr, data, embed_size) 63 | print('loading', model_name) 64 | # model = ol3.models.PytorchOpenl3('_', input_repr, embed_size) 65 | # model = model.load_state_dict(torch.load(os.path.join(base_dir,model_name))) 66 | model = ol3.core.load_audio_embedding_model(input_repr, data, embed_size) 67 | self.model = model.eval().requires_grad_(False) 68 | self.model.speclayer = torch.nn.Identity() 69 | 70 | self.n_mels = int(input_repr[3:]) 71 | mult = self.orig_sr/48000 72 | self.n_dft = int(mult * NDFT_48K_MEL) # 682 73 | self.hop_len = int(mult * HOP_48K) # 80 74 | mel_fb = librosa.filters.mel( 75 | sr=self.orig_sr, 76 | n_fft=self.n_dft, 77 | n_mels=self.n_mels, 78 | fmin=0, 79 | fmax=24000, 80 | htk=True, 81 | norm=1, 82 | ) # lots of empty banks 83 | self.register_buffer("mel_fb", torch.tensor(mel_fb, requires_grad=False)) 84 | 85 | self.pad_freq = ORIG_SIZE[0] - (self.n_dft//2+1) 86 | self.remainder_w = ((16000 - self.hop_len) // self.hop_len + 1) - ORIG_SIZE[1] 87 | assert self.remainder_w >= 0 88 | self.lin_spec = CustomSTFT( 89 | n_dft=self.n_dft, 90 | n_hop=self.hop_len, 91 | power_spectrogram=2.0, 92 | return_decibel_spectrogram=False, 93 | ) 94 | 95 | def stft_mel(self, x): 96 | # batch, 1, frame_size 97 | x = slice_windows(x, self.orig_sr, int(self.orig_sr*self.hop_s)).flatten(0,1).unsqueeze(1) 98 | x = custom_pad(x, self.n_dft, self.hop_len, self.orig_sr) 99 | spec = self.lin_spec(x) 100 | assert spec.shape[2] >= ORIG_SIZE[1], spec.shape[2] 101 | spec = spec[:, :, :ORIG_SIZE[1], 0] 102 | # [batch_slices, 257, 197, 1] 103 | melspec = torch.matmul(self.mel_fb, spec) 104 | melspec = torch.sqrt(melspec+1e-10) 105 | return amplitude_to_decibel(melspec) 106 | 107 | def get_embed(self, x, layers=None): 108 | mel = self.stft_mel(x).contiguous() 109 | if layers is not None: # [1,2, ..., 28?] 
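            # the call below returns activations of every intermediate layer (keep_all_outputs=True);
            # only the requested layer indices are kept for the layer-wise perceptual loss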
110 | embeds = self.model(mel, keep_all_outputs=True) 111 | return [embeds[i] for i in layers] 112 | else: 113 | return self.model(mel) 114 | 115 | def perceptual_loss(self, target_audio, input_audio): 116 | target_embed = self.get_embed(target_audio, self.layers) 117 | input_embed = self.get_embed(input_audio, self.layers) 118 | if self.layers is not None: 119 | loss = 0 120 | for target_e, input_e in zip(target_embed, input_embed): 121 | loss += (1 - F.cosine_similarity(target_e, input_e, dim=-1)).mean() 122 | return loss / len(target_e) 123 | else: 124 | return (1 - F.cosine_similarity(target_embed, input_embed, dim=-1)).mean() 125 | 126 | def load_openl3_model(base_dir, input_repr='mel256', data='music', embed_size=512): 127 | assert input_repr in ['mel256', 'mel128', 'linear'] 128 | assert data in ['music', 'env'] 129 | assert embed_size in [512, 6144] 130 | # model_name='torchopenl3_{0}_{1}_{2}.pth.tar'.format(input_repr, data, embed_size) 131 | # print('loading', model_name) 132 | # model = ol3.models.PytorchOpenl3('_', input_repr, embed_size) 133 | # model = model.load_state_dict(torch.load(os.path.join(base_dir,model_name))) 134 | model = ol3.core.load_audio_embedding_model(input_repr, data, embed_size) 135 | return model.eval().requires_grad_(False) 136 | 137 | def openl3_loss(model, target_audio, input_audio, hop_size=1.0): 138 | # target/input : (batch, n_samples) 139 | # resampled to 48000Hz and sliced to be [batch*n_slice, 48000*1(window_size)] 140 | target_audio = ol3.utils.preprocess_audio_batch(target_audio, sr=16000, center=False, hop_size=hop_size, sampler='julian') 141 | input_audio = ol3.utils.preprocess_audio_batch(input_audio, sr=16000, center=False, hop_size=hop_size, sampler='julian') 142 | 143 | target_embed = model(target_audio.contiguous()) 144 | input_embed = model(input_audio.contiguous()) 145 | return (1 - F.cosine_similarity(target_embed, input_embed, dim=1)).mean() 146 | -------------------------------------------------------------------------------- /diffsynth/perceptual/perceptual.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from diffsynth.spectral import MelSpec 8 | from diffsynth.layers import MLP, Normalize2d, Resnet2D 9 | from diffsynth.transforms import LogTransform 10 | 11 | class Perceptual(nn.Module): 12 | def perceptual_loss(self, target_audio, input_audio): 13 | raise NotImplementedError 14 | 15 | class PerceptualClassifier(Perceptual): 16 | # take melspectrogram 17 | def __init__(self, output_dims, n_samples, n_mels=128, n_downsample=3, stride=(2,2), res_depth=3, channels=32, dilation_growth_rate=2, m_conv=1.0, n_fft=1024, hop=256, norm='batch', sample_rate=16000): 18 | super().__init__() 19 | self.n_mels = n_mels 20 | self.channels = channels 21 | self.logmel = nn.Sequential(MelSpec(n_fft=n_fft, hop_length=hop, n_mels=n_mels, sample_rate=sample_rate, power=2), LogTransform()) 22 | self.norm = Normalize2d(norm) if norm else None 23 | 24 | spec_len = math.ceil((n_samples - n_fft) / hop) + 1 25 | self.spec_len_target = 2**(int(math.log2(spec_len))+1) # power of 2 26 | kernel_size = [s*2 for s in stride] 27 | pad = [s//2 for s in stride] 28 | input_size = (n_mels, self.spec_len_target) 29 | final_size = input_size 30 | blocks = [] 31 | for i in range(n_downsample): 32 | block = nn.Sequential( 33 | # downsampling conv, output size is L_in/stride 34 | nn.Conv2d(1 if i == 0 else channels, channels, 
kernel_size, stride, pad), 35 | # ResNet with growing dilation, doesn't change size 36 | Resnet2D(channels, res_depth, m_conv, dilation_growth_rate, reverse_dilation=False), 37 | ) 38 | blocks.append(block) 39 | final_size = (final_size[0] // stride[0], final_size[1] // stride[1]) 40 | 41 | self.convmodel = nn.Sequential(*blocks) 42 | print('output dims after convolution', final_size) 43 | 44 | self.mlp = MLP(final_size[0] * final_size[1] * channels, 64, loop=2) 45 | self.out = nn.Linear(64, output_dims) 46 | 47 | def perceptual_loss(self, target_audio, input_audio, layers=(2, )): 48 | self.eval() 49 | batch_size = input_audio.shape[0] 50 | audios = torch.cat([input_audio, target_audio], dim=0) 51 | specs = self.logmel(audios).unsqueeze(1) 52 | loss = 0 53 | out = specs 54 | for i, m in enumerate(self.convmodel): 55 | out = m(out) 56 | if i in layers: 57 | loss += F.l1_loss(out[:batch_size], out[batch_size:]) 58 | return loss 59 | 60 | def transform(self, audio): 61 | spec = self.logmel(audio) 62 | if self.norm is not None: 63 | spec = self.norm(spec) 64 | # (batch, n_mels, time) 65 | batch_size, n_mels, n_frames = spec.shape 66 | padded_spec = F.pad(spec, (0, self.spec_len_target-n_frames)) 67 | return padded_spec 68 | 69 | def forward(self, audio): 70 | x = self.transform(audio).unsqueeze(1) 71 | x = self.convmodel(x) 72 | x = x.flatten(1, -1) 73 | out = self.mlp(x) 74 | out = self.out(out) 75 | return out 76 | 77 | def train_epoch(self, loader, optimizer, device, clip=1.0): 78 | self.train() 79 | sum_loss = 0 80 | count = 0 81 | for data_dict in loader: 82 | # send data to device 83 | data_dict = {name:tensor.to(device, non_blocking=True) for name, tensor in data_dict.items()} 84 | target = data_dict['label'] 85 | audio = data_dict['audio'] 86 | logits = self(audio) 87 | batch_loss = F.cross_entropy(logits, target) 88 | # Perform backward 89 | optimizer.zero_grad() 90 | batch_loss.backward() 91 | torch.nn.utils.clip_grad_norm_(self.parameters(), clip) 92 | optimizer.step() 93 | sum_loss += batch_loss.detach().item() 94 | count += 1 95 | sum_loss /= count 96 | return sum_loss 97 | 98 | def eval_epoch(self, loader, device): 99 | self.eval() 100 | sum_correct = 0 101 | count = 0 102 | with torch.no_grad(): 103 | for data_dict in loader: 104 | # send data to device 105 | data_dict = {name:tensor.to(device, non_blocking=True) for name, tensor in data_dict.items()} 106 | audio = data_dict['audio'] 107 | target = data_dict['label'] 108 | logits = self(audio) 109 | sum_correct += (torch.argmax(logits, dim=-1) == target).sum().item() 110 | count += audio.shape[0] 111 | return sum_correct/count # accuracy -------------------------------------------------------------------------------- /diffsynth/perceptual/wav2vec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchcrepe 3 | import torch.nn.functional as F 4 | from torchaudio.models import wav2vec2_base, wav2vec2_large, wav2vec2_large_lv60k 5 | from .perceptual import Perceptual 6 | 7 | class Wav2VecLoss(Perceptual): 8 | """ 9 | Use feature extractor only 10 | """ 11 | def __init__(self, model_spec, state_dict): 12 | super().__init__() 13 | if model_spec == 'base': 14 | entire = wav2vec2_base(num_out=32) 15 | elif model_spec == 'large': 16 | entire = wav2vec2_large(num_out=32) 17 | elif model_spec == 'large_lv60k': 18 | entire = wav2vec2_large_lv60k(num_out=32) 19 | model = entire.feature_extractor 20 | model.load_state_dict(torch.load(state_dict), strict=False) 21 | self.model = 
model.eval() 22 | 23 | def perceptual_loss(self, target_audio, input_audio): 24 | # length = torch.ones(target_audio.shape[0], device=target_audio.device) * target_audio.shape[-1] 25 | target_embed = self.model(target_audio, None)[0] 26 | input_embed = self.model(input_audio, None)[0] 27 | return F.l1_loss(target_embed, input_embed) -------------------------------------------------------------------------------- /diffsynth/processor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import diffsynth.util as util 4 | import math 5 | 6 | SCALE_FNS = { 7 | 'raw': lambda x, low, high: x, 8 | 'sigmoid': lambda x, low, high: x*(high-low) + low, 9 | 'freq_sigmoid': lambda x, low, high: util.unit_to_hz(x, low, high, clip=False), 10 | 'exp_sigmoid': lambda x, low, high: util.exp_scale(x, math.log(10.0), high, 1e-7+low), 11 | } 12 | 13 | FREQ_RANGE = (32, 2000) # MIDI C1=32.7 B6=1975.5 14 | 15 | class Processor(nn.Module): 16 | def __init__(self, name): 17 | """ Initialize as module """ 18 | super().__init__() 19 | self.name = name 20 | self.param_desc = {} 21 | 22 | def process(self, scaled_params=[], **kwargs): 23 | # scaling each parameter according to the property 24 | # input is 0~1 25 | for k in kwargs.keys(): 26 | if k not in self.param_desc or k in scaled_params: 27 | continue 28 | desc = self.param_desc[k] 29 | scale_fn = SCALE_FNS[desc['type']] 30 | p_range = desc['range'] 31 | # if (kwargs[k] > 1).any(): 32 | # raise ValueError('parameter to be scaled is not 0~1') 33 | kwargs[k] = scale_fn(kwargs[k], p_range[0], p_range[1]) 34 | return self(**kwargs) 35 | 36 | def forward(self): 37 | raise NotImplementedError 38 | 39 | class Gen(Processor): 40 | def __init__(self, name): 41 | super().__init__(name) 42 | 43 | class Add(Processor): 44 | def __init__(self, name='add'): 45 | super().__init__(name=name) 46 | self.param_desc = { 47 | 'signal_a': {'size':1, 'range': (-1, 1), 'type': 'raw'}, 'signal_b': {'size':1, 'range': (-1, 1), 'type': 'raw'} 48 | } 49 | 50 | def forward(self, signal_a, signal_b): 51 | # kinda sucks can only add two 52 | return signal_a+signal_b 53 | 54 | class Mix(Processor): 55 | def __init__(self, name='add'): 56 | super().__init__(name=name) 57 | self.param_desc = { 58 | 'signal_a': {'size':1, 'range': (-1, 1), 'type': 'raw'}, 59 | 'signal_b': {'size':1, 'range': (-1, 1), 'type': 'raw'}, 60 | 'mix_a': {'size':1, 'range': (0, 1), 'type': 'sigmoid'}, 61 | 'mix_b': {'size':1, 'range': (0, 1), 'type': 'sigmoid'}, 62 | } 63 | 64 | def forward(self, signal_a, signal_b, mix_a, mix_b): 65 | # kinda sucks can only add two 66 | return mix_a[:, :, 0]*signal_a+mix_b[:, :, 0]*signal_b -------------------------------------------------------------------------------- /diffsynth/schedules.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | For scheduling AE parameters 4 | 5 | """ 6 | import functools 7 | 8 | def linear_anneal(i, end_value, start_value, start, warm): 9 | l = max(i - start, 0) 10 | value = (end_value-start_value) * (float(l) / float(max(warm, l))) + start_value 11 | return value 12 | 13 | # loss weights and other parameters used during training 14 | required_args = ['param_w', # parameter loss 15 | 'sw_w', # spectral/waveform loss 16 | 'perc_w', # perceptual loss 17 | ] 18 | 19 | class ParamSchedule(): 20 | def __init__(self, sched_cfg): 21 | self.sched = {} 22 | for param_name, param_sched in sched_cfg.items(): 23 | if param_name == 'name': 24 | 
continue 25 | if isinstance(param_sched, float): 26 | self.sched[param_name] = param_sched 27 | continue 28 | if param_sched['type'] == 'linear': 29 | self.sched[param_name] = functools.partial(linear_anneal, 30 | start=param_sched['start'], 31 | warm=param_sched['warm'], 32 | start_value=param_sched['start_v'], 33 | end_value=param_sched['end_v']) 34 | else: 35 | raise ValueError() 36 | for k in required_args: 37 | if k not in self.sched: 38 | self.sched[k] = 0.0 39 | 40 | def get_parameters(self, cur_step): 41 | cur_param = {} 42 | for param_name, param_func in self.sched.items(): 43 | cur_param[param_name] = param_func(i=cur_step) if callable(param_func) else param_func 44 | return cur_param 45 | 46 | 47 | # Below is defunct 48 | 49 | SCHEDULE_REGISTRY = {} 50 | 51 | class ParamScheduler(): 52 | def __init__(self, name): 53 | self.sched = SCHEDULE_REGISTRY[name] 54 | for k in required_args: 55 | if k not in self.sched: 56 | self.sched[k] = 0.0 57 | 58 | def get_parameters(self, cur_step): 59 | cur_param = {} 60 | for param_name, param_func in self.sched.items(): 61 | cur_param[param_name] = param_func(i=cur_step) if callable(param_func) else param_func 62 | return cur_param 63 | 64 | switch_1 = { 65 | # parameter loss weight 66 | 'param_w': functools.partial(linear_anneal, end_value=0.0, start_value=10.0, start=12500, warm=37500), 67 | # reconstruction (spectral/wave) loss weight 68 | 'sw_w': functools.partial(linear_anneal, end_value=1.0, start_value=0.0, start=12500, warm=37500), 69 | } 70 | SCHEDULE_REGISTRY['switch_1'] = switch_1 71 | 72 | # even weights 73 | even_1 = { 74 | # parameter loss weight 75 | 'param_w': functools.partial(linear_anneal, end_value=5.0, start_value=10.0, start=12500, warm=37500), 76 | # reconstruction (spectral/wave) loss weight 77 | 'sw_w': functools.partial(linear_anneal, end_value=0.5, start_value=0.0, start=12500, warm=37500), 78 | } 79 | SCHEDULE_REGISTRY['even_1'] = even_1 80 | 81 | # switch completely from param to spectral loss and perceptual loss 82 | switch_p = { 83 | # parameter loss weight 84 | 'param_w': functools.partial(linear_anneal, end_value=0.0, start_value=10.0, start=12500, warm=37500), 85 | # reconstruction (spectral/wave) loss weight 86 | 'sw_w': functools.partial(linear_anneal, end_value=1.0, start_value=0.0, start=12500, warm=37500), 87 | # perceptual loss based on ae 88 | 'perc_w': functools.partial(linear_anneal, end_value=0.01, start_value=0.0, start=12500, warm=37500), 89 | } 90 | SCHEDULE_REGISTRY['switch_p'] = switch_p 91 | 92 | only_param = { 93 | 'param_w': 10.0, 94 | } 95 | SCHEDULE_REGISTRY['only_param'] = only_param 96 | 97 | sw_param = { 98 | 'param_w': 5.0, 99 | 'sw_w': 0.5 100 | } 101 | SCHEDULE_REGISTRY['sw_param'] = sw_param 102 | 103 | only_sw = { 104 | 'sw_w': 1.0, 105 | } 106 | SCHEDULE_REGISTRY['only_sw'] = only_sw 107 | 108 | only_perc = { 109 | 'perc_w': 10.0, 110 | } 111 | SCHEDULE_REGISTRY['only_perc'] = only_perc 112 | 113 | sw_perc = { 114 | 'sw_w': 1.0, 115 | 'perc_w': 10.0 116 | } 117 | SCHEDULE_REGISTRY['sw_perc'] = sw_perc -------------------------------------------------------------------------------- /diffsynth/spectral.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | import librosa 6 | from torchaudio.transforms import MelScale 7 | from torchaudio.functional import create_dct 8 | from diffsynth.util import log_eps, pad_or_trim_to_expected_length 9 | 10 | amp = 
lambda x: x[...,0]**2 + x[...,1]**2 11 | 12 | class MelSpec(nn.Module): 13 | def __init__(self, n_fft=2048, hop_length=1024, n_mels=128, sample_rate=16000, power=1, f_min=40, f_max=7600, pad_end=True, center=False): 14 | """ 15 | 16 | """ 17 | super().__init__() 18 | self.n_fft = n_fft 19 | self.hop_length = hop_length 20 | self.power = power 21 | self.f_min = f_min 22 | self.f_max = f_max 23 | self.sample_rate = sample_rate 24 | self.n_mels = n_mels 25 | self.pad_end = pad_end 26 | self.center = center 27 | self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1) 28 | 29 | def forward(self, audio): 30 | if self.pad_end: 31 | _batch_dim, l_x = audio.shape 32 | remainder = (l_x - self.n_fft) % self.hop_length 33 | pad = 0 if (remainder == 0) else self.hop_length - remainder 34 | audio = F.pad(audio, (0, pad), 'constant') 35 | spec = spectrogram(audio, self.n_fft, self.hop_length, self.power, self.center) 36 | mel_spec = self.mel_scale(spec) 37 | return mel_spec 38 | 39 | class Spec(nn.Module): 40 | def __init__(self, n_fft=2048, hop_length=1024, power=2, pad_end=True, center=False): 41 | """ 42 | 43 | """ 44 | super().__init__() 45 | self.n_fft = n_fft 46 | self.hop_length = hop_length 47 | self.power = power 48 | self.pad_end = pad_end 49 | self.center = center 50 | 51 | def forward(self, audio): 52 | if self.pad_end: 53 | _batch_dim, l_x = audio.shape 54 | remainder = (l_x - self.n_fft) % self.hop_length 55 | pad = 0 if (remainder == 0) else self.hop_length - remainder 56 | audio = F.pad(audio, (0, pad), 'constant') 57 | spec = spectrogram(audio, self.n_fft, self.hop_length, self.power, self.center) 58 | return spec 59 | 60 | class Mfcc(nn.Module): 61 | def __init__(self, n_fft=2048, hop_length=1024, n_mels=128, n_mfcc=40, norm='ortho', sample_rate=16000, f_min=40, f_max=7600, pad_end=True, center=False): 62 | """ 63 | uses log mels 64 | """ 65 | super().__init__() 66 | self.norm = norm 67 | self.n_mfcc = n_mfcc 68 | self.melspec = MelSpec(n_fft, hop_length, n_mels, sample_rate, power=2, f_min=f_min, f_max=f_max, pad_end=pad_end, center=center) 69 | dct_mat = create_dct(self.n_mfcc, self.melspec.n_mels, self.norm) 70 | self.register_buffer('dct_mat', dct_mat) 71 | 72 | def forward(self, audio): 73 | mel_spec = self.melspec(audio) 74 | mel_spec = torch.log(mel_spec+1e-6) 75 | # (batch, n_mels, time).tranpose(...) dot (n_mels, n_mfcc) 76 | # -> (batch, time, n_mfcc).tranpose(...) 
77 | mfcc = torch.matmul(mel_spec.transpose(1, 2), self.dct_mat).transpose(1, 2) 78 | return mfcc 79 | 80 | def spectrogram(audio, size=2048, hop_length=1024, power=2, center=False, window=None): 81 | power_spec = amp(torch.view_as_real(torch.stft(audio, size, window=window, hop_length=hop_length, center=center, return_complex=True))) 82 | if power == 2: 83 | spec = power_spec 84 | elif power == 1: 85 | spec = power_spec.sqrt() 86 | return spec 87 | 88 | def compute_lsd(orig_audio, resyn_audio): 89 | window = torch.hann_window(1024).to(orig_audio.device) 90 | orig_power_s = spectrogram(orig_audio, 1024, 256, window=window).detach() 91 | resyn_power_s = spectrogram(resyn_audio, 1024, 256, window=window).detach() 92 | lsd = torch.sqrt(((10 * (torch.log10(resyn_power_s+1e-5)-torch.log10(orig_power_s+1e-5)))**2).sum(dim=(1,2))) / orig_power_s.shape[-1] 93 | lsd = lsd.mean() 94 | return lsd 95 | 96 | def multiscale_fft(audio, sizes=[64, 128, 256, 512, 1024, 2048], hop_lengths=None, win_lengths=None) -> torch.Tensor: 97 | """multiscale fft power spectrogram 98 | uses torch.stft so it should be differentiable 99 | 100 | Args: 101 | audio : (batch) input audio tensor Shape: [(batch), n_samples] 102 | sizes : fft sizes. Defaults to [64, 128, 256, 512, 1024, 2048]. 103 | overlap : overlap between windows. Defaults to 0.75. 104 | """ 105 | specs = [] 106 | if hop_lengths is None: 107 | overlap = 0.75 108 | hop_lengths = [int((1-overlap)*s) for s in sizes] 109 | if win_lengths is None: 110 | win_lengths = [None for s in sizes] 111 | if isinstance(audio, np.ndarray): 112 | audio = torch.from_numpy(audio) 113 | stft_params = zip(sizes, hop_lengths, win_lengths) 114 | for n_fft, hl, wl in stft_params: 115 | window = torch.hann_window(n_fft if wl is None else wl).to(audio.device) 116 | stft = torch.stft(audio, n_fft, window=window, hop_length=hl, win_length=wl, center=False, return_complex=True) 117 | stft = torch.view_as_real(stft) 118 | specs.append(amp(stft)) 119 | return specs 120 | 121 | def compute_loudness(audio, sample_rate=16000, frame_rate=50, n_fft=2048, range_db=120.0, ref_db=20.7): 122 | """Perceptual loudness in dB, relative to white noise, amplitude=1. 123 | 124 | Args: 125 | audio: tensor. Shape [batch_size, audio_length] or [audio_length]. 126 | sample_rate: Audio sample rate in Hz. 127 | frame_rate: Rate of loudness frames in Hz. 128 | n_fft: Fft window size. 129 | range_db: Sets the dynamic range of loudness in decibels. The minimum loudness (per a frequency bin) corresponds to -range_db. 130 | ref_db: Sets the reference maximum perceptual loudness as given by (A_weighting + 10 * log10(abs(stft(audio))**2.0). The default value corresponds to white noise with amplitude=1.0 and n_fft=2048. There is a slight dependence on fft_size due to different granularity of perceptual weighting. 131 | 132 | Returns: 133 | Loudness in decibels. Shape [batch_size, n_frames] or [n_frames,]. 134 | """ 135 | # Temporarily a batch dimension for single examples. 136 | is_1d = (len(audio.shape) == 1) 137 | if is_1d: 138 | audio = audio[None, :] 139 | 140 | # Take STFT. 141 | hop_length = sample_rate // frame_rate 142 | s = torch.stft(audio, n_fft=n_fft, hop_length=hop_length, return_complex=True) 143 | s = torch.view_as_real(s) 144 | # batch, frequency_bins, n_frames 145 | 146 | # Compute power of each bin 147 | amplitude = torch.sqrt(amp(s) + 1e-5) #sqrt(0) gives nan gradient 148 | power_db = torch.log10(amplitude + 1e-5) 149 | power_db *= 20.0 150 | 151 | # Perceptual weighting. 
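    # librosa.A_weighting gives a per-frequency correction in dB, so adding it to power_db applies the perceptual weighting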
152 | frequencies = librosa.fft_frequencies(sr=sample_rate, n_fft=n_fft) 153 | a_weighting = librosa.A_weighting(frequencies)[None, :, None] 154 | loudness = power_db + torch.from_numpy(a_weighting.astype(np.float32)).to(audio.device) 155 | 156 | # Set dynamic range. 157 | loudness -= ref_db 158 | loudness = torch.clamp(loudness, min=-range_db) 159 | 160 | # Average over frequency bins. 161 | loudness = torch.mean(loudness, dim=1) 162 | 163 | # Remove temporary batch dimension. 164 | loudness = loudness[0] if is_1d else loudness 165 | 166 | # Compute expected length of loudness vector 167 | n_secs = audio.shape[-1] / float(sample_rate) # `n_secs` can have milliseconds 168 | expected_len = int(n_secs * frame_rate) 169 | 170 | # Pad with `-range_db` noise floor or trim vector 171 | loudness = pad_or_trim_to_expected_length(loudness, expected_len, -range_db) 172 | return loudness 173 | 174 | def loudness_loss(input_audio, target_audio, sr=16000): 175 | input_l = compute_loudness(input_audio, sr) 176 | target_l = compute_loudness(target_audio, sr) 177 | return F.l1_loss(input_l, target_l, reduction='mean') 178 | 179 | def compute_f0(audio, sample_rate, frame_rate, viterbi=True): 180 | """Fundamental frequency (f0) estimate using CREPE. 181 | 182 | This function is non-differentiable and takes input as a numpy array. 183 | Args: 184 | audio: Numpy ndarray of single audio example. Shape [audio_length,]. 185 | sample_rate: Sample rate in Hz. 186 | frame_rate: Rate of f0 frames in Hz. 187 | viterbi: Use Viterbi decoding to estimate f0. 188 | 189 | Returns: 190 | f0_hz: Fundamental frequency in Hz. Shape [n_frames,]. 191 | """ 192 | import crepe 193 | n_secs = len(audio) / float(sample_rate) # `n_secs` can have milliseconds 194 | crepe_step_size = 1000 / frame_rate # milliseconds 195 | expected_len = int(n_secs * frame_rate) 196 | audio = np.asarray(audio) 197 | 198 | # Compute f0 with crepe. 
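    # crepe.predict returns (time, frequency, confidence, activation); time and activation are discarded here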
199 | _, f0_hz, f0_confidence, _ = crepe.predict(audio, sr=sample_rate, viterbi=viterbi, step_size=crepe_step_size, center=False, verbose=0) 200 | 201 | # Postprocessing on f0_hz 202 | f0_hz = pad_or_trim_to_expected_length(torch.from_numpy(f0_hz), expected_len, 0) # pad with 0 203 | f0_hz = f0_hz.numpy().astype(np.float32) 204 | 205 | # # Postprocessing on f0_confidence 206 | # f0_confidence = pad_or_trim_to_expected_length(f0_confidence, expected_len, 1) 207 | # f0_confidence = np.nan_to_num(f0_confidence) # Set nans to 0 in confidence 208 | # f0_confidence = f0_confidence.astype(np.float32) 209 | return f0_hz 210 | 211 | def fix_f0(f0, diff_width=4, thres=0.4): 212 | """ 213 | f0: [batch, n_frames] 214 | f0 computed by crepe tends to fall off near note end 215 | fix the f0 value to the last sane value when f0 falls off fast 216 | """ 217 | orig_shape = f0.shape 218 | if len(orig_shape) == 3: #[batch, n_frames, feature_dim=1] 219 | f0 = f0.squeeze(-1) 220 | norm_diff = (f0[:, diff_width:] - f0[:, :-diff_width]) / f0[:, diff_width:] 221 | norm_diff = F.pad(norm_diff, (0, diff_width)) 222 | spike = norm_diff.abs() 1: 117 | wavetable = resample_frames(wavetable, n_samples) 118 | 119 | phase_velocity = frq_env / float(sample_rate) 120 | phase = torch.cumsum(phase_velocity, 1)[:, :-1] % 1.0 # [batch_size, n_samples] 121 | phase = torch.cat([torch.zeros(batch_size, 1).to(phase.device), phase], dim=1) # exclusive cumsum starting at 0 122 | if fm_signal is not None: 123 | audio = linear_lookup(phase+fm_signal / (2*np.pi) , wavetable) 124 | else: 125 | audio = linear_lookup(phase, wavetable) 126 | audio *= amp_env 127 | return audio 128 | 129 | def linear_lookup(phase, wavetable): 130 | """Lookup from wavetables 131 | 132 | Args: 133 | phase: instantaneous phase of base oscillator (0.0~1.0) [batch_size, n_samples] 134 | wavetable ([type]): [batch_size, n_samples, len_waveform] or [batch_size, len_waveform] 135 | """ 136 | phase = phase[:, :, None] 137 | len_waveform = wavetable.shape[-1] 138 | phase_wavetable = torch.linspace(0.0, 1.0, len_waveform).to(wavetable.device) 139 | if len(wavetable.shape) == 2: 140 | wavetable = wavetable.unsqueeze(1) 141 | 142 | # Get pair-wise distances from the oscillator phase to each wavetable point. 143 | # Axes are [batch, time, len_waveform]. NOTE: <- this is super large 144 | phase_distance = abs((phase - phase_wavetable[None, None, :])) 145 | phase_distance *= len_waveform - 1 146 | # Weighting for interpolation. 147 | # Distance is > 1.0 (and thus weights are 0.0) for all but nearest neighbors. 148 | weights = nn.functional.relu(1.0 - phase_distance) # [batch_size, n_samples, len_waveform] 149 | weighted_wavetables = weights * wavetable 150 | return torch.sum(weighted_wavetables, dim=-1) 151 | 152 | def resample_frames(inputs, n_timesteps, mode='linear', add_endpoint=True): 153 | """interpolate signals with a value each frame into signal with a value each timestep 154 | [n_frames] -> [n_timesteps] 155 | 156 | Args: 157 | inputs (torch.Tensor): [n_frames], [batch_size, n_frames], [batch_size, n_frames, channels] 158 | n_timesteps (int): 159 | mode (str): 'window' for interpolating with overlapping windows 160 | add_endpoint ([type]): I think its for windowed interpolation 161 | Returns: 162 | torch.Tensor 163 | [n_timesteps], [batch_size, n_timesteps], or [batch_size, n_timesteps, channels?] 
164 | """ 165 | orig_shape = inputs.shape 166 | 167 | if len(orig_shape)==1: 168 | inputs = inputs.unsqueeze(0) # [dummy_batch, n_frames] 169 | inputs = inputs.unsqueeze(1) # [dummy_batch, dummy_channel, n_frames] 170 | if len(orig_shape)==2: 171 | inputs = inputs.unsqueeze(1) # [batch, dummy_channel, n_frames] 172 | if len(orig_shape)==3: 173 | inputs = inputs.permute(0, 2, 1) # # [batch, channels, n_frames] 174 | 175 | if mode == 'window': 176 | raise NotImplementedError 177 | # upsample_with_windows(outputs, n_timesteps, add_endpoint) 178 | else: 179 | # interpolate expects [batch_size, channel, (depth, height,) width] 180 | outputs = nn.functional.interpolate(inputs, size=n_timesteps, mode=mode, align_corners=not add_endpoint) 181 | 182 | if len(orig_shape) == 1: 183 | outputs = outputs.squeeze(1) # get rid of dummy channel 184 | outputs = outputs.squeeze(0) #[n_timesteps] 185 | if len(orig_shape) == 2: 186 | outputs = outputs.squeeze(1) # get rid of dummy channel # [n_frames, n_timesteps] 187 | if len(orig_shape)==3: 188 | outputs = outputs.permute(0, 2, 1) # [batch, n_frames, channels] 189 | 190 | return outputs 191 | 192 | def get_harmonic_frequencies(frequencies, n_harmonics): 193 | """Create integer multiples of the fundamental frequency. 194 | 195 | Args: 196 | frequencies: Fundamental frequencies (Hz). Shape [batch_size, time, 1]. 197 | n_harmonics: Number of harmonics. 198 | 199 | Returns: 200 | harmonic_frequencies: Oscillator frequencies (Hz). 201 | Shape [batch_size, time, n_harmonics]. 202 | """ 203 | f_ratios = torch.linspace(1.0, float(n_harmonics), int(n_harmonics)).to(frequencies.device) 204 | f_ratios = f_ratios[None, None, :] 205 | harmonic_frequencies = frequencies * f_ratios 206 | return harmonic_frequencies 207 | 208 | def remove_above_nyquist(frequency_envelopes, amplitude_envelopes, sample_rate=16000): 209 | """Set amplitudes for oscillators above nyquist to 0. 210 | 211 | Args: 212 | frequency_envelopes: Sample/frame-wise oscillator frequencies (Hz). Shape 213 | [batch_size, n_samples(or n_frames), n_sinusoids]. 214 | amplitude_envelopes: Sample/frame-wise oscillator amplitude. Shape [batch_size, 215 | n_samples(or n_frames), n_sinusoids]. 216 | sample_rate: Sample rate in samples per a second. 217 | 218 | Returns: 219 | amplitude_envelopes: Sample-wise filtered oscillator amplitude. 220 | Shape [batch_size, n_samples, n_sinusoids]. 221 | """ 222 | if amplitude_envelopes.shape[1] != frequency_envelopes.shape[1]: 223 | frequency_envelopes = resample_frames(frequency_envelopes, amplitude_envelopes.shape[1]) 224 | 225 | amplitude_envelopes = torch.where(torch.ge(frequency_envelopes, sample_rate / 2.0), torch.zeros_like(amplitude_envelopes), amplitude_envelopes) 226 | return amplitude_envelopes 227 | 228 | def harmonic_synthesis(frequencies, amplitudes, harmonic_shifts=None, harmonic_distribution=None, n_samples=64000, sample_rate=16000, amp_resample_method='window'): 229 | """Generate audio from frame-wise monophonic harmonic oscillator bank. 230 | 231 | Args: 232 | frequencies: Frame-wise fundamental frequency in Hz. Shape [batch_size, n_frame, 1]. 233 | amplitudes: Frame-wise oscillator peak amplitude. Shape [batch_size, n_frames, 1]. 234 | harmonic_shifts: Harmonic frequency variations (Hz), zero-centered. Total frequency of a harmonic is equal to (frequencies * harmonic_number * (1 + harmonic_shifts)). Shape [batch_size, n_frames, n_harmonics]. 235 | harmonic_distribution: Harmonic amplitude variations, ranged zero to one. 
Total amplitude of a harmonic is equal to (amplitudes * harmonic_distribution). Shape [batch_size, n_frames, n_harmonics]. 236 | n_samples: Total length of output audio. Interpolates and crops to this. 237 | sample_rate: Sample rate. 238 | amp_resample_method: Mode with which to resample amplitude envelopes. 239 | 240 | Returns: 241 | audio: Output audio. Shape [batch_size, n_samples] 242 | """ 243 | 244 | if harmonic_distribution is not None: 245 | n_harmonics = harmonic_distribution.shape[-1] 246 | elif harmonic_shifts is not None: 247 | n_harmonics = harmonic_shifts.shape[-1] 248 | else: 249 | n_harmonics = 1 250 | 251 | # Create harmonic frequencies [batch_size, n_frames, n_harmonics]. 252 | harmonic_frequencies = get_harmonic_frequencies(frequencies, n_harmonics) 253 | if harmonic_shifts is not None: 254 | harmonic_frequencies *= (1.0 + harmonic_shifts) 255 | 256 | # Create harmonic amplitudes [batch_size, n_frames, n_harmonics]. 257 | if harmonic_distribution is not None: 258 | harmonic_amplitudes = amplitudes * harmonic_distribution 259 | else: 260 | harmonic_amplitudes = amplitudes 261 | 262 | # Create sample-wise envelopes. 263 | frequency_envelopes = resample_frames(harmonic_frequencies, n_samples) # cycles/sec 264 | # amplitude_envelopes = resample_frames(harmonic_amplitudes, n_samples,method=amp_resample_method) 265 | amplitude_envelopes = resample_frames(harmonic_amplitudes, n_samples) # window has not been implemented yet 266 | # Synthesize from harmonics [batch_size, n_samples]. 267 | audio = oscillator_bank(frequency_envelopes, 268 | amplitude_envelopes, 269 | sample_rate=sample_rate) 270 | return audio 271 | 272 | def oscillator_bank(frequency_envelopes, amplitude_envelopes, sample_rate=16000, sum_sinusoids=True): 273 | """Generates audio from sample-wise frequencies for a bank of oscillators. 274 | 275 | Args: 276 | frequency_envelopes: Sample-wise oscillator frequencies (Hz). Shape [batch_size, n_samples, n_sinusoids]. 277 | amplitude_envelopes: Sample-wise oscillator amplitude. Shape [batch_size, n_samples, n_sinusoids]. 278 | sample_rate: Sample rate in samples per a second. 279 | sum_sinusoids: Add up audio from all the sinusoids. 280 | 281 | Returns: 282 | wav: Sample-wise audio. Shape [batch_size, n_samples, n_sinusoids] if sum_sinusoids=False, else shape is [batch_size, n_samples]. 283 | """ 284 | 285 | # Don't exceed Nyquist. 286 | amplitude_envelopes = remove_above_nyquist(frequency_envelopes, amplitude_envelopes, sample_rate) 287 | 288 | # Change Hz to radians per sample. 289 | omegas = frequency_envelopes * (2.0 * np.pi) # rad / sec 290 | omegas = omegas / float(sample_rate) # rad / sample 291 | 292 | # Accumulate phase and synthesize. 293 | output = torch.cumsum(omegas, 1) 294 | output = torch.sin(output) 295 | output = amplitude_envelopes * output # [batch_size, n_samples, n_sinusoids] 296 | if sum_sinusoids: 297 | output = torch.sum(output, dim=-1) # [batch_size, n_samples] 298 | return output 299 | 300 | def get_fft_size(frame_size: int, ir_size: int) -> int: 301 | """Calculate final size for efficient FFT. power of 2 302 | fft size should be greater than frame_size + ir_size 303 | Args: 304 | frame_size: Size of the audio frame. 305 | ir_size: Size of the convolving impulse response. 306 | 307 | Returns: 308 | fft_size: Size for efficient FFT. 309 | """ 310 | convolved_frame_size = ir_size + frame_size - 1 311 | # Next power of 2. 
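    # e.g. frame_size=16000, ir_size=16000 -> convolved_frame_size=31999 -> fft_size=32768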
312 | return int(2**np.ceil(np.log2(convolved_frame_size))) 313 | 314 | def frame_signal(signal, frame_size): 315 | """ 316 | cut signal into nonoverlapping frames 317 | Args: 318 | signal: [batch, n_samples] 319 | frame_size: int 320 | Returns: 321 | [batch, n_frames, frame_size] 322 | """ 323 | signal_len = signal.shape[-1] 324 | padding = (frame_size - (signal_len % frame_size) ) % frame_size 325 | signal = torch.nn.functional.pad(signal, (0, padding), 'constant', 0) 326 | frames = torch.split(signal.unsqueeze(1), frame_size, dim=-1) 327 | return torch.cat(frames, dim=1) 328 | 329 | def slice_windows(signal, frame_size, hop_size, window=None): 330 | """ 331 | slice signal into overlapping frames 332 | pads end if (l_x - frame_size) % hop_size != 0 333 | Args: 334 | signal: [batch, n_samples] 335 | frame_size (int): size of frames 336 | hop_size (int): size between frames 337 | Returns: 338 | [batch, n_frames, frame_size] 339 | """ 340 | _batch_dim, l_x = signal.shape 341 | remainder = (l_x - frame_size) % hop_size 342 | pad = 0 if (remainder == 0) else hop_size - remainder 343 | signal = F.pad(signal, (0, pad), 'constant') 344 | signal = signal[:, None, None, :] # adding dummy channel/height 345 | frames = F.unfold(signal, (1, frame_size), stride=(1, hop_size)) #batch, frame_size, n_frames 346 | frames = frames.permute(0, 2, 1) # batch, n_frames, frame_size 347 | if window == 'hamming': 348 | win = torch.hamming_window(frame_size)[None, None, :].to(frames.device) 349 | frames = frames * win 350 | return frames 351 | 352 | def variable_delay(phase, audio, buf_size): 353 | """delay with variable length 354 | 355 | Args: 356 | phase (torch.Tensor): 0~1 0: no delay 1: delay=max_length (batch, n_samples) 357 | audio (torch.Tensor): audio signal (batch, n_samples) 358 | buf_size (int) : buffer size in samples = max delay length 359 | 360 | Returns: 361 | torch.Tensor: delayed audio (batch, n_samples) 362 | """ 363 | batch_size, n_samples = audio.shape 364 | audio_4d = audio[:, None, None, :] # (B, C=1, H=1, W=n_samples) 365 | delay_ratio = buf_size*2/n_samples 366 | grid_x = torch.linspace(-1, 1, n_samples, device=audio.device)[None, :] 367 | grid_x = grid_x - delay_ratio + delay_ratio*phase # B, W=n_samples 368 | grid_x = grid_x[:, None, :, None] # B, H=1, W=n_samples, 1 369 | grid_y = torch.zeros(batch_size, 1, n_samples, 1, device=audio.device) # # B, H=1, W=n_samples, 1 370 | grid = torch.cat([grid_x, grid_y], dim=-1) 371 | output = torch.nn.functional.grid_sample(audio_4d, grid, align_corners=True) 372 | # shape: (B, C=1, H=1, W) 373 | output = output.squeeze(2).squeeze(1) 374 | return output 375 | 376 | # pad_sig = torch.nn.functional.pad(audio, (buf_size-1, 0)) 377 | # assert orig_len % proc_len == 0 378 | # if proc_len is None: 379 | # frames = slice_windows(pad_sig, buf_size, 1) # (b, n_samples, buf_size) 380 | # return linear_lookup(phase, frames) 381 | # else: 382 | # output = [] 383 | # sig_frames = slice_windows(pad_sig, proc_len+buf_size-1, proc_len) 384 | # n_frames = sig_frames.shape[1] 385 | # for i in range(n_frames): 386 | # sig = sig_frames[:, i] # (batch, proc_len+buf_size-1) 387 | # ph = phase[:, i*proc_len:(i+1)*proc_len] 388 | # ll_frames = slice_windows(sig, buf_size, 1) # (batch, proc_len, max_length) 389 | # output.append(linear_lookup(ph, ll_frames)) # (batch, proc_len) 390 | # output = torch.cat(output, dim=-1) 391 | # output = output[:, :orig_len] 392 | # return output 393 | 394 | def overlap_and_add(signal, frame_step): 395 | """overlap-add signals ported from 
tf.signals 396 | 397 | Args: 398 | signal (torch.Tensor): (batch_size, frames, frame_length) 399 | frame_step (int): size of overlap offset 400 | Returns: 401 | A `Tensor` with shape `[..., output_size]` 402 | """ 403 | batch_size = signal.shape[0] 404 | n_frames = signal.shape[1] 405 | frame_length = signal.shape[2] 406 | 407 | output_length = frame_length + frame_step * (n_frames - 1) 408 | if frame_length == frame_step: 409 | return signal.flatten(-2, -1) 410 | segments = -(-frame_length // frame_step) 411 | pad_width = segments * frame_step - frame_length 412 | # The following code is documented using this example: 413 | # 414 | # frame_step = 2 415 | # signal.shape = (3, 5) 416 | # a b c d e 417 | # f g h i j 418 | # k l m n o 419 | 420 | # Pad the frame_length dimension to a multiple of the frame step. 421 | # Pad the frames dimension by `segments` so that signal.shape = (6, 6) 422 | # a b c d e 0 423 | # f g h i j 0 424 | # k l m n o 0 425 | # 0 0 0 0 0 0 426 | # 0 0 0 0 0 0 427 | # 0 0 0 0 0 0 428 | signal = nn.functional.pad(signal, (0, pad_width, 0, segments)) 429 | # Reshape so that signal.shape = (6, 3, 2) 430 | # ab cd e0 431 | # fg hi j0 432 | # kl mn o0 433 | # 00 00 00 434 | # 00 00 00 435 | # 00 00 00 436 | signal = signal.reshape(batch_size, n_frames+segments, segments, frame_step) 437 | # Transpose dimensions so that signal.shape = (3, 6, 2) 438 | # ab fg kl 00 00 00 439 | # cd hi mn 00 00 00 440 | # e0 j0 o0 00 00 00 441 | signal = signal.transpose(-2, -3) 442 | # Reshape so that signal.shape = (18, 2) 443 | # ab fg kl 00 00 00 cd hi mn 00 00 00 e0 j0 o0 00 00 00 444 | signal = torch.flatten(signal, -3, -2) 445 | signal.shape 446 | # Truncate so that signal.shape = (15, 2) 447 | # ab fg kl 00 00 00 cd hi mn 00 00 00 e0 j0 o0 448 | signal = signal[..., :(n_frames + segments - 1) * segments, :] 449 | # Reshape so that signal.shape = (3, 5, 2) 450 | # ab fg kl 00 00 451 | # 00 cd hi mn 00 452 | # 00 00 e0 j0 o0 453 | signal = signal.reshape(batch_size, segments, (n_frames + segments - 1), frame_step) 454 | # Now, reduce over the columns, to achieve the desired sum. 455 | signal = torch.sum(signal, -3) 456 | # Flatten the array. 457 | signal = signal.reshape(batch_size, (n_frames + segments - 1) * frame_step) 458 | 459 | # Truncate to final length. 460 | signal = signal[..., :output_length] 461 | return signal 462 | 463 | def crop_and_compensate_delay(audio, audio_size, ir_size, padding, delay_compensation): 464 | """Copied over from ddsp 465 | Crop audio output from convolution to compensate for group delay. 466 | 467 | Args: 468 | audio: Audio after convolution. Tensor of shape [batch, time_steps]. 469 | audio_size: Initial size of the audio before convolution. 470 | ir_size: Size of the convolving impulse response. 471 | padding: Either 'valid' or 'same'. For 'same' the final output to be the 472 | same size as the input audio (audio_timesteps). For 'valid' the audio is 473 | extended to include the tail of the impulse response (audio_timesteps + 474 | ir_timesteps - 1). 475 | delay_compensation: Samples to crop from start of output audio to compensate 476 | for group delay of the impulse response. If delay_compensation < 0 it 477 | defaults to automatically calculating a constant group delay of the 478 | windowed linear phase filter from frequency_impulse_response(). 479 | 480 | Returns: 481 | Tensor of cropped and shifted audio. 482 | 483 | Raises: 484 | ValueError: If padding is not either 'valid' or 'same'. 485 | """ 486 | # Crop the output. 
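    # 'valid' keeps the full convolution tail (audio_size + ir_size - 1 samples); 'same' crops back to the original audio length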
487 | if padding == 'valid': 488 | crop_size = ir_size + audio_size - 1 489 | elif padding == 'same': 490 | crop_size = audio_size 491 | else: 492 | raise ValueError('Padding must be \'valid\' or \'same\', instead ' 493 | 'of {}.'.format(padding)) 494 | 495 | # Compensate for the group delay of the filter by trimming the front. 496 | # For an impulse response produced by frequency_impulse_response(), 497 | # the group delay is constant because the filter is linear phase. 498 | total_size = audio.shape[-1] 499 | crop = total_size - crop_size 500 | start = ((ir_size - 1) // 2 - 501 | 1 if delay_compensation < 0 else delay_compensation) 502 | end = crop - start 503 | return audio[:, start:-end] 504 | 505 | def fir_filter(audio, freq_response, filter_size): 506 | # get IR 507 | h = torch.fft.irfft(freq_response, n=filter_size, dim=-1) 508 | 509 | # Compute filter windowed impulse response 510 | # window_size == filter_size 511 | filter_window = torch.hann_window(filter_size).roll(filter_size//2,-1).to(audio.device) 512 | h = filter_window[None, None, :] * h 513 | filtered = fft_convolve(audio, h, padding='same') 514 | return filtered 515 | 516 | def fft_convolve(audio, impulse_response, padding = 'same', delay_compensation = -1): 517 | """ ported from ddsp original description below 518 | Filter audio with frames of time-varying impulse responses. 519 | 520 | Time-varying filter. Given audio [batch, n_samples], and a series of impulse 521 | responses [batch, n_frames, n_impulse_response], splits the audio into frames, 522 | applies filters, and then overlap-and-adds audio back together. 523 | Applies non-windowed non-overlapping STFT/ISTFT to efficiently compute 524 | convolution for large impulse response sizes. 525 | 526 | Args: 527 | audio: Input audio. Tensor of shape [batch, n_samples]. 528 | impulse_response: Finite impulse response to convolve. Can either be a 2-D 529 | Tensor of shape [batch, ir_size], or a 3-D Tensor of shape [batch, 530 | ir_frames, ir_size]. A 2-D tensor will apply a single linear 531 | time-invariant filter to the audio. A 3-D Tensor will apply a linear 532 | time-varying filter. Automatically chops the audio into equally shaped 533 | blocks to match ir_frames. 534 | padding: Either 'valid' or 'same'. For 'same' the final output to be the 535 | same size as the input audio (n_samples). For 'valid' the audio is 536 | extended to include the tail of the impulse response (n_samples + 537 | ir_timesteps - 1). 538 | delay_compensation: Samples to crop from start of output audio to compensate 539 | for group delay of the impulse response. If delay_compensation is less 540 | than 0 it defaults to automatically calculating a constant group delay of 541 | the windowed linear phase filter from frequency_impulse_response(). 542 | 543 | Returns: 544 | audio_out: Convolved audio. Tensor of shape 545 | [batch, n_samples + ir_timesteps - 1] ('valid' padding) or shape 546 | [batch, audio_timesteps] ('same' padding). 547 | 548 | Raises: 549 | ValueError: If audio and impulse response have different batch size. 550 | ValueError: If audio cannot be split into evenly spaced frames. (i.e. the 551 | number of impulse response frames is on the order of the audio size and 552 | not a multiple of the audio size.) 553 | """ 554 | # Add a frame dimension to impulse response if it doesn't have one. 
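    # a 2-D impulse response [batch, ir_size] is treated as a single time-invariant filter (one IR frame)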
555 | ir_shape = impulse_response.shape 556 | if len(ir_shape) == 2: 557 | impulse_response = impulse_response[:, None, :] 558 | ir_shape = impulse_response.shape 559 | 560 | # Get shapes of audio and impulse response. 561 | batch_size_ir, n_ir_frames, ir_size = ir_shape 562 | batch_size, audio_size = audio.shape 563 | 564 | # Validate that batch sizes match. 565 | if batch_size != batch_size_ir: 566 | raise ValueError('Batch size of audio ({}) and impulse response ({}) must ' 567 | 'be the same.'.format(batch_size, batch_size_ir)) 568 | 569 | # Cut audio into frames. 570 | frame_size = int(np.ceil(audio_size / n_ir_frames)) 571 | 572 | audio_frames = frame_signal(audio, frame_size) 573 | 574 | # Check that number of frames match. 575 | n_audio_frames = audio_frames.shape[1] 576 | if n_audio_frames != n_ir_frames: 577 | raise ValueError( 578 | 'Number of Audio frames ({}) and impulse response frames ({}) do not ' 579 | 'match. For small hop size = ceil(audio_size / n_ir_frames), ' 580 | 'number of impulse response frames must be a multiple of the audio ' 581 | 'size.'.format(n_audio_frames, n_ir_frames)) 582 | 583 | # Pad and FFT the audio and impulse responses. 584 | fft_size = get_fft_size(frame_size, ir_size) 585 | 586 | S = torch.fft.rfft(audio_frames, n=fft_size, dim=-1) # zeropadded 587 | H = torch.fft.rfft(impulse_response, n=fft_size, dim=-1) 588 | 589 | # Multiply the FFTs (same as convolution in time). 590 | # Filter the original audio 591 | audio_ir_fft = S*H 592 | 593 | # Take the IFFT to resynthesize audio. 594 | # batch_size, n_frames, fft_size 595 | audio_frames_out = torch.fft.irfft(audio_ir_fft, n=fft_size, dim=-1) 596 | audio_out = overlap_and_add(audio_frames_out, frame_size) 597 | 598 | # Crop and shift the output audio. 599 | return crop_and_compensate_delay(audio_out, audio_size, ir_size, padding, 600 | delay_compensation) 601 | 602 | def pad_or_trim_to_expected_length(vector, expected_len, pad_value=0, len_tolerance=20): 603 | """Ported from DDSP 604 | Make vector equal to the expected length. 605 | 606 | Feature extraction functions like `compute_loudness()` or `compute_f0` produce feature vectors that vary in length depending on factors such as `sample_rate` or `hop_size`. This function corrects vectors to the expected length, warning the user if the difference between the vector and expected length was unusually high to begin with. 607 | 608 | Args: 609 | vector: Tensor. Shape [(batch,) vector_length] 610 | expected_len: Expected length of vector. 611 | pad_value: Value to pad at end of vector. 612 | len_tolerance: Tolerance of difference between original and desired vector length. 613 | 614 | Returns: 615 | vector: Vector with corrected length. 616 | 617 | Raises: 618 | ValueError: if `len(vector)` is different from `expected_len` beyond 619 | `len_tolerance` to begin with. 
620 | """ 621 | expected_len = int(expected_len) 622 | vector_len = int(vector.shape[-1]) 623 | 624 | if abs(vector_len - expected_len) > len_tolerance: 625 | # Ensure vector was close to expected length to begin with 626 | raise ValueError('Vector length: {} differs from expected length: {} ' 627 | 'beyond tolerance of : {}'.format(vector_len, 628 | expected_len, 629 | len_tolerance)) 630 | 631 | is_1d = (len(vector.shape) == 1) 632 | vector = vector[None, :] if is_1d else vector 633 | 634 | # Pad missing samples 635 | if vector_len < expected_len: 636 | n_padding = expected_len - vector_len 637 | vector = F.pad(vector, ((0, 0, 0, n_padding)), mode='constant', value=pad_value) 638 | # Trim samples 639 | elif vector_len > expected_len: 640 | vector = vector[..., :expected_len] 641 | 642 | # Remove temporary batch dimension. 643 | vector = vector[0] if is_1d else vector 644 | return vector 645 | 646 | def midi_to_hz(notes): 647 | return 440.0 * (2.0**((notes - 69.0) / 12.0)) 648 | 649 | def hz_to_midi(frequencies): 650 | """torch-compatible hz_to_midi function.""" 651 | if isinstance(frequencies, torch.Tensor): 652 | notes = 12.0 * (torch.log2(frequencies+1e-5) - math.log2(440.0)) + 69.0 653 | # Map 0 Hz to MIDI 0 (Replace -inf MIDI with 0.) 654 | notes = torch.where(torch.le(frequencies, 0.0), torch.zeros_like(frequencies, device=frequencies.device), notes) 655 | else: 656 | notes = 12.0 * (math.log2(frequencies+1e-5) - math.log2(440.0)) + 69.0 657 | return notes 658 | 659 | def unit_to_midi(unit, midi_min, midi_max = 90.0, clip = False): 660 | """Map the unit interval [0, 1] to MIDI notes.""" 661 | unit = torch.clamp(unit, 0.0, 1.0) if clip else unit 662 | return midi_min + (midi_max - midi_min) * unit 663 | 664 | def unit_to_hz(unit, hz_min, hz_max, clip = False): 665 | """Map unit interval [0, 1] to [hz_min, hz_max], scaling logarithmically.""" 666 | midi = unit_to_midi(unit, midi_min=hz_to_midi(hz_min), midi_max=hz_to_midi(hz_max), clip=clip) 667 | return midi_to_hz(midi) 668 | 669 | def frequencies_sigmoid(freqs, hz_min=8.2, hz_max=8000.0): 670 | """Sum of sigmoids to logarithmically scale network outputs to frequencies. 671 | without depth 672 | 673 | Args: 674 | freqs: Neural network outputs, [batch, time, n_sinusoids] 675 | hz_min: Lowest frequency to consider. 676 | hz_max: Highest frequency to consider. 677 | 678 | Returns: 679 | A tensor of frequencies in hertz [batch, time, n_sinusoids]. 680 | """ 681 | freqs = torch.sigmoid(freqs) 682 | return unit_to_hz(freqs, hz_min=hz_min, hz_max=hz_max) 683 | 684 | # def upsample_with_windows(inputs, n_timesteps, add_endpoint): 685 | # """[summary] 686 | # code borrowed from ddsp 687 | 688 | # Args: 689 | # inputs ([type]): [description] 690 | # n_timesteps ([type]): [description] 691 | # add_endpoint ([type]): [description] 692 | 693 | # Returns: 694 | # [type]: [description] 695 | # """ 696 | # if len(inputs.shape) != 3: 697 | # raise ValueError('Upsample_with_windows() only supports 3 dimensions, ' 698 | # 'not {}.'.format(inputs.shape)) 699 | 700 | # # Mimic behavior of tf.image.resize. 701 | # # For forward (not endpointed), hold value for last interval. 
702 | # if add_endpoint: 703 | # inputs = torch.cat([inputs, inputs[:, -1:, :]], axis=1) 704 | 705 | # n_frames = int(inputs.shape[1]) 706 | # n_intervals = (n_frames - 1) 707 | 708 | # if n_frames >= n_timesteps: 709 | # raise ValueError('Upsample with windows cannot be used for downsampling' 710 | # 'More input frames ({}) than output timesteps ({})'.format( 711 | # n_frames, n_timesteps)) 712 | 713 | # if n_timesteps % n_intervals != 0.0: 714 | # minus_one = '' if add_endpoint else ' - 1' 715 | # raise ValueError( 716 | # 'For upsampling, the target the number of timesteps must be divisible ' 717 | # 'by the number of input frames{}. (timesteps:{}, frames:{}, ' 718 | # 'add_endpoint={}).'.format(minus_one, n_timesteps, n_frames, 719 | # add_endpoint)) 720 | 721 | # # Constant overlap-add, half overlapping windows. 722 | # hop_size = n_timesteps // n_intervals 723 | # window_length = 2 * hop_size 724 | # window = torch.hann_window(window_length) # [window] 725 | 726 | # # Transpose for overlap_and_add. 727 | # x = torch.transpose(inputs, 1, 2) # [batch_size, n_channels, n_frames] 728 | 729 | # # Broadcast multiply. 730 | # # Add dimension for windows [batch_size, n_channels, n_frames, window]. 731 | # x = x.unsqueeze(-1) 732 | # window = window[None, None, None, :] 733 | # x_windowed = (x * window) 734 | # x = overlap_add(x_windowed, hop_size) 735 | # nn.functional.fold(x_windowed, stride=hop_size) 736 | # # Transpose back. 737 | # x = tf.transpose(x, perm=[0, 2, 1]) # [batch_size, n_timesteps, n_channels] 738 | 739 | # # Trim the rise and fall of the first and last window. 740 | # return x[:, hop_size:-hop_size, :] -------------------------------------------------------------------------------- /gen_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import soundfile as sf 3 | import tqdm 4 | import argparse, os 5 | from diffsynth.modelutils import construct_synth_from_conf 6 | from omegaconf import OmegaConf 7 | 8 | def make_dirs(base_dir, synth_name): 9 | dat_dir = os.path.join(base_dir, synth_name) 10 | audio_dir = os.path.join(dat_dir, 'audio') 11 | param_dir = os.path.join(dat_dir, 'param') 12 | os.makedirs(audio_dir, exist_ok=True) 13 | os.makedirs(param_dir, exist_ok=True) 14 | return audio_dir, param_dir 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('dataset_dir', type=str, help='') 19 | parser.add_argument('synth_conf', type=str, help='') 20 | parser.add_argument('--data_size', type=int, default=20000) 21 | parser.add_argument('--audio_len', type=float, default=4.0) 22 | parser.add_argument('--sr', type=int, default=16000) 23 | parser.add_argument('--batch_size', type=int, default=64) 24 | parser.add_argument('--save_param', action='store_true') 25 | args = parser.parse_args() 26 | 27 | conf = OmegaConf.load(args.synth_conf) 28 | synth = construct_synth_from_conf(conf).to('cuda') 29 | 30 | audio_dir, param_dir = make_dirs(args.dataset_dir, conf.name) 31 | 32 | n_samples = int(args.audio_len * args.sr) 33 | count = 0 34 | break_flag = False 35 | skip_count = 0 36 | if args.save_param: 37 | save_params = conf.save_params # harmor_q, harmor_cutoff, etc. 
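        # only the parameter names listed in the synth config are written to disk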
38 | else: # save all external params 39 | rev_dag_summary = {v: k for k,v in synth.dag_summary.items()} # HARM_Q: harmor_q 40 | save_params = [rev_dag_summary[k] for k in synth.ext_param_sizes.keys()] 41 | with torch.no_grad(): 42 | with tqdm.tqdm(total=args.data_size) as pbar: 43 | while True: 44 | if break_flag: 45 | break 46 | audio, output = synth.uniform(args.batch_size, n_samples, 'cuda') 47 | params = {k: output[synth.dag_summary[k]].cpu() for k in save_params} 48 | for j in range(args.batch_size): 49 | if count >= args.data_size: 50 | break_flag=True 51 | break 52 | aud = audio[j] 53 | # remove silence 54 | if aud.abs().max() < 0.05: 55 | skip_count += 1 56 | continue 57 | p = {k:pv[j] for k, pv in params.items()} 58 | param_path = os.path.join(param_dir, '{0:05}.pt'.format(count)) 59 | torch.save(p, param_path) 60 | audio_path = os.path.join(audio_dir, '{0:05}.wav'.format(count)) 61 | sf.write(audio_path, aud.cpu().numpy(), samplerate=args.sr) 62 | count+=1 63 | pbar.update(1) 64 | print('skipped {0} quiet sounds'.format(skip_count)) -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import librosa 4 | import librosa.display 5 | import matplotlib.pyplot as plt 6 | import torch 7 | from pytorch_lightning.utilities.distributed import rank_zero_only 8 | from pytorch_lightning.callbacks import Callback 9 | 10 | def plot_spec(y, ax, sr=16000): 11 | D = librosa.stft(y) # STFT of y 12 | S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max) 13 | img = librosa.display.specshow(S_db, sr=sr, x_axis='time', y_axis='log', ax=ax) 14 | ax.label_outer() 15 | 16 | def plot_recons(x, x_tilde, plot_dir, name=None, epochs=None, sr=16000, num=6, save=True): 17 | """Plot spectrograms/waveforms of original/reconstructed audio 18 | 19 | Args: 20 | x (numpy array): [batch, n_samples] 21 | x_tilde (numpy array): [batch, n_samples] 22 | sr (int, optional): sample rate. Defaults to 16000. 23 | dir (str): plot directory. 24 | name (str, optional): file name. 25 | epochs (int, optional): no. of epochs. 26 | num (int, optional): number of spectrograms to plot. Defaults to 6. 
27 | """ 28 | fig, axes = plt.subplots(num, 4, figsize=(15, 30)) 29 | for i in range(num): 30 | plot_spec(x[i], axes[i, 0], sr) 31 | plot_spec(x_tilde[i], axes[i, 1], sr) 32 | axes[i, 2].plot(x[i]) 33 | axes[i, 2].set_ylim(-1,1) 34 | axes[i, 3].plot(x_tilde[i]) 35 | axes[i, 3].set_ylim(-1,1) 36 | if save: 37 | if epochs: 38 | fig.savefig(os.path.join(plot_dir, 'epoch{:0>3}_recons.png'.format(epochs))) 39 | plt.close(fig) 40 | else: 41 | fig.savefig(os.path.join(plot_dir, name+'.png')) 42 | plt.close(fig) 43 | else: 44 | return fig 45 | 46 | def save_to_board(i, name, writer, orig_audio, resyn_audio, plot_num=4, sr=16000): 47 | orig_audio = orig_audio.detach().cpu() 48 | resyn_audio = resyn_audio.detach().cpu() 49 | for j in range(plot_num): 50 | writer.add_audio('{0}_orig/{1}'.format(name, j), orig_audio[j].unsqueeze(0), i, sample_rate=sr) 51 | writer.add_audio('{0}_resyn/{1}'.format(name, j), resyn_audio[j].unsqueeze(0), i, sample_rate=sr) 52 | fig = plot_recons(orig_audio.detach().cpu().numpy(), resyn_audio.detach().cpu().numpy(), '', sr=sr, num=plot_num, save=False) 53 | writer.add_figure('plot_recon_{0}'.format(name), fig, i) 54 | 55 | class AudioLogger(Callback): 56 | def __init__(self, batch_frequency=1000): 57 | super().__init__() 58 | self.batch_freq = batch_frequency 59 | 60 | @rank_zero_only 61 | def log_local(self, writer, name, current_epoch, orig_audio, resyn_audio): 62 | save_to_board(current_epoch, name, writer, orig_audio, resyn_audio) 63 | 64 | def log_audio(self, pl_module, batch, batch_idx, name="train"): 65 | if batch_idx % self.batch_freq == 0: 66 | is_train = pl_module.training 67 | if is_train: 68 | pl_module.eval() 69 | # get audio 70 | with torch.no_grad(): 71 | resyn_audio, _outputs = pl_module(batch) 72 | resyn_audio = torch.clamp(resyn_audio.detach().cpu(), -1, 1) 73 | orig_audio = torch.clamp(batch['audio'].detach().cpu(), -1, 1) 74 | 75 | self.log_local(pl_module.logger.experiment, name, pl_module.current_epoch, orig_audio, resyn_audio) 76 | 77 | if is_train: 78 | pl_module.train() 79 | 80 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 81 | self.log_audio(pl_module, batch, batch_idx, name="train") 82 | 83 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 84 | data_type='id' if dataloader_idx==0 else 'ood' 85 | self.log_audio(pl_module, batch, batch_idx, name="val_"+data_type) 86 | 87 | def save_to_board_mel(i, writer, orig_mel, recon_mel, plot_num=8): 88 | orig_mel = orig_mel.detach().cpu() 89 | recon_mel = recon_mel.detach().cpu() 90 | 91 | fig, axes = plt.subplots(2, plot_num, figsize=(30, 8)) 92 | for j in range(plot_num): 93 | axes[0, j].imshow(orig_mel[j], aspect=0.25) 94 | axes[1, j].imshow(recon_mel[j], aspect=0.25) 95 | fig.tight_layout() 96 | writer.add_figure('plot_recon', fig, i) 97 | 98 | def plot_param_dist(param_stats): 99 | """ 100 | violin plot of parameter values 101 | """ 102 | 103 | fig, ax = plt.subplots(figsize=(15, 5)) 104 | labels = param_stats.keys() 105 | parts = ax.violinplot(param_stats.values(), showmeans=True) 106 | ax.set_xticks(np.arange(1, len(labels) + 1)) 107 | ax.set_xticklabels(labels, fontsize=8) 108 | ax.set_xlim(0.25, len(labels) + 0.75) 109 | ax.set_ylim(0, 1) 110 | return fig -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os, argparse, json, pickle, re 2 | import matplotlib.pyplot as plt 3 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os, argparse, json, pickle, re
2 | import matplotlib.pyplot as plt
3 | import torch
4 | 
5 | from diffsynth.loss import SpecWaveLoss
6 | from diffsynth import util
7 | from diffsynth.model import EstimatorSynth
8 | from plot import plot_spec, plot_param_dist
9 | import soundfile as sf
10 | 
11 | import hydra
12 | import pytorch_lightning as pl
13 | 
14 | def write_plot_audio(y, name):
15 |     # y: numpy array of audio
16 |     # write audio file
17 |     sf.write('{0}.wav'.format(name), y, 16000)
18 |     fig, ax = plt.subplots(figsize=(1.5, 1), tight_layout=True)
19 |     ax.axis('off')
20 |     plot_spec(y, ax, 16000)
21 |     fig.savefig('{0}.png'.format(name))
22 |     plt.close(fig)
23 | 
24 | def test_model(model, id_loader, ood_loader, device, sw_loss=None, perc_model=None):
25 |     model.eval()
26 |     # in-domain
27 |     syn_result = util.StatsLog()
28 |     param_stats = [util.StatsLog(), util.StatsLog()]
29 |     with torch.no_grad():
30 |         for data_dict in id_loader:
31 |             params = data_dict.pop('params')
32 |             params = {name:tensor.to(device, non_blocking=True) for name, tensor in params.items()}
33 |             data_dict = {name:tensor.to(device, non_blocking=True) for name, tensor in data_dict.items()}
34 |             data_dict['params'] = params
35 | 
36 |             resyn_audio, outputs = model(data_dict)
37 |             # parameter values
38 |             monitor_params = list(params.keys())
39 |             for pname, pvalue in outputs.items():
40 |                 if pname in monitor_params:
41 |                     # pvalue: batch, n_frames, param_dim>=1
42 |                     pvs = pvalue.mean(dim=1)
43 |                     for i, pv in enumerate(pvs.unbind(-1)):
44 |                         param_stats[0].add_entry(pname+'{0}'.format(i), pv)
45 | 
46 |             # Reconstruction loss
47 |             losses = model.train_losses(data_dict, outputs, sw_loss=sw_loss, perc_model=perc_model)
48 |             losses.update(model.monitor_losses(data_dict, outputs))
49 |             syn_result.update(losses)
50 |     syn_result_dict = {'id/'+k: v for k, v in syn_result.average().items()}
51 | 
52 |     # out-of-domain
53 |     real_result = util.StatsLog()
54 |     with torch.no_grad():
55 |         for data_dict in ood_loader:
56 |             data_dict = {name:tensor.to(device, non_blocking=True) for name, tensor in data_dict.items()}
57 | 
58 |             resyn_audio, outputs = model(data_dict)
59 |             # parameter values
60 |             monitor_params = list(params.keys()) # synth parameter names carried over from the in-domain loop above
61 |             for pname, pvalue in outputs.items():
62 |                 if pname in monitor_params:
63 |                     # pvalue: batch, n_frames, param_dim>=1
64 |                     pvs = pvalue.mean(dim=1)
65 |                     for i, pv in enumerate(pvs.unbind(-1)):
66 |                         param_stats[1].add_entry(pname+'{0}'.format(i), pv)
67 | 
68 |             # Reconstruction loss
69 |             losses = model.train_losses(data_dict, outputs, sw_loss=sw_loss, perc_model=perc_model)
70 |             losses.update(model.monitor_losses(data_dict, outputs))
71 |             real_result.update(losses)
72 |     real_result_dict = {'ood/'+k: v for k, v in real_result.average().items()}
73 | 
74 |     result = {}
75 |     result.update(syn_result_dict)
76 |     result.update(real_result_dict)
77 |     return result, param_stats
78 | 
79 | if __name__ == "__main__":
80 |     parser = argparse.ArgumentParser()
81 |     parser.add_argument('ckpt', type=str, help='')
82 |     parser.add_argument('--batch_size', type=int, default=64, help='')
83 |     parser.add_argument('--write_audio', action='store_true')
84 |     args = parser.parse_args()
85 | 
86 |     pl.seed_everything(0, workers=True)
87 |     device = 'cuda'
88 | 
89 |     ckpt_dir = args.ckpt
90 |     config_dir = re.sub(r'tb_logs.*', '.hydra', ckpt_dir)
91 |     # load the saved hydra config and initialize the model
92 |     hydra.initialize(config_path=config_dir, job_name="test")
93 |     cfg = hydra.compose(config_name="config")
94 | 
95 |     model = EstimatorSynth(cfg.model, cfg.synth, cfg.schedule) # re-created from the checkpoint below
96 |     datamodule = hydra.utils.instantiate(cfg.data)
97 |     datamodule.setup(None)
98 |     id_test_loader, ood_test_loader = datamodule.test_dataloader()
99 |     model = EstimatorSynth.load_from_checkpoint(ckpt_dir).to(device)
100 | 
101 |     # directory for audio/spectrogram output
102 |     output_dir = re.sub(r'tb_logs.*', 'test/output', ckpt_dir)
103 |     os.makedirs(output_dir, exist_ok=True)
104 |     # directory for ground-truth
105 |     target_dir = re.sub(r'tb_logs.*', 'test/target', ckpt_dir)
106 |     os.makedirs(target_dir, exist_ok=True)
107 | 
108 |     id_testbatch = next(iter(id_test_loader))
109 |     id_testbatch.pop('params')
110 |     id_testbatch = {name:tensor.to(device) for name, tensor in id_testbatch.items()}
111 |     ood_testbatch = next(iter(ood_test_loader))
112 |     ood_testbatch = {name:tensor.to(device) for name, tensor in ood_testbatch.items()}
113 | 
114 |     sw_loss = SpecWaveLoss(l1_w=0.0, l2_w=0.0, norm=None)
115 |     with torch.no_grad():
116 |         model = model.eval()
117 |         if args.write_audio:
118 |             # render audio and plot spectrograms
119 |             id_resyn_audio, _output = model(id_testbatch)
120 |             for i in range(args.batch_size): # assumes the first test batch is full-sized
121 |                 resyn_audio = id_resyn_audio[i].detach().cpu().numpy()
122 |                 write_plot_audio(resyn_audio, os.path.join(output_dir, 'id_{0:03}'.format(i)))
123 |                 orig_audio = id_testbatch['audio'][i].detach().cpu().numpy()
124 |                 write_plot_audio(orig_audio, os.path.join(target_dir, 'id_{0:03}'.format(i)))
125 |             ood_resyn_audio, _output = model(ood_testbatch)
126 |             for i in range(args.batch_size):
127 |                 resyn_audio = ood_resyn_audio[i].detach().cpu().numpy()
128 |                 write_plot_audio(resyn_audio, os.path.join(output_dir, 'ood_{0:03}'.format(i)))
129 |                 orig_audio = ood_testbatch['audio'][i].detach().cpu().numpy()
130 |                 write_plot_audio(orig_audio, os.path.join(target_dir, 'ood_{0:03}'.format(i)))
131 |             print('finished writing audio')
132 | 
133 |     # get objective measure
134 |     test_losses, param_stats = test_model(model, id_loader=id_test_loader, ood_loader=ood_test_loader, device=device, sw_loss=sw_loss)
135 |     results_str = 'Test loss: '
136 |     for k in test_losses:
137 |         results_str += '{0}: {1:.3f} '.format(k, test_losses[k])
138 |     print(results_str)
139 |     with open(os.path.join(output_dir, 'test_loss.json'), 'w') as f:
140 |         json.dump(test_losses, f)
141 |     # plot parameter stats
142 |     fig_1 = plot_param_dist(param_stats[0].stats)
143 |     fig_1.savefig(os.path.join(output_dir, 'id_params_dist.png'))
144 |     fig_2 = plot_param_dist(param_stats[1].stats)
145 |     fig_2.savefig(os.path.join(output_dir, 'ood_params_dist.png'))
146 |     with open(os.path.join(output_dir, 'params_dists.pkl'), 'wb') as f:
147 |         pickle.dump(param_stats, f)
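
A note on the path handling in test.py above: the re.sub calls assume the checkpoint lives inside the Hydra run directory written by train.py (below), with TensorBoard logs and checkpoints under tb_logs/. With PyTorch Lightning's default last-checkpoint name, the expected layout is roughly as follows (illustrative paths, not taken verbatim from the code):

ckpt argument: <run_dir>/tb_logs/checkpoints/last.ckpt
config_dir:    <run_dir>/.hydra
output_dir:    <run_dir>/test/output
target_dir:    <run_dir>/test/target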
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import pytorch_lightning as pl
3 | from plot import AudioLogger
4 | import warnings
5 | from pytorch_lightning.callbacks import ModelCheckpoint
6 | from diffsynth.model import EstimatorSynth
7 | 
8 | @hydra.main(config_path="configs/", config_name="config.yaml")
9 | def main(cfg):
10 |     pl.seed_everything(0, workers=True)
11 |     warnings.simplefilter('ignore', RuntimeWarning)
12 |     model = EstimatorSynth(cfg.model, cfg.synth, cfg.schedule)
13 |     logger = pl.loggers.TensorBoardLogger("tb_logs", "", default_hp_metric=False, version='')
14 |     hparams = {'data': cfg.data.train_type, 'schedule': cfg.schedule.name, 'synth': cfg.synth.name}
15 |     # dummy value
16 |     logger.log_hyperparams(hparams, {'val_id/lsd': 40, 'val_ood/lsd': 40})
17 |     # log audio examples
18 |     checkpoint_callback = ModelCheckpoint(monitor="val_ood/lsd", save_top_k=1, filename="epoch_{epoch:03}_{val_ood/lsd:.2f}", save_last=True, auto_insert_metric_name=False)
19 |     callbacks = [pl.callbacks.LearningRateMonitor(logging_interval='step'), AudioLogger(), checkpoint_callback]
20 |     trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
21 |     datamodule = hydra.utils.instantiate(cfg.data)
22 |     # train
23 |     trainer.fit(model=model, datamodule=datamodule)
24 | 
25 | if __name__ == "__main__":
26 |     main()
27 | 
--------------------------------------------------------------------------------
/trainclassifier.py:
--------------------------------------------------------------------------------
1 | import os, json, argparse
2 | import tqdm
3 | import numpy as np
4 | import librosa
5 | 
6 | import torch
7 | from torch.utils.data import Dataset
8 | 
9 | from trainutils import get_loaders
10 | from diffsynth.perceptual.perceptual import PerceptualClassifier
11 | 
12 | class NSynthDataset(Dataset):
13 |     def __init__(self, base_dir, sample_rate=16000, length=4.0):
14 |         self.base_dir = base_dir
15 |         self.raw_dir = os.path.join(base_dir, 'audio')
16 |         self.length = length
17 |         self.sample_rate = sample_rate
18 |         # load json file that comes with nsynth dataset
19 |         with open(os.path.join(self.base_dir, 'examples.json')) as f:
20 |             self.json_dict = json.load(f)
21 |         self.json_keys = list(self.json_dict.keys())
22 |         # use every note in examples.json (no category restriction applied)
23 |         self.nb_files = len(self.json_keys)
24 | 
25 |     def __getitem__(self, index):
26 |         output = {}
27 |         note = self.json_dict[self.json_keys[index]]
28 |         file_name = os.path.join(self.raw_dir, note['note_str']+'.wav')
29 |         output['label'] = int(note['instrument_family']) # 0=bass, 1=brass, 2=flute, etc.
30 |         output['audio'], _sr = librosa.load(file_name, sr=self.sample_rate, duration=self.length)
31 |         return output
32 | 
33 |     def __len__(self):
34 |         return self.nb_files
35 | 
36 | if __name__ == '__main__':
37 |     parser = argparse.ArgumentParser()
38 |     parser.add_argument('output_dir', type=str, help='')
39 |     parser.add_argument('dataset', type=str, help='directory of dataset')
40 |     parser.add_argument('--batch_size', type=int, default=64)
41 |     parser.add_argument('--lr', type=float, default=1e-4)
42 |     parser.add_argument('--epochs', type=int, default=200)
43 |     parser.add_argument('--decay_rate', type=float, default=1.0, help='')
44 |     parser.add_argument('--length', type=float, default=4.0, help='')
45 |     parser.add_argument('--sr', type=int, default=16000, help='')
46 |     args = parser.parse_args()
47 | 
48 |     torch.manual_seed(0) # just fixes the dataset/dataloader and initial weights
49 |     np.random.seed(seed=0) # fixes the random train/valid subset split
50 | 
51 |     device = 'cuda'
52 |     # output dir
53 |     os.makedirs(args.output_dir, exist_ok=True)
54 |     model_dir = os.path.join(args.output_dir, 'model')
55 |     os.makedirs(model_dir, exist_ok=True)
56 | 
57 |     with open(os.path.join(args.output_dir, 'args.txt'), 'w') as f:
58 |         json.dump(args.__dict__, f, indent=4)
59 | 
60 |     # load dataset
61 |     dset = NSynthDataset(args.dataset, sample_rate=args.sr, length=args.length)
62 |     dsets, loaders = get_loaders(dset, args.batch_size, splits=[.8, .2, 0.0], nbworkers=4)
63 |     dset_train, dset_valid, _dset_test = dsets
64 |     train_loader, valid_loader, _test_loader = loaders
65 |     testbatch = next(iter(valid_loader))
66 | 
67 |     # load model
68 |     model = PerceptualClassifier(11, testbatch['audio'].shape[-1]).to(device) # 11 NSynth instrument families to classify
69 |     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
70 |     scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.decay_rate)
71 | 
72 |     best_acc = 0
73 |     # training loop
74 |     for i in tqdm.tqdm(range(1, args.epochs+1)):
75 |         train_loss = model.train_epoch(loader=train_loader, optimizer=optimizer, device=device)
76 |         valid_acc = model.eval_epoch(loader=valid_loader, device=device)
77 |         scheduler.step()
78 |         tqdm.tqdm.write('Epoch: {0:03} Train: {1:.4f} Valid: {2:.4f}'.format(i, train_loss, valid_acc))
79 |         if valid_acc > best_acc:
80 |             best_acc = valid_acc
81 |             torch.save(model.state_dict(), os.path.join(model_dir, 'state_dict.pth')) # best model so far (by validation accuracy)
82 |         if i % 10 == 0:
83 |             torch.save(model.state_dict(), os.path.join(model_dir, 'statedict_{0:03}.pth'.format(i))) # periodic snapshot every 10 epochs
--------------------------------------------------------------------------------
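
trainclassifier.py saves its best weights to <output_dir>/model/state_dict.pth, and test_model() in test.py accepts a perceptual model through its perc_model argument. A minimal sketch of how the two might be wired together, assuming that argument expects the PerceptualClassifier trained above; the checkpoint path is illustrative, and the input length (64000 samples = 4.0 s at 16 kHz) mirrors the script's defaults:

import torch
from diffsynth.perceptual.perceptual import PerceptualClassifier

# rebuild the classifier with the same constructor arguments used for training
perc_model = PerceptualClassifier(11, 64000)  # 11 NSynth instrument families, 4 s at 16 kHz
perc_model.load_state_dict(torch.load('classifier_out/model/state_dict.pth'))
perc_model = perc_model.to('cuda').eval()
# then pass it along when evaluating, e.g.:
# test_model(model, id_loader, ood_loader, device, sw_loss=sw_loss, perc_model=perc_model)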