├── inference.py ├── utils ├── plot.py ├── __init__.py ├── loading.py └── data.py ├── tests ├── __init__.py ├── spec │ ├── spec_helper.rb │ └── localhost │ │ ├── util_spec.rb │ │ └── pip_spec.rb ├── data │ ├── example.mp3 │ ├── human_example_clean.wav │ └── human_example_ipad_balcony.wav ├── test_dataset_config.yaml ├── test_train_config.yaml ├── test_scripts.py ├── conftest.py ├── test_features.py └── test_models.py ├── ckpts ├── .gitignore ├── g.pt.dvc └── do.pt.dvc ├── .rspec ├── .dvc ├── .gitignore ├── config └── plots │ ├── default.json │ ├── smooth.json │ ├── confusion.json │ ├── confusion_normalized.json │ ├── scatter.json │ └── linear.json ├── modules ├── __init__.py ├── commons.py ├── losses.py └── models.py ├── .dvcignore ├── features ├── __init__.py ├── denoise.py ├── speaker_embed.py ├── loudness.py ├── f0.py └── ppg.py ├── .gitignore ├── requirements.txt ├── dvc.yaml ├── Rakefile ├── .github └── workflows │ └── python-app.yml ├── LICENSE ├── config.yaml ├── preprocess.py ├── README.MD ├── dvc.lock └── train.py /inference.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/plot.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ckpts/.gitignore: -------------------------------------------------------------------------------- 1 | /g.pt 2 | /do.pt 3 | -------------------------------------------------------------------------------- /.rspec: -------------------------------------------------------------------------------- 1 | --color 2 | --format documentation 3 | -------------------------------------------------------------------------------- /.dvc/.gitignore: -------------------------------------------------------------------------------- 1 | /config.local 2 | /tmp 3 | /cache 4 | -------------------------------------------------------------------------------- /tests/spec/spec_helper.rb: -------------------------------------------------------------------------------- 1 | require 'serverspec' 2 | 3 | set :backend, :exec 4 | 5 | -------------------------------------------------------------------------------- /ckpts/g.pt.dvc: -------------------------------------------------------------------------------- 1 | outs: 2 | - md5: 7365bafab826f8e018cdeed12789e70e 3 | size: 159689487 4 | path: g.pt 5 | -------------------------------------------------------------------------------- /ckpts/do.pt.dvc: -------------------------------------------------------------------------------- 1 | outs: 2 | - md5: a2165794251fed3a1e6bdabbd322b2e0 3 | size: 1034751023 4 | path: do.pt 5 | -------------------------------------------------------------------------------- /tests/data/example.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SolomidHero/speech-regeneration-enhancer/HEAD/tests/data/example.mp3 -------------------------------------------------------------------------------- /.dvc/config: -------------------------------------------------------------------------------- 1 | [core] 2 | remote = gdrive 3 | ['remote "gdrive"'] 4 | url = gdrive://1ILm45lXT8WBLS4wVxVw-UmQkvOHy4Khd 5 | 
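# the checkpoints tracked in ckpts/ (g.pt, do.pt) should be fetchable from this remote via `dvc pull`, assuming access to the Google Drive folder above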
-------------------------------------------------------------------------------- /tests/data/human_example_clean.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SolomidHero/speech-regeneration-enhancer/HEAD/tests/data/human_example_clean.wav -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import Generator, Discriminator 2 | from .losses import MultiResolutionSTFTLoss, adversarial_loss, discriminator_loss -------------------------------------------------------------------------------- /tests/data/human_example_ipad_balcony.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SolomidHero/speech-regeneration-enhancer/HEAD/tests/data/human_example_ipad_balcony.wav -------------------------------------------------------------------------------- /.dvcignore: -------------------------------------------------------------------------------- 1 | # Add patterns of files dvc should ignore, which could improve 2 | # the performance. Learn more at 3 | # https://dvc.org/doc/user-guide/dvcignore 4 | -------------------------------------------------------------------------------- /features/__init__.py: -------------------------------------------------------------------------------- 1 | from .loudness import get_loudness 2 | from .f0 import get_f0 3 | from .ppg import get_ppg 4 | from .speaker_embed import get_speaker_embed 5 | from .denoise import denoise -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | **/.pytest_cache/ 3 | 4 | **/.DS_Store 5 | 6 | tests/test_dataset/ 7 | tests/test_experiment/ 8 | tests/test_logs/ 9 | 10 | /data 11 | /exp_ckpts 12 | /exp_logs 13 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .loading import load_checkpoint, save_checkpoint, scan_checkpoint 2 | from .data import ( 3 | PreprocessDataset, FeatureDataset, 4 | get_datafolder_files, define_train_list, train_test_split, 5 | save_dataset_filelist, load_dataset_filelist 6 | ) 7 | -------------------------------------------------------------------------------- /modules/commons.py: -------------------------------------------------------------------------------- 1 | # commons.py 2 | # Implementation of some torch.nn modules 3 | 4 | import torch.nn as nn 5 | 6 | 7 | class Conv1d(nn.Conv1d): 8 | "Conv1d with orthogonal initialisation" 9 | def reset_parameters(self) -> None: 10 | super().reset_parameters() 11 | nn.init.orthogonal_(self.weight) 12 | -------------------------------------------------------------------------------- /tests/test_dataset_config.yaml: -------------------------------------------------------------------------------- 1 | wav_dir: 'tests/data' 2 | ppg_dir: 'tests/test_dataset/ppg' 3 | f0_dir: 'tests/test_dataset/f0' 4 | loudness_dir: 'tests/test_dataset/loudness' 5 | spk_embs_file: 'tests/test_dataset/spk_embs.pt' 6 | train_list: 'tests/test_dataset/train.list' 7 | test_list: 'tests/test_dataset/test.list' 8 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | git+git://github.com/pyannote/pyannote-audio.git@18ed481c7ef7aa7239dc49e269703b7fbbcf6753 2 | torchaudio>=0.7.2 3 | torch>=1.7.1 4 | torch_optimizer==0.1.0 5 | pytest==6.2.2 6 | scipy>=1.6.0 7 | numpy>=1.20.1 8 | tqdm==4.55.0 9 | pyworld==0.2.12 10 | hydra_core>=1.0.6 11 | omegaconf==2.0.6 12 | transformers==4.3.2 13 | librosa==0.8.0 14 | -------------------------------------------------------------------------------- /tests/spec/localhost/util_spec.rb: -------------------------------------------------------------------------------- 1 | require_relative '../spec_helper' 2 | 3 | 4 | describe file('/usr/local/lib/pkgconfig/sndfile.pc') do 5 | it { should exist } 6 | end 7 | 8 | 9 | describe file('/proc/driver/nvidia/version') do 10 | it { should exist } 11 | end 12 | 13 | 14 | describe package('nvcc') do 15 | it { should be_installed } 16 | end 17 | -------------------------------------------------------------------------------- /tests/test_train_config.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 2 2 | epochs: 1 3 | num_workers: 4 4 | n_gpu: null 5 | 6 | stdout_interval: 1 7 | checkpoint_interval: 1 8 | summary_interval: 1 9 | validation_interval: 1 10 | lambda_adv: 4.0 11 | grad_norm_clip_value: 1.0 12 | ckpt_dir: 'tests/test_experiment/' 13 | logs_dir: 'tests/test_logs' 14 | 15 | dist_config: 16 | dist_backend: nccl 17 | dist_url: tcp://localhost:54321 18 | world_size: 1 -------------------------------------------------------------------------------- /dvc.yaml: -------------------------------------------------------------------------------- 1 | stages: 2 | preprocess: 3 | cmd: python3 preprocess.py dataset.wav_dir=tests/data 4 | deps: 5 | - preprocess.py 6 | params: 7 | - config.yaml: 8 | - data 9 | - dataset 10 | outs: 11 | - data 12 | train: 13 | cmd: python3 train.py train.ckpt_dir=exp_ckpts train.logs_dir=exp_logs train.batch_size=1 14 | train.epochs=10 train.checkpoint_interval=10 train.validation_interval=10 15 | deps: 16 | - data 17 | - train.py 18 | params: 19 | - config.yaml: 20 | - data 21 | - dataset 22 | - model 23 | - seed 24 | - train 25 | outs: 26 | - exp_ckpts 27 | - exp_logs 28 | -------------------------------------------------------------------------------- /utils/loading.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import glob 4 | 5 | 6 | def load_checkpoint(filepath, device): 7 | assert os.path.isfile(filepath) 8 | print("Loading '{}'".format(filepath)) 9 | checkpoint_dict = torch.load(filepath, map_location=device) 10 | print("Complete.") 11 | return checkpoint_dict 12 | 13 | 14 | def save_checkpoint(filepath, obj): 15 | print("Saving checkpoint to {}".format(filepath)) 16 | torch.save(obj, filepath) 17 | print("Complete.") 18 | 19 | 20 | def scan_checkpoint(cp_dir, prefix): 21 | pattern = os.path.join(cp_dir, prefix + '????????') 22 | cp_list = glob.glob(pattern) 23 | if len(cp_list) == 0: 24 | return None 25 | return sorted(cp_list)[-1] -------------------------------------------------------------------------------- /features/denoise.py: -------------------------------------------------------------------------------- 1 | # denoising.py 2 | # Preprocessing of audio with high-quality denoising. 
3 | # facebookresearch/denoiser based on demucs architecture 4 | # is used as denoise model 5 | 6 | import torch 7 | 8 | def load_denoiser(model_name="dns64"): 9 | return torch.hub.load("facebookresearch/denoiser", model_name, force_reload=False).eval() 10 | 11 | def denoise(wav, sr=None, device='cpu'): 12 | """ 13 | Denoise .wav audio data 14 | Args: 15 | wav - waveform (numpy array) 16 | device - (default 'cpu') 17 | Returns: 18 | wav - same wav, denoised 19 | """ 20 | model = load_denoiser().to(device) 21 | with torch.no_grad(): 22 | res = model(torch.from_numpy(wav).unsqueeze(0).to(device)) 23 | return res.squeeze().cpu().numpy() -------------------------------------------------------------------------------- /Rakefile: -------------------------------------------------------------------------------- 1 | require 'rake' 2 | require 'rspec/core/rake_task' 3 | 4 | task :spec => 'spec:all' 5 | task :default => :spec 6 | 7 | namespace :spec do 8 | targets = [] 9 | Dir.glob('./tests/spec/*').each do |dir| 10 | next unless File.directory?(dir) 11 | target = File.basename(dir) 12 | target = "_#{target}" if target == "default" 13 | targets << target 14 | end 15 | 16 | task :all => targets 17 | task :default => :all 18 | 19 | targets.each do |target| 20 | original_target = target == "_default" ? target[1..-1] : target 21 | desc "Run serverspec tests to #{original_target}" 22 | RSpec::Core::RakeTask.new(target.to_sym) do |t| 23 | ENV['TARGET_HOST'] = original_target 24 | t.pattern = "tests/spec/#{original_target}/*_spec.rb" 25 | end 26 | end 27 | end 28 | -------------------------------------------------------------------------------- /.dvc/plots/default.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://vega.github.io/schema/vega-lite/v4.json", 3 | "data": { 4 | "values": "" 5 | }, 6 | "title": "", 7 | "width": 300, 8 | "height": 300, 9 | "mark": { 10 | "type": "line" 11 | }, 12 | "encoding": { 13 | "x": { 14 | "field": "", 15 | "type": "quantitative", 16 | "title": "" 17 | }, 18 | "y": { 19 | "field": "", 20 | "type": "quantitative", 21 | "title": "", 22 | "scale": { 23 | "zero": false 24 | } 25 | }, 26 | "color": { 27 | "field": "rev", 28 | "type": "nominal" 29 | } 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /features/speaker_embed.py: -------------------------------------------------------------------------------- 1 | # speaker_embed.py 2 | # Speaker Embedding feature extraction utils.
3 | # Dense vector of constant dim, which contains representation 4 | # of speaker prosody and style 5 | 6 | from pyannote.audio import Inference 7 | import torch 8 | 9 | def load_pyannote_audio(ckpt_path='hbredin/SpeakerEmbedding-XVectorMFCC-VoxCeleb', device='cpu'): 10 | """Load speaker embedding model from pyannote.audio""" 11 | model = Inference(ckpt_path, device=device, window='sliding') 12 | return model 13 | 14 | def normalize(embed): 15 | return embed / (embed ** 2).sum(-1, keepdims=True) ** 0.5 16 | 17 | def get_speaker_embed(wav, sr, device='cpu', backend='pyannote'): 18 | if backend == 'pyannote': 19 | model = load_pyannote_audio(device=device) 20 | if len(wav.shape) == 1: 21 | wav = wav[None] 22 | spk_emb = model({ 23 | 'waveform': torch.from_numpy(wav).to(device), 24 | 'sample_rate': sr, 25 | }).data.mean(0) 26 | spk_emb = normalize(spk_emb) 27 | 28 | return spk_emb -------------------------------------------------------------------------------- /.dvc/plots/smooth.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://vega.github.io/schema/vega-lite/v4.json", 3 | "data": { 4 | "values": "" 5 | }, 6 | "title": "", 7 | "mark": { 8 | "type": "line" 9 | }, 10 | "encoding": { 11 | "x": { 12 | "field": "", 13 | "type": "quantitative", 14 | "title": "" 15 | }, 16 | "y": { 17 | "field": "", 18 | "type": "quantitative", 19 | "title": "", 20 | "scale": { 21 | "zero": false 22 | } 23 | }, 24 | "color": { 25 | "field": "rev", 26 | "type": "nominal" 27 | } 28 | }, 29 | "transform": [ 30 | { 31 | "loess": "", 32 | "on": "", 33 | "groupby": [ 34 | "rev" 35 | ], 36 | "bandwidth": 0.3 37 | } 38 | ] 39 | } 40 | -------------------------------------------------------------------------------- /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | schedule: 12 | - cron: 0 0 * * 1 13 | 14 | jobs: 15 | build: 16 | 17 | runs-on: ubuntu-latest 18 | strategy: 19 | matrix: 20 | python-version: [3.7, 3.8, 3.9] 21 | steps: 22 | - uses: actions/checkout@v2 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | sudo apt-get install libsndfile-dev ffmpeg 30 | python -m pip install --upgrade pip 31 | pip install pytest 32 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 33 | - name: Test with pytest 34 | run: | 35 | pytest -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 SolomidHero 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, 
subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /features/loudness.py: -------------------------------------------------------------------------------- 1 | # loudness.py 2 | # loudness feature extraction utils. 3 | # Basically it is a perceived measure of signal energy. 4 | # Here we use weighting of power spectrogram frequencies, 5 | # then mean and log for each frame. 6 | 7 | import librosa 8 | import numpy as np 9 | 10 | 11 | def get_loudness(wav, sr, n_fft=1280, hop_length=320, win_length=None, ref=1.0, min_db=-80.0): 12 | """ 13 | Extract the loudness measurement of the signal. 14 | Feature is extracted using A-weighting of the signal frequencies. 15 | 16 | Args: 17 | wav - waveform (numpy array) 18 | sr - sampling rate 19 | n_fft - number of points for fft 20 | hop_length - stride of stft 21 | win_length - size of window of stft 22 | ref - reference for amplitude log-scale 23 | min_db - floor for db difference 24 | Returns: 25 | loudness - loudness of signal, shape (n_frames, 1) 26 | """ 27 | 28 | A_weighting = librosa.A_weighting(librosa.fft_frequencies(sr, n_fft=n_fft)+1e-6, min_db=min_db) 29 | weighting = 10 ** (A_weighting / 10) 30 | 31 | power_spec = abs(librosa.stft(wav, n_fft=n_fft, hop_length=hop_length, win_length=win_length)) ** 2 32 | loudness = np.mean(power_spec * weighting[:, None], axis=0) 33 | loudness = librosa.power_to_db(loudness, ref=ref) # in db 34 | 35 | return loudness[:, np.newaxis].astype(np.float32) -------------------------------------------------------------------------------- /tests/spec/localhost/pip_spec.rb: -------------------------------------------------------------------------------- 1 | require_relative '../spec_helper' 2 | 3 | 4 | # ml packages 5 | describe package('torch') do 6 | it { should be_installed.by('pip').with_version('1.8.0') } 7 | end 8 | 9 | describe package('torch-optimizer') do 10 | it { should be_installed.by('pip').with_version('0.1.0') } 11 | end 12 | 13 | 14 | # tests 15 | describe package('pytest') do 16 | it { should be_installed.by('pip').with_version('6.2.2') } 17 | end 18 | 19 | 20 | # configurations 21 | describe package('omegaconf') do 22 | it { should be_installed.by('pip').with_version('2.0.6') } 23 | end 24 | 25 | describe package('hydra-core') do 26 | it { should be_installed.by('pip').with_version('1.0.6') } 27 | end 28 | 29 | 30 | # audio processing and feature extraction 31 | describe package('librosa') do 32 | it { should be_installed.by('pip').with_version('0.8.0') } 33 | end 34 | 35 | describe package('torchaudio') do 36 | it { should be_installed.by('pip').with_version('0.8.0') } 37 | end 38 | 39 | describe package('transformers') do 40 | it { should be_installed.by('pip').with_version('4.3.2') } 41 | end 42 | 43 | describe package('pyworld') do 44 | it { should be_installed.by('pip').with_version('0.2.12') }
45 | end 46 | 47 | describe package('numpy') do 48 | it { should be_installed.by('pip') } 49 | end 50 | 51 | describe package('scipy') do 52 | it { should be_installed.by('pip') } 53 | end -------------------------------------------------------------------------------- /tests/test_scripts.py: -------------------------------------------------------------------------------- 1 | # this test uses 2 | # - tests/test_dataset_config.yaml config for cfg.dataset 3 | # - tests/test_train_config.yaml config for cfg.train 4 | # instead of default in root directory 5 | 6 | 7 | from preprocess import main as preprocess_main 8 | from train import main as train_main 9 | 10 | import os 11 | import pytest 12 | import torch 13 | from pathlib import Path 14 | from hydra.experimental import compose, initialize 15 | 16 | @pytest.fixture(scope="module") 17 | def n_files(cfg): 18 | n_files = 0 19 | for _, _, filenames in os.walk(cfg.dataset.wav_dir): 20 | for name in filenames: 21 | if Path(name).suffix == '.wav': 22 | n_files += 1 23 | 24 | return n_files 25 | 26 | 27 | def test_config_consistency(cfg): 28 | with initialize(config_path="../"): 29 | train_config = compose(config_name="config") 30 | 31 | assert set(train_config.dataset.keys()) == set(cfg.dataset.keys()) 32 | assert set(train_config.train.keys()) == set(cfg.train.keys()) 33 | 34 | 35 | def test_preprocess(cfg, n_files): 36 | preprocess_main(cfg) 37 | 38 | assert len(os.listdir(cfg.dataset.ppg_dir)) == n_files 39 | assert len(os.listdir(cfg.dataset.f0_dir)) == n_files 40 | assert len(os.listdir(cfg.dataset.loudness_dir)) == n_files 41 | 42 | spk_embs = torch.load(cfg.dataset.spk_embs_file) 43 | assert len(spk_embs) == n_files 44 | 45 | 46 | def test_train(cfg): 47 | cfg.train.n_gpu = 0 48 | train_main(cfg) 49 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | opt: 2 | lr: 1e-4 3 | betas: [0.9, 0.999] 4 | lr_decay: 0.995 # per epoch 5 | 6 | data: 7 | segment_size: 9600 # of target audio (should be divisible by upsampling rate (mult of upsamples)) 8 | sample_rate: 16000 9 | target_sample_rate: 24000 10 | hop_length: 320 # should match 20ms of source audio 11 | n_fft: 1280 # for loudness 12 | win_length: 1280 # for loudness 13 | f_min: 50 # for f0 14 | f_max: null # for f0 15 | 16 | seed: 17 17 | 18 | model: 19 | feature_dims: [768, 1, 1] # ppg_dim, f0_dim, loud_dim 20 | cond_dims: [512, 128] # spk_emb_dim, z_dim 21 | 22 | generator: 23 | hidden_dim: 768 24 | n_blocks: 7 25 | upsamples: [2, 2, 2, 3, 4, 5] 26 | channel_divs: [1, 1, 2, 1, 1, 2, 2] 27 | 28 | discriminator: 29 | out_channels: [16, 64, 256, 1024, 1024, 1024, 1] 30 | kernels: [15, 41, 41, 41, 41, 5, 3] 31 | downsamples: [1, 2, 2, 4, 4, 1, 1] 32 | lrelu_slope: 0.2 33 | 34 | train: 35 | batch_size: 16 36 | epochs: 1000 37 | num_workers: 4 38 | n_gpu: null # will be defined on training 39 | 40 | stdout_interval: 20 41 | checkpoint_interval: 5000 42 | summary_interval: 100 43 | validation_interval: 1000 44 | lambda_adv: 4.0 45 | grad_norm_clip_value: 1.0 46 | ckpt_dir: 'ckpts' 47 | logs_dir: 'logs' 48 | 49 | dist_config: 50 | dist_backend: nccl 51 | dist_url: tcp://localhost:54321 52 | world_size: 1 53 | 54 | dataset: 55 | wav_dir: 'data/wavs' 56 | ppg_dir: 'data/ppg' 57 | f0_dir: 'data/f0' 58 | loudness_dir: 'data/loudness' 59 | spk_embs_file: 'data/spk_embs.pt' 60 | train_list: 'data/train.list' 61 | test_list: 'data/test.list' 62 | 63 | hydra: 64 | run: 65 | dir: 
. -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # here we make fixtures of toy data 2 | # real parameters are stored and accessed from config 3 | 4 | 5 | import pytest 6 | import librosa 7 | import os 8 | import numpy as np 9 | 10 | from hydra.experimental import compose, initialize 11 | 12 | 13 | @pytest.fixture(scope="session") 14 | def cfg(): 15 | with initialize(config_path="../", job_name="test_app"): 16 | config = compose(config_name="config") 17 | config.dataset = compose(config_name="tests/test_dataset_config") 18 | config.train = compose(config_name="tests/test_train_config") 19 | 20 | return config 21 | 22 | @pytest.fixture(scope="session") 23 | def sample_rate(cfg): 24 | return cfg.data.sample_rate 25 | 26 | @pytest.fixture(scope="session") 27 | def example_wav(sample_rate): 28 | wav, sr = librosa.load( 29 | os.path.dirname(__file__) + "/data/example.mp3", 30 | sr=sample_rate, dtype=np.float32, 31 | ) 32 | return { 'wav': wav, 'sr': sr } 33 | 34 | @pytest.fixture(scope="session") 35 | def n_fft(cfg): 36 | return cfg.data.n_fft 37 | 38 | @pytest.fixture(scope="session") 39 | def hop_length(cfg): 40 | return cfg.data.hop_length 41 | 42 | @pytest.fixture(scope="session") 43 | def win_length(cfg): 44 | return cfg.data.win_length 45 | 46 | @pytest.fixture(scope="session") 47 | def f_min(cfg): 48 | return cfg.data.f_min 49 | 50 | @pytest.fixture(scope="session") 51 | def f_max(cfg): 52 | return cfg.data.f_max 53 | 54 | @pytest.fixture(scope="session") 55 | def hop_ms(example_wav, hop_length): 56 | return 1e3 * hop_length / example_wav['sr'] 57 | 58 | @pytest.fixture(scope="session") 59 | def n_frames(example_wav, hop_length): 60 | return (example_wav['wav'].shape[-1] - 1) // hop_length + 1 61 | 62 | # It is not clear if we should cleanup the test directories 63 | # or leave them for debugging 64 | # https://github.com/pytest-dev/pytest/issues/3051 65 | @pytest.fixture(autouse=True, scope='session') 66 | def clear_files_teardown(): 67 | yield None 68 | os.system("rm -r tests/test_dataset tests/test_experiment tests/test_logs") -------------------------------------------------------------------------------- /features/f0.py: -------------------------------------------------------------------------------- 1 | # f0.py 2 | # f0 (main signal frequency) extraction utils. 3 | # f0 contour (a.k.a. 
pitch contour) shows the current 4 | # dominant note/frequency for each frame 5 | 6 | import numpy as np 7 | import scipy.interpolate 8 | import pyworld as pw 9 | 10 | # best practice is to make f0 continuous and log-scaled 11 | def convert_continuos_f0(f0): 12 | """CONVERT F0 TO CONTINUOUS F0 13 | Reference: 14 | https://github.com/bigpon/vcc20_baseline_cyclevae/blob/master/baseline/src/bin/feature_extract.py 15 | 16 | Args: 17 | f0 (ndarray): original f0 sequence with the shape (T) 18 | Return: 19 | (ndarray): continuous f0 with the shape (T) 20 | """ 21 | # get uv information as binary 22 | uv = np.float32(f0 != 0) 23 | 24 | # get start and end of f0 25 | start_f0 = f0[f0 != 0][0] 26 | end_f0 = f0[f0 != 0][-1] 27 | 28 | # padding start and end of f0 sequence 29 | start_idx = np.where(f0 == start_f0)[0][0] 30 | end_idx = np.where(f0 == end_f0)[0][-1] 31 | f0[:start_idx] = start_f0 32 | f0[end_idx:] = end_f0 33 | 34 | # get non-zero frame index 35 | nz_frames = np.where(f0 != 0)[0] 36 | 37 | # perform linear interpolation 38 | f = scipy.interpolate.interp1d(nz_frames, f0[nz_frames]) 39 | cont_f0 = f(np.arange(0, f0.shape[0])) 40 | 41 | return np.log(cont_f0) 42 | # return uv, np.log(cont_f0) 43 | 44 | 45 | def get_f0(wav, sr, hop_ms, f_min=0, f_max=None): 46 | """ 47 | Extract f0 (array of per-frame values) from wav (1d-array of point values). 48 | Args: 49 | wav - waveform (numpy array) 50 | sr - sampling rate 51 | hop_ms - stride (in milliseconds) for frames 52 | f_min - f0 floor frequency 53 | f_max - f0 ceil frequency 54 | Returns: 55 | f0 - interpolated fundamental frequency, shape (n_frames, 1) 56 | """ 57 | if f_max is None: 58 | f_max = sr / 2 59 | 60 | _f0, t = pw.dio(wav.astype(np.float64), sr, frame_period=hop_ms, f0_floor=f_min, f0_ceil=f_max) # raw pitch extractor 61 | f0 = pw.stonemask(wav.astype(np.float64), _f0, t, sr) # pitch refinement 62 | 63 | return convert_continuos_f0(f0)[:, np.newaxis].astype(np.float32) -------------------------------------------------------------------------------- /features/ppg.py: -------------------------------------------------------------------------------- 1 | # ppg.py 2 | # PPG feature extraction utils.
3 | # Phonetic Posteriorgrams are fully linguistic features 4 | # and are often represented by bottleneck feature in ASR network 5 | 6 | from transformers import Wav2Vec2Model 7 | import torch 8 | 9 | 10 | def load_wav2vec2(ckpt_path='facebook/wav2vec2-base-960h'): 11 | """Load pretrained Wav2Vec2 model.""" 12 | def extract_features(self, wav, mask): 13 | return [self(wav).last_hidden_state] 14 | 15 | Wav2Vec2Model.extract_features = extract_features # for same behaviour as fairseq.Wav2Vec2Model 16 | model = Wav2Vec2Model.from_pretrained(ckpt_path).eval() 17 | return model 18 | 19 | 20 | def get_ppg(wav, sr, device='cpu', backend='wav2vec2', max_window=20.0, overlap=2.0): 21 | wav = torch.from_numpy(wav).unsqueeze(0).to(device) 22 | 23 | if backend == 'wav2vec2': 24 | # wav2vec has window of 400 and hop of 320, 25 | # so we pad to center windows 26 | hop_length = 320 27 | win_length = 400 28 | 29 | n_frames = wav.shape[1] // hop_length + 1 30 | wav = torch.nn.functional.pad(wav.unsqueeze(1), (win_length // 2, win_length // 2), mode='reflect').squeeze(1) 31 | model = load_wav2vec2().to(device) 32 | 33 | with torch.no_grad(): 34 | if wav.shape[-1] / sr > max_window: 35 | segment_n_frames = int(max_window * sr) // hop_length 36 | segment_n_points = segment_n_frames * hop_length 37 | overlap_n_frames = int(overlap * sr) // hop_length 38 | overlap_n_points = overlap_n_frames * hop_length 39 | hop_segment_len = segment_n_points - overlap_n_points 40 | 41 | n_segments = (wav.shape[-1] - segment_n_points) // hop_segment_len + 1 42 | ppgs = [] 43 | 44 | # process ppg for every window, except last 45 | for i in range(n_segments): 46 | sub_wav = wav[:, i * hop_segment_len:i * hop_segment_len + segment_n_points + win_length - hop_length] 47 | cur_ppg = model.extract_features(sub_wav, None)[0].squeeze(0).cpu() 48 | cur_ppg = cur_ppg[overlap_n_points // hop_length:] if i > 0 else cur_ppg 49 | ppgs.append(cur_ppg) 50 | 51 | # add last window 52 | n_frames_calced = (n_segments - 1) * hop_segment_len // hop_length + segment_n_frames 53 | n_frames_left = n_frames - n_frames_calced 54 | 55 | sub_wav = wav[:, (n_frames - segment_n_frames) * hop_length:] 56 | cur_ppg = model.extract_features(sub_wav, None)[0].squeeze(0).cpu() 57 | cur_ppg = cur_ppg[-n_frames_left:] 58 | ppgs.append(cur_ppg) 59 | 60 | # cat into one ppg 61 | ppg = torch.cat(ppgs, dim=0) 62 | else: 63 | ppg = model.extract_features(wav, None)[0].squeeze(0).cpu() 64 | 65 | return ppg.numpy() -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Precompute Wav2Vec2, f0, speaker embedding features.""" 3 | 4 | import os 5 | import json 6 | from pathlib import Path 7 | from multiprocessing import cpu_count 8 | 9 | import tqdm 10 | import torch 11 | from torch.utils.data import DataLoader 12 | 13 | from features import get_ppg, get_f0, get_loudness, get_speaker_embed, denoise 14 | from utils import ( 15 | PreprocessDataset, 16 | get_datafolder_files, define_train_list, train_test_split, save_dataset_filelist 17 | ) 18 | 19 | 20 | import hydra 21 | from omegaconf import DictConfig, OmegaConf 22 | 23 | 24 | @hydra.main(config_name="config") 25 | def main(cfg: DictConfig): 26 | """Preprocessing function for DAPS-like dataset (https://archive.org/details/daps_dataset). 
27 | - Extracts features (PPG, f0, loudness, spk embedding) for every wav in datafolder 28 | - Builds train and test list 29 | """ 30 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 31 | 32 | print(OmegaConf.to_yaml(cfg.dataset)) 33 | print(OmegaConf.to_yaml(cfg.data)) 34 | print('device:', device) 35 | 36 | # folders preparation 37 | for out_dir_path in [cfg.dataset.f0_dir, cfg.dataset.ppg_dir, cfg.dataset.loudness_dir]: 38 | out_dir_path = Path(out_dir_path) 39 | if out_dir_path.exists(): 40 | assert out_dir_path.is_dir() 41 | else: 42 | out_dir_path.mkdir(parents=True) 43 | 44 | # preprocess dataset and loader 45 | filepathes = get_datafolder_files(cfg.dataset.wav_dir) 46 | dataset = PreprocessDataset(filepathes, cfg.data) 47 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False, num_workers=cfg.train.num_workers) 48 | spk_embs = dict() 49 | 50 | 51 | # feature extraction 52 | pbar = tqdm.tqdm(total=len(dataset), ncols=0) 53 | for wav, filename in dataloader: 54 | # batch size is 1 55 | wav = denoise(wav[0].numpy(), device=device) 56 | filename = Path(filename[0]) 57 | 58 | with torch.no_grad(): 59 | ppg = torch.from_numpy(get_ppg(wav, cfg.data.sample_rate, device=device)) 60 | f0 = torch.from_numpy(get_f0( 61 | wav, cfg.data.sample_rate, 62 | hop_ms=cfg.data.hop_length * 1000 / cfg.data.sample_rate, f_min=cfg.data.f_min, f_max=cfg.data.f_max 63 | )) 64 | loudness = torch.from_numpy(get_loudness( 65 | wav, cfg.data.sample_rate, n_fft=cfg.data.n_fft, 66 | hop_length=cfg.data.hop_length, win_length=cfg.data.win_length 67 | )) 68 | spk_emb = torch.from_numpy(get_speaker_embed(wav, cfg.data.sample_rate, device=device)).cpu() 69 | 70 | 71 | torch.save(ppg, os.path.join(cfg.dataset.ppg_dir, filename.with_suffix('.pt'))) 72 | torch.save(f0, os.path.join(cfg.dataset.f0_dir, filename.with_suffix('.pt'))) 73 | torch.save(loudness, os.path.join(cfg.dataset.loudness_dir, filename.with_suffix('.pt'))) 74 | spk_embs[filename.stem] = spk_emb 75 | 76 | pbar.update(dataloader.batch_size) 77 | 78 | torch.save(spk_embs, cfg.dataset.spk_embs_file) 79 | 80 | # generation of train and test files 81 | train_list = define_train_list(filepathes) 82 | train_list, test_list = train_test_split(train_list) 83 | save_dataset_filelist(train_list, cfg.dataset.train_list) 84 | save_dataset_filelist(test_list, cfg.dataset.test_list) 85 | 86 | 87 | if __name__ == "__main__": 88 | main() -------------------------------------------------------------------------------- /tests/test_features.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | def test_config_feature_dim_values(cfg): 5 | assert len(cfg.model.feature_dims) == 3 6 | assert len(cfg.model.cond_dims) == 2 7 | 8 | 9 | def test_denoise(example_wav): 10 | from features import denoise 11 | 12 | denoised = denoise(example_wav['wav']) 13 | 14 | assert len(denoised.shape) == 1 15 | assert denoised.shape == example_wav['wav'].shape 16 | assert denoised.dtype == np.float32 17 | 18 | 19 | @pytest.fixture(scope='module') 20 | def speaker_embed(example_wav): 21 | from features import get_speaker_embed 22 | speaker_embed = get_speaker_embed(example_wav['wav'], example_wav['sr']) 23 | return speaker_embed 24 | 25 | def test_speaker_embed(speaker_embed): 26 | assert len(speaker_embed.shape) == 1 27 | 28 | 29 | @pytest.fixture(scope='module') 30 | def loudness(example_wav, n_fft, hop_length, win_length, n_frames): 31 | from features import 
get_loudness 32 | loudness = get_loudness( 33 | example_wav['wav'], 34 | example_wav['sr'], 35 | n_fft=n_fft, hop_length=hop_length, win_length=win_length 36 | ) 37 | return loudness 38 | 39 | def test_loudness(loudness, n_frames): 40 | assert len(loudness.shape) == 2 41 | assert len(loudness) == n_frames 42 | assert loudness.dtype == np.float32 43 | 44 | 45 | @pytest.fixture(scope='module') 46 | def f0(example_wav, hop_ms, f_min, f_max, n_frames): 47 | from features import get_f0 48 | f0 = get_f0(example_wav['wav'], example_wav['sr'], hop_ms, f_min, f_max) 49 | return f0 50 | 51 | def test_f0(f0, n_frames): 52 | assert len(f0.shape) == 2 53 | assert len(f0) == n_frames 54 | assert f0.dtype == np.float32 55 | 56 | 57 | @pytest.fixture(scope='module') 58 | def ppg(example_wav, hop_ms): 59 | from features import get_ppg 60 | ppg = get_ppg(example_wav['wav'], example_wav['sr'], backend='wav2vec2') 61 | return ppg 62 | 63 | @pytest.fixture(scope='module') 64 | def ppg_windowed(example_wav, hop_ms): 65 | from features import get_ppg 66 | ppg = get_ppg(example_wav['wav'], example_wav['sr'], backend='wav2vec2', max_window=1.0, overlap=0.1) 67 | return ppg 68 | 69 | 70 | def test_ppg(ppg, ppg_windowed, example_wav): 71 | # wav2vec2 has 320 stride, receptive field (window) 400 72 | desired_len = example_wav['wav'].shape[-1] // 320 + 1 73 | assert len(ppg.shape) == 2 74 | assert len(ppg_windowed.shape) == 2 75 | 76 | assert ppg.shape[0] == desired_len 77 | assert ppg_windowed.shape[0] == desired_len 78 | 79 | assert ppg.dtype == np.float32 80 | assert ppg_windowed.dtype == np.float32 81 | 82 | 83 | def test_features_alignment(speaker_embed, loudness, f0, ppg): 84 | assert ppg.shape[0] == f0.shape[0] 85 | assert ppg.shape[0] == loudness.shape[0] 86 | 87 | 88 | def test_feature_dims_with_config(cfg, ppg, f0, loudness, speaker_embed): 89 | # source features 90 | assert ppg.shape[1] == cfg.model.feature_dims[0], f"extracted PPG feature dim ({ppg.shape[1]}) must match in config ({cfg.model.feature_dims[0]})" 91 | assert f0.shape[1] == cfg.model.feature_dims[1], f"extracted f0 feature dim ({f0.shape[1]}) must match in config ({cfg.model.feature_dims[1]})" 92 | assert loudness.shape[1] == cfg.model.feature_dims[2], f"extracted loudness feature dim ({loudness.shape[1]}) must match in config ({cfg.model.feature_dims[2]})" 93 | 94 | # condition features 95 | assert speaker_embed.shape[0] == cfg.model.cond_dims[0], f"extracted speaker embedding feature dim ({speaker_embed.shape[0]}) must match in config ({cfg.model.cond_dims[0]})" 96 | -------------------------------------------------------------------------------- /.dvc/plots/confusion.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://vega.github.io/schema/vega-lite/v4.json", 3 | "data": { 4 | "values": "" 5 | }, 6 | "title": "", 7 | "facet": { 8 | "field": "rev", 9 | "type": "nominal" 10 | }, 11 | "spec": { 12 | "transform": [ 13 | { 14 | "aggregate": [ 15 | { 16 | "op": "count", 17 | "as": "xy_count" 18 | } 19 | ], 20 | "groupby": [ 21 | "", 22 | "" 23 | ] 24 | }, 25 | { 26 | "impute": "xy_count", 27 | "groupby": [ 28 | "rev", 29 | "" 30 | ], 31 | "key": "", 32 | "value": 0 33 | }, 34 | { 35 | "impute": "xy_count", 36 | "groupby": [ 37 | "rev", 38 | "" 39 | ], 40 | "key": "", 41 | "value": 0 42 | }, 43 | { 44 | "joinaggregate": [ 45 | { 46 | "op": "max", 47 | "field": "xy_count", 48 | "as": "max_count" 49 | } 50 | ], 51 | "groupby": [] 52 | }, 53 | { 54 | "calculate": "datum.xy_count / 
datum.max_count", 55 | "as": "percent_of_max" 56 | } 57 | ], 58 | "encoding": { 59 | "x": { 60 | "field": "", 61 | "type": "nominal", 62 | "sort": "ascending", 63 | "title": "" 64 | }, 65 | "y": { 66 | "field": "", 67 | "type": "nominal", 68 | "sort": "ascending", 69 | "title": "" 70 | } 71 | }, 72 | "layer": [ 73 | { 74 | "mark": "rect", 75 | "width": 300, 76 | "height": 300, 77 | "encoding": { 78 | "color": { 79 | "field": "xy_count", 80 | "type": "quantitative", 81 | "title": "", 82 | "scale": { 83 | "domainMin": 0, 84 | "nice": true 85 | } 86 | } 87 | } 88 | }, 89 | { 90 | "mark": "text", 91 | "encoding": { 92 | "text": { 93 | "field": "xy_count", 94 | "type": "quantitative" 95 | }, 96 | "color": { 97 | "condition": { 98 | "test": "datum.percent_of_max > 0.5", 99 | "value": "white" 100 | }, 101 | "value": "black" 102 | } 103 | } 104 | } 105 | ] 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /.dvc/plots/confusion_normalized.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://vega.github.io/schema/vega-lite/v4.json", 3 | "data": { 4 | "values": "" 5 | }, 6 | "title": "", 7 | "facet": { 8 | "field": "rev", 9 | "type": "nominal" 10 | }, 11 | "spec": { 12 | "transform": [ 13 | { 14 | "aggregate": [ 15 | { 16 | "op": "count", 17 | "as": "xy_count" 18 | } 19 | ], 20 | "groupby": [ 21 | "", 22 | "" 23 | ] 24 | }, 25 | { 26 | "impute": "xy_count", 27 | "groupby": [ 28 | "rev", 29 | "" 30 | ], 31 | "key": "", 32 | "value": 0 33 | }, 34 | { 35 | "impute": "xy_count", 36 | "groupby": [ 37 | "rev", 38 | "" 39 | ], 40 | "key": "", 41 | "value": 0 42 | }, 43 | { 44 | "joinaggregate": [ 45 | { 46 | "op": "sum", 47 | "field": "xy_count", 48 | "as": "sum_y" 49 | } 50 | ], 51 | "groupby": [ 52 | "" 53 | ] 54 | }, 55 | { 56 | "calculate": "datum.xy_count / datum.sum_y", 57 | "as": "percent_of_y" 58 | } 59 | ], 60 | "encoding": { 61 | "x": { 62 | "field": "", 63 | "type": "nominal", 64 | "sort": "ascending", 65 | "title": "" 66 | }, 67 | "y": { 68 | "field": "", 69 | "type": "nominal", 70 | "sort": "ascending", 71 | "title": "" 72 | } 73 | }, 74 | "layer": [ 75 | { 76 | "mark": "rect", 77 | "width": 300, 78 | "height": 300, 79 | "encoding": { 80 | "color": { 81 | "field": "percent_of_y", 82 | "type": "quantitative", 83 | "title": "", 84 | "scale": { 85 | "domain": [ 86 | 0, 87 | 1 88 | ] 89 | } 90 | } 91 | } 92 | }, 93 | { 94 | "mark": "text", 95 | "encoding": { 96 | "text": { 97 | "field": "percent_of_y", 98 | "type": "quantitative", 99 | "format": ".2f" 100 | }, 101 | "color": { 102 | "condition": { 103 | "test": "datum.percent_of_y > 0.5", 104 | "value": "white" 105 | }, 106 | "value": "black" 107 | } 108 | } 109 | } 110 | ] 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /.dvc/plots/scatter.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://vega.github.io/schema/vega-lite/v4.json", 3 | "data": { 4 | "values": "" 5 | }, 6 | "title": "", 7 | "width": 300, 8 | "height": 300, 9 | "layer": [ 10 | { 11 | "encoding": { 12 | "x": { 13 | "field": "", 14 | "type": "quantitative", 15 | "title": "" 16 | }, 17 | "y": { 18 | "field": "", 19 | "type": "quantitative", 20 | "title": "", 21 | "scale": { 22 | "zero": false 23 | } 24 | }, 25 | "color": { 26 | "field": "rev", 27 | "type": "nominal" 28 | } 29 | }, 30 | "layer": [ 31 | { 32 | "mark": "point" 33 | }, 34 | { 35 | 
"selection": { 36 | "label": { 37 | "type": "single", 38 | "nearest": true, 39 | "on": "mouseover", 40 | "encodings": [ 41 | "x" 42 | ], 43 | "empty": "none", 44 | "clear": "mouseout" 45 | } 46 | }, 47 | "mark": "point", 48 | "encoding": { 49 | "opacity": { 50 | "condition": { 51 | "selection": "label", 52 | "value": 1 53 | }, 54 | "value": 0 55 | } 56 | } 57 | } 58 | ] 59 | }, 60 | { 61 | "transform": [ 62 | { 63 | "filter": { 64 | "selection": "label" 65 | } 66 | } 67 | ], 68 | "layer": [ 69 | { 70 | "encoding": { 71 | "text": { 72 | "type": "quantitative", 73 | "field": "" 74 | }, 75 | "x": { 76 | "field": "", 77 | "type": "quantitative" 78 | }, 79 | "y": { 80 | "field": "", 81 | "type": "quantitative" 82 | } 83 | }, 84 | "layer": [ 85 | { 86 | "mark": { 87 | "type": "text", 88 | "align": "left", 89 | "dx": 5, 90 | "dy": -5 91 | }, 92 | "encoding": { 93 | "color": { 94 | "type": "nominal", 95 | "field": "rev" 96 | } 97 | } 98 | } 99 | ] 100 | } 101 | ] 102 | } 103 | ] 104 | } 105 | -------------------------------------------------------------------------------- /README.MD: -------------------------------------------------------------------------------- 1 | # Regeneration Enhancer 2 | [![Python application](https://github.com/SolomidHero/speech-regeneration-enhancer/actions/workflows/python-app.yml/badge.svg)](https://github.com/SolomidHero/speech-regeneration-enhancer/actions/workflows/python-app.yml) 3 | 4 | This repository provides speech enhancement via regeneration implementation with Pytorch. Algorithm is based on [paper](https://arxiv.org/abs/2102.00429), but several changes were made in feature extraction and therefore model parameters. 5 | 6 | **TODO list:** 7 | - add inference scripts 8 | - implement streaming model and its inference 9 | - provide multilingual enhancement models (and adapt feature extraction too) 10 | - make pypi package 11 | - release pretrained models 12 | 13 | # Requirements 14 | 15 | This repository is tested on Ubuntu 16.04 with a GPU 1080 Ti. 16 | 17 | - Python 3.7+ (follow [installation page](https://www.python.org/downloads/)) 18 | - Cuda 10.0+ ([guide for ubuntu](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)) 19 | 21 | - libsndfile (you can install via `sudo apt install libsndfile-dev` in ubuntu) 22 | 23 | - pip requirements (defined in `requirements.txt`, install via `pip install -r requirements.txt`): 24 | - hydra-core 1.0.6+ 25 | - pytorch 1.7+ 26 | - torchaudio 0.7.2+ 27 | - librosa 0.8.0+ 28 | - pytest 6.2.0+ 29 | - transformers 4.3.0+, pyworld 0.2.12+, pyannote.audio 2.0+ (for feature extraction) 30 | - (optional) ffmpeg (for .mp3 support, you can install via `sudo apt install ffmpeg` in ubuntu) 31 | 32 | # Installation 33 | 34 | ```bash 35 | git clone https://github.com/SolomidHero/speech-regeneration-enhancer 36 | pip install -e ./speech-regeneration-enhancer 37 | ``` 38 | 39 | # Training 40 | 41 | For training you should use [DAPS dataset](https://archive.org/details/daps_dataset), or dataset with similar file namings (folder structure doesn't matter): 42 | 43 | ``` 44 | data_folder/ 45 | wav_1_clean.wav 46 | dirty/ 47 | wav_1_recoder_bathroom.wav 48 | wav_2_microphone_street.wav 49 | some_sub_tree/ 50 | wav_2_clean.wav 51 | ``` 52 | 53 | In this repository we use hydra configuration ([read more](https://hydra.cc/)), thus for training and inference you can only change `config.yaml` file. Also defining through parameters in bash is available. 
54 | 55 | After changing the config, you can check that your parameters are acceptable with either of these commands: 56 | ```bash 57 | pytest # to check if everything is working 58 | pytest tests/test_scripts.py # to check if training process can be done 59 | ``` 60 | 61 | 1. After downloading the data and adjusting the config, run the preprocessing script (feature extraction happens here): 62 | 63 | ```bash 64 | preprocess.py dataset.wav_dir=/path/to/wavs # parameters can be added into config directly 65 | ``` 66 | 67 | 2. Finally, train the model: 68 | 69 | ```bash 70 | train.py train.epochs=50 train.ckpt_dir=/path/to/ckpts # parameters can be added into config directly 71 | ``` 72 | 73 | Checkpoints for the generator and the rest of the training state (discriminator, optimizers) will be saved to `/path/to/ckpts`. 74 | 75 | 76 | # Reference 77 | 78 | - paper: *["HIGH FIDELITY SPEECH REGENERATION WITH APPLICATION TO SPEECH ENHANCEMENT"](https://arxiv.org/abs/2102.00429)* 79 | - data: [Device and Produced Speech Dataset](https://archive.org/details/daps_dataset) 80 | - [wav2vec2](https://huggingface.co/transformers/model_doc/wav2vec2.html) 81 | - [pyannote-audio](https://github.com/pyannote/pyannote-audio/) 82 | - [python wrapper for world](https://github.com/JeremyCCHsu/Python-Wrapper-for-World-Vocoder) 83 | - [demucs denoiser](https://github.com/facebookresearch/denoiser/) 84 | -------------------------------------------------------------------------------- /modules/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def adversarial_loss(scores, as_real=True): 7 | if as_real: 8 | return torch.mean((1 - scores) ** 2) 9 | return torch.mean(scores ** 2) 10 | 11 | 12 | def discriminator_loss(fake_scores, real_scores): 13 | loss = adversarial_loss(fake_scores, as_real=False) + adversarial_loss(real_scores, as_real=True) 14 | return loss 15 | 16 | 17 | def stft(x, n_fft, hop_length, win_length, window, eps=1e-6): 18 | """Perform STFT and convert to magnitude spectrogram. 19 | Args: 20 | x: Input signal tensor (B, T). 21 | Returns: 22 | Tensor: Magnitude spectrogram (B, n_fft // 2 + 1, #frames). 23 | """ 24 | x_stft = torch.stft(x, 25 | n_fft, hop_length, win_length, window, 26 | center=False, return_complex=True 27 | ).abs().clamp(min=eps) 28 | 29 | return x_stft 30 | 31 | 32 | class SpectralConvergence(nn.Module): 33 | def __init__(self): 34 | """Initialize spectral convergence loss module.""" 35 | super().__init__() 36 | 37 | def forward(self, predicts_mag, targets_mag): 38 | """Calculate norm of difference operator. 39 | Args: 40 | predicts_mag (Tensor): Magnitude spectrogram of predicted signal (B, #freq_bins, #frames). 41 | targets_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #freq_bins, #frames). 42 | Returns: 43 | Tensor: Spectral convergence loss value.
44 | """ 45 | 46 | return torch.mean( 47 | torch.norm(targets_mag - predicts_mag, dim=(1, 2), p='fro') / torch.norm(targets_mag, dim=(1, 2), p='fro') 48 | ) 49 | 50 | 51 | class LogSTFTMagnitude(nn.Module): 52 | def __init__(self): 53 | super().__init__() 54 | 55 | def forward(self, predicts_mag, targets_mag): 56 | log_predicts_mag = torch.log(predicts_mag) 57 | log_targets_mag = torch.log(targets_mag) 58 | 59 | outputs = F.l1_loss(log_predicts_mag, log_targets_mag) 60 | 61 | return outputs 62 | 63 | 64 | class STFTLoss(nn.Module): 65 | def __init__(self, n_fft, hop_length, win_length, device='cpu'): 66 | super().__init__() 67 | 68 | self.n_fft = n_fft 69 | self.hop_length = hop_length 70 | self.win_length = win_length 71 | self.window = torch.hann_window(win_length).to(device) 72 | self.sc_loss = SpectralConvergence() 73 | self.mag_loss = LogSTFTMagnitude() 74 | 75 | def forward(self, predicts, targets): 76 | """ 77 | Args: 78 | x: predicted signal (B, T). 79 | y: truth signal (B, T). 80 | Returns: 81 | Tensor: STFT loss values. 82 | """ 83 | 84 | predicts_mag = stft(predicts, self.n_fft, self.hop_length, self.win_length, self.window) 85 | targets_mag = stft(targets, self.n_fft, self.hop_length, self.win_length, self.window) 86 | 87 | sc_loss = self.sc_loss(predicts_mag, targets_mag) 88 | mag_loss = self.mag_loss(predicts_mag, targets_mag) 89 | 90 | return sc_loss, mag_loss 91 | 92 | 93 | class MultiResolutionSTFTLoss(nn.Module): 94 | def __init__(self, 95 | fft_sizes=[2048, 1024, 512, 256, 128, 64], 96 | win_sizes=[2048, 1024, 512, 256, 128, 64], 97 | hop_sizes=[512, 256, 128, 64, 32, 16], 98 | device='cpu' 99 | ): 100 | super().__init__() 101 | 102 | self.loss_layers = torch.nn.ModuleList([ 103 | STFTLoss(n_fft, hop_length, win_length, device=device) 104 | for n_fft, win_length, hop_length in zip(fft_sizes, win_sizes, hop_sizes) 105 | ]) 106 | 107 | def forward(self, fake_signals, true_signals): 108 | res_losses = [] 109 | for layer in self.loss_layers: 110 | sc_loss, mag_loss = layer(fake_signals, true_signals) 111 | res_losses.append(sc_loss + mag_loss) 112 | 113 | loss = sum(res_losses) / len(res_losses) 114 | 115 | return loss -------------------------------------------------------------------------------- /dvc.lock: -------------------------------------------------------------------------------- 1 | schema: '2.0' 2 | stages: 3 | preprocess: 4 | cmd: python3 preprocess.py dataset.wav_dir=tests/data 5 | deps: 6 | - path: preprocess.py 7 | md5: c68d62690b22efc45fbb173212686e7c 8 | size: 2982 9 | params: 10 | config.yaml: 11 | data: 12 | segment_size: 9600 13 | sample_rate: 16000 14 | target_sample_rate: 24000 15 | hop_length: 320 16 | n_fft: 1280 17 | win_length: 1280 18 | f_min: 50 19 | f_max: 20 | dataset: 21 | wav_dir: data/wavs 22 | ppg_dir: data/ppg 23 | f0_dir: data/f0 24 | loudness_dir: data/loudness 25 | spk_embs_file: data/spk_embs.pt 26 | train_list: data/train.list 27 | test_list: data/test.list 28 | outs: 29 | - path: data 30 | md5: d5a01d5b45d55f1417f210e6054bb545.dir 31 | size: 712086 32 | nfiles: 9 33 | train: 34 | cmd: python3 train.py train.ckpt_dir=exp_ckpts train.logs_dir=exp_logs train.batch_size=1 35 | train.epochs=10 train.checkpoint_interval=10 train.validation_interval=10 36 | deps: 37 | - path: data 38 | md5: d5a01d5b45d55f1417f210e6054bb545.dir 39 | size: 712086 40 | nfiles: 9 41 | - path: train.py 42 | md5: c55ad4192429f80a6b9de5ba2d773108 43 | size: 9120 44 | params: 45 | config.yaml: 46 | data: 47 | segment_size: 9600 48 | sample_rate: 16000 49 | 
target_sample_rate: 24000 50 | hop_length: 320 51 | n_fft: 1280 52 | win_length: 1280 53 | f_min: 50 54 | f_max: 55 | dataset: 56 | wav_dir: data/wavs 57 | ppg_dir: data/ppg 58 | f0_dir: data/f0 59 | loudness_dir: data/loudness 60 | spk_embs_file: data/spk_embs.pt 61 | train_list: data/train.list 62 | test_list: data/test.list 63 | model: 64 | feature_dims: 65 | - 768 66 | - 1 67 | - 1 68 | cond_dims: 69 | - 512 70 | - 128 71 | generator: 72 | hidden_dim: 768 73 | n_blocks: 7 74 | upsamples: 75 | - 2 76 | - 2 77 | - 2 78 | - 3 79 | - 4 80 | - 5 81 | channel_divs: 82 | - 1 83 | - 1 84 | - 2 85 | - 1 86 | - 1 87 | - 2 88 | - 2 89 | discriminator: 90 | out_channels: 91 | - 16 92 | - 64 93 | - 256 94 | - 1024 95 | - 1024 96 | - 1024 97 | - 1 98 | kernels: 99 | - 15 100 | - 41 101 | - 41 102 | - 41 103 | - 41 104 | - 5 105 | - 3 106 | downsamples: 107 | - 1 108 | - 2 109 | - 2 110 | - 4 111 | - 4 112 | - 1 113 | - 1 114 | lrelu_slope: 0.2 115 | seed: 17 116 | train: 117 | batch_size: 16 118 | epochs: 1000 119 | num_workers: 4 120 | n_gpu: 121 | stdout_interval: 20 122 | checkpoint_interval: 5000 123 | summary_interval: 100 124 | validation_interval: 1000 125 | lambda_adv: 4.0 126 | grad_norm_clip_value: 1.0 127 | ckpt_dir: ckpts 128 | logs_dir: logs 129 | dist_config: 130 | dist_backend: nccl 131 | dist_url: tcp://localhost:54321 132 | world_size: 1 133 | outs: 134 | - path: exp_ckpts 135 | md5: b520f9691d17329bc76febdaf5742a33.dir 136 | size: 2388880764 137 | nfiles: 4 138 | - path: exp_logs 139 | md5: 65eb50df5a3863aec2ba35f0c303cdd6.dir 140 | size: 77444 141 | nfiles: 1 142 | -------------------------------------------------------------------------------- /.dvc/plots/linear.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://vega.github.io/schema/vega-lite/v4.json", 3 | "data": { 4 | "values": "" 5 | }, 6 | "title": "", 7 | "width": 300, 8 | "height": 300, 9 | "layer": [ 10 | { 11 | "encoding": { 12 | "x": { 13 | "field": "", 14 | "type": "quantitative", 15 | "title": "" 16 | }, 17 | "y": { 18 | "field": "", 19 | "type": "quantitative", 20 | "title": "", 21 | "scale": { 22 | "zero": false 23 | } 24 | }, 25 | "color": { 26 | "field": "rev", 27 | "type": "nominal" 28 | } 29 | }, 30 | "layer": [ 31 | { 32 | "mark": "line" 33 | }, 34 | { 35 | "selection": { 36 | "label": { 37 | "type": "single", 38 | "nearest": true, 39 | "on": "mouseover", 40 | "encodings": [ 41 | "x" 42 | ], 43 | "empty": "none", 44 | "clear": "mouseout" 45 | } 46 | }, 47 | "mark": "point", 48 | "encoding": { 49 | "opacity": { 50 | "condition": { 51 | "selection": "label", 52 | "value": 1 53 | }, 54 | "value": 0 55 | } 56 | } 57 | } 58 | ] 59 | }, 60 | { 61 | "transform": [ 62 | { 63 | "filter": { 64 | "selection": "label" 65 | } 66 | } 67 | ], 68 | "layer": [ 69 | { 70 | "mark": { 71 | "type": "rule", 72 | "color": "gray" 73 | }, 74 | "encoding": { 75 | "x": { 76 | "field": "", 77 | "type": "quantitative" 78 | } 79 | } 80 | }, 81 | { 82 | "encoding": { 83 | "text": { 84 | "type": "quantitative", 85 | "field": "" 86 | }, 87 | "x": { 88 | "field": "", 89 | "type": "quantitative" 90 | }, 91 | "y": { 92 | "field": "", 93 | "type": "quantitative" 94 | } 95 | }, 96 | "layer": [ 97 | { 98 | "mark": { 99 | "type": "text", 100 | "align": "left", 101 | "dx": 5, 102 | "dy": -5 103 | }, 104 | "encoding": { 105 | "color": { 106 | "type": "nominal", 107 | "field": "rev" 108 | } 109 | } 110 | } 111 | ] 112 | } 113 | ] 114 | } 115 | ] 116 | } 117 | 
-------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import reduce 3 | import pytest 4 | from modules import ( 5 | Generator, Discriminator, 6 | MultiResolutionSTFTLoss, adversarial_loss, discriminator_loss 7 | ) 8 | 9 | @pytest.fixture(autouse=True, scope="module") 10 | def n_frames(): 11 | return 13 12 | 13 | @pytest.fixture(autouse=True, scope="module") 14 | def origin_wav_len(n_frames, cfg): 15 | return n_frames * cfg.data.hop_length 16 | 17 | @pytest.fixture(autouse=True, scope="module") 18 | def n_features(cfg): 19 | return sum(cfg.model.feature_dims) 20 | 21 | @pytest.fixture(autouse=True, scope="module") 22 | def cond_dim(cfg): 23 | return cfg.model.cond_dims[0] 24 | 25 | @pytest.fixture(autouse=True, scope="module") 26 | def z_dim(cfg): 27 | return cfg.model.cond_dims[1] 28 | 29 | @pytest.fixture(autouse=True, scope="module") 30 | def generator(cfg, n_features, cond_dim, z_dim): 31 | generator = Generator(n_features, cond_dim, z_dim, **cfg.model.generator) 32 | return generator 33 | 34 | @pytest.fixture(autouse=True, scope="module") 35 | def discriminator(cfg): 36 | discriminator = Discriminator(**cfg.model.discriminator) 37 | return discriminator 38 | 39 | 40 | def test_generator(cfg, generator, n_features, cond_dim, z_dim, n_frames): 41 | total_upsample = reduce(lambda x, y: x * y, cfg.model.generator.upsamples) 42 | 43 | # assert total_upsample == cfg.data.hop_length, \ 44 | # "upsampling factor and hop_length should match" 45 | 46 | features = torch.randn(1, n_features, n_frames) 47 | cond = torch.randn(1, cond_dim) 48 | z = torch.randn(1, z_dim) 49 | 50 | generator.eval() 51 | with torch.no_grad(): 52 | output = generator(features, cond) 53 | output_with_z = generator(features, cond, z=z) 54 | 55 | assert output.shape == output_with_z.shape 56 | assert output.size(1) == n_frames * total_upsample, \ 57 | f"generator output shape {output.shape} expected to have {n_frames} samples" 58 | 59 | 60 | def test_generator_with_grad(cfg, generator, n_features, cond_dim, z_dim, n_frames): 61 | features = torch.randn(1, n_features, n_frames) 62 | cond = torch.randn(1, cond_dim) 63 | z = torch.randn(1, z_dim) 64 | 65 | generator.train() 66 | output = generator(features, cond, z=z) 67 | MultiResolutionSTFTLoss()(output, output.detach()).backward() 68 | 69 | 70 | def test_discriminator(cfg, discriminator, origin_wav_len): 71 | total_downsample = reduce(lambda x, y: x * y, cfg.model.discriminator.downsamples) 72 | 73 | wav = torch.randn(1, origin_wav_len) 74 | with torch.no_grad(): 75 | scores = discriminator(wav) 76 | 77 | assert scores.size(1) <= origin_wav_len // total_downsample, \ 78 | f"discriminator output shape {scores.shape} expected to have {origin_wav_len // total_downsample} or less samples" 79 | 80 | 81 | def test_discriminator_with_grad(cfg, discriminator, origin_wav_len): 82 | wav = torch.randn(1, origin_wav_len) 83 | scores = discriminator(wav) 84 | adversarial_loss(scores).backward() 85 | 86 | 87 | def test_discriminator_loss_with_grad(cfg, generator, discriminator, n_features, cond_dim, z_dim, n_frames, origin_wav_len): 88 | features = torch.randn(1, n_features, n_frames) 89 | cond = torch.randn(1, cond_dim) 90 | z = torch.randn(1, z_dim) 91 | wav = torch.randn(1, origin_wav_len) 92 | 93 | generator.train() 94 | discriminator.train() 95 | 96 | fake_wav = generator(features, cond, z=z) 97 | real_scores, 
fake_scores = discriminator(wav), discriminator(fake_wav.detach()) 98 | 99 | d_loss = discriminator_loss(fake_scores, real_scores) 100 | d_loss.backward() 101 | 102 | 103 | def test_generator_loss_with_grad(cfg, generator, discriminator, n_features, cond_dim, z_dim, n_frames, origin_wav_len): 104 | total_upsample = reduce(lambda x, y: x * y, cfg.model.generator.upsamples) 105 | features = torch.randn(1, n_features, n_frames) 106 | cond = torch.randn(1, cond_dim) 107 | z = torch.randn(1, z_dim) 108 | wav = torch.randn(1, n_frames * total_upsample) 109 | 110 | generator.train() 111 | discriminator.train() 112 | 113 | fake_wav = generator(features, cond, z=z) 114 | fake_scores = discriminator(fake_wav) 115 | 116 | g_loss = adversarial_loss(fake_scores) + MultiResolutionSTFTLoss()(fake_wav, wav) 117 | g_loss.backward() -------------------------------------------------------------------------------- /modules/models.py: -------------------------------------------------------------------------------- 1 | # models.py 2 | # Implementation of necessary modules and enhancer model itself. 3 | 4 | import torch 5 | from .commons import Conv1d 6 | from torch import nn 7 | import torch.nn.functional as F 8 | from torch.nn.utils import spectral_norm 9 | 10 | 11 | class ConditionalBatchNorm1d(nn.Module): 12 | def __init__(self, n_features, cond_dim, eps=1e-5, momentum=0.1): 13 | super().__init__() 14 | self.n_features = n_features 15 | self.bn = nn.BatchNorm1d(n_features, affine=False, eps=eps, momentum=momentum,) 16 | 17 | # linear layers 18 | self.gamma_embed = spectral_norm(Conv1d(cond_dim, n_features, 1, bias=False)) 19 | self.beta_embed = spectral_norm(Conv1d(cond_dim, n_features, 1, bias=False)) 20 | 21 | def forward(self, x, y): 22 | out = self.bn(x) 23 | gamma = self.gamma_embed(y) + 1 24 | beta = self.beta_embed(y) 25 | out = gamma.view(-1, self.n_features, 1) * out + beta.view(-1, self.n_features, 1) 26 | return out 27 | 28 | 29 | class ResBlock(nn.Module): 30 | """ 31 | 1-dimensional ResBlock (with or w/o upsample only) 32 | This is half of GBlock presented in reference. 
33 | Convolution in skip-connection and upsample here is an option 34 | 35 | Reference: https://arxiv.org/pdf/1909.11646.pdf 36 | """ 37 | def __init__(self, in_channel, out_channel, cond_dim, 38 | kernel_size=3, padding=1, stride=1, dilations=(1, 2), 39 | bn=True, activation=F.relu, upsample=1, 40 | ): 41 | super().__init__() 42 | 43 | 44 | self.upsample = upsample 45 | self.activation = activation 46 | 47 | self.conv0 = spectral_norm( 48 | Conv1d(in_channel, out_channel, kernel_size, stride, padding, bias=not bn) 49 | ) 50 | self.conv1 = spectral_norm( 51 | Conv1d(out_channel, out_channel, kernel_size, stride, padding, bias=not bn) 52 | ) 53 | 54 | self.skip_proj = False 55 | if in_channel != out_channel or upsample != 1: 56 | self.conv_skip = spectral_norm(Conv1d(in_channel, out_channel, 1, 1, 0)) 57 | self.skip_proj = True 58 | 59 | self.bn = bn 60 | if bn: 61 | self.cbn0 = ConditionalBatchNorm1d(in_channel, cond_dim) 62 | self.cbn1 = ConditionalBatchNorm1d(out_channel, cond_dim) 63 | 64 | def forward(self, x, cond): 65 | skip = x 66 | 67 | # first conv layers 68 | if self.bn: 69 | x = self.cbn0(x, cond) 70 | x = self.activation(x) 71 | if self.upsample != 1: 72 | x = F.interpolate(x, scale_factor=self.upsample) 73 | x = self.conv0(x) 74 | 75 | # second conv layers 76 | if self.bn: 77 | x = self.cbn1(x, cond) 78 | x = self.activation(x) 79 | x = self.conv1(x) 80 | 81 | # skip connection 82 | if self.upsample != 1: 83 | skip = F.interpolate(skip, scale_factor=self.upsample) 84 | if self.skip_proj: 85 | skip = self.conv_skip(skip) 86 | return x + skip 87 | 88 | 89 | class GBlock(nn.Module): 90 | def __init__(self, in_channel, out_channel, cond_dim, 91 | upsample=1, dilations=(1, 2, 4, 8) 92 | ): 93 | super().__init__() 94 | 95 | self.gblock = nn.ModuleList([ 96 | ResBlock(in_channel, out_channel, cond_dim, dilations=dilations[0:2], upsample=upsample), 97 | ResBlock(out_channel, out_channel, cond_dim, dilations=dilations[2:4]), 98 | ]) 99 | 100 | def forward(self, x, cond): 101 | for resblock in self.gblock: 102 | x = resblock(x, cond) 103 | 104 | return x 105 | 106 | 107 | class Generator(nn.Module): 108 | def __init__(self, in_channel, cond_dim, z_dim, hidden_dim=768, 109 | n_blocks=7, upsamples=[2, 2, 2, 3, 4, 5], channel_divs=[1, 1, 2, 1, 1, 2, 2], 110 | ): 111 | super().__init__() 112 | 113 | self.z_dim = z_dim 114 | upsamples = list(upsamples) 115 | channel_divs = list(channel_divs) 116 | 117 | assert len(upsamples) <= n_blocks, \ 118 | f"Number of blocks with upsample ({len(upsamples)} must be <= total number of blocks ({n_blocks})" 119 | assert len(channel_divs) == n_blocks, \ 120 | f"Number of channel divisions {channel_divs} for blocks != n_blocks ({n_blocks})" 121 | 122 | upsamples = [1] * (n_blocks - len(upsamples)) + upsamples 123 | 124 | from itertools import accumulate 125 | channel_divs = list(accumulate([1] + channel_divs, lambda x, y: x * y)) 126 | 127 | self.input_conv = spectral_norm(Conv1d(in_channel, hidden_dim, 3, 1, 1)) 128 | self.gblocks = nn.ModuleList([ 129 | GBlock(hidden_dim // channel_divs[i], hidden_dim // channel_divs[i + 1], cond_dim + z_dim, upsamples[i]) 130 | for i in range(len(upsamples)) 131 | ]) 132 | self.output_conv = spectral_norm(Conv1d(hidden_dim // channel_divs[-1], 1, 1, 1, 0)) 133 | self.tanh = nn.Tanh() 134 | 135 | def forward(self, x, cond, z=None): 136 | if z is None: 137 | z = torch.randn(x.shape[0], self.z_dim, device=cond.device) 138 | 139 | z = z.unsqueeze(2) if len(z.shape) == 2 else z 140 | cond = cond.unsqueeze(2) if len(cond.shape) 
== 2 else cond 141 | cond = torch.cat([cond, z], dim=1) 142 | 143 | x = self.input_conv(x) 144 | for block in self.gblocks: 145 | x = block(x, cond) 146 | x = self.output_conv(x).flatten(1, -1) 147 | 148 | x = self.tanh(x) 149 | return x 150 | 151 | 152 | class Discriminator(nn.Module): 153 | def __init__(self, 154 | out_channels=[16, 64, 256, 1024, 1024, 1024, 1], 155 | kernels=[15, 41, 41, 41, 41, 5, 3], 156 | downsamples=[1, 2, 2, 4, 4, 1, 1], 157 | lrelu_slope=0.2, 158 | ): 159 | super().__init__() 160 | 161 | assert out_channels[-1] == 1, "out_channels last value must be 1" 162 | assert len(out_channels) == len(kernels) and len(kernels) == len(downsamples), \ 163 | "out_channels, kernel sizes, downsamples numbers must match" 164 | 165 | self.lrelu = nn.LeakyReLU(lrelu_slope) 166 | self.convs = nn.ModuleList([ 167 | spectral_norm(Conv1d(in_c, out_c, kernel_size, stride=down, padding=kernel_size // 2)) 168 | for in_c, out_c, kernel_size, down in zip([1] + out_channels[:-1], out_channels, kernels, downsamples) 169 | ]) 170 | 171 | def forward(self, x): 172 | if len(x.shape) < 3: 173 | x = x.unsqueeze(1) 174 | 175 | for i, l in enumerate(self.convs): 176 | x = l(x) 177 | if i + 1 == len(self.convs): 178 | break 179 | x = self.lrelu(x) 180 | 181 | x = torch.flatten(x, 1, -1) 182 | 183 | return x 184 | -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | from pathlib import Path 5 | from collections import defaultdict 6 | 7 | import torch 8 | import torch.utils.data 9 | import torchaudio 10 | 11 | import numpy as np 12 | import librosa 13 | from librosa.util import normalize 14 | from scipy.io.wavfile import read 15 | from librosa.filters import mel as librosa_mel_fn 16 | 17 | 18 | def load_wav(path, sr=16000): 19 | """Load audio from path and resample it 20 | Return: 1d numpy.array of audio data 21 | """ 22 | wav = librosa.load(path, sr=sr, dtype=np.float32, mono=True)[0] 23 | return wav 24 | 25 | 26 | def get_datafolder_files(datafolder_path, pattern='.wav'): 27 | """Get all files with specified extension in directory tree 28 | Return: list of file pathes 29 | """ 30 | filelist = [] 31 | 32 | for root, _, filenames in os.walk(datafolder_path): 33 | for filename in filenames: 34 | if Path(filename).suffix == pattern: 35 | filelist.append(os.path.join(root, filename)) 36 | 37 | return filelist 38 | 39 | 40 | def define_train_list(filepathes, clean_suffix='clean', n_utterance_tokens=2): 41 | """Return dict in following format: 42 | { utterance_name : [filepath, filename1, filename2...] 
}, 43 | where first value of list is groundtruth file's path 44 | for all other files (with same utterance) 45 | """ 46 | assert clean_suffix in ['cleanraw', 'clean', 'produced'] 47 | train_list = defaultdict(list) 48 | 49 | for filepath in filepathes: 50 | p = Path(filepath) 51 | tokens = p.stem.split('_') 52 | utterance = '_'.join(tokens[:n_utterance_tokens]) 53 | 54 | if tokens[-1] == clean_suffix: 55 | train_list[utterance] = [filepath] + train_list[utterance] 56 | else: 57 | train_list[utterance].append(p.stem) 58 | 59 | return train_list 60 | 61 | 62 | def train_test_split(filelist, p=0.85, seed=17): 63 | """Return train and test set of filenames 64 | This function follows `define_train_list` and uses its output 65 | """ 66 | random.seed(seed) 67 | train_list, test_list = dict(), dict() 68 | 69 | for utterance, files in filelist.items(): 70 | gt_filepath, filenames = files[0], files[1:] 71 | random.shuffle(filenames) 72 | 73 | val_len = int((1 - p) * len(filenames)) 74 | 75 | train_list[utterance] = [gt_filepath] + filenames[val_len:] 76 | test_list[utterance] = [gt_filepath] + filenames[:val_len] 77 | 78 | return train_list, test_list 79 | 80 | 81 | def save_dataset_filelist(filelist, filelist_path, delim='|'): 82 | with open(filelist_path, 'w', encoding='utf-8') as f: 83 | for utterance, files in filelist.items(): 84 | print(utterance + delim + delim.join(files), file=f) 85 | 86 | 87 | def load_dataset_filelist(filelist_path, delim='|'): 88 | filelist = dict() 89 | with open(filelist_path, 'r', encoding='utf-8') as f: 90 | for line in f: 91 | line = line.strip() 92 | if len(line) <= 0: 93 | continue 94 | tokens = line.split(delim) 95 | utterance, files = tokens[0], tokens[1:] 96 | filelist[utterance] = files 97 | 98 | return filelist 99 | 100 | 101 | class PreprocessDataset(torch.utils.data.Dataset): 102 | """ 103 | Torch Dataset class for wav files for their following feature extraction. 104 | Assumes that wav files (all names are different) 105 | can be in some subdirectories of root dir. 106 | """ 107 | def __init__(self, filepathes, data_cfg): 108 | self.data_cfg = data_cfg 109 | self.audio_pathes = filepathes 110 | self.filenames = list(map(lambda p: Path(p).name, filepathes)) 111 | 112 | def __getitem__(self, index): 113 | file_path = self.audio_pathes[index] 114 | filename = self.filenames[index] 115 | wav = load_wav(file_path, sr=self.data_cfg.sample_rate) 116 | 117 | return wav, filename 118 | 119 | def __len__(self): 120 | return len(self.audio_pathes) 121 | 122 | 123 | class FeatureDataset(torch.utils.data.Dataset): 124 | """ 125 | Torch Dataset class for wav files and their features. 126 | Assumes that wav files (all names are different) 127 | can be in some subdirectories of root dir, 128 | but feature files in corresponding feature directories alone. 
129 |     """
130 |     def __init__(self, dataset_cfg, filelist, data_cfg, preload_gt=True, segmented=True, segment_size=None, seed=17):
131 |         self.data_cfg = data_cfg
132 |         self.dataset_cfg = dataset_cfg
133 |         self.filelist = filelist
134 |         self.segmented = segmented
135 |         self.upsampling_rate = self.data_cfg.hop_length * self.data_cfg.target_sample_rate // self.data_cfg.sample_rate
136 |         if self.segmented:
137 |             segment_size = self.data_cfg.segment_size if segment_size is None else segment_size
138 | 
139 |             assert segment_size % (self.data_cfg.hop_length * self.data_cfg.target_sample_rate // self.data_cfg.sample_rate) == 0
140 | 
141 |             self.n_points = segment_size
142 |             self.n_frames = self.n_points // self.upsampling_rate
143 | 
144 |         self.noise_to_gt_dict = dict()
145 |         for files in filelist.values():
146 |             # first is ground truth
147 |             self.noise_to_gt_dict[Path(files[0]).stem] = files[0]
148 |             for filename in files[1:]:
149 |                 self.noise_to_gt_dict[filename] = files[0]
150 | 
151 |         self.filenames = list(map(lambda p: Path(p).stem, self.noise_to_gt_dict.keys()))
152 |         self.gt_list = set(self.noise_to_gt_dict.values())
153 | 
154 |         self.preload_gt = preload_gt
155 |         if preload_gt:
156 |             self.gt_data = {
157 |                 gt_path : load_wav(gt_path, sr=self.data_cfg.target_sample_rate)
158 |                 for gt_path in self.gt_list
159 |             }
160 | 
161 |         self.spk_embs = torch.load(self.dataset_cfg.spk_embs_file)
162 | 
163 |         random.seed(seed)
164 | 
165 |     def __getitem__(self, index):
166 |         filename = self.filenames[index]
167 |         gt_path = self.noise_to_gt_dict[filename]
168 | 
169 |         if self.preload_gt:
170 |             gt_wav = self.gt_data[gt_path]
171 |         else:
172 |             gt_wav = load_wav(gt_path, sr=self.data_cfg.target_sample_rate)
173 | 
174 |         # features: [ppg, f0, loudness]
175 |         features = []
176 |         for feature_dir in (
177 |             self.dataset_cfg.ppg_dir, self.dataset_cfg.f0_dir, self.dataset_cfg.loudness_dir
178 |         ):
179 |             feat = torch.load(os.path.join(feature_dir, Path(filename).with_suffix('.pt')))
180 |             features.append(feat)
181 |         features = torch.cat(features, dim=1)
182 | 
183 |         # condition: [embed]
184 |         cond = self.spk_embs[filename]
185 | 
186 |         # because center=True for feature extraction
187 |         n_pad = self.upsampling_rate // 2
188 |         gt_wav = np.pad(gt_wav, (n_pad, n_pad), mode='reflect')[:self.upsampling_rate * features.shape[0]]
189 | 
190 |         if self.segmented:
191 |             start_frame = random.randint(0, features.shape[0] - self.n_frames)
192 |             start_point = start_frame * self.upsampling_rate  # gt_wav is at target_sample_rate, so one frame spans upsampling_rate samples (not hop_length)
193 | 
194 |             gt_wav = gt_wav[start_point:start_point + self.n_points]
195 |             features = features[start_frame:start_frame + self.n_frames]
196 | 
197 |         return gt_wav, features, cond
198 | 
199 |     def __len__(self):
200 |         return len(self.filenames)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import warnings
3 | warnings.simplefilter(action='ignore', category=FutureWarning)
4 | 
5 | import os
6 | import time
7 | import datetime
8 | from tqdm.auto import tqdm
9 | 
10 | import torch
11 | from torch.utils.tensorboard import SummaryWriter
12 | from torch.utils.data import DistributedSampler, DataLoader
13 | import torch.multiprocessing as mp
14 | from torch.distributed import init_process_group
15 | from torch.nn.parallel import DistributedDataParallel
16 | 
17 | from torch.optim import AdamW
18 | from modules import (
19 |     Generator, Discriminator,
20 |     MultiResolutionSTFTLoss, adversarial_loss,
    discriminator_loss
21 | )
22 | from utils import (
23 |     FeatureDataset, load_dataset_filelist,
24 |     scan_checkpoint, load_checkpoint, save_checkpoint
25 | )
26 | 
27 | from omegaconf import DictConfig, OmegaConf
28 | import hydra
29 | 
30 | torch.backends.cudnn.benchmark = True
31 | 
32 | 
33 | def train(rank: int, cfg: DictConfig):
34 |     print(OmegaConf.to_yaml(cfg))
35 | 
36 |     if cfg.train.n_gpu > 1:
37 |         init_process_group(
38 |             backend=cfg.train.dist_config['dist_backend'], init_method=cfg.train.dist_config['dist_url'],
39 |             world_size=cfg.train.dist_config['world_size'] * cfg.train.n_gpu, rank=rank
40 |         )
41 | 
42 |     device = torch.device('cuda:{:d}'.format(rank) if torch.cuda.is_available() else 'cpu')
43 |     criterion = MultiResolutionSTFTLoss(device=device)
44 | 
45 |     # Defining model (generator and discriminator)
46 |     generator = Generator(
47 |         sum(cfg.model.feature_dims),
48 |         *cfg.model.cond_dims,
49 |         **cfg.model.generator
50 |     ).to(device)
51 |     discriminator = Discriminator(**cfg.model.discriminator).to(device)
52 | 
53 |     if rank == 0:
54 |         print(generator)
55 |         os.makedirs(cfg.train.ckpt_dir, exist_ok=True)
56 |         print("checkpoints directory : ", cfg.train.ckpt_dir)
57 | 
58 |     # Loading checkpoints (if exist)
59 |     if os.path.isdir(cfg.train.ckpt_dir):
60 |         cp_g = scan_checkpoint(cfg.train.ckpt_dir, 'g_')
61 |         cp_do = scan_checkpoint(cfg.train.ckpt_dir, 'do_')  # checkpoints are saved below with the 'do_' prefix
62 | 
63 |     steps = 1
64 |     if cp_g is None or cp_do is None:
65 |         state_dict_do = None
66 |         last_epoch = -1
67 |     else:
68 |         state_dict_g = load_checkpoint(cp_g, device)
69 |         state_dict_do = load_checkpoint(cp_do, device)
70 |         generator.load_state_dict(state_dict_g['generator'])
71 |         discriminator.load_state_dict(state_dict_do['discriminator'])
72 |         steps = state_dict_do['steps'] + 1
73 |         last_epoch = state_dict_do['epoch']
74 | 
75 |     if cfg.train.n_gpu > 1:
76 |         generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
77 |         discriminator = DistributedDataParallel(discriminator, device_ids=[rank]).to(device)
78 | 
79 |     # Defining optimizers
80 |     optim_g = AdamW(generator.parameters(), cfg.opt.lr, betas=cfg.opt.betas)
81 |     optim_d = AdamW(discriminator.parameters(), cfg.opt.lr, betas=cfg.opt.betas)
82 | 
83 |     if state_dict_do is not None:
84 |         optim_g.load_state_dict(state_dict_do['optim_g'])
85 |         optim_d.load_state_dict(state_dict_do['optim_d'])
86 | 
87 |     # Defining schedulers
88 |     scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=cfg.opt.lr_decay, last_epoch=last_epoch)
89 |     scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=cfg.opt.lr_decay, last_epoch=last_epoch)
90 | 
91 |     # Data preparation
92 |     print("Preparing train data...")
93 |     train_filelist = load_dataset_filelist(cfg.dataset.train_list)
94 |     trainset = FeatureDataset(
95 |         cfg.dataset, train_filelist, cfg.data, segmented=True, preload_gt=True, seed=cfg.seed
96 |     )
97 |     train_sampler = DistributedSampler(trainset) if cfg.train.n_gpu > 1 else None
98 |     train_loader = DataLoader(
99 |         trainset, batch_size=cfg.train.batch_size, num_workers=cfg.train.num_workers, shuffle=(train_sampler is None),  # shuffle and sampler are mutually exclusive in DataLoader
100 |         sampler=train_sampler, pin_memory=True, drop_last=True
101 |     )
102 | 
103 |     if rank == 0:
104 |         print("Preparing validation data...")
105 |         val_filelist = load_dataset_filelist(cfg.dataset.test_list)
106 |         valset = FeatureDataset(
107 |             cfg.dataset, val_filelist, cfg.data, segmented=True, preload_gt=True,
108 |             segment_size=cfg.data.segment_size * cfg.train.batch_size
109 |         )
110 |         val_loader = DataLoader(
111 |             valset, batch_size=1,
            num_workers=cfg.train.num_workers, shuffle=False,
112 |             sampler=None, pin_memory=True  # train_sampler indexes the train set; validation runs on rank 0 without a distributed sampler
113 |         )
114 | 
115 |         log_dir = f'{cfg.train.logs_dir}/{datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")}'
116 |         sw = SummaryWriter(log_dir)
117 | 
118 |     # Train loop
119 |     generator.train()
120 |     discriminator.train()
121 |     for epoch in range(max(0, last_epoch), cfg.train.epochs):
122 |         if rank == 0:
123 |             start = time.time()
124 |             print("Epoch: {}".format(epoch+1))
125 | 
126 |         if cfg.train.n_gpu > 1:
127 |             train_sampler.set_epoch(epoch)
128 | 
129 |         for y, x_noised_features, x_noised_cond in train_loader:
130 |             if rank == 0:
131 |                 start_b = time.time()
132 | 
133 |             y = y.to(device, non_blocking=True)
134 |             x_noised_features = x_noised_features.transpose(1, 2).to(device, non_blocking=True)
135 |             x_noised_cond = x_noised_cond.to(device, non_blocking=True)
136 |             z1 = torch.randn(cfg.train.batch_size, cfg.model.cond_dims[1], device=device)
137 |             z2 = torch.randn(cfg.train.batch_size, cfg.model.cond_dims[1], device=device)
138 | 
139 |             y_hat1 = generator(x_noised_features, x_noised_cond, z=z1)
140 |             y_hat2 = generator(x_noised_features, x_noised_cond, z=z2)
141 | 
142 | 
143 |             # Discriminator
144 |             real_scores, fake_scores = discriminator(y), discriminator(y_hat1.detach())
145 |             d_loss = discriminator_loss(real_scores, fake_scores)
146 | 
147 |             optim_d.zero_grad()
148 |             d_loss.backward()
149 |             d_grad_norm = torch.nn.utils.clip_grad_norm_(discriminator.parameters(), cfg.train.grad_norm_clip_value)
150 |             optim_d.step()
151 | 
152 |             # Generator
153 |             fake_scores = discriminator(y_hat1)
154 |             g_stft_loss = criterion(y, y_hat1) + criterion(y, y_hat2) - criterion(y_hat1, y_hat2)
155 |             g_adv_loss = adversarial_loss(fake_scores)
156 |             g_loss = g_adv_loss + cfg.train.lambda_adv * g_stft_loss
157 | 
158 |             optim_g.zero_grad()
159 |             g_loss.backward()
160 |             g_grad_norm = torch.nn.utils.clip_grad_norm_(generator.parameters(), cfg.train.grad_norm_clip_value)
161 |             optim_g.step()
162 | 
163 |             if rank == 0:
164 |                 # stdout logging
165 |                 if steps % cfg.train.stdout_interval == 0:
166 |                     with torch.no_grad():
167 |                         print('Steps : {:d}, Gen Loss Total : {:4.3f}, STFT Error : {:4.3f}, s/b : {:4.3f}'.
168 | format(steps, g_loss, g_stft_loss, time.time() - start_b)) 169 | 170 | # checkpointing 171 | if steps % cfg.train.checkpoint_interval == 0: 172 | ckpt_dir = "{}/g_{:08d}".format(cfg.train.ckpt_dir, steps) 173 | save_checkpoint( 174 | ckpt_dir, 175 | { 'generator': (generator.module if cfg.train.n_gpu > 1 else generator).state_dict() } 176 | ) 177 | ckpt_dir = "{}/do_{:08d}".format(cfg.train.ckpt_dir, steps) 178 | save_checkpoint( 179 | ckpt_dir, { 180 | 'discriminator': (discriminator.module if cfg.train.n_gpu > 1 else discriminator).state_dict(), 181 | 'optim_g': optim_g.state_dict(), 182 | 'optim_d': optim_d.state_dict(), 183 | 'steps': steps, 184 | 'epoch': epoch 185 | }) 186 | 187 | # Tensorboard summary logging 188 | if steps % cfg.train.summary_interval == 0: 189 | sw.add_scalar("loss/g_loss_total", g_loss, steps) 190 | sw.add_scalar("loss/g_adv_plus_d_total", g_adv_loss.item() + d_loss.item(), steps) 191 | 192 | sw.add_scalar("loss_component/g_stft_error", g_stft_loss, steps) 193 | sw.add_scalar("loss_component/g_adv_loss", g_adv_loss, steps) 194 | sw.add_scalar("loss_component/d_loss", d_loss, steps) 195 | 196 | sw.add_scalar("grad/d_grad_norm", d_grad_norm, steps) 197 | sw.add_scalar("grad/g_grad_norm", g_grad_norm, steps) 198 | 199 | # Validation 200 | if steps % cfg.train.validation_interval == 0: 201 | generator.eval() 202 | torch.cuda.empty_cache() 203 | 204 | val_err_tot = 0 205 | val_pbar = tqdm(total=len(valset), ncols=0, desc="Valid", unit=" uttr") 206 | 207 | for j, (y, x_noised_features, x_noised_cond) in enumerate(val_loader): 208 | with torch.no_grad(): 209 | y_hat = generator(x_noised_features.transpose(1, 2).to(device), x_noised_cond.to(device)) 210 | val_err_tot += criterion(y.to(device), y_hat).item() 211 | 212 | if j <= 4: 213 | sw.add_audio('generated/y_hat_{}'.format(j), y_hat[0], steps, cfg.data.target_sample_rate) 214 | sw.add_audio('gt/y_{}'.format(j), y[0], steps, cfg.data.target_sample_rate) 215 | 216 | val_pbar.update(val_loader.batch_size) 217 | val_pbar.set_postfix(loss=f"{val_err_tot / (j + 1):.2f}") 218 | 219 | val_err = val_err_tot / (j + 1) 220 | sw.add_scalar("validation/stft_error", val_err, steps) 221 | 222 | val_pbar.close() 223 | generator.train() 224 | 225 | steps += 1 226 | 227 | scheduler_g.step() 228 | scheduler_d.step() 229 | 230 | if rank == 0: 231 | print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start))) 232 | 233 | 234 | @hydra.main(config_name="config") 235 | def main(cfg: DictConfig): 236 | torch.manual_seed(cfg.seed) 237 | 238 | if torch.cuda.is_available(): 239 | torch.cuda.manual_seed(cfg.seed) 240 | cfg.train.n_gpu = torch.cuda.device_count() 241 | cfg.train.batch_size = int(cfg.train.batch_size / cfg.train.n_gpu) 242 | print('Batch size per GPU :', cfg.train.batch_size) 243 | else: 244 | cfg.train.n_gpu = 0 245 | print('No GPU registered for training!') 246 | 247 | if cfg.train.n_gpu > 1: 248 | mp.spawn(train, nprocs=cfg.train.n_gpu, args=(cfg,)) 249 | else: 250 | train(0, cfg) 251 | 252 | 253 | if __name__ == '__main__': 254 | main() 255 | --------------------------------------------------------------------------------
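
For completeness, a minimal sketch of restoring the most recent generator checkpoint outside of train.py, reusing only the call shapes shown above (Generator construction from the config, scan_checkpoint/load_checkpoint from utils). It assumes config.yaml at the repository root mirrors the parameter block recorded in dvc.lock; treat it as an illustration, not the project's inference script (inference.py is empty in this snapshot).

# Sketch: rebuild the generator and load the newest 'g_*' checkpoint for evaluation.
# Call shapes mirror train.py above; config.yaml contents are assumed to match dvc.lock.
import torch
from omegaconf import OmegaConf

from modules import Generator
from utils import scan_checkpoint, load_checkpoint

cfg = OmegaConf.load("config.yaml")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator = Generator(
    sum(cfg.model.feature_dims),   # 768 + 1 + 1 -> ppg, f0, loudness
    *cfg.model.cond_dims,          # speaker-embedding dim, z dim
    **cfg.model.generator,
).to(device)

cp_g = scan_checkpoint(cfg.train.ckpt_dir, 'g_')   # newest ckpts/g_XXXXXXXX, if any
if cp_g is not None:
    generator.load_state_dict(load_checkpoint(cp_g, device)['generator'])
generator.eval()
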