├── voicefixer
│   ├── tools
│   │   ├── modules
│   │   │   ├── filters
│   │   │   │   ├── f_2_64.mat
│   │   │   │   ├── f_4_64.mat
│   │   │   │   ├── f_8_64.mat
│   │   │   │   ├── h_2_64.mat
│   │   │   │   ├── h_4_64.mat
│   │   │   │   └── h_8_64.mat
│   │   │   ├── __init__.py
│   │   │   ├── pqmf.py
│   │   │   └── fDomainHelper.py
│   │   ├── path.py
│   │   ├── __init__.py
│   │   ├── io.py
│   │   ├── random_.py
│   │   ├── pytorch_util.py
│   │   ├── wav.py
│   │   ├── base.py
│   │   └── mel_scale.py
│   ├── vocoder
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── pqmf.py
│   │   │   ├── res_msd.py
│   │   │   ├── util.py
│   │   │   ├── generator.py
│   │   │   └── modules.py
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── config.py
│   ├── __init__.py
│   ├── restorer
│   │   ├── __init__.py
│   │   ├── model_kqq_bn.py
│   │   ├── modules.py
│   │   └── model.py
│   ├── __main__.py
│   └── base.py
├── bin
│   ├── voicefixer.cmd
│   └── voicefixer
├── .gitignore
├── test
│   ├── figure.png
│   ├── streamlit.png
│   ├── utterance
│   │   ├── output
│   │   │   ├── oracle.flac
│   │   │   ├── output_mode_0.flac
│   │   │   ├── output_mode_1.flac
│   │   │   └── output_mode_2.flac
│   │   ├── target
│   │   │   ├── oracle.flac
│   │   │   ├── output_mode_0.flac
│   │   │   ├── output_mode_1.flac
│   │   │   └── output_mode_2.flac
│   │   └── original
│   │       ├── original.flac
│   │       ├── original.wav
│   │       └── p360_001_mic1.flac
│   ├── streamlit.py
│   ├── inference.py
│   └── test.py
├── LICENSE
├── setup.py
└── README.md
/voicefixer/tools/modules/filters/f_2_64.mat:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/voicefixer/tools/modules/filters/f_4_64.mat:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/voicefixer/tools/modules/filters/f_8_64.mat:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/voicefixer/tools/modules/filters/h_2_64.mat:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/voicefixer/tools/modules/filters/h_4_64.mat:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/voicefixer/tools/modules/filters/h_8_64.mat:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/bin/voicefixer.cmd:
--------------------------------------------------------------------------------
1 | @echo OFF
2 | python -m voicefixer %*
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS*
2 | __pycache__
3 | dist
4 | *egg*
5 | .idea
6 | build
--------------------------------------------------------------------------------
/test/figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wetdog/voicefixer/main/test/figure.png
--------------------------------------------------------------------------------
/test/streamlit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wetdog/voicefixer/main/test/streamlit.png
--------------------------------------------------------------------------------
/test/utterance/output/oracle.flac:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/output/oracle.flac -------------------------------------------------------------------------------- /test/utterance/target/oracle.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/target/oracle.flac -------------------------------------------------------------------------------- /test/utterance/original/original.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/original/original.flac -------------------------------------------------------------------------------- /test/utterance/original/original.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/original/original.wav -------------------------------------------------------------------------------- /test/utterance/output/output_mode_0.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/output/output_mode_0.flac -------------------------------------------------------------------------------- /test/utterance/output/output_mode_1.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/output/output_mode_1.flac -------------------------------------------------------------------------------- /test/utterance/output/output_mode_2.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/output/output_mode_2.flac -------------------------------------------------------------------------------- /test/utterance/target/output_mode_0.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/target/output_mode_0.flac -------------------------------------------------------------------------------- /test/utterance/target/output_mode_1.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/target/output_mode_1.flac -------------------------------------------------------------------------------- /test/utterance/target/output_mode_2.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/target/output_mode_2.flac -------------------------------------------------------------------------------- /test/utterance/original/p360_001_mic1.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/original/p360_001_mic1.flac -------------------------------------------------------------------------------- /voicefixer/tools/path.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def find_and_build(root, path): 5 | path = os.path.join(root, path) 6 | if not os.path.exists(path): 7 | os.makedirs(path, exist_ok=True) 8 | return path 9 | 10 
| 11 | def root_path(repo_name="voicefixer"): 12 | path = os.path.abspath(__file__) 13 | return path.split(repo_name)[0] 14 | -------------------------------------------------------------------------------- /voicefixer/tools/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | """ 4 | @File : __init__.py.py 5 | @Contact : haoheliu@gmail.com 6 | @License : (C)Copyright 2020-2100 7 | 8 | @Modify Time @Author @Version @Desciption 9 | ------------ ------- -------- ----------- 10 | 9/14/21 12:28 AM Haohe Liu 1.0 None 11 | """ 12 | -------------------------------------------------------------------------------- /voicefixer/tools/modules/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | """ 4 | @File : __init__.py.py 5 | @Contact : haoheliu@gmail.com 6 | @License : (C)Copyright 2020-2100 7 | 8 | @Modify Time @Author @Version @Desciption 9 | ------------ ------- -------- ----------- 10 | 9/14/21 12:29 AM Haohe Liu 1.0 None 11 | """ 12 | -------------------------------------------------------------------------------- /voicefixer/vocoder/model/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | """ 4 | @File : __init__.py.py 5 | @Contact : haoheliu@gmail.com 6 | @License : (C)Copyright 2020-2100 7 | 8 | @Modify Time @Author @Version @Desciption 9 | ------------ ------- -------- ----------- 10 | 9/14/21 1:00 AM Haohe Liu 1.0 None 11 | """ 12 | -------------------------------------------------------------------------------- /voicefixer/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | """ 4 | @File : __init__.py.py 5 | @Contact : haoheliu@gmail.com 6 | @License : (C)Copyright 2020-2100 7 | 8 | @Modify Time @Author @Version @Desciption 9 | ------------ ------- -------- ----------- 10 | 9/14/21 12:31 AM Haohe Liu 1.0 None 11 | """ 12 | 13 | from voicefixer.vocoder.base import Vocoder 14 | from voicefixer.base import VoiceFixer 15 | -------------------------------------------------------------------------------- /voicefixer/vocoder/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | """ 4 | @File : __init__.py.py 5 | @Contact : haoheliu@gmail.com 6 | @License : (C)Copyright 2020-2100 7 | 8 | @Modify Time @Author @Version @Desciption 9 | ------------ ------- -------- ----------- 10 | 9/14/21 1:00 AM Haohe Liu 1.0 None 11 | """ 12 | 13 | import os 14 | from voicefixer.vocoder.config import Config 15 | import urllib.request 16 | 17 | if not os.path.exists(Config.ckpt): 18 | os.makedirs(os.path.dirname(Config.ckpt), exist_ok=True) 19 | print("Downloading the weight of neural vocoder: TFGAN") 20 | urllib.request.urlretrieve( 21 | "https://zenodo.org/record/5469951/files/model.ckpt-1490000_trimed.pt?download=1", 22 | Config.ckpt, 23 | ) 24 | print( 25 | "Weights downloaded in: {} Size: {}".format( 26 | Config.ckpt, os.path.getsize(Config.ckpt) 27 | ) 28 | ) 29 | # cmd = "wget https://zenodo.org/record/5469951/files/model.ckpt-1490000_trimed.pt?download=1 -O " + Config.ckpt 30 | # os.system(cmd) 31 | -------------------------------------------------------------------------------- /voicefixer/tools/io.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | 4 | 5 | def read_list(fname): 6 | result = [] 7 | with open(fname, "r") as f: 8 | for each in f.readlines(): 9 | each = each.strip("\n") 10 | result.append(each) 11 | return result 12 | 13 | 14 | def write_list(list, fname): 15 | with open(fname, "w") as f: 16 | for word in list: 17 | f.write(word) 18 | f.write("\n") 19 | 20 | 21 | def write_json(my_dict, fname): 22 | # print("Save json file at "+fname) 23 | json_str = json.dumps(my_dict) 24 | with open(fname, "w") as json_file: 25 | json_file.write(json_str) 26 | 27 | 28 | def load_json(fname): 29 | with open(fname, "r") as f: 30 | data = json.load(f) 31 | return data 32 | 33 | 34 | def save_pickle(obj, fname): 35 | # print("Save pickle at "+fname) 36 | with open(fname, "wb") as f: 37 | pickle.dump(obj, f) 38 | 39 | 40 | def load_pickle(fname): 41 | # print("Load pickle at "+fname) 42 | with open(fname, "rb") as f: 43 | res = pickle.load(f) 44 | return res 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /voicefixer/tools/random_.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | RANDOM_RESOLUTION = 2**31 5 | 6 | 7 | def random_torch(high, to_int=True): 8 | if to_int: 9 | return int((torch.rand(1)) * high) # do not use numpy.random.random 10 | else: 11 | return (torch.rand(1)) * high # do not use numpy.random.random 12 | 13 | 14 | def shuffle_torch(list): 15 | length = len(list) 16 | res = [] 17 | order = torch.randperm(length) 18 | for each in order: 19 | res.append(list[each]) 20 | assert len(list) == len(res) 21 | return res 22 | 23 | 24 | def random_choose_list(list): 25 | num = int(uniform_torch(0, len(list))) 26 | return list[num] 27 | 28 | 29 | def normal_torch(mean=0, segma=1): 30 | return float(torch.normal(mean=mean, std=torch.Tensor([segma]))[0]) 31 | 32 | 33 | def uniform_torch(lower, upper): 34 | if abs(lower - upper) < 1e-5: 35 | return upper 36 | return (upper - lower) * torch.rand(1) + lower 37 | 38 | 39 | def random_key(keys: list, weights: list): 40 | return random.choices(keys, weights=weights)[0] 41 | 42 | 43 | def random_select(probs): 44 | res = [] 45 | chance = random_torch(RANDOM_RESOLUTION) 46 | threshold = None 47 | for prob in probs: 48 | # if(threshold is None):threshold=prob 49 | # else:threshold*=prob 50 | threshold = prob 51 | res.append(chance < threshold * RANDOM_RESOLUTION) 52 | return res, chance 53 | -------------------------------------------------------------------------------- /voicefixer/restorer/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | """ 4 | @File : __init__.py.py 5 | @Contact : haoheliu@gmail.com 6 | @License : (C)Copyright 2020-2100 7 | 8 | @Modify Time @Author @Version @Desciption 9 | ------------ ------- -------- ----------- 10 | 9/14/21 12:31 AM Haohe Liu 1.0 None 11 | """ 12 | 13 | import os 14 | import torch 15 | import urllib.request 16 | 17 | meta = { 18 | "voicefixer_fe": { 19 | "path": os.path.join( 20 | os.path.expanduser("~"), 21 | ".cache/voicefixer/analysis_module/checkpoints/vf.ckpt", 22 | ), 23 | "url": "https://zenodo.org/record/5600188/files/vf.ckpt?download=1", 24 | }, 25 | } 26 | 27 | if not os.path.exists(meta["voicefixer_fe"]["path"]): 28 | os.makedirs(os.path.dirname(meta["voicefixer_fe"]["path"]), exist_ok=True) 29 | print("Downloading the main structure of voicefixer") 30 | 31 | urllib.request.urlretrieve( 32 | meta["voicefixer_fe"]["url"], meta["voicefixer_fe"]["path"] 33 | ) 34 | print( 35 | "Weights downloaded in: {} Size: {}".format( 36 | meta["voicefixer_fe"]["path"], 37 | os.path.getsize(meta["voicefixer_fe"]["path"]), 38 | ) 39 | ) 40 | 41 | # cmd = "wget "+ meta["voicefixer_fe"]['url'] + " -O " + meta["voicefixer_fe"]['path'] 42 | # os.system(cmd) 43 | # temp = torch.load(meta["voicefixer_fe"]['path']) 44 | # torch.save(temp['state_dict'], os.path.join(os.path.expanduser('~'), ".cache/voicefixer/analysis_module/checkpoints/vf.ckpt")) 45 | -------------------------------------------------------------------------------- /test/streamlit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | import librosa 5 | import soundfile 6 | import streamlit as st 7 | import torch 8 | from io import BytesIO 9 | from voicefixer import VoiceFixer 10 | 11 | 12 | 
@st.experimental_singleton 13 | def init_voicefixer(): 14 | return VoiceFixer() 15 | 16 | 17 | # init with global shared singleton instance 18 | voice_fixer = init_voicefixer() 19 | 20 | 21 | sample_rate = 44100 22 | 23 | 24 | st.write("Wav player") 25 | 26 | 27 | w = st.file_uploader("Upload a wav file", type="wav") 28 | 29 | 30 | if w: 31 | st.write("Inference : ") 32 | 33 | # choose options 34 | mode = st.radio( 35 | "Voice fixer modes (0: original mode, 1: Add preprocessing module 2: Train mode (may work sometimes on seriously degraded speech))", 36 | [0, 1, 2], 37 | ) 38 | if torch.cuda.is_available(): 39 | is_cuda = st.radio("Turn on GPU", [True, False]) 40 | if is_cuda != list(voice_fixer._model.parameters())[0].is_cuda: 41 | device = "cuda" if is_cuda else "cpu" 42 | voice_fixer._model = voice_fixer._model.to(device) 43 | else: 44 | is_cuda = False 45 | 46 | t1 = time.time() 47 | 48 | # Load audio from binary 49 | audio, _ = librosa.load(w, sr=sample_rate, mono=True) 50 | 51 | # Inference 52 | pred_wav = voice_fixer.restore_inmem(audio, mode=mode, cuda=is_cuda) 53 | 54 | pred_time = time.time() - t1 55 | 56 | # original audio 57 | st.write("Original Audio : ") 58 | 59 | st.audio(w) 60 | 61 | # predicted audio 62 | st.write("Predicted Audio : ") 63 | 64 | # make buffer 65 | with BytesIO() as buffer: 66 | soundfile.write(buffer, pred_wav.T, samplerate=sample_rate, format="WAV") 67 | st.write("Time: {:.3f}s".format(pred_time)) 68 | st.audio(buffer.getvalue(), format="audio/wav") 69 | -------------------------------------------------------------------------------- /voicefixer/vocoder/model/pqmf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import scipy.io.wavfile 7 | 8 | 9 | class PQMF(nn.Module): 10 | def __init__(self, N, M, file_path="utils/pqmf_hk_4_64.dat"): 11 | super().__init__() 12 | self.N = N # nsubband 13 | self.M = M # nfilter 14 | self.ana_conv_filter = nn.Conv1d( 15 | 1, out_channels=N, kernel_size=M, stride=N, bias=False 16 | ) 17 | data = np.reshape(np.fromfile(file_path, dtype=np.float32), (N, M)) 18 | data = np.flipud(data.T).T 19 | gk = data.copy() 20 | data = np.reshape(data, (N, 1, M)).copy() 21 | dict_new = self.ana_conv_filter.state_dict().copy() 22 | dict_new["weight"] = torch.from_numpy(data) 23 | self.ana_pad = nn.ConstantPad1d((M - N, 0), 0) 24 | self.ana_conv_filter.load_state_dict(dict_new) 25 | 26 | self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0) 27 | self.syn_conv_filter = nn.Conv1d( 28 | N, out_channels=N, kernel_size=M // N, stride=1, bias=False 29 | ) 30 | gk = np.transpose(np.reshape(gk, (4, 16, 4)), (1, 0, 2)) * N 31 | gk = np.transpose(gk[::-1, :, :], (2, 1, 0)).copy() 32 | dict_new = self.syn_conv_filter.state_dict().copy() 33 | dict_new["weight"] = torch.from_numpy(gk) 34 | self.syn_conv_filter.load_state_dict(dict_new) 35 | 36 | for param in self.parameters(): 37 | param.requires_grad = False 38 | 39 | def analysis(self, inputs): 40 | return self.ana_conv_filter(self.ana_pad(inputs)) 41 | 42 | def synthesis(self, inputs): 43 | return self.syn_conv_filter(self.syn_pad(inputs)) 44 | 45 | def forward(self, inputs): 46 | return self.ana_conv_filter(self.ana_pad(inputs)) 47 | 48 | 49 | if __name__ == "__main__": 50 | a = PQMF(4, 64) 51 | # x = np.load('data/train/audio/010000.npy') 52 | x = np.zeros([8, 24000], np.float32) 53 | x = np.reshape(x, (8, 1, -1)) 54 | x = torch.from_numpy(x) 55 | b = a.analysis(x) 56 
|     c = a.synthesis(b)
57 |     print(x.shape, b.shape, c.shape)
58 |     b = (b * 32768).numpy()
59 |     b = np.reshape(np.transpose(b, (0, 2, 1)), (-1, 1)).astype(np.int16)
60 |     # b.tofile('1.pcm')
61 |     # np.reshape(np.transpose(c.numpy()*32768, (0, 2, 1)), (-1,1)).astype(np.int16).tofile('2.pcm')
--------------------------------------------------------------------------------
/voicefixer/vocoder/model/res_msd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6 | 
7 | LRELU_SLOPE = 0.1
8 | 
9 | 
10 | def init_weights(m, mean=0.0, std=0.01):
11 |     classname = m.__class__.__name__
12 |     if classname.find("Conv") != -1:
13 |         m.weight.data.normal_(mean, std)
14 | 
15 | 
16 | def get_padding(kernel_size, dilation=1):
17 |     return int((kernel_size * dilation - dilation) / 2)
18 | 
19 | 
20 | class ResStack(nn.Module):
21 |     def __init__(self, channels=384, kernel_size=3, resstack_depth=3, hp=None):
22 |         super(ResStack, self).__init__()
23 |         dilation = [2 * i + 1 for i in range(resstack_depth)]  # [1, 3, 5]
24 |         self.convs1 = nn.ModuleList(
25 |             [
26 |                 weight_norm(
27 |                     Conv1d(
28 |                         channels,
29 |                         channels,
30 |                         kernel_size,
31 |                         1,
32 |                         dilation=dilation[i],
33 |                         padding=get_padding(kernel_size, dilation[i]),
34 |                     )
35 |                 )
36 |                 for i in range(resstack_depth)
37 |             ]
38 |         )
39 |         self.convs1.apply(init_weights)
40 | 
41 |         self.convs2 = nn.ModuleList(
42 |             [
43 |                 weight_norm(
44 |                     Conv1d(
45 |                         channels,
46 |                         channels,
47 |                         kernel_size,
48 |                         1,
49 |                         dilation=1,
50 |                         padding=get_padding(kernel_size, 1),
51 |                     )
52 |                 )
53 |                 for i in range(resstack_depth)
54 |             ]
55 |         )
56 |         self.convs2.apply(init_weights)
57 | 
58 |     def forward(self, x):
59 |         for c1, c2 in zip(self.convs1, self.convs2):
60 |             xt = F.leaky_relu(x, LRELU_SLOPE)
61 |             xt = c1(xt)
62 |             xt = F.leaky_relu(xt, LRELU_SLOPE)
63 |             xt = c2(xt)
64 |             x = xt + x
65 |         return x
66 | 
67 |     def remove_weight_norm(self):
68 |         for l in self.convs1:
69 |             remove_weight_norm(l)
70 |         for l in self.convs2:
71 |             remove_weight_norm(l)
--------------------------------------------------------------------------------
/test/inference.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 | """
4 | @File    : inference.py
5 | @Contact : haoheliu@gmail.com
6 | @License : (C)Copyright 2020-2100
7 | 
8 | @Modify Time      @Author    @Version    @Description
9 | ------------      -------    --------    -----------
10 | 9/6/21 3:08 PM   Haohe Liu      1.0         None
11 | """
12 | 
13 | from voicefixer import VoiceFixer
14 | from voicefixer import Vocoder
15 | import os
16 | from os.path import isdir, exists, basename, join
17 | from argparse import ArgumentParser
18 | from progressbar import *
19 | 
20 | parser = ArgumentParser()
21 | 
22 | parser.add_argument(
23 |     "-i",
24 |     "--input_file_path",
25 |     default="/Users/liuhaohe/Desktop/test.wav",
26 |     help="The .wav file or the audio folder to be processed",
27 | )
28 | parser.add_argument(
29 |     "-o", "--output_path", default=".", help="The output dirpath for the results"
30 | )
31 | parser.add_argument("-m", "--models", default="voicefixer_fe")
32 | parser.add_argument(
33 |     "--cuda", type=bool, default=False, help="Whether use GPU acceleration."
34 | ) 35 | args = parser.parse_args() 36 | 37 | if __name__ == "__main__": 38 | voicefixer = VoiceFixer() 39 | 40 | if not isdir(args.output_path): 41 | raise ValueError("Error: output path need to be a directory, not a file name.") 42 | if not exists(args.output_path): 43 | os.makedirs(args.output_path, exist_ok=True) 44 | 45 | if not isdir(args.input_file_path): 46 | assert ( 47 | args.input_file_path[-3:] == "wav" or args.input_file_path[-4:] == "flac" 48 | ), ( 49 | "Error: invalid file " 50 | + args.input_file_path 51 | + ", we only accept .wav and .flac file." 52 | ) 53 | output_path = join(args.output_path, basename(args.input_file_path)) 54 | print("Start Prediction.") 55 | voicefixer.restore( 56 | input=args.input_file_path, output=output_path, cuda=args.cuda 57 | ) 58 | else: 59 | files = os.listdir(args.input_file_path) 60 | print("Found", len(files), "files in", args.input_file_path) 61 | widgets = [ 62 | "Performing Resotartion", 63 | " [", 64 | Timer(), 65 | "] ", 66 | Bar(), 67 | " (", 68 | ETA(), 69 | ") ", 70 | ] 71 | pbar = ProgressBar(widgets=widgets).start() 72 | print("Start Prediction.") 73 | for i, file in enumerate(files): 74 | if not file[-3:] == "wav" and not file[-4:] == "flac": 75 | print( 76 | "Ignore file", 77 | file, 78 | " unsupported file type. Please use wav or flac format.", 79 | ) 80 | continue 81 | output_path = join(args.output_path, basename(file)) 82 | voicefixer.restore( 83 | input=join(args.input_file_path, file), 84 | output=output_path, 85 | cuda=args.cuda, 86 | ) 87 | pbar.update(int((i / (len(files))) * 100)) 88 | print("Congratulations! Prediction Complete.") 89 | -------------------------------------------------------------------------------- /test/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | """ 4 | @File : test.py 5 | @Contact : haoheliu@gmail.com 6 | @License : (C)Copyright 2020-2100 7 | 8 | @Modify Time @Author @Version @Desciption 9 | ------------ ------- -------- ----------- 10 | 9/14/21 11:02 AM Haohe Liu 1.0 None 11 | """ 12 | 13 | import git 14 | import os 15 | import sys 16 | import librosa 17 | import numpy as np 18 | import torch 19 | 20 | git_root = git.Repo("", search_parent_directories=True).git.rev_parse("--show-toplevel") 21 | sys.path.append(git_root) 22 | from voicefixer import VoiceFixer, Vocoder 23 | 24 | os.makedirs(os.path.join(git_root, "test/utterance/output"), exist_ok=True) 25 | 26 | 27 | def check(fname): 28 | """ 29 | check if the output is normal 30 | """ 31 | output = os.path.join(git_root, "test/utterance/output", fname) 32 | target = os.path.join(git_root, "test/utterance/target", fname) 33 | output, _ = librosa.load(output, sr=44100) 34 | target, _ = librosa.load(target, sr=44100) 35 | assert np.mean(np.abs(output - target)) < 0.01 36 | 37 | 38 | # TEST VOICEFIXER 39 | ## Initialize a voicefixer 40 | print("Initializing VoiceFixer...") 41 | voicefixer = VoiceFixer() 42 | # Mode 0: Original Model (suggested by default) 43 | # Mode 1: Add preprocessing module (remove higher frequency) 44 | # Mode 2: Train mode (might work sometimes on seriously degraded real speech) 45 | for mode in [0, 1, 2]: 46 | print("Test voicefixer mode", mode, end=", ") 47 | print("Using CPU:") 48 | voicefixer.restore( 49 | input=os.path.join( 50 | git_root, "test/utterance/original/original.flac" 51 | ), # low quality .wav/.flac file 52 | output=os.path.join( 53 | git_root, "test/utterance/output/output_mode_" + str(mode) + ".flac" 
54 | ), # save file path 55 | cuda=False, # GPU acceleration 56 | mode=mode, 57 | ) 58 | if mode != 2: 59 | check("output_mode_" + str(mode) + ".flac") 60 | 61 | if torch.cuda.is_available(): 62 | print("Using GPU:") 63 | voicefixer.restore( 64 | input=os.path.join(git_root, "test/utterance/original/original.flac"), 65 | # low quality .wav/.flac file 66 | output=os.path.join( 67 | git_root, "test/utterance/output/output_mode_" + str(mode) + ".flac" 68 | ), 69 | # save file path 70 | cuda=True, # GPU acceleration 71 | mode=mode, 72 | ) 73 | if mode != 2: 74 | check("output_mode_" + str(mode) + ".flac") 75 | print("Pass") 76 | 77 | # TEST VOCODER 78 | ## Initialize a vocoder 79 | print("Initializing 44.1kHz speech vocoder...") 80 | vocoder = Vocoder(sample_rate=44100) 81 | 82 | ### read wave (fpath) -> mel spectrogram -> vocoder -> wave -> save wave (out_path) 83 | print("Test vocoder using groundtruth mel spectrogram...") 84 | print("Using CPU:") 85 | vocoder.oracle( 86 | fpath=os.path.join(git_root, "test/utterance/original/p360_001_mic1.flac"), 87 | out_path=os.path.join(git_root, "test/utterance/output/oracle.flac"), 88 | cuda=False, 89 | ) # GPU acceleration 90 | 91 | check("oracle.flac") 92 | 93 | if torch.cuda.is_available(): 94 | print("Using GPU:") 95 | vocoder.oracle( 96 | fpath=os.path.join(git_root, "test/utterance/original/p360_001_mic1.flac"), 97 | out_path=os.path.join(git_root, "test/utterance/output/oracle.flac"), 98 | cuda=True, 99 | ) # GPU acceleration 100 | # Another interface 101 | # vocoder.forward(mel=mel) 102 | check("oracle.flac") 103 | 104 | print("Pass") 105 | -------------------------------------------------------------------------------- /voicefixer/vocoder/base.py: -------------------------------------------------------------------------------- 1 | from voicefixer.vocoder.model.generator import Generator 2 | from voicefixer.tools.wav import read_wave, save_wave 3 | from voicefixer.tools.pytorch_util import * 4 | from voicefixer.vocoder.model.util import * 5 | from voicefixer.vocoder.config import Config 6 | import os 7 | import numpy as np 8 | 9 | 10 | class Vocoder(nn.Module): 11 | def __init__(self, sample_rate): 12 | super(Vocoder, self).__init__() 13 | Config.refresh(sample_rate) 14 | self.rate = sample_rate 15 | if(not os.path.exists(Config.ckpt)): 16 | raise RuntimeError("Error 1: The checkpoint for synthesis module / vocoder (model.ckpt-1490000_trimed) is not found in ~/.cache/voicefixer/synthesis_module/44100. \ 17 | By default the checkpoint should be download automatically by this program. Something bad may happened. Apologies for the inconvenience.\ 18 | But don't worry! Alternatively you can download it directly from Zenodo: https://zenodo.org/record/5600188/files/model.ckpt-1490000_trimed.pt?download=1") 19 | self._load_pretrain(Config.ckpt) 20 | self.weight_torch = Config.get_mel_weight_torch(percent=1.0)[ 21 | None, None, None, ... 
22 | ] 23 | 24 | def _load_pretrain(self, pth): 25 | self.model = Generator(Config.cin_channels) 26 | checkpoint = load_checkpoint(pth, torch.device("cpu")) 27 | load_try(checkpoint["generator"], self.model) 28 | self.model.eval() 29 | self.model.remove_weight_norm() 30 | self.model.remove_weight_norm() 31 | for p in self.model.parameters(): 32 | p.requires_grad = False 33 | 34 | # def vocoder_mel_npy(self, mel, save_dir, sample_rate, gain): 35 | # mel = mel / Config.get_mel_weight(percent=gain)[...,None] 36 | # mel = normalize(amp_to_db(np.abs(mel)) - 20) 37 | # mel = pre(np.transpose(mel, (1, 0))) 38 | # with torch.no_grad(): 39 | # wav_re = self.model(mel) # torch.Size([1, 1, 104076]) 40 | # save_wave(tensor2numpy(wav_re)*2**15,save_dir,sample_rate=sample_rate) 41 | 42 | def forward(self, mel, cuda=False): 43 | """ 44 | :param non normalized mel spectrogram: [batchsize, 1, t-steps, n_mel] 45 | :return: [batchsize, 1, samples] 46 | """ 47 | assert mel.size()[-1] == 128 48 | check_cuda_availability(cuda=cuda) 49 | self.model = try_tensor_cuda(self.model, cuda=cuda) 50 | mel = try_tensor_cuda(mel, cuda=cuda) 51 | self.weight_torch = self.weight_torch.type_as(mel) 52 | mel = mel / self.weight_torch 53 | mel = tr_normalize(tr_amp_to_db(torch.abs(mel)) - 20.0) 54 | mel = tr_pre(mel[:, 0, ...]) 55 | wav_re = self.model(mel) 56 | return wav_re 57 | 58 | def oracle(self, fpath, out_path, cuda=False): 59 | check_cuda_availability(cuda=cuda) 60 | self.model = try_tensor_cuda(self.model, cuda=cuda) 61 | wav = read_wave(fpath, sample_rate=self.rate)[..., 0] 62 | wav = wav / np.max(np.abs(wav)) 63 | stft = np.abs( 64 | librosa.stft( 65 | wav, 66 | hop_length=Config.hop_length, 67 | win_length=Config.win_size, 68 | n_fft=Config.n_fft, 69 | ) 70 | ) 71 | mel = linear_to_mel(stft) 72 | mel = normalize(amp_to_db(np.abs(mel)) - 20) 73 | mel = pre(np.transpose(mel, (1, 0))) 74 | mel = try_tensor_cuda(mel, cuda=cuda) 75 | with torch.no_grad(): 76 | wav_re = self.model(mel) 77 | save_wave(tensor2numpy(wav_re * 2**15), out_path, sample_rate=self.rate) 78 | 79 | 80 | if __name__ == "__main__": 81 | model = Vocoder(sample_rate=44100) 82 | print(model.device) 83 | # model.load_pretrain(Config.ckpt) 84 | # model.oracle(path="/Users/liuhaohe/Desktop/test.wav", 85 | # sample_rate=44100, 86 | # save_dir="/Users/liuhaohe/Desktop/test_vocoder.wav") 87 | -------------------------------------------------------------------------------- /voicefixer/tools/modules/pqmf.py: -------------------------------------------------------------------------------- 1 | """ 2 | @File : subband_util.py 3 | @Contact : liu.8948@buckeyemail.osu.edu 4 | @License : (C)Copyright 2020-2021 5 | @Modify Time @Author @Version @Desciption 6 | ------------ ------- -------- ----------- 7 | 2020/4/3 4:54 PM Haohe Liu 1.0 None 8 | """ 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | import torch.nn as nn 13 | import numpy as np 14 | import os.path as op 15 | from scipy.io import loadmat 16 | 17 | 18 | def load_mat2numpy(fname=""): 19 | if len(fname) == 0: 20 | return None 21 | else: 22 | return loadmat(fname) 23 | 24 | 25 | class PQMF(nn.Module): 26 | def __init__(self, N, M, project_root): 27 | super().__init__() 28 | self.N = N # nsubband 29 | self.M = M # nfilter 30 | try: 31 | assert (N, M) in [(8, 64), (4, 64), (2, 64)] 32 | except: 33 | print("Warning:", N, "subbandand ", M, " filter is not supported") 34 | self.pad_samples = 64 35 | self.name = str(N) + "_" + str(M) + ".mat" 36 | self.ana_conv_filter = nn.Conv1d( 37 | 1, 
out_channels=N, kernel_size=M, stride=N, bias=False 38 | ) 39 | data = load_mat2numpy( 40 | op.join( 41 | project_root, 42 | "arnold_workspace/restorer/tools/pytorch/modules/filters/f_" 43 | + self.name, 44 | ) 45 | ) 46 | data = data["f"].astype(np.float32) / N 47 | data = np.flipud(data.T).T 48 | data = np.reshape(data, (N, 1, M)).copy() 49 | dict_new = self.ana_conv_filter.state_dict().copy() 50 | dict_new["weight"] = torch.from_numpy(data) 51 | self.ana_pad = nn.ConstantPad1d((M - N, 0), 0) 52 | self.ana_conv_filter.load_state_dict(dict_new) 53 | 54 | self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0) 55 | self.syn_conv_filter = nn.Conv1d( 56 | N, out_channels=N, kernel_size=M // N, stride=1, bias=False 57 | ) 58 | gk = load_mat2numpy( 59 | op.join( 60 | project_root, 61 | "arnold_workspace/restorer/tools/pytorch/modules/filters/h_" 62 | + self.name, 63 | ) 64 | ) 65 | gk = gk["h"].astype(np.float32) 66 | gk = np.transpose(np.reshape(gk, (N, M // N, N)), (1, 0, 2)) * N 67 | gk = np.transpose(gk[::-1, :, :], (2, 1, 0)).copy() 68 | dict_new = self.syn_conv_filter.state_dict().copy() 69 | dict_new["weight"] = torch.from_numpy(gk) 70 | self.syn_conv_filter.load_state_dict(dict_new) 71 | 72 | for param in self.parameters(): 73 | param.requires_grad = False 74 | 75 | def __analysis_channel(self, inputs): 76 | return self.ana_conv_filter(self.ana_pad(inputs)) 77 | 78 | def __systhesis_channel(self, inputs): 79 | ret = self.syn_conv_filter(self.syn_pad(inputs)).permute(0, 2, 1) 80 | return torch.reshape(ret, (ret.shape[0], 1, -1)) 81 | 82 | def analysis(self, inputs): 83 | """ 84 | :param inputs: [batchsize,channel,raw_wav],value:[0,1] 85 | :return: 86 | """ 87 | inputs = F.pad(inputs, ((0, self.pad_samples))) 88 | ret = None 89 | for i in range(inputs.size()[1]): # channels 90 | if ret is None: 91 | ret = self.__analysis_channel(inputs[:, i : i + 1, :]) 92 | else: 93 | ret = torch.cat( 94 | (ret, self.__analysis_channel(inputs[:, i : i + 1, :])), dim=1 95 | ) 96 | return ret 97 | 98 | def synthesis(self, data): 99 | """ 100 | :param data: [batchsize,self.N*K,raw_wav_sub],value:[0,1] 101 | :return: 102 | """ 103 | ret = None 104 | # data = F.pad(data,((0,self.pad_samples//self.N))) 105 | for i in range(data.size()[1]): # channels 106 | if i % self.N == 0: 107 | if ret is None: 108 | ret = self.__systhesis_channel(data[:, i : i + self.N, :]) 109 | else: 110 | new = self.__systhesis_channel(data[:, i : i + self.N, :]) 111 | ret = torch.cat((ret, new), dim=1) 112 | ret = ret[..., : -self.pad_samples] 113 | return ret 114 | 115 | def forward(self, inputs): 116 | return self.ana_conv_filter(self.ana_pad(inputs)) 117 | -------------------------------------------------------------------------------- /voicefixer/vocoder/model/util.py: -------------------------------------------------------------------------------- 1 | from voicefixer.vocoder.config import Config 2 | from voicefixer.tools.pytorch_util import try_tensor_cuda, check_cuda_availability 3 | import torch 4 | import librosa 5 | import numpy as np 6 | 7 | 8 | def tr_normalize(S): 9 | if Config.allow_clipping_in_normalization: 10 | if Config.symmetric_mels: 11 | return torch.clip( 12 | (2 * Config.max_abs_value) * ((S - Config.min_db) / (-Config.min_db)) 13 | - Config.max_abs_value, 14 | -Config.max_abs_value, 15 | Config.max_abs_value, 16 | ) 17 | else: 18 | return torch.clip( 19 | Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db)), 20 | 0, 21 | Config.max_abs_value, 22 | ) 23 | 24 | assert S.max() <= 0 and S.min() - 
Config.min_db >= 0 25 | if Config.symmetric_mels: 26 | return (2 * Config.max_abs_value) * ( 27 | (S - Config.min_db) / (-Config.min_db) 28 | ) - Config.max_abs_value 29 | else: 30 | return Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db)) 31 | 32 | 33 | def tr_amp_to_db(x): 34 | min_level = torch.exp(Config.min_level_db / 20 * torch.log(torch.tensor(10.0))) 35 | min_level = min_level.type_as(x) 36 | return 20 * torch.log10(torch.maximum(min_level, x)) 37 | 38 | 39 | def normalize(S): 40 | if Config.allow_clipping_in_normalization: 41 | if Config.symmetric_mels: 42 | return np.clip( 43 | (2 * Config.max_abs_value) * ((S - Config.min_db) / (-Config.min_db)) 44 | - Config.max_abs_value, 45 | -Config.max_abs_value, 46 | Config.max_abs_value, 47 | ) 48 | else: 49 | return np.clip( 50 | Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db)), 51 | 0, 52 | Config.max_abs_value, 53 | ) 54 | 55 | assert S.max() <= 0 and S.min() - Config.min_db >= 0 56 | if Config.symmetric_mels: 57 | return (2 * Config.max_abs_value) * ( 58 | (S - Config.min_db) / (-Config.min_db) 59 | ) - Config.max_abs_value 60 | else: 61 | return Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db)) 62 | 63 | 64 | def amp_to_db(x): 65 | min_level = np.exp(Config.min_level_db / 20 * np.log(10)) 66 | return 20 * np.log10(np.maximum(min_level, x)) 67 | 68 | 69 | def tr_pre(npy): 70 | # conditions = torch.FloatTensor(npy).type_as(npy) # to(device) 71 | conditions = npy.transpose(1, 2) 72 | l = conditions.size(-1) 73 | pad_tail = l % 2 + 4 74 | zeros = ( 75 | torch.zeros([conditions.size()[0], Config.num_mels, pad_tail]).type_as( 76 | conditions 77 | ) 78 | + -4.0 79 | ) 80 | return torch.cat([conditions, zeros], dim=-1) 81 | 82 | 83 | def pre(npy): 84 | conditions = npy 85 | ## padding tail 86 | if type(conditions) == np.ndarray: 87 | conditions = torch.FloatTensor(conditions).unsqueeze(0) 88 | else: 89 | conditions = torch.FloatTensor(conditions.float()).unsqueeze(0) 90 | conditions = conditions.transpose(1, 2) 91 | l = conditions.size(-1) 92 | pad_tail = l % 2 + 4 93 | zeros = torch.zeros([1, Config.num_mels, pad_tail]) + -4.0 94 | return torch.cat([conditions, zeros], dim=-1) 95 | 96 | 97 | def load_try(state, model): 98 | model_dict = model.state_dict() 99 | try: 100 | model_dict.update(state) 101 | model.load_state_dict(model_dict) 102 | except RuntimeError as e: 103 | print(str(e)) 104 | model_dict = model.state_dict() 105 | for k, v in state.items(): 106 | model_dict[k] = v 107 | model.load_state_dict(model_dict) 108 | 109 | 110 | def load_checkpoint(checkpoint_path, device): 111 | checkpoint = torch.load(checkpoint_path, map_location=device) 112 | return checkpoint 113 | 114 | 115 | def build_mel_basis(): 116 | return librosa.filters.mel( 117 | Config.sample_rate, 118 | Config.n_fft, 119 | htk=True, 120 | n_mels=Config.num_mels, 121 | fmin=0, 122 | fmax=int(Config.sample_rate // 2), 123 | ) 124 | 125 | 126 | def linear_to_mel(spectogram): 127 | _mel_basis = build_mel_basis() 128 | return np.dot(_mel_basis, spectogram) 129 | 130 | 131 | if __name__ == "__main__": 132 | data = torch.randn((3, 5, 100)) 133 | b = normalize(amp_to_db(data.numpy())) 134 | a = tr_normalize(tr_amp_to_db(data)).numpy() 135 | print(a - b) 136 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | # python3 setup.py sdist bdist_wheel 4 | """ 5 
| @File    : setup.py
6 | @Contact : haoheliu@gmail.com
7 | @License : (C)Copyright 2020-2100
8 | 
9 | @Modify Time      @Author    @Version    @Description
10 | ------------      -------    --------    -----------
11 | 9/6/21 5:16 PM   Haohe Liu      1.0         None
12 | """
13 | 
14 | # !/usr/bin/env python
15 | # -*- coding: utf-8 -*-
16 | 
17 | # Note: To use the 'upload' functionality of this file, you must:
18 | # $ pipenv install twine --dev
19 | 
20 | import io
21 | import os
22 | import sys
23 | from shutil import rmtree
24 | 
25 | from setuptools import find_packages, setup, Command
26 | 
27 | # Package meta-data.
28 | NAME = "voicefixer"
29 | DESCRIPTION = "This package is written for the restoration of degraded speech"
30 | URL = "https://github.com/haoheliu/voicefixer"
31 | EMAIL = "haoheliu@gmail.com"
32 | AUTHOR = "Haohe Liu"
33 | REQUIRES_PYTHON = ">=3.7.0"
34 | VERSION = "0.1.2"
35 | 
36 | # What packages are required for this module to be executed?
37 | REQUIRED = [
38 |     "librosa>=0.8.1,<0.9.0",
39 |     "matplotlib",
40 |     "torch>=1.7.0",
41 |     "progressbar",
42 |     "torchlibrosa==0.0.7",
43 |     "GitPython",
44 |     "streamlit>=1.12.0",
45 |     "pyyaml",
46 | ]
47 | 
48 | # What packages are optional?
49 | EXTRAS = {}
50 | 
51 | # The rest you shouldn't have to touch too much :)
52 | # ------------------------------------------------
53 | # Except, perhaps the License and Trove Classifiers!
54 | # If you do change the License, remember to change the Trove Classifier for that!
55 | 
56 | here = os.path.abspath(os.path.dirname(__file__))
57 | 
58 | # Import the README and use it as the long-description.
59 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file!
60 | try:
61 |     with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f:
62 |         long_description = "\n" + f.read()
63 | except FileNotFoundError:
64 |     long_description = DESCRIPTION
65 | 
66 | # Load the package's __version__.py module as a dictionary.
67 | about = {}
68 | if not VERSION:
69 |     project_slug = NAME.lower().replace("-", "_").replace(" ", "_")
70 |     with open(os.path.join(here, project_slug, "__version__.py")) as f:
71 |         exec(f.read(), about)
72 | else:
73 |     about["__version__"] = VERSION
74 | 
75 | 
76 | class UploadCommand(Command):
77 |     """Support setup.py upload."""
78 | 
79 |     description = "Build and publish the package."
80 | user_options = [] 81 | 82 | @staticmethod 83 | def status(s): 84 | """Prints things in bold.""" 85 | print("\033[1m{0}\033[0m".format(s)) 86 | 87 | def initialize_options(self): 88 | pass 89 | 90 | def finalize_options(self): 91 | pass 92 | 93 | def run(self): 94 | try: 95 | self.status("Removing previous builds…") 96 | rmtree(os.path.join(here, "dist")) 97 | except OSError: 98 | pass 99 | 100 | self.status("Building Source and Wheel (universal) distribution…") 101 | os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable)) 102 | 103 | self.status("Uploading the package to PyPI via Twine…") 104 | os.system("twine upload dist/*") 105 | 106 | self.status("Pushing git tags…") 107 | os.system("git tag v{0}".format(about["__version__"])) 108 | os.system("git push --tags") 109 | 110 | sys.exit() 111 | 112 | 113 | # Where the magic happens: 114 | setup( 115 | name=NAME, 116 | version=about["__version__"], 117 | description=DESCRIPTION, 118 | long_description=long_description, 119 | long_description_content_type="text/markdown", 120 | author=AUTHOR, 121 | author_email=EMAIL, 122 | python_requires=REQUIRES_PYTHON, 123 | url=URL, 124 | # packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]), 125 | # If your package is a single module, use this instead of 'packages': 126 | py_modules=["voicefixer"], 127 | # entry_points={ 128 | # 'console_scripts': ['mycli=mymodule:cli'], 129 | # }, 130 | install_requires=REQUIRED, 131 | extras_require=EXTRAS, 132 | packages=find_packages(), 133 | include_package_data=True, 134 | license="MIT", 135 | classifiers=[ 136 | # Trove classifiers 137 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 138 | "License :: OSI Approved :: MIT License", 139 | "Programming Language :: Python", 140 | "Programming Language :: Python :: 3", 141 | "Programming Language :: Python :: 3.7", 142 | "Programming Language :: Python :: Implementation :: CPython", 143 | "Programming Language :: Python :: Implementation :: PyPy", 144 | ], 145 | # $ setup.py publish support. 
146 | cmdclass={ 147 | "upload": UploadCommand, 148 | }, 149 | scripts=['bin/voicefixer.cmd', "bin/voicefixer"] 150 | ) 151 | -------------------------------------------------------------------------------- /voicefixer/tools/pytorch_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | def check_cuda_availability(cuda): 7 | if cuda and not torch.cuda.is_available(): 8 | raise RuntimeError("Error: You set cuda=True but no cuda device found.") 9 | 10 | 11 | def try_tensor_cuda(tensor, cuda): 12 | if cuda and torch.cuda.is_available(): 13 | return tensor.cuda() 14 | else: 15 | return tensor.cpu() 16 | 17 | 18 | def to_log(input): 19 | assert torch.sum(input < 0) == 0, ( 20 | str(input) + " has negative values counts " + str(torch.sum(input < 0)) 21 | ) 22 | return torch.log10(torch.clip(input, min=1e-8)) 23 | 24 | 25 | def from_log(input): 26 | input = torch.clip(input, min=-np.inf, max=5) 27 | return 10**input 28 | 29 | 30 | def move_data_to_device(x, device): 31 | if "float" in str(x.dtype): 32 | x = torch.Tensor(x) 33 | elif "int" in str(x.dtype): 34 | x = torch.LongTensor(x) 35 | else: 36 | return x 37 | return x.to(device) 38 | 39 | 40 | def tensor2numpy(tensor): 41 | if "cuda" in str(tensor.device): 42 | return tensor.detach().cpu().numpy() 43 | else: 44 | return tensor.detach().numpy() 45 | 46 | 47 | def count_parameters(model): 48 | for p in model.parameters(): 49 | if p.requires_grad: 50 | print(p.shape) 51 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 52 | 53 | 54 | def count_flops(model, audio_length): 55 | multiply_adds = False 56 | list_conv2d = [] 57 | 58 | def conv2d_hook(self, input, output): 59 | batch_size, input_channels, input_height, input_width = input[0].size() 60 | output_channels, output_height, output_width = output[0].size() 61 | 62 | kernel_ops = ( 63 | self.kernel_size[0] 64 | * self.kernel_size[1] 65 | * (self.in_channels / self.groups) 66 | * (2 if multiply_adds else 1) 67 | ) 68 | bias_ops = 1 if self.bias is not None else 0 69 | 70 | params = output_channels * (kernel_ops + bias_ops) 71 | flops = batch_size * params * output_height * output_width 72 | 73 | list_conv2d.append(flops) 74 | 75 | list_conv1d = [] 76 | 77 | def conv1d_hook(self, input, output): 78 | batch_size, input_channels, input_length = input[0].size() 79 | output_channels, output_length = output[0].size() 80 | 81 | kernel_ops = ( 82 | self.kernel_size[0] 83 | * (self.in_channels / self.groups) 84 | * (2 if multiply_adds else 1) 85 | ) 86 | bias_ops = 1 if self.bias is not None else 0 87 | 88 | params = output_channels * (kernel_ops + bias_ops) 89 | flops = batch_size * params * output_length 90 | 91 | list_conv1d.append(flops) 92 | 93 | list_linear = [] 94 | 95 | def linear_hook(self, input, output): 96 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 97 | 98 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 99 | bias_ops = self.bias.nelement() 100 | 101 | flops = batch_size * (weight_ops + bias_ops) 102 | list_linear.append(flops) 103 | 104 | list_bn = [] 105 | 106 | def bn_hook(self, input, output): 107 | list_bn.append(input[0].nelement()) 108 | 109 | list_relu = [] 110 | 111 | def relu_hook(self, input, output): 112 | list_relu.append(input[0].nelement()) 113 | 114 | list_pooling2d = [] 115 | 116 | def pooling2d_hook(self, input, output): 117 | batch_size, input_channels, input_height, input_width = input[0].size() 118 | 
output_channels, output_height, output_width = output[0].size() 119 | 120 | kernel_ops = self.kernel_size * self.kernel_size 121 | bias_ops = 0 122 | params = output_channels * (kernel_ops + bias_ops) 123 | flops = batch_size * params * output_height * output_width 124 | 125 | list_pooling2d.append(flops) 126 | 127 | list_pooling1d = [] 128 | 129 | def pooling1d_hook(self, input, output): 130 | batch_size, input_channels, input_length = input[0].size() 131 | output_channels, output_length = output[0].size() 132 | 133 | kernel_ops = self.kernel_size 134 | bias_ops = 0 135 | params = output_channels * (kernel_ops + bias_ops) 136 | flops = batch_size * params * output_length 137 | 138 | list_pooling2d.append(flops) 139 | 140 | def foo(net): 141 | childrens = list(net.children()) 142 | if not childrens: 143 | if isinstance(net, nn.Conv2d): 144 | net.register_forward_hook(conv2d_hook) 145 | elif isinstance(net, nn.ConvTranspose2d): 146 | net.register_forward_hook(conv2d_hook) 147 | elif isinstance(net, nn.Conv1d): 148 | net.register_forward_hook(conv1d_hook) 149 | elif isinstance(net, nn.Linear): 150 | net.register_forward_hook(linear_hook) 151 | elif isinstance(net, nn.BatchNorm2d) or isinstance(net, nn.BatchNorm1d): 152 | net.register_forward_hook(bn_hook) 153 | elif isinstance(net, nn.ReLU): 154 | net.register_forward_hook(relu_hook) 155 | elif isinstance(net, nn.AvgPool2d) or isinstance(net, nn.MaxPool2d): 156 | net.register_forward_hook(pooling2d_hook) 157 | elif isinstance(net, nn.AvgPool1d) or isinstance(net, nn.MaxPool1d): 158 | net.register_forward_hook(pooling1d_hook) 159 | else: 160 | print("Warning: flop of module {} is not counted!".format(net)) 161 | return 162 | for c in childrens: 163 | foo(c) 164 | 165 | foo(model) 166 | 167 | input = torch.rand(1, audio_length, 2) 168 | out = model(input) 169 | 170 | total_flops = ( 171 | sum(list_conv2d) 172 | + sum(list_conv1d) 173 | + sum(list_linear) 174 | + sum(list_bn) 175 | + sum(list_relu) 176 | + sum(list_pooling2d) 177 | + sum(list_pooling1d) 178 | ) 179 | 180 | return total_flops 181 | -------------------------------------------------------------------------------- /voicefixer/__main__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | from genericpath import exists 3 | import os.path 4 | import argparse 5 | from voicefixer import VoiceFixer 6 | import torch 7 | import os 8 | 9 | 10 | def writefile(infile, outfile, mode, append_mode, cuda, verbose=False): 11 | if append_mode is True: 12 | outbasename, outext = os.path.splitext(os.path.basename(outfile)) 13 | outfile = os.path.join( 14 | os.path.dirname(outfile), "{}-mode{}{}".format(outbasename, mode, outext) 15 | ) 16 | if verbose: 17 | print("Processing {}, mode={}".format(infile, mode)) 18 | voicefixer.restore(input=infile, output=outfile, cuda=cuda, mode=int(mode)) 19 | 20 | def check_arguments(args): 21 | process_file, process_folder = len(args.infile) != 0, len(args.infolder) != 0 22 | # assert len(args.infile) == 0 and len(args.outfile) == 0 or process_file, \ 23 | # "Error: You should give the input and output file path at the same time. The input and output file path we receive is %s and %s" % (args.infile, args.outfile) 24 | # assert len(args.infolder) == 0 and len(args.outfolder) == 0 or process_folder, \ 25 | # "Error: You should give the input and output folder path at the same time. 
The input and output folder path we receive is %s and %s" % (args.infolder, args.outfolder) 26 | assert ( 27 | process_file or process_folder 28 | ), "Error: You need to specify a input file path (--infile) or a input folder path (--infolder) to proceed. For more information please run: voicefixer -h" 29 | 30 | # if(args.cuda and not torch.cuda.is_available()): 31 | # print("Warning: You set --cuda while no cuda device found on your machine. We will use CPU instead.") 32 | if process_file: 33 | assert os.path.exists(args.infile), ( 34 | "Error: The input file %s is not found." % args.infile 35 | ) 36 | output_dirname = os.path.dirname(args.outfile) 37 | if len(output_dirname) > 1: 38 | os.makedirs(output_dirname, exist_ok=True) 39 | if process_folder: 40 | assert os.path.exists(args.infolder), ( 41 | "Error: The input folder %s is not found." % args.infile 42 | ) 43 | output_dirname = args.outfolder 44 | if len(output_dirname) > 1: 45 | os.makedirs(args.outfolder, exist_ok=True) 46 | 47 | return process_file, process_folder 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser( 52 | description="VoiceFixer - restores degraded speech" 53 | ) 54 | parser.add_argument( 55 | "-i", 56 | "--infile", 57 | type=str, 58 | default="", 59 | help="An input file to be processed by VoiceFixer.", 60 | ) 61 | parser.add_argument( 62 | "-o", 63 | "--outfile", 64 | type=str, 65 | default="outfile.wav", 66 | help="An output file to store the result.", 67 | ) 68 | 69 | parser.add_argument( 70 | "-ifdr", 71 | "--infolder", 72 | type=str, 73 | default="", 74 | help="Input folder. Place all your wav file that need process in this folder.", 75 | ) 76 | parser.add_argument( 77 | "-ofdr", 78 | "--outfolder", 79 | type=str, 80 | default="outfolder", 81 | help="Output folder. The processed files will be stored in this folder.", 82 | ) 83 | 84 | parser.add_argument( 85 | "--mode", help="mode", choices=["0", "1", "2", "all"], default="0" 86 | ) 87 | parser.add_argument('--disable-cuda', help='Set this flag if you do not want to use your gpu.', default=False, action="store_true") 88 | parser.add_argument( 89 | "--silent", 90 | help="Set this flag if you do not want to see any message.", 91 | default=False, 92 | action="store_true", 93 | ) 94 | 95 | args = parser.parse_args() 96 | 97 | if torch.cuda.is_available() and not args.disable_cuda: 98 | cuda = True 99 | else: 100 | cuda = False 101 | 102 | process_file, process_folder = check_arguments(args) 103 | 104 | if not args.silent: 105 | print("Initializing VoiceFixer") 106 | voicefixer = VoiceFixer() 107 | 108 | if not args.silent: 109 | print("Start processing the input file %s." % args.infile) 110 | 111 | if process_file: 112 | audioext = os.path.splitext(os.path.basename(args.infile))[-1] 113 | if audioext != ".wav": 114 | raise ValueError( 115 | "Error: Error processing the input file. We only support the .wav format currently. Please convert your %s format to .wav. Thanks." 
116 | % audioext 117 | ) 118 | if args.mode == "all": 119 | for file_mode in range(3): 120 | writefile( 121 | args.infile, 122 | args.outfile, 123 | file_mode, 124 | True, 125 | cuda, 126 | verbose=not args.silent, 127 | ) 128 | else: 129 | writefile( 130 | args.infile, 131 | args.outfile, 132 | args.mode, 133 | False, 134 | cuda, 135 | verbose=not args.silent, 136 | ) 137 | 138 | if process_folder: 139 | files = [ 140 | file 141 | for file in os.listdir(args.infolder) 142 | if (os.path.splitext(os.path.basename(file))[-1] == ".wav") 143 | ] 144 | if not args.silent: 145 | print( 146 | "Found %s .wav files in the input folder %s. Start processing." 147 | % (len(files), args.infolder) 148 | ) 149 | for file in files: 150 | outbasename, outext = os.path.splitext(os.path.basename(file)) 151 | in_file = os.path.join(args.infolder, file) 152 | out_file = os.path.join(args.outfolder, file) 153 | 154 | if args.mode == "all": 155 | for file_mode in range(3): 156 | writefile( 157 | in_file, 158 | out_file, 159 | file_mode, 160 | True, 161 | cuda, 162 | verbose=not args.silent, 163 | ) 164 | else: 165 | writefile( 166 | in_file, out_file, args.mode, False, cuda, verbose=not args.silent 167 | ) 168 | 169 | if not args.silent: 170 | print("Done") 171 | -------------------------------------------------------------------------------- /bin/voicefixer: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | from genericpath import exists 3 | import os.path 4 | import argparse 5 | from voicefixer import VoiceFixer 6 | import torch 7 | import os 8 | 9 | 10 | def writefile(infile, outfile, mode, append_mode, cuda, verbose=False): 11 | if append_mode is True: 12 | outbasename, outext = os.path.splitext(os.path.basename(outfile)) 13 | outfile = os.path.join( 14 | os.path.dirname(outfile), "{}-mode{}{}".format(outbasename, mode, outext) 15 | ) 16 | if verbose: 17 | print("Processing {}, mode={}".format(infile, mode)) 18 | voicefixer.restore(input=infile, output=outfile, cuda=cuda, mode=int(mode)) 19 | 20 | def check_arguments(args): 21 | process_file, process_folder = len(args.infile) != 0, len(args.infolder) != 0 22 | # assert len(args.infile) == 0 and len(args.outfile) == 0 or process_file, \ 23 | # "Error: You should give the input and output file path at the same time. The input and output file path we receive is %s and %s" % (args.infile, args.outfile) 24 | # assert len(args.infolder) == 0 and len(args.outfolder) == 0 or process_folder, \ 25 | # "Error: You should give the input and output folder path at the same time. The input and output folder path we receive is %s and %s" % (args.infolder, args.outfolder) 26 | assert ( 27 | process_file or process_folder 28 | ), "Error: You need to specify a input file path (--infile) or a input folder path (--infolder) to proceed. For more information please run: voicefixer -h" 29 | 30 | # if(args.cuda and not torch.cuda.is_available()): 31 | # print("Warning: You set --cuda while no cuda device found on your machine. We will use CPU instead.") 32 | if process_file: 33 | assert os.path.exists(args.infile), ( 34 | "Error: The input file %s is not found." % args.infile 35 | ) 36 | output_dirname = os.path.dirname(args.outfile) 37 | if len(output_dirname) > 1: 38 | os.makedirs(output_dirname, exist_ok=True) 39 | if process_folder: 40 | assert os.path.exists(args.infolder), ( 41 | "Error: The input folder %s is not found." 
% args.infile 42 | ) 43 | output_dirname = args.outfolder 44 | if len(output_dirname) > 1: 45 | os.makedirs(args.outfolder, exist_ok=True) 46 | 47 | return process_file, process_folder 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser( 52 | description="VoiceFixer - restores degraded speech" 53 | ) 54 | parser.add_argument( 55 | "-i", 56 | "--infile", 57 | type=str, 58 | default="", 59 | help="An input file to be processed by VoiceFixer.", 60 | ) 61 | parser.add_argument( 62 | "-o", 63 | "--outfile", 64 | type=str, 65 | default="outfile.wav", 66 | help="An output file to store the result.", 67 | ) 68 | 69 | parser.add_argument( 70 | "-ifdr", 71 | "--infolder", 72 | type=str, 73 | default="", 74 | help="Input folder. Place all your wav file that need process in this folder.", 75 | ) 76 | parser.add_argument( 77 | "-ofdr", 78 | "--outfolder", 79 | type=str, 80 | default="outfolder", 81 | help="Output folder. The processed files will be stored in this folder.", 82 | ) 83 | 84 | parser.add_argument( 85 | "--mode", help="mode", choices=["0", "1", "2", "all"], default="0" 86 | ) 87 | parser.add_argument('--disable-cuda', help='Set this flag if you do not want to use your gpu.', default=False, action="store_true") 88 | parser.add_argument( 89 | "--silent", 90 | help="Set this flag if you do not want to see any message.", 91 | default=False, 92 | action="store_true", 93 | ) 94 | 95 | args = parser.parse_args() 96 | 97 | if torch.cuda.is_available() and not args.disable_cuda: 98 | cuda = True 99 | else: 100 | cuda = False 101 | 102 | process_file, process_folder = check_arguments(args) 103 | 104 | if not args.silent: 105 | print("Initializing VoiceFixer") 106 | voicefixer = VoiceFixer() 107 | 108 | if not args.silent: 109 | print("Start processing the input file %s." % args.infile) 110 | 111 | if process_file: 112 | audioext = os.path.splitext(os.path.basename(args.infile))[-1] 113 | if audioext != ".wav": 114 | raise ValueError( 115 | "Error: Error processing the input file. We only support the .wav format currently. Please convert your %s format to .wav. Thanks." 116 | % audioext 117 | ) 118 | if args.mode == "all": 119 | for file_mode in range(3): 120 | writefile( 121 | args.infile, 122 | args.outfile, 123 | file_mode, 124 | True, 125 | cuda, 126 | verbose=not args.silent, 127 | ) 128 | else: 129 | writefile( 130 | args.infile, 131 | args.outfile, 132 | args.mode, 133 | False, 134 | cuda, 135 | verbose=not args.silent, 136 | ) 137 | 138 | if process_folder: 139 | if not args.silent: 140 | files = [ 141 | file 142 | for file in os.listdir(args.infolder) 143 | if (os.path.splitext(os.path.basename(file))[-1] == ".wav") 144 | ] 145 | print( 146 | "Found %s .wav files in the input folder %s. Start processing." 
147 | % (len(files), args.infolder) 148 | ) 149 | for file in os.listdir(args.infolder): 150 | outbasename, outext = os.path.splitext(os.path.basename(file)) 151 | in_file = os.path.join(args.infolder, file) 152 | out_file = os.path.join(args.outfolder, file) 153 | 154 | if args.mode == "all": 155 | for file_mode in range(3): 156 | writefile( 157 | in_file, 158 | out_file, 159 | file_mode, 160 | True, 161 | cuda, 162 | verbose=not args.silent, 163 | ) 164 | else: 165 | writefile( 166 | in_file, out_file, args.mode, False, cuda, verbose=not args.silent 167 | ) 168 | 169 | if not args.silent: 170 | print("Done") -------------------------------------------------------------------------------- /voicefixer/restorer/model_kqq_bn.py: -------------------------------------------------------------------------------- 1 | from voicefixer.restorer.modules import * 2 | 3 | from voicefixer.tools.pytorch_util import * 4 | 5 | 6 | class UNetResComplex_100Mb(nn.Module): 7 | def __init__(self, channels, nsrc=1): 8 | super(UNetResComplex_100Mb, self).__init__() 9 | activation = "relu" 10 | momentum = 0.01 11 | 12 | self.nsrc = nsrc 13 | self.channels = channels 14 | self.downsample_ratio = 2**6 # This number equals 2^{#encoder_blcoks} 15 | 16 | self.encoder_block1 = EncoderBlockRes( 17 | in_channels=channels * nsrc, 18 | out_channels=32, 19 | downsample=(2, 2), 20 | activation=activation, 21 | momentum=momentum, 22 | ) 23 | self.encoder_block2 = EncoderBlockRes( 24 | in_channels=32, 25 | out_channels=64, 26 | downsample=(2, 2), 27 | activation=activation, 28 | momentum=momentum, 29 | ) 30 | self.encoder_block3 = EncoderBlockRes( 31 | in_channels=64, 32 | out_channels=128, 33 | downsample=(2, 2), 34 | activation=activation, 35 | momentum=momentum, 36 | ) 37 | self.encoder_block4 = EncoderBlockRes( 38 | in_channels=128, 39 | out_channels=256, 40 | downsample=(2, 2), 41 | activation=activation, 42 | momentum=momentum, 43 | ) 44 | self.encoder_block5 = EncoderBlockRes( 45 | in_channels=256, 46 | out_channels=384, 47 | downsample=(2, 2), 48 | activation=activation, 49 | momentum=momentum, 50 | ) 51 | self.encoder_block6 = EncoderBlockRes( 52 | in_channels=384, 53 | out_channels=384, 54 | downsample=(2, 2), 55 | activation=activation, 56 | momentum=momentum, 57 | ) 58 | self.conv_block7 = ConvBlockRes( 59 | in_channels=384, 60 | out_channels=384, 61 | size=3, 62 | activation=activation, 63 | momentum=momentum, 64 | ) 65 | self.decoder_block1 = DecoderBlockRes( 66 | in_channels=384, 67 | out_channels=384, 68 | stride=(2, 2), 69 | activation=activation, 70 | momentum=momentum, 71 | ) 72 | self.decoder_block2 = DecoderBlockRes( 73 | in_channels=384, 74 | out_channels=384, 75 | stride=(2, 2), 76 | activation=activation, 77 | momentum=momentum, 78 | ) 79 | self.decoder_block3 = DecoderBlockRes( 80 | in_channels=384, 81 | out_channels=256, 82 | stride=(2, 2), 83 | activation=activation, 84 | momentum=momentum, 85 | ) 86 | self.decoder_block4 = DecoderBlockRes( 87 | in_channels=256, 88 | out_channels=128, 89 | stride=(2, 2), 90 | activation=activation, 91 | momentum=momentum, 92 | ) 93 | self.decoder_block5 = DecoderBlockRes( 94 | in_channels=128, 95 | out_channels=64, 96 | stride=(2, 2), 97 | activation=activation, 98 | momentum=momentum, 99 | ) 100 | self.decoder_block6 = DecoderBlockRes( 101 | in_channels=64, 102 | out_channels=32, 103 | stride=(2, 2), 104 | activation=activation, 105 | momentum=momentum, 106 | ) 107 | 108 | self.after_conv_block1 = ConvBlockRes( 109 | in_channels=32, 110 | out_channels=32, 111 | 
size=3, 112 | activation=activation, 113 | momentum=momentum, 114 | ) 115 | 116 | self.after_conv2 = nn.Conv2d( 117 | in_channels=32, 118 | out_channels=1, 119 | kernel_size=(1, 1), 120 | stride=(1, 1), 121 | padding=(0, 0), 122 | bias=True, 123 | ) 124 | 125 | self.init_weights() 126 | 127 | def init_weights(self): 128 | init_layer(self.after_conv2) 129 | 130 | def forward(self, sp): 131 | """ 132 | Args: 133 | input: (batch_size, channels_num, segment_samples) 134 | 135 | Outputs: 136 | output_dict: { 137 | 'wav': (batch_size, channels_num, segment_samples), 138 | 'sp': (batch_size, channels_num, time_steps, freq_bins)} 139 | """ 140 | 141 | # Batch normalization 142 | x = sp 143 | 144 | # Pad spectrogram to be evenly divided by downsample ratio. 145 | origin_len = x.shape[2] # time_steps 146 | pad_len = ( 147 | int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio 148 | - origin_len 149 | ) 150 | x = F.pad(x, pad=(0, 0, 0, pad_len)) 151 | x = x[..., 0 : x.shape[-1] - 1] # (bs, channels, T, F) 152 | 153 | # UNet 154 | (x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F / 2) 155 | (x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F / 4) 156 | (x3_pool, x3) = self.encoder_block3(x2_pool) # x3_pool: (bs, 128, T / 8, F / 8) 157 | (x4_pool, x4) = self.encoder_block4( 158 | x3_pool 159 | ) # x4_pool: (bs, 256, T / 16, F / 16) 160 | (x5_pool, x5) = self.encoder_block5( 161 | x4_pool 162 | ) # x5_pool: (bs, 512, T / 32, F / 32) 163 | (x6_pool, x6) = self.encoder_block6( 164 | x5_pool 165 | ) # x6_pool: (bs, 1024, T / 64, F / 64) 166 | x_center = self.conv_block7(x6_pool) # (bs, 2048, T / 64, F / 64) 167 | x7 = self.decoder_block1(x_center, x6) # (bs, 1024, T / 32, F / 32) 168 | x8 = self.decoder_block2(x7, x5) # (bs, 512, T / 16, F / 16) 169 | x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F / 8) 170 | x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F / 4) 171 | x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F / 2) 172 | x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F) 173 | x = self.after_conv_block1(x12) # (bs, 32, T, F) 174 | x = self.after_conv2(x) # (bs, channels, T, F) 175 | 176 | # Recover shape 177 | x = F.pad(x, pad=(0, 1)) 178 | x = x[:, :, 0:origin_len, :] 179 | 180 | output_dict = {"mel": x} 181 | return output_dict 182 | 183 | 184 | if __name__ == "__main__": 185 | model = UNetResComplex_100Mb(channels=1) 186 | print(model(torch.randn((1, 1, 101, 128)))["mel"].size()) 187 | -------------------------------------------------------------------------------- /voicefixer/vocoder/model/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from voicefixer.vocoder.model.modules import UpsampleNet, ResStack 5 | from voicefixer.vocoder.config import Config 6 | from voicefixer.vocoder.model.pqmf import PQMF 7 | import os 8 | 9 | os.environ["KMP_DUPLICATE_LIB_OK"] = "True" 10 | 11 | 12 | class Generator(nn.Module): 13 | def __init__( 14 | self, 15 | in_channels=128, 16 | use_elu=False, 17 | use_gcnn=False, 18 | up_org=False, 19 | group=1, 20 | hp=None, 21 | ): 22 | super(Generator, self).__init__() 23 | self.hp = hp 24 | channels = Config.channels 25 | self.upsample_scales = Config.upsample_scales 26 | self.use_condnet = Config.use_condnet 27 | self.out_channels = Config.out_channels 28 | self.resstack_depth = Config.resstack_depth 29 | self.use_postnet = Config.use_postnet 30 | self.use_cond_rnn = 
Config.use_cond_rnn 31 | if self.use_condnet: 32 | cond_channels = Config.cond_channels 33 | self.condnet = nn.Sequential( 34 | nn.utils.weight_norm( 35 | nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) 36 | ), 37 | nn.ELU(), 38 | nn.utils.weight_norm( 39 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 40 | ), 41 | nn.ELU(), 42 | nn.utils.weight_norm( 43 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 44 | ), 45 | nn.ELU(), 46 | nn.utils.weight_norm( 47 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 48 | ), 49 | nn.ELU(), 50 | nn.utils.weight_norm( 51 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 52 | ), 53 | nn.ELU(), 54 | ) 55 | in_channels = cond_channels 56 | if self.use_cond_rnn: 57 | self.rnn = nn.GRU( 58 | cond_channels, 59 | cond_channels // 2, 60 | num_layers=1, 61 | batch_first=True, 62 | bidirectional=True, 63 | ) 64 | 65 | if use_elu: 66 | act = nn.ELU() 67 | else: 68 | act = nn.LeakyReLU(0.2, True) 69 | 70 | kernel_size = Config.kernel_size 71 | 72 | if self.out_channels == 1: 73 | self.generator = nn.Sequential( 74 | nn.ReflectionPad1d(3), 75 | nn.utils.weight_norm(nn.Conv1d(in_channels, channels, kernel_size=7)), 76 | act, 77 | UpsampleNet(channels, channels // 2, self.upsample_scales[0], hp, 0), 78 | ResStack(channels // 2, kernel_size[0], self.resstack_depth[0], hp), 79 | act, 80 | UpsampleNet( 81 | channels // 2, channels // 4, self.upsample_scales[1], hp, 1 82 | ), 83 | ResStack(channels // 4, kernel_size[1], self.resstack_depth[1], hp), 84 | act, 85 | UpsampleNet( 86 | channels // 4, channels // 8, self.upsample_scales[2], hp, 2 87 | ), 88 | ResStack(channels // 8, kernel_size[2], self.resstack_depth[2], hp), 89 | act, 90 | UpsampleNet( 91 | channels // 8, channels // 16, self.upsample_scales[3], hp, 3 92 | ), 93 | ResStack(channels // 16, kernel_size[3], self.resstack_depth[3], hp), 94 | act, 95 | nn.ReflectionPad1d(3), 96 | nn.utils.weight_norm( 97 | nn.Conv1d(channels // 16, self.out_channels, kernel_size=7) 98 | ), 99 | nn.Tanh(), 100 | ) 101 | else: 102 | channels = Config.m_channels 103 | self.generator = nn.Sequential( 104 | nn.ReflectionPad1d(3), 105 | nn.utils.weight_norm(nn.Conv1d(in_channels, channels, kernel_size=7)), 106 | act, 107 | UpsampleNet(channels, channels // 2, self.upsample_scales[0], hp), 108 | ResStack(channels // 2, kernel_size[0], self.resstack_depth[0], hp), 109 | act, 110 | UpsampleNet(channels // 2, channels // 4, self.upsample_scales[1], hp), 111 | ResStack(channels // 4, kernel_size[1], self.resstack_depth[1], hp), 112 | act, 113 | UpsampleNet(channels // 4, channels // 8, self.upsample_scales[3], hp), 114 | ResStack(channels // 8, kernel_size[3], self.resstack_depth[2], hp), 115 | act, 116 | nn.ReflectionPad1d(3), 117 | nn.utils.weight_norm( 118 | nn.Conv1d(channels // 8, self.out_channels, kernel_size=7) 119 | ), 120 | nn.Tanh(), 121 | ) 122 | if self.out_channels > 1: 123 | self.pqmf = PQMF(4, 64) 124 | 125 | self.num_params() 126 | 127 | def forward(self, conditions, use_res=False, f0=None): 128 | res = conditions 129 | if self.use_condnet: 130 | conditions = self.condnet(conditions) 131 | if self.use_cond_rnn: 132 | conditions, _ = self.rnn(conditions.transpose(1, 2)) 133 | conditions = conditions.transpose(1, 2) 134 | 135 | wav = self.generator(conditions) 136 | if self.out_channels > 1: 137 | B = wav.size(0) 138 | f_wav = ( 139 | self.pqmf.synthesis(wav) 140 | .transpose(1, 2) 141 | .reshape(B, 1, -1) 142 | .clamp(-0.99, 0.99) 143 | ) 
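            # Added note: at this point `wav` holds the PQMF sub-band signals
            # produced by the generator; pqmf.synthesis() above recombines them
            # into a single full-band waveform `f_wav` of shape (B, 1, samples),
            # clamped to (-0.99, 0.99), and both outputs are returned.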
144 | return f_wav, wav 145 | return wav 146 | 147 | def num_params(self): 148 | parameters = filter(lambda p: p.requires_grad, self.parameters()) 149 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 150 | return parameters 151 | # print('Trainable Parameters: %.3f million' % parameters) 152 | 153 | def remove_weight_norm(self): 154 | def _remove_weight_norm(m): 155 | try: 156 | torch.nn.utils.remove_weight_norm(m) 157 | except ValueError: # this module didn't have weight norm 158 | return 159 | 160 | self.apply(_remove_weight_norm) 161 | 162 | 163 | if __name__ == "__main__": 164 | model = Generator(128) 165 | x = torch.randn(3, 128, 13) 166 | print(x.shape) 167 | y = model(x) 168 | print(y.shape) 169 | -------------------------------------------------------------------------------- /voicefixer/base.py: -------------------------------------------------------------------------------- 1 | import librosa.display 2 | from voicefixer.tools.pytorch_util import * 3 | from voicefixer.tools.wav import * 4 | from voicefixer.restorer.model import VoiceFixer as voicefixer_fe 5 | import os 6 | 7 | EPS = 1e-8 8 | 9 | 10 | class VoiceFixer(nn.Module): 11 | def __init__(self): 12 | super(VoiceFixer, self).__init__() 13 | self._model = voicefixer_fe(channels=2, sample_rate=44100) 14 | # print(os.path.join(os.path.expanduser('~'), ".cache/voicefixer/analysis_module/checkpoints/epoch=15_trimed_bn.ckpt")) 15 | self.analysis_module_ckpt = os.path.join( 16 | os.path.expanduser("~"), 17 | ".cache/voicefixer/analysis_module/checkpoints/vf.ckpt", 18 | ) 19 | if(not os.path.exists(self.analysis_module_ckpt)): 20 | raise RuntimeError("Error 0: The checkpoint for analysis module (vf.ckpt) is not found in ~/.cache/voicefixer/analysis_module/checkpoints. \ 21 | By default the checkpoint should be download automatically by this program. Something bad may happened.\ 22 | But don't worry! 
Alternatively you can download it directly from Zenodo: https://zenodo.org/record/5600188/files/vf.ckpt?download=1.") 23 | self._model.load_state_dict( 24 | torch.load( 25 | self.analysis_module_ckpt 26 | ) 27 | ) 28 | self._model.eval() 29 | 30 | def _load_wav_energy(self, path, sample_rate, threshold=0.95): 31 | wav_10k, _ = librosa.load(path, sr=sample_rate) 32 | stft = np.log10(np.abs(librosa.stft(wav_10k)) + 1.0) 33 | fbins = stft.shape[0] 34 | e_stft = np.sum(stft, axis=1) 35 | for i in range(e_stft.shape[0]): 36 | e_stft[-i - 1] = np.sum(e_stft[: -i - 1]) 37 | total = e_stft[-1] 38 | for i in range(e_stft.shape[0]): 39 | if e_stft[i] < total * threshold: 40 | continue 41 | else: 42 | break 43 | return wav_10k, int((sample_rate // 2) * (i / fbins)) 44 | 45 | def _load_wav(self, path, sample_rate, threshold=0.95): 46 | wav_10k, _ = librosa.load(path, sr=sample_rate) 47 | return wav_10k 48 | 49 | def _amp_to_original_f(self, mel_sp_est, mel_sp_target, cutoff=0.2): 50 | freq_dim = mel_sp_target.size()[-1] 51 | mel_sp_est_low, mel_sp_target_low = ( 52 | mel_sp_est[..., 5 : int(freq_dim * cutoff)], 53 | mel_sp_target[..., 5 : int(freq_dim * cutoff)], 54 | ) 55 | energy_est, energy_target = torch.mean(mel_sp_est_low, dim=(2, 3)), torch.mean( 56 | mel_sp_target_low, dim=(2, 3) 57 | ) 58 | amp_ratio = energy_target / energy_est 59 | return mel_sp_est * amp_ratio[..., None, None], mel_sp_target 60 | 61 | def _trim_center(self, est, ref): 62 | diff = np.abs(est.shape[-1] - ref.shape[-1]) 63 | if est.shape[-1] == ref.shape[-1]: 64 | return est, ref 65 | elif est.shape[-1] > ref.shape[-1]: 66 | min_len = min(est.shape[-1], ref.shape[-1]) 67 | est, ref = est[..., int(diff // 2) : -int(diff // 2)], ref 68 | est, ref = est[..., :min_len], ref[..., :min_len] 69 | return est, ref 70 | else: 71 | min_len = min(est.shape[-1], ref.shape[-1]) 72 | est, ref = est, ref[..., int(diff // 2) : -int(diff // 2)] 73 | est, ref = est[..., :min_len], ref[..., :min_len] 74 | return est, ref 75 | 76 | def _pre(self, model, input, cuda): 77 | input = input[None, None, ...] 78 | input = torch.tensor(input) 79 | input = try_tensor_cuda(input, cuda=cuda) 80 | sp, _, _ = model.f_helper.wav_to_spectrogram_phase(input) 81 | mel_orig = model.mel(sp.permute(0, 1, 3, 2)).permute(0, 1, 3, 2) 82 | # return models.to_log(sp), models.to_log(mel_orig) 83 | return sp, mel_orig 84 | 85 | def remove_higher_frequency(self, wav, ratio=0.95): 86 | stft = librosa.stft(wav) 87 | real, img = np.real(stft), np.imag(stft) 88 | mag = (real**2 + img**2) ** 0.5 89 | cos, sin = real / (mag + EPS), img / (mag + EPS) 90 | spec = np.abs(stft) # [1025,T] 91 | feature = spec.copy() 92 | feature = np.log10(feature + EPS) 93 | feature[feature < 0] = 0 94 | energy_level = np.sum(feature, axis=1) 95 | threshold = np.sum(energy_level) * ratio 96 | curent_level, i = energy_level[0], 0 97 | while i < energy_level.shape[0] and curent_level < threshold: 98 | curent_level += energy_level[i + 1, ...] 99 | i += 1 100 | spec[i:, ...] 
= np.zeros_like(spec[i:, ...]) 101 | stft = spec * cos + 1j * spec * sin 102 | return librosa.istft(stft) 103 | 104 | @torch.no_grad() 105 | def restore_inmem(self, wav_10k, cuda=False, mode=0, your_vocoder_func=None): 106 | check_cuda_availability(cuda=cuda) 107 | self._model = try_tensor_cuda(self._model, cuda=cuda) 108 | if mode == 0: 109 | self._model.eval() 110 | elif mode == 1: 111 | self._model.eval() 112 | elif mode == 2: 113 | self._model.train() # More effective on seriously demaged speech 114 | res = [] 115 | seg_length = 44100 * 30 116 | break_point = seg_length 117 | while break_point < wav_10k.shape[0] + seg_length: 118 | segment = wav_10k[break_point - seg_length : break_point] 119 | if mode == 1: 120 | segment = self.remove_higher_frequency(segment) 121 | sp, mel_noisy = self._pre(self._model, segment, cuda) 122 | out_model = self._model(sp, mel_noisy) 123 | denoised_mel = from_log(out_model["mel"]) 124 | if your_vocoder_func is None: 125 | out = self._model.vocoder(denoised_mel, cuda=cuda) 126 | else: 127 | out = your_vocoder_func(denoised_mel) 128 | # unify energy 129 | if torch.max(torch.abs(out)) > 1.0: 130 | out = out / torch.max(torch.abs(out)) 131 | print("Warning: Exceed energy limit,", input) 132 | # frame alignment 133 | out, _ = self._trim_center(out, segment) 134 | res.append(out) 135 | break_point += seg_length 136 | out = torch.cat(res, -1) 137 | return tensor2numpy(out.squeeze(0)) 138 | 139 | def restore(self, input, output, cuda=False, mode=0, your_vocoder_func=None): 140 | wav_10k = self._load_wav(input, sample_rate=44100) 141 | out_np_wav = self.restore_inmem( 142 | wav_10k, cuda=cuda, mode=mode, your_vocoder_func=your_vocoder_func 143 | ) 144 | save_wave(out_np_wav, fname=output, sample_rate=44100) 145 | -------------------------------------------------------------------------------- /voicefixer/restorer/modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | class ConvBlockRes(nn.Module): 8 | def __init__(self, in_channels, out_channels, size, activation, momentum): 9 | super(ConvBlockRes, self).__init__() 10 | 11 | self.activation = activation 12 | if type(size) == type((3, 4)): 13 | pad = size[0] // 2 14 | size = size[0] 15 | else: 16 | pad = size // 2 17 | size = size 18 | 19 | self.conv1 = nn.Conv2d( 20 | in_channels=in_channels, 21 | out_channels=out_channels, 22 | kernel_size=(size, size), 23 | stride=(1, 1), 24 | dilation=(1, 1), 25 | padding=(pad, pad), 26 | bias=False, 27 | ) 28 | 29 | self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum) 30 | # self.abn1 = InPlaceABN(num_features=in_channels, momentum=momentum, activation='leaky_relu') 31 | 32 | self.conv2 = nn.Conv2d( 33 | in_channels=out_channels, 34 | out_channels=out_channels, 35 | kernel_size=(size, size), 36 | stride=(1, 1), 37 | dilation=(1, 1), 38 | padding=(pad, pad), 39 | bias=False, 40 | ) 41 | 42 | self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum) 43 | 44 | # self.abn2 = InPlaceABN(num_features=out_channels, momentum=momentum, activation='leaky_relu') 45 | 46 | if in_channels != out_channels: 47 | self.shortcut = nn.Conv2d( 48 | in_channels=in_channels, 49 | out_channels=out_channels, 50 | kernel_size=(1, 1), 51 | stride=(1, 1), 52 | padding=(0, 0), 53 | ) 54 | self.is_shortcut = True 55 | else: 56 | self.is_shortcut = False 57 | 58 | self.init_weights() 59 | 60 | def init_weights(self): 61 | init_bn(self.bn1) 62 | 
init_layer(self.conv1) 63 | init_layer(self.conv2) 64 | 65 | if self.is_shortcut: 66 | init_layer(self.shortcut) 67 | 68 | def forward(self, x): 69 | origin = x 70 | x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01)) 71 | x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01)) 72 | 73 | if self.is_shortcut: 74 | return self.shortcut(origin) + x 75 | else: 76 | return origin + x 77 | 78 | 79 | class EncoderBlockRes(nn.Module): 80 | def __init__(self, in_channels, out_channels, downsample, activation, momentum): 81 | super(EncoderBlockRes, self).__init__() 82 | size = 3 83 | 84 | self.conv_block1 = ConvBlockRes( 85 | in_channels, out_channels, size, activation, momentum 86 | ) 87 | self.conv_block2 = ConvBlockRes( 88 | out_channels, out_channels, size, activation, momentum 89 | ) 90 | self.conv_block3 = ConvBlockRes( 91 | out_channels, out_channels, size, activation, momentum 92 | ) 93 | self.conv_block4 = ConvBlockRes( 94 | out_channels, out_channels, size, activation, momentum 95 | ) 96 | self.downsample = downsample 97 | 98 | def forward(self, x): 99 | encoder = self.conv_block1(x) 100 | encoder = self.conv_block2(encoder) 101 | encoder = self.conv_block3(encoder) 102 | encoder = self.conv_block4(encoder) 103 | encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample) 104 | return encoder_pool, encoder 105 | 106 | 107 | class DecoderBlockRes(nn.Module): 108 | def __init__(self, in_channels, out_channels, stride, activation, momentum): 109 | super(DecoderBlockRes, self).__init__() 110 | size = 3 111 | self.activation = activation 112 | 113 | self.conv1 = torch.nn.ConvTranspose2d( 114 | in_channels=in_channels, 115 | out_channels=out_channels, 116 | kernel_size=(size, size), 117 | stride=stride, 118 | padding=(0, 0), 119 | output_padding=(0, 0), 120 | bias=False, 121 | dilation=(1, 1), 122 | ) 123 | 124 | self.bn1 = nn.BatchNorm2d(in_channels) 125 | self.conv_block2 = ConvBlockRes( 126 | out_channels * 2, out_channels, size, activation, momentum 127 | ) 128 | self.conv_block3 = ConvBlockRes( 129 | out_channels, out_channels, size, activation, momentum 130 | ) 131 | self.conv_block4 = ConvBlockRes( 132 | out_channels, out_channels, size, activation, momentum 133 | ) 134 | self.conv_block5 = ConvBlockRes( 135 | out_channels, out_channels, size, activation, momentum 136 | ) 137 | 138 | def init_weights(self): 139 | init_layer(self.conv1) 140 | 141 | def prune(self, x, both=False): 142 | """Prune the shape of x after transpose convolution.""" 143 | if both: 144 | x = x[:, :, 0:-1, 0:-1] 145 | else: 146 | x = x[:, :, 0:-1, :] 147 | return x 148 | 149 | def forward(self, input_tensor, concat_tensor, both=False): 150 | x = self.conv1(F.relu_(self.bn1(input_tensor))) 151 | x = self.prune(x, both=both) 152 | x = torch.cat((x, concat_tensor), dim=1) 153 | x = self.conv_block2(x) 154 | x = self.conv_block3(x) 155 | x = self.conv_block4(x) 156 | x = self.conv_block5(x) 157 | return x 158 | 159 | 160 | def init_layer(layer): 161 | """Initialize a Linear or Convolutional layer.""" 162 | nn.init.xavier_uniform_(layer.weight) 163 | 164 | if hasattr(layer, "bias"): 165 | if layer.bias is not None: 166 | layer.bias.data.fill_(0.0) 167 | 168 | 169 | def init_bn(bn): 170 | """Initialize a Batchnorm layer.""" 171 | bn.bias.data.fill_(0.0) 172 | bn.weight.data.fill_(1.0) 173 | 174 | 175 | def init_gru(rnn): 176 | """Initialize a GRU layer.""" 177 | 178 | def _concat_init(tensor, init_funcs): 179 | (length, fan_out) = tensor.shape 180 | fan_in = length // len(init_funcs) 181 | 182 | 
for (i, init_func) in enumerate(init_funcs): 183 | init_func(tensor[i * fan_in : (i + 1) * fan_in, :]) 184 | 185 | def _inner_uniform(tensor): 186 | fan_in = nn.init._calculate_correct_fan(tensor, "fan_in") 187 | nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in)) 188 | 189 | for i in range(rnn.num_layers): 190 | _concat_init( 191 | getattr(rnn, "weight_ih_l{}".format(i)), 192 | [_inner_uniform, _inner_uniform, _inner_uniform], 193 | ) 194 | torch.nn.init.constant_(getattr(rnn, "bias_ih_l{}".format(i)), 0) 195 | 196 | _concat_init( 197 | getattr(rnn, "weight_hh_l{}".format(i)), 198 | [_inner_uniform, _inner_uniform, nn.init.orthogonal_], 199 | ) 200 | torch.nn.init.constant_(getattr(rnn, "bias_hh_l{}".format(i)), 0) 201 | 202 | 203 | from torch.cuda import init 204 | 205 | 206 | def act(x, activation): 207 | if activation == "relu": 208 | return F.relu_(x) 209 | 210 | elif activation == "leaky_relu": 211 | return F.leaky_relu_(x, negative_slope=0.2) 212 | 213 | elif activation == "swish": 214 | return x * torch.sigmoid(x) 215 | 216 | else: 217 | raise Exception("Incorrect activation!") 218 | -------------------------------------------------------------------------------- /voicefixer/tools/wav.py: -------------------------------------------------------------------------------- 1 | import wave 2 | import os 3 | import numpy as np 4 | import scipy.signal as signal 5 | import soundfile as sf 6 | import librosa 7 | 8 | 9 | def save_wave(frames: np.ndarray, fname, sample_rate=44100): 10 | shape = list(frames.shape) 11 | if len(shape) == 1: 12 | frames = frames[..., None] 13 | in_samples, in_channels = shape[-2], shape[-1] 14 | if in_channels >= 3: 15 | if len(shape) == 2: 16 | frames = np.transpose(frames, (1, 0)) 17 | elif len(shape) == 3: 18 | frames = np.transpose(frames, (0, 2, 1)) 19 | msg = ( 20 | "Warning: Save audio with " 21 | + str(in_channels) 22 | + " channels, save permute audio with shape " 23 | + str(list(frames.shape)) 24 | + " please check if it's correct." 25 | ) 26 | # print(msg) 27 | if ( 28 | np.max(frames) <= 1 29 | and frames.dtype == np.float32 30 | or frames.dtype == np.float16 31 | or frames.dtype == np.float64 32 | ): 33 | frames *= 2**15 34 | frames = frames.astype(np.short) 35 | if len(frames.shape) >= 3: 36 | frames = frames[0, ...] 37 | sf.write(fname, frames, samplerate=sample_rate) 38 | 39 | 40 | def constrain_length(chunk, length): 41 | frames_length = chunk.shape[0] 42 | if frames_length == length: 43 | return chunk 44 | elif frames_length < length: 45 | return np.pad(chunk, ((0, int(length - frames_length)), (0, 0)), "constant") 46 | else: 47 | return chunk[:length, ...] 
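# Illustrative example (added, hypothetical arrays): constrain_length pads the
# tail with zeros or truncates along the sample axis so every chunk has exactly
# `length` frames, e.g.
#   constrain_length(np.zeros((44100, 2)), 88200).shape  -> (88200, 2)
#   constrain_length(np.zeros((88200, 2)), 44100).shape  -> (44100, 2)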
48 | 49 | 50 | def random_chunk_wav_file(fname, chunk_length): 51 | """ 52 | fname: path to wav file 53 | chunk_length: frame length in seconds 54 | """ 55 | with wave.open(fname) as f: 56 | params = f.getparams() 57 | duration = params[3] / params[2] 58 | sample_rate = params[2] 59 | sample_length = params[3] 60 | if duration < chunk_length or abs(duration - chunk_length) < 1e-4: 61 | frames = read_wave(fname, sample_rate) 62 | return frames, duration, sample_rate # [-1,1] 63 | else: 64 | # Random trunk 65 | random_starts = np.random.randint( 66 | 0, sample_length - sample_rate * chunk_length 67 | ) 68 | random_end = random_starts + sample_rate * chunk_length 69 | random_starts, random_end = ( 70 | random_starts / sample_rate, 71 | random_end / sample_rate, 72 | ) 73 | random_starts, random_end = random_starts / duration, random_end / duration 74 | frames = read_wave( 75 | fname, sample_rate, portion_start=random_starts, portion_end=random_end 76 | ) 77 | frames = constrain_length(frames, length=int(chunk_length * sample_rate)) 78 | return frames, chunk_length, sample_rate 79 | 80 | 81 | def random_chunk_wav_file_v2(fname, chunk_length, random_starts=None, random_end=None): 82 | """ 83 | fname: path to wav file 84 | chunk_length: frame length in seconds 85 | """ 86 | with wave.open(fname) as f: 87 | params = f.getparams() 88 | duration = params[3] / params[2] 89 | sample_rate = params[2] 90 | sample_length = params[3] 91 | if duration < chunk_length or abs(duration - chunk_length) < 1e-4: 92 | frames = read_wave(fname, sample_rate) 93 | return frames, duration, sample_rate # [-1,1] 94 | else: 95 | # Random trunk 96 | if random_starts is None and random_end is None: 97 | random_starts = np.random.randint( 98 | 0, sample_length - sample_rate * chunk_length 99 | ) 100 | random_end = random_starts + sample_rate * chunk_length 101 | random_starts, random_end = ( 102 | random_starts / sample_rate, 103 | random_end / sample_rate, 104 | ) 105 | random_starts, random_end = ( 106 | random_starts / duration, 107 | random_end / duration, 108 | ) 109 | frames = read_wave( 110 | fname, sample_rate, portion_start=random_starts, portion_end=random_end 111 | ) 112 | frames = constrain_length(frames, length=int(chunk_length * sample_rate)) 113 | return frames, chunk_length, sample_rate, random_starts, random_end 114 | 115 | 116 | def read_wave( 117 | fname, 118 | sample_rate, 119 | portion_start=0, 120 | portion_end=1, 121 | ): # Whether you want raw bytes 122 | """ 123 | :param fname: wav file path 124 | :param sample_rate: 125 | :param portion_start: 126 | :param portion_end: 127 | :return: [sample, channels] 128 | """ 129 | # sr = get_sample_rate(fname) 130 | # if(sr != sample_rate): 131 | # print("Warning: Sample rate not match, may lead to unexpected behavior.") 132 | if portion_end > 1 and portion_end < 1.1: 133 | portion_end = 1 134 | if portion_end != 1: 135 | duration = get_duration(fname) 136 | wav, _ = librosa.load( 137 | fname, 138 | sr=sample_rate, 139 | offset=portion_start * duration, 140 | duration=(portion_end - portion_start) * duration, 141 | mono=False, 142 | ) 143 | else: 144 | wav, _ = librosa.load(fname, sr=sample_rate, mono=False) 145 | if len(list(wav.shape)) == 1: 146 | wav = wav[..., None] 147 | else: 148 | wav = np.transpose(wav, (1, 0)) 149 | return wav 150 | 151 | 152 | def get_channels_sampwidth_and_sample_rate(fname): 153 | with wave.open(fname) as f: 154 | params = f.getparams() 155 | return ( 156 | params[0], 157 | params[1], 158 | params[2], 159 | ) # == 
(2,2,44100),(params[0],params[1],params[2]) 160 | 161 | 162 | def get_channels(fname): 163 | with wave.open(fname) as f: 164 | params = f.getparams() 165 | return params[0] 166 | 167 | 168 | def get_sample_rate(fname): 169 | with wave.open(fname) as f: 170 | params = f.getparams() 171 | return params[2] 172 | 173 | 174 | def get_duration(fname): 175 | with wave.open(fname) as f: 176 | params = f.getparams() 177 | return params[3] / params[2] 178 | 179 | 180 | def get_framesLength(fname): 181 | with wave.open(fname) as f: 182 | params = f.getparams() 183 | return params[3] 184 | 185 | 186 | def restore_wave(zxx): 187 | _, w = signal.istft(zxx) 188 | return w 189 | 190 | 191 | def calculate_total_times(dir): 192 | total = 0 193 | for each in os.listdir(dir): 194 | fname = os.path.join(dir, each) 195 | try: 196 | duration = get_duration(fname) 197 | except: 198 | print(fname) 199 | total += duration 200 | return total 201 | 202 | 203 | def filter(pth): 204 | global dic 205 | temp = [] 206 | for each in os.listdir(pth): 207 | temp.append(os.path.join(pth, each)) 208 | for each in temp: 209 | sr = get_sample_rate(each) 210 | if sr not in dic.keys(): 211 | dic[sr] = [] 212 | dic[sr].append(each) 213 | for each in dic[16000]: 214 | # print(each) 215 | pass 216 | print(dic.keys()) 217 | for each in list(dic.keys()): 218 | print(each, len(dic[each])) 219 | 220 | 221 | if __name__ == "__main__": 222 | path = "/Users/admin/Desktop/p376_025.wav" 223 | stereo = "/Users/admin/Desktop/vocals.wav" 224 | path_16 = "/Users/admin/Desktop/SI869.WAV.wav" 225 | import time 226 | 227 | start = time.time() 228 | for i in range(1000): 229 | frames, duration, sample_rate = random_chunk_wav_file(stereo, chunk_length=3.0) 230 | print(frames.shape, np.max(frames)) 231 | save_wave(frames, "stero.wav", sample_rate=44100) 232 | frames, duration, sample_rate = random_chunk_wav_file(path, chunk_length=3.0) 233 | print(frames.shape, np.max(frames)) 234 | save_wave(frames, "mono.wav", sample_rate=44100) 235 | frames, duration, sample_rate = random_chunk_wav_file(path_16, chunk_length=3.0) 236 | print(frames.shape, np.max(frames)) 237 | save_wave(frames, "16.wav", sample_rate=16000) 238 | print(time.time() - start) 239 | # frames = read_wave(stereo,sample_rate=44100) 240 | print(frames.shape) 241 | 242 | print(frames) 243 | -------------------------------------------------------------------------------- /voicefixer/tools/base.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import os 6 | import torch.fft 7 | 8 | os.environ["KMP_DUPLICATE_LIB_OK"] = "True" 9 | 10 | 11 | def get_window(window_size, window_type, square_root_window=True): 12 | """Return the window""" 13 | window = { 14 | "hamming": torch.hamming_window(window_size), 15 | "hanning": torch.hann_window(window_size), 16 | }[window_type] 17 | if square_root_window: 18 | window = torch.sqrt(window) 19 | return window 20 | 21 | 22 | def fft_point(dim): 23 | assert dim > 0 24 | num = math.log(dim, 2) 25 | num_point = 2 ** (math.ceil(num)) 26 | return num_point 27 | 28 | 29 | def pre_emphasis(signal, coefficient=0.97): 30 | """Pre-emphasis original signal 31 | y(n) = x(n) - a*x(n-1) 32 | """ 33 | return np.append(signal[0], signal[1:] - coefficient * signal[:-1]) 34 | 35 | 36 | def de_emphasis(signal, coefficient=0.97): 37 | """De-emphasis original signal 38 | y(n) = x(n) + a*x(n-1) 39 | """ 40 | length = signal.shape[0] 41 | for i in range(1, length): 42 | signal[i] = 
signal[i] + coefficient * signal[i - 1] 43 | return signal 44 | 45 | 46 | def seperate_magnitude(magnitude, phase): 47 | real = torch.cos(phase) * magnitude 48 | imagine = torch.sin(phase) * magnitude 49 | expand_dim = len(list(real.size())) 50 | return torch.stack((real, imagine), expand_dim) 51 | 52 | 53 | def stft_single( 54 | signal, 55 | sample_rate=44100, 56 | frame_length=46, 57 | frame_shift=10, 58 | window_type="hanning", 59 | device=torch.device("cuda"), 60 | square_root_window=True, 61 | ): 62 | """Compute the Short Time Fourier Transform. 63 | 64 | Args: 65 | signal: input speech signal, 66 | sample_rate: waveform datas sample frequency (Hz) 67 | frame_length: frame length in milliseconds 68 | frame_shift: frame shift in milliseconds 69 | window_type: type of window 70 | square_root_window: square root window 71 | Return: 72 | fft: (n/2)+1 dim complex STFT restults 73 | """ 74 | hop_length = int( 75 | sample_rate * frame_shift / 1000 76 | ) # The greater sample_rate, the greater hop_length 77 | win_length = int(sample_rate * frame_length / 1000) 78 | # num_point = fft_point(win_length) 79 | num_point = win_length 80 | window = get_window(num_point, window_type, square_root_window) 81 | if "cuda" in str(device): 82 | window = window.cuda(device) 83 | feat = torch.stft( 84 | signal, 85 | n_fft=num_point, 86 | hop_length=hop_length, 87 | win_length=window.shape[0], 88 | window=window, 89 | ) 90 | real, imag = feat[..., 0], feat[..., 1] 91 | return real.permute(0, 2, 1).unsqueeze(1), imag.permute(0, 2, 1).unsqueeze(1) 92 | 93 | 94 | def istft( 95 | real, 96 | imag, 97 | length, 98 | sample_rate=44100, 99 | frame_length=46, 100 | frame_shift=10, 101 | window_type="hanning", 102 | preemphasis=0.0, 103 | device=torch.device("cuda"), 104 | square_root_window=True, 105 | ): 106 | """Convert frames to signal using overlap-and-add systhesis. 107 | Args: 108 | spectrum: magnitude spectrum [batchsize,x,y,2] 109 | signal: wave signal to supply phase information 110 | Return: 111 | wav: synthesied output waveform 112 | """ 113 | real = real.permute(0, 3, 2, 1) 114 | imag = imag.permute(0, 3, 2, 1) 115 | spectrum = torch.cat([real, imag], dim=-1) 116 | 117 | hop_length = int(sample_rate * frame_shift / 1000) 118 | win_length = int(sample_rate * frame_length / 1000) 119 | 120 | # num_point = fft_point(win_length) 121 | num_point = win_length 122 | if "cuda" in str(device): 123 | window = get_window(num_point, window_type, square_root_window).cuda(device) 124 | else: 125 | window = get_window(num_point, window_type, square_root_window) 126 | 127 | wav = torch_istft( 128 | spectrum, 129 | num_point, 130 | hop_length=hop_length, 131 | win_length=window.shape[0], 132 | window=window, 133 | ) 134 | return wav[..., :length] 135 | 136 | 137 | def torch_istft( 138 | stft_matrix, # type: Tensor 139 | n_fft, # type: int 140 | hop_length=None, # type: Optional[int] 141 | win_length=None, # type: Optional[int] 142 | window=None, # type: Optional[Tensor] 143 | center=True, # type: bool 144 | pad_mode="reflect", # type: str 145 | normalized=False, # type: bool 146 | onesided=True, # type: bool 147 | length=None, # type: Optional[int] 148 | ): 149 | # type: (...) 
-> Tensor 150 | 151 | stft_matrix_dim = stft_matrix.dim() 152 | assert 3 <= stft_matrix_dim <= 4, "Incorrect stft dimension: %d" % (stft_matrix_dim) 153 | 154 | if stft_matrix_dim == 3: 155 | # add a channel dimension 156 | stft_matrix = stft_matrix.unsqueeze(0) 157 | 158 | dtype = stft_matrix.dtype 159 | device = stft_matrix.device 160 | fft_size = stft_matrix.size(1) 161 | assert (onesided and n_fft // 2 + 1 == fft_size) or ( 162 | not onesided and n_fft == fft_size 163 | ), ( 164 | "one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. " 165 | + "Given values were onesided: %s, n_fft: %d, fft_size: %d" 166 | % ("True" if onesided else False, n_fft, fft_size) 167 | ) 168 | 169 | # use stft defaults for Optionals 170 | if win_length is None: 171 | win_length = n_fft 172 | 173 | if hop_length is None: 174 | hop_length = int(win_length // 4) 175 | 176 | # There must be overlap 177 | assert 0 < hop_length <= win_length 178 | assert 0 < win_length <= n_fft 179 | 180 | if window is None: 181 | window = torch.ones(win_length, requires_grad=False, device=device, dtype=dtype) 182 | 183 | assert window.dim() == 1 and window.size(0) == win_length 184 | 185 | if win_length != n_fft: 186 | # center window with pad left and right zeros 187 | left = (n_fft - win_length) // 2 188 | window = torch.nn.functional.pad(window, (left, n_fft - win_length - left)) 189 | assert window.size(0) == n_fft 190 | # win_length and n_fft are synonymous from here on 191 | 192 | stft_matrix = stft_matrix.transpose(1, 2) # size (channel, n_frames, fft_size, 2) 193 | stft_matrix = torch.irfft( 194 | stft_matrix, 1, normalized, onesided, signal_sizes=(n_fft,) 195 | ) # size (channel, n_frames, n_fft) 196 | 197 | assert stft_matrix.size(2) == n_fft 198 | n_frames = stft_matrix.size(1) 199 | 200 | ytmp = stft_matrix * window.view(1, 1, n_fft) # size (channel, n_frames, n_fft) 201 | # each column of a channel is a frame which needs to be overlap added at the right place 202 | ytmp = ytmp.transpose(1, 2) # size (channel, n_fft, n_frames) 203 | 204 | eye = torch.eye(n_fft, requires_grad=False, device=device, dtype=dtype).unsqueeze( 205 | 1 206 | ) # size (n_fft, 1, n_fft) 207 | 208 | # this does overlap add where the frames of ytmp are added such that the i'th frame of 209 | # ytmp is added starting at i*hop_length in the output 210 | y = torch.nn.functional.conv_transpose1d( 211 | ytmp, eye, stride=hop_length, padding=0 212 | ) # size (channel, 1, expected_signal_len) 213 | 214 | # do the same for the window function 215 | window_sq = ( 216 | window.pow(2).view(n_fft, 1).repeat((1, n_frames)).unsqueeze(0) 217 | ) # size (1, n_fft, n_frames) 218 | window_envelop = torch.nn.functional.conv_transpose1d( 219 | window_sq, eye, stride=hop_length, padding=0 220 | ) # size (1, 1, expected_signal_len) 221 | 222 | expected_signal_len = n_fft + hop_length * (n_frames - 1) 223 | assert y.size(2) == expected_signal_len 224 | assert window_envelop.size(2) == expected_signal_len 225 | 226 | half_n_fft = n_fft // 2 227 | # we need to trim the front padding away if center 228 | start = half_n_fft if center else 0 229 | end = -half_n_fft if length is None else start + length 230 | 231 | y = y[:, :, start:end] 232 | window_envelop = window_envelop[:, :, start:end] 233 | 234 | # check NOLA non-zero overlap condition 235 | window_envelop_lowest = window_envelop.abs().min() 236 | assert window_envelop_lowest > 1e-11, "window overlap add min: %f" % ( 237 | window_envelop_lowest 238 | ) 239 | 240 | y = (y / 
window_envelop).squeeze(1) # size (channel, expected_signal_len) 241 | 242 | if stft_matrix_dim == 3: # remove the channel dimension 243 | y = y.squeeze(0) 244 | return y 245 | -------------------------------------------------------------------------------- /voicefixer/tools/mel_scale.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import Optional 4 | import math 5 | 6 | import warnings 7 | 8 | 9 | class MelScale(torch.nn.Module): 10 | r"""Turn a normal STFT into a mel frequency STFT, using a conversion 11 | matrix. This uses triangular filter banks. 12 | 13 | User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)). 14 | 15 | Args: 16 | n_mels (int, optional): Number of mel filterbanks. (Default: ``128``) 17 | sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``) 18 | f_min (float, optional): Minimum frequency. (Default: ``0.``) 19 | f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``) 20 | n_stft (int, optional): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. (Default: ``201``) 21 | norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band 22 | (area normalization). (Default: ``None``) 23 | mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``) 24 | 25 | See also: 26 | :py:func:`torchaudio.functional.melscale_fbanks` - The function used to 27 | generate the filter banks. 28 | """ 29 | __constants__ = ["n_mels", "sample_rate", "f_min", "f_max"] 30 | 31 | def __init__( 32 | self, 33 | n_mels: int = 128, 34 | sample_rate: int = 16000, 35 | f_min: float = 0.0, 36 | f_max: Optional[float] = None, 37 | n_stft: int = 201, 38 | norm: Optional[str] = None, 39 | mel_scale: str = "htk", 40 | ) -> None: 41 | super(MelScale, self).__init__() 42 | self.n_mels = n_mels 43 | self.sample_rate = sample_rate 44 | self.f_max = f_max if f_max is not None else float(sample_rate // 2) 45 | self.f_min = f_min 46 | self.norm = norm 47 | self.mel_scale = mel_scale 48 | 49 | assert f_min <= self.f_max, "Require f_min: {} < f_max: {}".format( 50 | f_min, self.f_max 51 | ) 52 | fb = melscale_fbanks( 53 | n_stft, 54 | self.f_min, 55 | self.f_max, 56 | self.n_mels, 57 | self.sample_rate, 58 | self.norm, 59 | self.mel_scale, 60 | ) 61 | self.register_buffer("fb", fb) 62 | 63 | def forward(self, specgram: Tensor) -> Tensor: 64 | r""" 65 | Args: 66 | specgram (Tensor): A spectrogram STFT of dimension (..., freq, time). 67 | 68 | Returns: 69 | Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time). 70 | """ 71 | 72 | # (..., time, freq) dot (freq, n_mels) -> (..., n_mels, time) 73 | mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose( 74 | -1, -2 75 | ) 76 | 77 | return mel_specgram 78 | 79 | 80 | def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float: 81 | r"""Convert Hz to Mels. 82 | 83 | Args: 84 | freqs (float): Frequencies in Hz 85 | mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. 
(Default: ``htk``) 86 | 87 | Returns: 88 | mels (float): Frequency in Mels 89 | """ 90 | 91 | if mel_scale not in ["slaney", "htk"]: 92 | raise ValueError('mel_scale should be one of "htk" or "slaney".') 93 | 94 | if mel_scale == "htk": 95 | return 2595.0 * math.log10(1.0 + (freq / 700.0)) 96 | 97 | # Fill in the linear part 98 | f_min = 0.0 99 | f_sp = 200.0 / 3 100 | 101 | mels = (freq - f_min) / f_sp 102 | 103 | # Fill in the log-scale part 104 | min_log_hz = 1000.0 105 | min_log_mel = (min_log_hz - f_min) / f_sp 106 | logstep = math.log(6.4) / 27.0 107 | 108 | if freq >= min_log_hz: 109 | mels = min_log_mel + math.log(freq / min_log_hz) / logstep 110 | 111 | return mels 112 | 113 | 114 | def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor: 115 | """Convert mel bin numbers to frequencies. 116 | 117 | Args: 118 | mels (Tensor): Mel frequencies 119 | mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``) 120 | 121 | Returns: 122 | freqs (Tensor): Mels converted in Hz 123 | """ 124 | 125 | if mel_scale not in ["slaney", "htk"]: 126 | raise ValueError('mel_scale should be one of "htk" or "slaney".') 127 | 128 | if mel_scale == "htk": 129 | return 700.0 * (10.0 ** (mels / 2595.0) - 1.0) 130 | 131 | # Fill in the linear scale 132 | f_min = 0.0 133 | f_sp = 200.0 / 3 134 | freqs = f_min + f_sp * mels 135 | 136 | # And now the nonlinear scale 137 | min_log_hz = 1000.0 138 | min_log_mel = (min_log_hz - f_min) / f_sp 139 | logstep = math.log(6.4) / 27.0 140 | 141 | log_t = mels >= min_log_mel 142 | freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel)) 143 | 144 | return freqs 145 | 146 | 147 | def _create_triangular_filterbank( 148 | all_freqs: Tensor, 149 | f_pts: Tensor, 150 | ) -> Tensor: 151 | """Create a triangular filter bank. 152 | 153 | Args: 154 | all_freqs (Tensor): STFT freq points of size (`n_freqs`). 155 | f_pts (Tensor): Filter mid points of size (`n_filter`). 156 | 157 | Returns: 158 | fb (Tensor): The filter bank of size (`n_freqs`, `n_filter`). 159 | """ 160 | # Adopted from Librosa 161 | # calculate the difference between each filter mid point and each stft freq point in hertz 162 | f_diff = f_pts[1:] - f_pts[:-1] # (n_filter + 1) 163 | slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_filter + 2) 164 | # create overlapping triangles 165 | zero = torch.zeros(1) 166 | down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_filter) 167 | up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_filter) 168 | fb = torch.max(zero, torch.min(down_slopes, up_slopes)) 169 | 170 | return fb 171 | 172 | 173 | def melscale_fbanks( 174 | n_freqs: int, 175 | f_min: float, 176 | f_max: float, 177 | n_mels: int, 178 | sample_rate: int, 179 | norm: Optional[str] = None, 180 | mel_scale: str = "htk", 181 | ) -> Tensor: 182 | r"""Create a frequency bin conversion matrix. 183 | 184 | Note: 185 | For the sake of the numerical compatibility with librosa, not all the coefficients 186 | in the resulting filter bank has magnitude of 1. 187 | 188 | .. 
image:: https://download.pytorch.org/torchaudio/doc-assets/mel_fbanks.png 189 | :alt: Visualization of generated filter bank 190 | 191 | Args: 192 | n_freqs (int): Number of frequencies to highlight/apply 193 | f_min (float): Minimum frequency (Hz) 194 | f_max (float): Maximum frequency (Hz) 195 | n_mels (int): Number of mel filterbanks 196 | sample_rate (int): Sample rate of the audio waveform 197 | norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band 198 | (area normalization). (Default: ``None``) 199 | mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``) 200 | 201 | Returns: 202 | Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``) 203 | meaning number of frequencies to highlight/apply to x the number of filterbanks. 204 | Each column is a filterbank so that assuming there is a matrix A of 205 | size (..., ``n_freqs``), the applied result would be 206 | ``A * melscale_fbanks(A.size(-1), ...)``. 207 | 208 | """ 209 | 210 | if norm is not None and norm != "slaney": 211 | raise ValueError("norm must be one of None or 'slaney'") 212 | 213 | # freq bins 214 | all_freqs = torch.linspace(0, sample_rate // 2, n_freqs) 215 | 216 | # calculate mel freq bins 217 | m_min = _hz_to_mel(f_min, mel_scale=mel_scale) 218 | m_max = _hz_to_mel(f_max, mel_scale=mel_scale) 219 | 220 | m_pts = torch.linspace(m_min, m_max, n_mels + 2) 221 | f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale) 222 | 223 | # create filterbank 224 | fb = _create_triangular_filterbank(all_freqs, f_pts) 225 | 226 | if norm is not None and norm == "slaney": 227 | # Slaney-style mel is scaled to be approx constant energy per channel 228 | enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels]) 229 | fb *= enorm.unsqueeze(0) 230 | 231 | if (fb.max(dim=0).values == 0.0).any(): 232 | warnings.warn( 233 | "At least one mel filterbank has all zero values. " 234 | f"The value for `n_mels` ({n_mels}) may be set too high. " 235 | f"Or, the value for `n_freqs` ({n_freqs}) may be set too low." 236 | ) 237 | 238 | return fb 239 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![arXiv](https://img.shields.io/badge/arXiv-2109.13731-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2109.13731) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1HYYUepIsl2aXsdET6P_AmNVXuWP1MCMf?usp=sharing) [![PyPI version](https://badge.fury.io/py/voicefixer.svg)](https://badge.fury.io/py/voicefixer) [![githubio](https://img.shields.io/badge/GitHub.io-Audio_Samples-blue?logo=Github&style=flat-square)](https://haoheliu.github.io/demopage-voicefixer)[![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97-Models%20on%20Hub-yellow)](https://huggingface.co/spaces/akhaliq/VoiceFixer) 2 | 3 | - [:speaking_head: :wrench: VoiceFixer](#speaking_head-wrench-voicefixer) 4 | - [Demo](#demo) 5 | - [Usage](#usage) 6 | - [Command line](#command-line) 7 | - [Desktop App](#desktop-app) 8 | - [Python Examples](#python-examples) 9 | - [Others Features](#others-features) 10 | - [Materials](#materials) 11 | - [Change log](#change-log) 12 | 13 | # :speaking_head: :wrench: VoiceFixer 14 | 15 | *Voicefixer* aims to restore human speech regardless how serious its degraded. 
It can handle noise, reveberation, low resolution (2kHz~44.1kHz) and clipping (0.1-1.0 threshold) effect within one model. 16 | 17 | This package provides: 18 | - A pretrained *Voicefixer*, which is build based on neural vocoder. 19 | - A pretrained 44.1k universal speaker-independent neural vocoder. 20 | 21 | ![main](test/figure.png) 22 | 23 | - If you found this repo helpful, please consider citing or [!["Buy Me A Coffee"](https://www.buymeacoffee.com/assets/img/custom_images/orange_img.png)](https://www.buymeacoffee.com/haoheliuP) 24 | 25 | ```bib 26 | @misc{liu2021voicefixer, 27 | title={VoiceFixer: Toward General Speech Restoration With Neural Vocoder}, 28 | author={Haohe Liu and Qiuqiang Kong and Qiao Tian and Yan Zhao and DeLiang Wang and Chuanzeng Huang and Yuxuan Wang}, 29 | year={2021}, 30 | eprint={2109.13731}, 31 | archivePrefix={arXiv}, 32 | primaryClass={cs.SD} 33 | } 34 | ``` 35 | 36 | ## Demo 37 | 38 | Please visit [demo page](https://haoheliu.github.io/demopage-voicefixer/) to view what voicefixer can do. 39 | 40 | ## Usage 41 | 42 | ### Command line 43 | 44 | First, install voicefixer via pip: 45 | ```shell 46 | pip install voicefixer==0.1.2 47 | ``` 48 | 49 | Process a file: 50 | ```shell 51 | # Specify the input .wav file. Output file is outfile.wav. 52 | voicefixer --infile test/utterance/original/original.wav 53 | # Or specify a output path 54 | voicefixer --infile test/utterance/original/original.wav --outfile test/utterance/original/original_processed.wav 55 | ``` 56 | 57 | Process files in a folder: 58 | ```shell 59 | voicefixer --infolder /path/to/input --outfolder /path/to/output 60 | ``` 61 | 62 | Change mode (The default mode is 0): 63 | ```shell 64 | voicefixer --infile /path/to/input.wav --outfile /path/to/output.wav --mode 1 65 | ``` 66 | 67 | Run all modes: 68 | ```shell 69 | # output file saved to `/path/to/output-modeX.wav`. 70 | voicefixer --infile /path/to/input.wav --outfile /path/to/output.wav --mode all 71 | ``` 72 | 73 | For more helper information please run: 74 | 75 | ```shell 76 | voicefixer -h 77 | ``` 78 | 79 | ### Desktop App 80 | 81 | [Demo on Youtube](https://www.youtube.com/watch?v=d_j8UKTZ7J8) (Thanks @Justin John) 82 | 83 | Install voicefixer via pip: 84 | ```shell script 85 | pip install voicefixer==0.1.2 86 | ``` 87 | 88 | You can test audio samples on your desktop by running website (powered by [streamlit](https://streamlit.io/)) 89 | 90 | 1. Clone the repo first. 91 | ```shell script 92 | git clone https://github.com/haoheliu/voicefixer.git 93 | cd voicefixer 94 | ``` 95 | :warning: **For windows users**, please make sure you have installed [WGET](https://eternallybored.org/misc/wget) and added the wget command to the system path (thanks @justinjohn0306). 96 | 97 | 98 | 2. Initialize and start web page. 99 | ```shell script 100 | # Run streamlit 101 | streamlit run test/streamlit.py 102 | ``` 103 | 104 | - If you run for the first time: the web page may leave blank for several minutes for downloading models. You can checkout the terminal for downloading progresses. 105 | 106 | - You can use [this low quality speech file](https://github.com/haoheliu/voicefixer/blob/main/test/utterance/original/original.wav) we provided for a test run. The page after processing will look like the following. 107 | 108 |

[figure: screenshot of the web page after processing]
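- To check that the pretrained checkpoints have finished downloading, you can look for them in the default cache locations this package uses (a small sketch added here for convenience, not an official step):

```python
import os

cache = os.path.expanduser("~/.cache/voicefixer")
for ckpt in [
    "analysis_module/checkpoints/vf.ckpt",
    "synthesis_module/44100/model.ckpt-1490000_trimed.pt",
]:
    path = os.path.join(cache, ckpt)
    print(path, "found" if os.path.exists(path) else "missing")
```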

109 | 110 | - For users from main land China, if you experience difficulty on downloading checkpoint. You can access them alternatively on [百度网盘](https://pan.baidu.com/s/194ufkUR_PYf1nE1KqkEZjQ) (提取密码: qis6). Please download the two checkpoints inside and place them in the following folder. 111 | - Place **vf.ckpt** inside *~/.cache/voicefixer/analysis_module/checkpoints*. (The "~" represents your home directory) 112 | - Place **model.ckpt-1490000_trimed.pt** inside *~/.cache/voicefixer/synthesis_module/44100*. (The "~" represents your home directory) 113 | 114 | ### Python Examples 115 | 116 | First, install voicefixer via pip: 117 | ```shell script 118 | pip install voicefixer==0.1.2 119 | ``` 120 | 121 | Then run the following scripts for a test run: 122 | 123 | ```shell script 124 | git clone https://github.com/haoheliu/voicefixer.git; cd voicefixer 125 | python3 test/test.py # test script 126 | ``` 127 | We expect it will give you the following output: 128 | ```shell script 129 | Initializing VoiceFixer... 130 | Test voicefixer mode 0, Pass 131 | Test voicefixer mode 1, Pass 132 | Test voicefixer mode 2, Pass 133 | Initializing 44.1kHz speech vocoder... 134 | Test vocoder using groundtruth mel spectrogram... 135 | Pass 136 | ``` 137 | *test/test.py* mainly contains the test of the following two APIs: 138 | - voicefixer.restore 139 | - vocoder.oracle 140 | 141 | ```python 142 | ... 143 | 144 | # TEST VOICEFIXER 145 | ## Initialize a voicefixer 146 | print("Initializing VoiceFixer...") 147 | voicefixer = VoiceFixer() 148 | # Mode 0: Original Model (suggested by default) 149 | # Mode 1: Add preprocessing module (remove higher frequency) 150 | # Mode 2: Train mode (might work sometimes on seriously degraded real speech) 151 | for mode in [0,1,2]: 152 | print("Testing mode",mode) 153 | voicefixer.restore(input=os.path.join(git_root,"test/utterance/original/original.flac"), # low quality .wav/.flac file 154 | output=os.path.join(git_root,"test/utterance/output/output_mode_"+str(mode)+".flac"), # save file path 155 | cuda=False, # GPU acceleration 156 | mode=mode) 157 | if(mode != 2): 158 | check("output_mode_"+str(mode)+".flac") 159 | print("Pass") 160 | 161 | # TEST VOCODER 162 | ## Initialize a vocoder 163 | print("Initializing 44.1kHz speech vocoder...") 164 | vocoder = Vocoder(sample_rate=44100) 165 | 166 | ### read wave (fpath) -> mel spectrogram -> vocoder -> wave -> save wave (out_path) 167 | print("Test vocoder using groundtruth mel spectrogram...") 168 | vocoder.oracle(fpath=os.path.join(git_root,"test/utterance/original/p360_001_mic1.flac"), 169 | out_path=os.path.join(git_root,"test/utterance/output/oracle.flac"), 170 | cuda=False) # GPU acceleration 171 | 172 | ... 173 | ``` 174 | 175 | You can clone this repo and try to run test.py inside the *test* folder. 176 | 177 | ### Others Features 178 | 179 | - How to use your own vocoder, like pre-trained HiFi-Gan? 180 | 181 | First you need to write a following helper function with your model. 
204 | 205 | ## Materials 206 | - Voicefixer training: https://github.com/haoheliu/voicefixer_main.git 207 | - Demo page: https://haoheliu.github.io/demopage-voicefixer/ 208 | 209 | [![46dnPO.png](https://z3.ax1x.com/2021/09/26/46dnPO.png)](https://imgtu.com/i/46dnPO) 210 | [![46dMxH.png](https://z3.ax1x.com/2021/09/26/46dMxH.png)](https://imgtu.com/i/46dMxH) 211 | 212 | 213 | ## Change log 214 | - 2022-09-03: Fix bugs in the command-line voicefixer for Windows users. 215 | - 2022-08-18: Add the command-line voicefixer tool to the pip package. 216 | 217 | 218 | 219 | 220 | 221 | 222 | -------------------------------------------------------------------------------- /voicefixer/vocoder/config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from voicefixer.tools.path import root_path 5 | 6 | 7 | class Config: 8 | @classmethod 9 | def refresh(cls, sr): 10 | if sr == 44100: 11 | Config.ckpt = os.path.join( 12 | os.path.expanduser("~"), 13 | ".cache/voicefixer/synthesis_module/44100/model.ckpt-1490000_trimed.pt", 14 | ) 15 | Config.cond_channels = 512 16 | Config.m_channels = 768 17 | Config.resstack_depth = [8, 8, 8, 8] 18 | Config.channels = 1024 19 | Config.cin_channels = 128 20 | Config.upsample_scales = [7, 7, 3, 3] 21 | Config.num_mels = 128 22 | Config.n_fft = 2048 23 | Config.hop_length = 441 24 | Config.sample_rate = 44100 25 | Config.fmax = 22000 26 | Config.mel_win = 128 27 | Config.local_condition_dim = 128 28 | else: 29 | raise RuntimeError( 30 | "Error: Vocoder currently only support 44100 samplerate."
31 | ) 32 | 33 | ckpt = os.path.join( 34 | os.path.expanduser("~"), 35 | ".cache/voicefixer/synthesis_module/44100/model.ckpt-1490000_trimed.pt", 36 | ) 37 | m_channels = 384 38 | bits = 10 39 | opt = "Ralamb" 40 | cond_channels = 256 41 | clip = 0.5 42 | num_bands = 1 43 | cin_channels = 128 44 | upsample_scales = [7, 7, 3, 3] 45 | filterbands = "test/filterbanks_4bands.dat" 46 | ##For inference 47 | tag = "" 48 | min_db = -115 49 | num_mels = 128 50 | n_fft = 2048 51 | hop_length = 441 52 | win_size = None 53 | sample_rate = 44100 54 | frame_shift_ms = None 55 | 56 | trim_fft_size = 512 57 | trim_hop_size = 128 58 | trim_top_db = 23 59 | 60 | signal_normalization = True 61 | allow_clipping_in_normalization = True 62 | symmetric_mels = True 63 | max_abs_value = 4.0 64 | 65 | preemphasis = 0.85 66 | min_level_db = -100 67 | ref_level_db = 20 68 | fmin = 50 69 | fmax = 22000 70 | power = 1.5 71 | griffin_lim_iters = 60 72 | rescale = False 73 | rescaling_max = 0.95 74 | trim_silence = False 75 | clip_mels_length = True 76 | max_mel_frames = 2000 77 | 78 | mel_win = 128 79 | batch_size = 24 80 | g_learning_rate = 0.001 81 | d_learning_rate = 0.001 82 | warmup_steps = 100000 83 | decay_learning_rate = 0.5 84 | exponential_moving_average = True 85 | ema_decay = 0.99 86 | 87 | reset_opt = False 88 | reset_g_opt = False 89 | reset_d_opt = False 90 | 91 | local_condition_dim = 128 92 | lambda_update_G = 1 93 | multiscale_D = 3 94 | 95 | lambda_adv = 4.0 96 | lambda_fm_loss = 0.0 97 | lambda_sc_loss = 5.0 98 | lambda_mag_loss = 5.0 99 | lambda_mel_loss = 50.0 100 | use_mle_loss = False 101 | lambda_mle_loss = 5.0 102 | 103 | lambda_freq_loss = 2.0 104 | lambda_energy_loss = 100.0 105 | lambda_t_loss = 200.0 106 | lambda_phase_loss = 100.0 107 | lambda_f0_loss = 1.0 108 | use_elu = False 109 | de_preem = False # train 110 | up_org = False 111 | use_one = True 112 | use_small_D = False 113 | use_condnet = True 114 | use_depreem = False # inference 115 | use_msd = False 116 | model_type = "tfgan" # or bytewave, frame level vocoder using istft 117 | use_hjcud = False 118 | no_skip = False 119 | out_channels = 1 120 | use_postnet = False # wn in postnet 121 | use_wn = False # wn in resstack 122 | up_type = "transpose" 123 | use_smooth = False 124 | use_drop = False 125 | use_shift_scale = False 126 | use_gcnn = False 127 | resstack_depth = [6, 6, 6, 6] 128 | kernel_size = [3, 3, 3, 3] 129 | channels = 512 130 | use_f0_loss = False 131 | use_sine = False 132 | use_cond_rnn = False 133 | use_rnn = False 134 | 135 | f0_step = 120 136 | use_lowfreq_loss = False 137 | lambda_lowfreq_loss = 1.0 138 | use_film = False 139 | use_mb_mr_gan = False 140 | 141 | use_mssl = False 142 | use_ml_gan = False 143 | use_mb_gan = True 144 | use_mpd = False 145 | use_spec_gan = True 146 | use_rwd = False 147 | use_mr_gan = True 148 | use_pqmf_rwd = False 149 | no_sine = False 150 | use_frame_mask = False 151 | 152 | lambda_var_loss = 0.0 153 | discriminator_train_start_steps = 40000 # 80k 154 | aux_d_train_start_steps = 40000 # 100k 155 | rescale_out = 0.40 156 | use_dist = True 157 | dist_backend = "nccl" 158 | dist_url = "tcp://localhost:12345" 159 | world_size = 1 160 | 161 | mel_weight_torch = torch.tensor( 162 | [ 163 | 19.40951426, 164 | 19.94047336, 165 | 20.4859038, 166 | 21.04629067, 167 | 21.62194148, 168 | 22.21335214, 169 | 22.8210215, 170 | 23.44529231, 171 | 24.08660962, 172 | 24.74541882, 173 | 25.42234287, 174 | 26.11770576, 175 | 26.83212784, 176 | 27.56615283, 177 | 28.32007747, 178 | 29.0947679, 179 | 
29.89060111, 180 | 30.70832636, 181 | 31.54828121, 182 | 32.41121487, 183 | 33.29780773, 184 | 34.20865341, 185 | 35.14437675, 186 | 36.1056621, 187 | 37.09332763, 188 | 38.10795802, 189 | 39.15039691, 190 | 40.22119881, 191 | 41.32154931, 192 | 42.45172373, 193 | 43.61293329, 194 | 44.80609379, 195 | 46.031602, 196 | 47.29070223, 197 | 48.58427549, 198 | 49.91327905, 199 | 51.27863232, 200 | 52.68119708, 201 | 54.1222372, 202 | 55.60274206, 203 | 57.12364703, 204 | 58.68617876, 205 | 60.29148652, 206 | 61.94081306, 207 | 63.63501986, 208 | 65.37562658, 209 | 67.16408954, 210 | 69.00109084, 211 | 70.88850318, 212 | 72.82736101, 213 | 74.81985537, 214 | 76.86654792, 215 | 78.96885475, 216 | 81.12900906, 217 | 83.34840929, 218 | 85.62810662, 219 | 87.97005418, 220 | 90.37689804, 221 | 92.84887686, 222 | 95.38872881, 223 | 97.99777002, 224 | 100.67862715, 225 | 103.43232942, 226 | 106.26140638, 227 | 109.16827015, 228 | 112.15470471, 229 | 115.22184756, 230 | 118.37439245, 231 | 121.6122689, 232 | 124.93877158, 233 | 128.35661454, 234 | 131.86761321, 235 | 135.47417938, 236 | 139.18059494, 237 | 142.98713744, 238 | 146.89771854, 239 | 150.91684347, 240 | 155.0446638, 241 | 159.28614648, 242 | 163.64270198, 243 | 168.12035831, 244 | 172.71749158, 245 | 177.44220154, 246 | 182.29556933, 247 | 187.28286676, 248 | 192.40502126, 249 | 197.6682721, 250 | 203.07516896, 251 | 208.63088733, 252 | 214.33770931, 253 | 220.19910108, 254 | 226.22363072, 255 | 232.41087124, 256 | 238.76803591, 257 | 245.30079083, 258 | 252.01064464, 259 | 258.90261676, 260 | 265.98474, 261 | 273.26010248, 262 | 280.73496362, 263 | 288.41440094, 264 | 296.30489752, 265 | 304.41180337, 266 | 312.7377183, 267 | 321.28877878, 268 | 330.07870237, 269 | 339.10812951, 270 | 348.38276173, 271 | 357.91393924, 272 | 367.70513992, 273 | 377.76413924, 274 | 388.09467408, 275 | 398.70920178, 276 | 409.61813793, 277 | 420.81980127, 278 | 432.33215467, 279 | 444.16083117, 280 | 456.30919947, 281 | 468.78589276, 282 | 481.61325588, 283 | 494.78824596, 284 | 508.31969844, 285 | 522.2238331, 286 | 536.51163441, 287 | 551.18859414, 288 | 566.26142988, 289 | 581.75006061, 290 | 597.66210737, 291 | ] 292 | ) 293 | 294 | x_orig = np.linspace(1, mel_weight_torch.shape[0], num=mel_weight_torch.shape[0]) 295 | 296 | x_orig_torch = torch.linspace( 297 | 1, mel_weight_torch.shape[0], steps=mel_weight_torch.shape[0] 298 | ) 299 | 300 | @classmethod 301 | def get_mel_weight(cls, percent=1, a=18.8927416350036, b=0.0269863588184314): 302 | b = percent * b 303 | 304 | def func(a, b, x): 305 | return a * np.exp(b * x) 306 | 307 | return func(a, b, Config.x_orig) 308 | 309 | @classmethod 310 | def get_mel_weight_torch(cls, percent=1, a=18.8927416350036, b=0.0269863588184314): 311 | b = percent * b 312 | 313 | def func(a, b, x): 314 | return a * torch.exp(b * x) 315 | 316 | return func(a, b, Config.x_orig_torch) 317 | -------------------------------------------------------------------------------- /voicefixer/tools/modules/fDomainHelper.py: -------------------------------------------------------------------------------- 1 | from torchlibrosa.stft import STFT, ISTFT, magphase 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | from voicefixer.tools.modules.pqmf import PQMF 6 | 7 | class FDomainHelper(nn.Module): 8 | def __init__( 9 | self, 10 | window_size=2048, 11 | hop_size=441, 12 | center=True, 13 | pad_mode="reflect", 14 | window="hann", 15 | freeze_parameters=True, 16 | subband=None, 17 | root="/Users/admin/Documents/projects/", 18 | 
): 19 | super(FDomainHelper, self).__init__() 20 | self.subband = subband 21 | # assert torchlibrosa.__version__ == "0.0.7", "Error: Found torchlibrosa version %s. Please install 0.0.7 version of torchlibrosa by: pip install torchlibrosa==0.0.7." % torchlibrosa.__version__ 22 | if self.subband is None: 23 | self.stft = STFT( 24 | n_fft=window_size, 25 | hop_length=hop_size, 26 | win_length=window_size, 27 | window=window, 28 | center=center, 29 | pad_mode=pad_mode, 30 | freeze_parameters=freeze_parameters, 31 | ) 32 | 33 | self.istft = ISTFT( 34 | n_fft=window_size, 35 | hop_length=hop_size, 36 | win_length=window_size, 37 | window=window, 38 | center=center, 39 | pad_mode=pad_mode, 40 | freeze_parameters=freeze_parameters, 41 | ) 42 | else: 43 | self.stft = STFT( 44 | n_fft=window_size // self.subband, 45 | hop_length=hop_size // self.subband, 46 | win_length=window_size // self.subband, 47 | window=window, 48 | center=center, 49 | pad_mode=pad_mode, 50 | freeze_parameters=freeze_parameters, 51 | ) 52 | 53 | self.istft = ISTFT( 54 | n_fft=window_size // self.subband, 55 | hop_length=hop_size // self.subband, 56 | win_length=window_size // self.subband, 57 | window=window, 58 | center=center, 59 | pad_mode=pad_mode, 60 | freeze_parameters=freeze_parameters, 61 | ) 62 | 63 | if subband is not None and root is not None: 64 | self.qmf = PQMF(subband, 64, root) 65 | 66 | def complex_spectrogram(self, input, eps=0.0): 67 | # [batchsize, samples] 68 | # return [batchsize, 2, t-steps, f-bins] 69 | real, imag = self.stft(input) 70 | return torch.cat([real, imag], dim=1) 71 | 72 | def reverse_complex_spectrogram(self, input, eps=0.0, length=None): 73 | # [batchsize, 2[real,imag], t-steps, f-bins] 74 | wav = self.istft(input[:, 0:1, ...], input[:, 1:2, ...], length=length) 75 | return wav 76 | 77 | def spectrogram(self, input, eps=0.0): 78 | (real, imag) = self.stft(input.float()) 79 | return torch.clamp(real**2 + imag**2, eps, np.inf) ** 0.5 80 | 81 | def spectrogram_phase(self, input, eps=0.0): 82 | (real, imag) = self.stft(input.float()) 83 | mag = torch.clamp(real**2 + imag**2, eps, np.inf) ** 0.5 84 | cos = real / mag 85 | sin = imag / mag 86 | return mag, cos, sin 87 | 88 | def wav_to_spectrogram_phase(self, input, eps=1e-8): 89 | """Waveform to spectrogram. 90 | 91 | Args: 92 | input: (batch_size, channels_num, segment_samples) 93 | 94 | Outputs: 95 | output: (batch_size, channels_num, time_steps, freq_bins) 96 | """ 97 | sp_list = [] 98 | cos_list = [] 99 | sin_list = [] 100 | channels_num = input.shape[1] 101 | for channel in range(channels_num): 102 | mag, cos, sin = self.spectrogram_phase(input[:, channel, :], eps=eps) 103 | sp_list.append(mag) 104 | cos_list.append(cos) 105 | sin_list.append(sin) 106 | 107 | sps = torch.cat(sp_list, dim=1) 108 | coss = torch.cat(cos_list, dim=1) 109 | sins = torch.cat(sin_list, dim=1) 110 | return sps, coss, sins 111 | 112 | def spectrogram_phase_to_wav(self, sps, coss, sins, length): 113 | channels_num = sps.size()[1] 114 | res = [] 115 | for i in range(channels_num): 116 | res.append( 117 | self.istft( 118 | sps[:, i : i + 1, ...] * coss[:, i : i + 1, ...], 119 | sps[:, i : i + 1, ...] * sins[:, i : i + 1, ...], 120 | length, 121 | ) 122 | ) 123 | res[-1] = res[-1].unsqueeze(1) 124 | return torch.cat(res, dim=1) 125 | 126 | def wav_to_spectrogram(self, input, eps=1e-8): 127 | """Waveform to spectrogram. 
128 | 129 | Args: 130 | input: (batch_size,channels_num, segment_samples) 131 | 132 | Outputs: 133 | output: (batch_size, channels_num, time_steps, freq_bins) 134 | """ 135 | sp_list = [] 136 | channels_num = input.shape[1] 137 | for channel in range(channels_num): 138 | sp_list.append(self.spectrogram(input[:, channel, :], eps=eps)) 139 | output = torch.cat(sp_list, dim=1) 140 | return output 141 | 142 | def spectrogram_to_wav(self, input, spectrogram, length=None): 143 | """Spectrogram to waveform. 144 | Args: 145 | input: (batch_size, segment_samples, channels_num) 146 | spectrogram: (batch_size, channels_num, time_steps, freq_bins) 147 | 148 | Outputs: 149 | output: (batch_size, segment_samples, channels_num) 150 | """ 151 | channels_num = input.shape[1] 152 | wav_list = [] 153 | for channel in range(channels_num): 154 | (real, imag) = self.stft(input[:, channel, :]) 155 | (_, cos, sin) = magphase(real, imag) 156 | wav_list.append( 157 | self.istft( 158 | spectrogram[:, channel : channel + 1, :, :] * cos, 159 | spectrogram[:, channel : channel + 1, :, :] * sin, 160 | length, 161 | ) 162 | ) 163 | 164 | output = torch.stack(wav_list, dim=1) 165 | return output 166 | 167 | # todo the following code is not bug free! 168 | def wav_to_complex_spectrogram(self, input, eps=0.0): 169 | # [batchsize , channels, samples] 170 | # [batchsize, 2[real,imag]*channels, t-steps, f-bins] 171 | res = [] 172 | channels_num = input.shape[1] 173 | for channel in range(channels_num): 174 | res.append(self.complex_spectrogram(input[:, channel, :], eps=eps)) 175 | return torch.cat(res, dim=1) 176 | 177 | def complex_spectrogram_to_wav(self, input, eps=0.0, length=None): 178 | # [batchsize, 2[real,imag]*channels, t-steps, f-bins] 179 | # return [batchsize, channels, samples] 180 | channels = input.size()[1] // 2 181 | wavs = [] 182 | for i in range(channels): 183 | wavs.append( 184 | self.reverse_complex_spectrogram( 185 | input[:, 2 * i : 2 * i + 2, ...], eps=eps, length=length 186 | ) 187 | ) 188 | wavs[-1] = wavs[-1].unsqueeze(1) 189 | return torch.cat(wavs, dim=1) 190 | 191 | def wav_to_complex_subband_spectrogram(self, input, eps=0.0): 192 | # [batchsize, channels, samples] 193 | # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins] 194 | subwav = self.qmf.analysis(input) # [batchsize, subband*channels, samples] 195 | subspec = self.wav_to_complex_spectrogram(subwav) 196 | return subspec 197 | 198 | def complex_subband_spectrogram_to_wav(self, input, eps=0.0): 199 | # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins] 200 | # [batchsize, channels, samples] 201 | subwav = self.complex_spectrogram_to_wav(input) 202 | data = self.qmf.synthesis(subwav) 203 | return data 204 | 205 | def wav_to_mag_phase_subband_spectrogram(self, input, eps=1e-8): 206 | """ 207 | :param input: 208 | :param eps: 209 | :return: 210 | loss = torch.nn.L1Loss() 211 | models = FDomainHelper(subband=4) 212 | data = torch.randn((3,1, 44100*3)) 213 | 214 | sps, coss, sins = models.wav_to_mag_phase_subband_spectrogram(data) 215 | wav = models.mag_phase_subband_spectrogram_to_wav(sps,coss,sins,44100*3//4) 216 | 217 | print(loss(data,wav)) 218 | print(torch.max(torch.abs(data-wav))) 219 | 220 | """ 221 | # [batchsize, channels, samples] 222 | # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins] 223 | subwav = self.qmf.analysis(input) # [batchsize, subband*channels, samples] 224 | sps, coss, sins = self.wav_to_spectrogram_phase(subwav, eps=eps) 225 | return sps, coss, sins 226 | 227 | def 
mag_phase_subband_spectrogram_to_wav(self, sps, coss, sins, length, eps=0.0): 228 | # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins] 229 | # [batchsize, channels, samples] 230 | subwav = self.spectrogram_phase_to_wav( 231 | sps, coss, sins, length + self.qmf.pad_samples // self.qmf.N 232 | ) 233 | data = self.qmf.synthesis(subwav) 234 | return data 235 | -------------------------------------------------------------------------------- /voicefixer/restorer/model.py: -------------------------------------------------------------------------------- 1 | # import pytorch_lightning as pl 2 | 3 | import torch.utils 4 | from voicefixer.tools.mel_scale import MelScale 5 | import torch.utils.data 6 | import matplotlib.pyplot as plt 7 | import librosa.display 8 | from voicefixer.vocoder.base import Vocoder 9 | from voicefixer.tools.pytorch_util import * 10 | from voicefixer.restorer.model_kqq_bn import UNetResComplex_100Mb 11 | from voicefixer.tools.random_ import * 12 | from voicefixer.tools.wav import * 13 | from voicefixer.tools.modules.fDomainHelper import FDomainHelper 14 | 15 | from voicefixer.tools.io import load_json, write_json 16 | from matplotlib import cm 17 | 18 | os.environ["KMP_DUPLICATE_LIB_OK"] = "True" 19 | EPS = 1e-8 20 | 21 | 22 | class BN_GRU(torch.nn.Module): 23 | def __init__( 24 | self, 25 | input_dim, 26 | hidden_dim, 27 | layer=1, 28 | bidirectional=False, 29 | batchnorm=True, 30 | dropout=0.0, 31 | ): 32 | super(BN_GRU, self).__init__() 33 | self.batchnorm = batchnorm 34 | if batchnorm: 35 | self.bn = nn.BatchNorm2d(1) 36 | self.gru = torch.nn.GRU( 37 | input_size=input_dim, 38 | hidden_size=hidden_dim, 39 | num_layers=layer, 40 | bidirectional=bidirectional, 41 | dropout=dropout, 42 | batch_first=True, 43 | ) 44 | self.init_weights() 45 | 46 | def init_weights(self): 47 | for m in self.modules(): 48 | if type(m) in [nn.GRU, nn.LSTM, nn.RNN]: 49 | for name, param in m.named_parameters(): 50 | if "weight_ih" in name: 51 | torch.nn.init.xavier_uniform_(param.data) 52 | elif "weight_hh" in name: 53 | torch.nn.init.orthogonal_(param.data) 54 | elif "bias" in name: 55 | param.data.fill_(0) 56 | 57 | def forward(self, inputs): 58 | # (batch, 1, seq, feature) 59 | if self.batchnorm: 60 | inputs = self.bn(inputs) 61 | out, _ = self.gru(inputs.squeeze(1)) 62 | return out.unsqueeze(1) 63 | 64 | 65 | class Generator(nn.Module): 66 | def __init__(self, n_mel, hidden, channels): 67 | super(Generator, self).__init__() 68 | # todo the currently running trail don't have dropout 69 | self.denoiser = nn.Sequential( 70 | nn.BatchNorm2d(1), 71 | nn.Linear(n_mel, n_mel * 2), 72 | nn.ReLU(inplace=True), 73 | nn.BatchNorm2d(1), 74 | nn.Linear(n_mel * 2, n_mel * 4), 75 | nn.Dropout(0.5), 76 | nn.ReLU(inplace=True), 77 | BN_GRU( 78 | input_dim=n_mel * 4, 79 | hidden_dim=n_mel * 2, 80 | bidirectional=True, 81 | layer=2, 82 | batchnorm=True, 83 | ), 84 | BN_GRU( 85 | input_dim=n_mel * 4, 86 | hidden_dim=n_mel * 2, 87 | bidirectional=True, 88 | layer=2, 89 | batchnorm=True, 90 | ), 91 | nn.BatchNorm2d(1), 92 | nn.ReLU(inplace=True), 93 | nn.Linear(n_mel * 4, n_mel * 4), 94 | nn.Dropout(0.5), 95 | nn.BatchNorm2d(1), 96 | nn.ReLU(inplace=True), 97 | nn.Linear(n_mel * 4, n_mel), 98 | nn.Sigmoid(), 99 | ) 100 | 101 | self.unet = UNetResComplex_100Mb(channels=channels) 102 | 103 | def forward(self, sp, mel_orig): 104 | # Denoising 105 | noisy = mel_orig.clone() 106 | clean = self.denoiser(noisy) * noisy 107 | x = to_log(clean.detach()) 108 | unet_in = torch.cat([to_log(mel_orig), x], dim=1) 109 
| # unet_in = lstm_out 110 | unet_out = self.unet(unet_in)["mel"] 111 | # masks 112 | mel = unet_out + x 113 | # todo mel and addition here are in log scales 114 | return { 115 | "mel": mel, 116 | "lstm_out": unet_out, 117 | "unet_out": unet_out, 118 | "noisy": noisy, 119 | "clean": clean, 120 | } 121 | 122 | 123 | class VoiceFixer(nn.Module): 124 | def __init__( 125 | self, 126 | channels, 127 | type_target="vocals", 128 | nsrc=1, 129 | loss="l1", 130 | lr=0.002, 131 | gamma=0.9, 132 | batchsize=None, 133 | frame_length=None, 134 | sample_rate=None, 135 | warm_up_steps=1000, 136 | reduce_lr_steps=15000, 137 | # datas 138 | check_val_every_n_epoch=5, 139 | ): 140 | super(VoiceFixer, self).__init__() 141 | 142 | if sample_rate == 44100: 143 | window_size = 2048 144 | hop_size = 441 145 | n_mel = 128 146 | elif sample_rate == 24000: 147 | window_size = 768 148 | hop_size = 240 149 | n_mel = 80 150 | elif sample_rate == 16000: 151 | window_size = 512 152 | hop_size = 160 153 | n_mel = 80 154 | else: 155 | raise ValueError( 156 | "Error: Sample rate " + str(sample_rate) + " not supported" 157 | ) 158 | 159 | center = (True,) 160 | pad_mode = "reflect" 161 | window = "hann" 162 | freeze_parameters = True 163 | 164 | # self.save_hyperparameters() 165 | self.nsrc = nsrc 166 | self.type_target = type_target 167 | self.channels = channels 168 | self.lr = lr 169 | self.generated = None 170 | self.gamma = gamma 171 | self.sample_rate = sample_rate 172 | self.sample_rate = sample_rate 173 | self.batchsize = batchsize 174 | self.frame_length = frame_length 175 | # self.hparams['channels'] = 2 176 | 177 | # self.am = AudioMetrics() 178 | # self.im = ImgMetrics() 179 | 180 | self.vocoder = Vocoder(sample_rate=44100) 181 | 182 | self.valid = None 183 | self.fake = None 184 | 185 | self.train_step = 0 186 | self.val_step = 0 187 | self.val_result_save_dir = None 188 | self.val_result_save_dir_step = None 189 | self.downsample_ratio = 2**6 # This number equals 2^{#encoder_blcoks} 190 | self.check_val_every_n_epoch = check_val_every_n_epoch 191 | 192 | self.f_helper = FDomainHelper( 193 | window_size=window_size, 194 | hop_size=hop_size, 195 | center=center, 196 | pad_mode=pad_mode, 197 | window=window, 198 | freeze_parameters=freeze_parameters, 199 | ) 200 | 201 | hidden = window_size // 2 + 1 202 | 203 | self.mel = MelScale(n_mels=n_mel, sample_rate=sample_rate, n_stft=hidden) 204 | 205 | # masking 206 | self.generator = Generator(n_mel, hidden, channels) 207 | 208 | self.lr_lambda = lambda step: self.get_lr_lambda( 209 | step, 210 | gamma=self.gamma, 211 | warm_up_steps=warm_up_steps, 212 | reduce_lr_steps=reduce_lr_steps, 213 | ) 214 | 215 | self.lr_lambda_2 = lambda step: self.get_lr_lambda( 216 | step, gamma=self.gamma, warm_up_steps=10, reduce_lr_steps=reduce_lr_steps 217 | ) 218 | 219 | self.mel_weight_44k_128 = ( 220 | torch.tensor( 221 | [ 222 | 19.40951426, 223 | 19.94047336, 224 | 20.4859038, 225 | 21.04629067, 226 | 21.62194148, 227 | 22.21335214, 228 | 22.8210215, 229 | 23.44529231, 230 | 24.08660962, 231 | 24.74541882, 232 | 25.42234287, 233 | 26.11770576, 234 | 26.83212784, 235 | 27.56615283, 236 | 28.32007747, 237 | 29.0947679, 238 | 29.89060111, 239 | 30.70832636, 240 | 31.54828121, 241 | 32.41121487, 242 | 33.29780773, 243 | 34.20865341, 244 | 35.14437675, 245 | 36.1056621, 246 | 37.09332763, 247 | 38.10795802, 248 | 39.15039691, 249 | 40.22119881, 250 | 41.32154931, 251 | 42.45172373, 252 | 43.61293329, 253 | 44.80609379, 254 | 46.031602, 255 | 47.29070223, 256 | 48.58427549, 257 | 
49.91327905, 258 | 51.27863232, 259 | 52.68119708, 260 | 54.1222372, 261 | 55.60274206, 262 | 57.12364703, 263 | 58.68617876, 264 | 60.29148652, 265 | 61.94081306, 266 | 63.63501986, 267 | 65.37562658, 268 | 67.16408954, 269 | 69.00109084, 270 | 70.88850318, 271 | 72.82736101, 272 | 74.81985537, 273 | 76.86654792, 274 | 78.96885475, 275 | 81.12900906, 276 | 83.34840929, 277 | 85.62810662, 278 | 87.97005418, 279 | 90.37689804, 280 | 92.84887686, 281 | 95.38872881, 282 | 97.99777002, 283 | 100.67862715, 284 | 103.43232942, 285 | 106.26140638, 286 | 109.16827015, 287 | 112.15470471, 288 | 115.22184756, 289 | 118.37439245, 290 | 121.6122689, 291 | 124.93877158, 292 | 128.35661454, 293 | 131.86761321, 294 | 135.47417938, 295 | 139.18059494, 296 | 142.98713744, 297 | 146.89771854, 298 | 150.91684347, 299 | 155.0446638, 300 | 159.28614648, 301 | 163.64270198, 302 | 168.12035831, 303 | 172.71749158, 304 | 177.44220154, 305 | 182.29556933, 306 | 187.28286676, 307 | 192.40502126, 308 | 197.6682721, 309 | 203.07516896, 310 | 208.63088733, 311 | 214.33770931, 312 | 220.19910108, 313 | 226.22363072, 314 | 232.41087124, 315 | 238.76803591, 316 | 245.30079083, 317 | 252.01064464, 318 | 258.90261676, 319 | 265.98474, 320 | 273.26010248, 321 | 280.73496362, 322 | 288.41440094, 323 | 296.30489752, 324 | 304.41180337, 325 | 312.7377183, 326 | 321.28877878, 327 | 330.07870237, 328 | 339.10812951, 329 | 348.38276173, 330 | 357.91393924, 331 | 367.70513992, 332 | 377.76413924, 333 | 388.09467408, 334 | 398.70920178, 335 | 409.61813793, 336 | 420.81980127, 337 | 432.33215467, 338 | 444.16083117, 339 | 456.30919947, 340 | 468.78589276, 341 | 481.61325588, 342 | 494.78824596, 343 | 508.31969844, 344 | 522.2238331, 345 | 536.51163441, 346 | 551.18859414, 347 | 566.26142988, 348 | 581.75006061, 349 | 597.66210737, 350 | ] 351 | ) 352 | / 19.40951426 353 | ) 354 | self.mel_weight_44k_128 = self.mel_weight_44k_128[None, None, None, ...] 355 | 356 | self.g_loss_weight = 0.01 357 | self.d_loss_weight = 1 358 | 359 | def get_vocoder(self): 360 | return self.vocoder 361 | 362 | def get_f_helper(self): 363 | return self.f_helper 364 | 365 | def get_lr_lambda(self, step, gamma, warm_up_steps, reduce_lr_steps): 366 | r"""Get lr_lambda for LambdaLR. E.g., 367 | 368 | .. 
code-block: python 369 | lr_lambda = lambda step: get_lr_lambda(step, warm_up_steps=1000, reduce_lr_steps=10000) 370 | 371 | from torch.optim.lr_scheduler import LambdaLR 372 | LambdaLR(optimizer, lr_lambda) 373 | """ 374 | if step <= warm_up_steps: 375 | return step / warm_up_steps 376 | else: 377 | return gamma ** (step // reduce_lr_steps) 378 | 379 | def init_weights(self, module: nn.Module): 380 | for m in module.modules(): 381 | if type(m) in [nn.GRU, nn.LSTM, nn.RNN]: 382 | for name, param in m.named_parameters(): 383 | if "weight_ih" in name: 384 | torch.nn.init.xavier_uniform_(param.data) 385 | elif "weight_hh" in name: 386 | torch.nn.init.orthogonal_(param.data) 387 | elif "bias" in name: 388 | param.data.fill_(0) 389 | 390 | def pre(self, input): 391 | sp, _, _ = self.f_helper.wav_to_spectrogram_phase(input) 392 | mel_orig = self.mel(sp.permute(0, 1, 3, 2)).permute(0, 1, 3, 2) 393 | return sp, mel_orig 394 | 395 | def forward(self, sp, mel_orig): 396 | """ 397 | Args: 398 | input: (batch_size, channels_num, segment_samples) 399 | 400 | Outputs: 401 | output_dict: { 402 | 'wav': (batch_size, channels_num, segment_samples), 403 | 'sp': (batch_size, channels_num, time_steps, freq_bins)} 404 | """ 405 | return self.generator(sp, mel_orig) 406 | 407 | def configure_optimizers(self): 408 | optimizer_g = torch.optim.Adam( 409 | [{"params": self.generator.parameters()}], 410 | lr=self.lr, 411 | amsgrad=True, 412 | betas=(0.5, 0.999), 413 | ) 414 | optimizer_d = torch.optim.Adam( 415 | [{"params": self.discriminator.parameters()}], 416 | lr=self.lr, 417 | amsgrad=True, 418 | betas=(0.5, 0.999), 419 | ) 420 | 421 | scheduler_g = { 422 | "scheduler": torch.optim.lr_scheduler.LambdaLR(optimizer_g, self.lr_lambda), 423 | "interval": "step", 424 | "frequency": 1, 425 | } 426 | scheduler_d = { 427 | "scheduler": torch.optim.lr_scheduler.LambdaLR(optimizer_d, self.lr_lambda), 428 | "interval": "step", 429 | "frequency": 1, 430 | } 431 | return [optimizer_g, optimizer_d], [scheduler_g, scheduler_d] 432 | 433 | def preprocess(self, batch, train=False, cutoff=None): 434 | if train: 435 | vocal = batch[self.type_target] # final target 436 | noise = batch["noise_LR"] # augmented low resolution audio with noise 437 | augLR = batch[ 438 | self.type_target + "_aug_LR" 439 | ] # # augment low resolution audio 440 | LR = batch[self.type_target + "_LR"] 441 | # embed() 442 | vocal, LR, augLR, noise = ( 443 | vocal.float().permute(0, 2, 1), 444 | LR.float().permute(0, 2, 1), 445 | augLR.float().permute(0, 2, 1), 446 | noise.float().permute(0, 2, 1), 447 | ) 448 | # LR, noise = self.add_random_noise(LR, noise) 449 | snr, scale = [], [] 450 | for i in range(vocal.size()[0]): 451 | ( 452 | vocal[i, ...], 453 | LR[i, ...], 454 | augLR[i, ...], 455 | noise[i, ...], 456 | _snr, 457 | _scale, 458 | ) = add_noise_and_scale_with_HQ_with_Aug( 459 | vocal[i, ...], 460 | LR[i, ...], 461 | augLR[i, ...], 462 | noise[i, ...], 463 | snr_l=-5, 464 | snr_h=45, 465 | scale_lower=0.6, 466 | scale_upper=1.0, 467 | ) 468 | snr.append(_snr), scale.append(_scale) 469 | # vocal, LR = self.amp_to_original_f(vocal, LR) 470 | # noise = (noise * 0.0) + 1e-8 # todo 471 | return vocal, augLR, LR, noise + augLR 472 | else: 473 | if cutoff is None: 474 | LR_noisy = batch["noisy"] 475 | LR = batch["vocals"] 476 | vocals = batch["vocals"] 477 | vocals, LR, LR_noisy = ( 478 | vocals.float().permute(0, 2, 1), 479 | LR.float().permute(0, 2, 1), 480 | LR_noisy.float().permute(0, 2, 1), 481 | ) 482 | return vocals, LR, LR_noisy, 
batch["fname"][0] 483 | else: 484 | LR_noisy = batch["noisy" + "LR" + "_" + str(cutoff)] 485 | LR = batch["vocals" + "LR" + "_" + str(cutoff)] 486 | vocals = batch["vocals"] 487 | vocals, LR, LR_noisy = ( 488 | vocals.float().permute(0, 2, 1), 489 | LR.float().permute(0, 2, 1), 490 | LR_noisy.float().permute(0, 2, 1), 491 | ) 492 | return vocals, LR, LR_noisy, batch["fname"][0] 493 | 494 | def training_step(self, batch, batch_nb, optimizer_idx): 495 | # dict_keys(['vocals', 'vocals_aug', 'vocals_augLR', 'noise']) 496 | config = load_json("temp_path.json") 497 | if "g_loss_weight" not in config.keys(): 498 | config["g_loss_weight"] = self.g_loss_weight 499 | config["d_loss_weight"] = self.d_loss_weight 500 | write_json(config, "temp_path.json") 501 | elif ( 502 | config["g_loss_weight"] != self.g_loss_weight 503 | or config["d_loss_weight"] != self.d_loss_weight 504 | ): 505 | print( 506 | "Update d_loss weight, from", 507 | self.d_loss_weight, 508 | "to", 509 | config["d_loss_weight"], 510 | ) 511 | print( 512 | "Update g_loss weight, from", 513 | self.g_loss_weight, 514 | "to", 515 | config["g_loss_weight"], 516 | ) 517 | self.g_loss_weight = config["g_loss_weight"] 518 | self.d_loss_weight = config["d_loss_weight"] 519 | 520 | if optimizer_idx == 0: 521 | self.vocal, self.augLR, _, self.LR_noisy = self.preprocess( 522 | batch, train=True 523 | ) 524 | 525 | for i in range(self.vocal.size()[0]): 526 | save_wave( 527 | tensor2numpy(self.vocal[i, ...]), 528 | str(i) + "vocal" + ".wav", 529 | sample_rate=44100, 530 | ) 531 | save_wave( 532 | tensor2numpy(self.LR_noisy[i, ...]), 533 | str(i) + "LR_noisy" + ".wav", 534 | sample_rate=44100, 535 | ) 536 | 537 | # all_mel_e2e in non-log scale 538 | _, self.mel_target = self.pre(self.vocal) 539 | self.sp_LR_target, self.mel_LR_target = self.pre(self.augLR) 540 | self.sp_LR_target_noisy, self.mel_LR_target_noisy = self.pre(self.LR_noisy) 541 | 542 | if self.valid is None or self.valid.size()[0] != self.mel_target.size()[0]: 543 | self.valid = torch.ones( 544 | self.mel_target.size()[0], 1, self.mel_target.size()[2], 1 545 | ) 546 | self.valid = self.valid.type_as(self.mel_target) 547 | if self.fake is None or self.fake.size()[0] != self.mel_target.size()[0]: 548 | self.fake = torch.zeros( 549 | self.mel_target.size()[0], 1, self.mel_target.size()[2], 1 550 | ) 551 | self.fake = self.fake.type_as(self.mel_target) 552 | 553 | self.generated = self(self.sp_LR_target_noisy, self.mel_LR_target_noisy) 554 | 555 | denoise_loss = self.l1loss(self.generated["clean"], self.mel_LR_target) 556 | targ_loss = self.l1loss(self.generated["mel"], to_log(self.mel_target)) 557 | 558 | self.log( 559 | "targ-l", 560 | targ_loss, 561 | on_step=True, 562 | on_epoch=False, 563 | logger=True, 564 | sync_dist=True, 565 | prog_bar=True, 566 | ) 567 | self.log( 568 | "noise-l", 569 | denoise_loss, 570 | on_step=True, 571 | on_epoch=False, 572 | logger=True, 573 | sync_dist=True, 574 | prog_bar=True, 575 | ) 576 | 577 | loss = targ_loss + denoise_loss 578 | 579 | if self.train_step >= 18000: 580 | g_loss = self.bce_loss( 581 | self.discriminator(self.generated["mel"]), self.valid 582 | ) 583 | self.log( 584 | "g_l", 585 | g_loss, 586 | on_step=True, 587 | on_epoch=False, 588 | logger=True, 589 | sync_dist=True, 590 | prog_bar=True, 591 | ) 592 | # print("g_loss", g_loss) 593 | all_loss = loss + self.g_loss_weight * g_loss 594 | self.log( 595 | "all_loss", 596 | all_loss, 597 | on_step=True, 598 | on_epoch=True, 599 | logger=True, 600 | sync_dist=True, 601 | ) 602 | else: 603 | 
all_loss = loss 604 | self.train_step += 0.5 605 | return {"loss": all_loss} 606 | 607 | elif optimizer_idx == 1: 608 | if self.train_step >= 16000: 609 | self.generated = self(self.sp_LR_target_noisy, self.mel_LR_target_noisy) 610 | self.train_step += 0.5 611 | real_loss = self.bce_loss( 612 | self.discriminator(to_log(self.mel_target)), self.valid 613 | ) 614 | self.log( 615 | "r_l", 616 | real_loss, 617 | on_step=True, 618 | on_epoch=False, 619 | logger=True, 620 | sync_dist=True, 621 | prog_bar=True, 622 | ) 623 | fake_loss = self.bce_loss( 624 | self.discriminator(self.generated["mel"].detach()), self.fake 625 | ) 626 | self.log( 627 | "d_l", 628 | fake_loss, 629 | on_step=True, 630 | on_epoch=False, 631 | logger=True, 632 | sync_dist=True, 633 | prog_bar=True, 634 | ) 635 | d_loss = self.d_loss_weight * (real_loss + fake_loss) / 2 636 | self.log( 637 | "discriminator_loss", 638 | d_loss, 639 | on_step=True, 640 | on_epoch=True, 641 | logger=True, 642 | sync_dist=True, 643 | ) 644 | return {"loss": d_loss} 645 | 646 | def draw_and_save( 647 | self, mel: torch.Tensor, path, clip_max=None, clip_min=None, needlog=True 648 | ): 649 | plt.figure(figsize=(15, 5)) 650 | if clip_min is None: 651 | clip_max, clip_min = self.clip(mel) 652 | mel = np.transpose(tensor2numpy(mel)[0, 0, ...], (1, 0)) 653 | # assert np.sum(mel < 0) == 0, str(np.sum(mel < 0)) + str(np.sum(mel < 0)) 654 | 655 | if needlog: 656 | assert np.sum(mel < 0) == 0, str(np.sum(mel < 0)) + "-" + path 657 | mel_log = np.log10(mel + EPS) 658 | else: 659 | mel_log = mel 660 | 661 | # plt.imshow(mel) 662 | librosa.display.specshow( 663 | mel_log, 664 | sr=44100, 665 | x_axis="frames", 666 | y_axis="mel", 667 | cmap=cm.jet, 668 | vmax=clip_max, 669 | vmin=clip_min, 670 | ) 671 | plt.colorbar() 672 | plt.savefig(path) 673 | plt.close() 674 | 675 | def clip(self, *args): 676 | val_max, val_min = [], [] 677 | for each in args: 678 | val_max.append(torch.max(each)) 679 | val_min.append(torch.min(each)) 680 | return max(val_max), min(val_min) 681 | -------------------------------------------------------------------------------- /voicefixer/vocoder/model/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from voicefixer.vocoder.config import Config 7 | 8 | # From xin wang of nii 9 | class SineGen(torch.nn.Module): 10 | """Definition of sine generator 11 | SineGen(samp_rate, harmonic_num = 0, 12 | sine_amp = 0.1, noise_std = 0.003, 13 | voiced_threshold = 0, 14 | flag_for_pulse=False) 15 | 16 | samp_rate: sampling rate in Hz 17 | harmonic_num: number of harmonic overtones (default 0) 18 | sine_amp: amplitude of sine-wavefrom (default 0.1) 19 | noise_std: std of Gaussian noise (default 0.003) 20 | voiced_thoreshold: F0 threshold for U/V classification (default 0) 21 | flag_for_pulse: this SinGen is used inside PulseGen (default False) 22 | 23 | Note: when flag_for_pulse is True, the first time step of a voiced 24 | segment is always sin(np.pi) or cos(0) 25 | """ 26 | 27 | def __init__( 28 | self, 29 | samp_rate=24000, 30 | harmonic_num=0, 31 | sine_amp=0.1, 32 | noise_std=0.003, 33 | voiced_threshold=0, 34 | flag_for_pulse=False, 35 | ): 36 | super(SineGen, self).__init__() 37 | self.sine_amp = sine_amp 38 | self.noise_std = noise_std 39 | self.harmonic_num = harmonic_num 40 | self.dim = self.harmonic_num + 1 41 | self.sampling_rate = samp_rate 42 | self.voiced_threshold = 
voiced_threshold 43 | self.flag_for_pulse = flag_for_pulse 44 | 45 | def _f02uv(self, f0): 46 | # generate uv signal 47 | uv = torch.ones_like(f0) 48 | uv = uv * (f0 > self.voiced_threshold) 49 | return uv 50 | 51 | def _f02sine(self, f0_values): 52 | """f0_values: (batchsize, length, dim) 53 | where dim indicates fundamental tone and overtones 54 | """ 55 | # convert to F0 in rad. The interger part n can be ignored 56 | # because 2 * np.pi * n doesn't affect phase 57 | rad_values = (f0_values / self.sampling_rate) % 1 58 | 59 | # initial phase noise (no noise for fundamental component) 60 | rand_ini = torch.rand( 61 | f0_values.shape[0], f0_values.shape[2], device=f0_values.device 62 | ) 63 | rand_ini[:, 0] = 0 64 | rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini 65 | 66 | # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) 67 | if not self.flag_for_pulse: 68 | # for normal case 69 | 70 | # To prevent torch.cumsum numerical overflow, 71 | # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1. 72 | # Buffer tmp_over_one_idx indicates the time step to add -1. 73 | # This will not change F0 of sine because (x-1) * 2*pi = x *2*pi 74 | tmp_over_one = torch.cumsum(rad_values, 1) % 1 75 | tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 76 | cumsum_shift = torch.zeros_like(rad_values) 77 | cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 78 | 79 | sines = torch.sin( 80 | torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi 81 | ) 82 | else: 83 | # If necessary, make sure that the first time step of every 84 | # voiced segments is sin(pi) or cos(0) 85 | # This is used for pulse-train generation 86 | 87 | # identify the last time step in unvoiced segments 88 | uv = self._f02uv(f0_values) 89 | uv_1 = torch.roll(uv, shifts=-1, dims=1) 90 | uv_1[:, -1, :] = 1 91 | u_loc = (uv < 1) * (uv_1 > 0) 92 | 93 | # get the instantanouse phase 94 | tmp_cumsum = torch.cumsum(rad_values, dim=1) 95 | # different batch needs to be processed differently 96 | for idx in range(f0_values.shape[0]): 97 | temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] 98 | temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] 99 | # stores the accumulation of i.phase within 100 | # each voiced segments 101 | tmp_cumsum[idx, :, :] = 0 102 | tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum 103 | 104 | # rad_values - tmp_cumsum: remove the accumulation of i.phase 105 | # within the previous voiced segment. 
106 | i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) 107 | 108 | # get the sines 109 | sines = torch.cos(i_phase * 2 * np.pi) 110 | return sines 111 | 112 | def forward(self, f0): 113 | """sine_tensor, uv = forward(f0) 114 | input F0: tensor(batchsize=1, length, dim=1) 115 | f0 for unvoiced steps should be 0 116 | output sine_tensor: tensor(batchsize=1, length, dim) 117 | output uv: tensor(batchsize=1, length, 1) 118 | """ 119 | 120 | with torch.no_grad(): 121 | f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) 122 | # fundamental component 123 | f0_buf[:, :, 0] = f0[:, :, 0] 124 | for idx in np.arange(self.harmonic_num): 125 | # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic 126 | f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2) 127 | 128 | # generate sine waveforms 129 | sine_waves = self._f02sine(f0_buf) * self.sine_amp 130 | 131 | # generate uv signal 132 | # uv = torch.ones(f0.shape) 133 | # uv = uv * (f0 > self.voiced_threshold) 134 | uv = self._f02uv(f0) 135 | 136 | # noise: for unvoiced should be similar to sine_amp 137 | # std = self.sine_amp/3 -> max value ~ self.sine_amp 138 | # . for voiced regions is self.noise_std 139 | noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 140 | noise = noise_amp * torch.randn_like(sine_waves) 141 | 142 | # first: set the unvoiced part to 0 by uv 143 | # then: additive noise 144 | sine_waves = sine_waves * uv + noise 145 | return sine_waves, uv, noise 146 | 147 | 148 | class LowpassBlur(nn.Module): 149 | """perform low pass filter after upsampling for anti-aliasing""" 150 | 151 | def __init__(self, channels=128, filt_size=3, pad_type="reflect", pad_off=0): 152 | super(LowpassBlur, self).__init__() 153 | self.filt_size = filt_size 154 | self.pad_off = pad_off 155 | self.pad_sizes = [ 156 | int(1.0 * (filt_size - 1) / 2), 157 | int(np.ceil(1.0 * (filt_size - 1) / 2)), 158 | ] 159 | self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes] 160 | self.off = 0 161 | self.channels = channels 162 | 163 | if self.filt_size == 1: 164 | a = np.array( 165 | [ 166 | 1.0, 167 | ] 168 | ) 169 | elif self.filt_size == 2: 170 | a = np.array([1.0, 1.0]) 171 | elif self.filt_size == 3: 172 | a = np.array([1.0, 2.0, 1.0]) 173 | elif self.filt_size == 4: 174 | a = np.array([1.0, 3.0, 3.0, 1.0]) 175 | elif self.filt_size == 5: 176 | a = np.array([1.0, 4.0, 6.0, 4.0, 1.0]) 177 | elif self.filt_size == 6: 178 | a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0]) 179 | elif self.filt_size == 7: 180 | a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0]) 181 | 182 | filt = torch.Tensor(a) 183 | filt = filt / torch.sum(filt) 184 | self.register_buffer("filt", filt[None, None, :].repeat((self.channels, 1, 1))) 185 | 186 | self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes) 187 | 188 | def forward(self, inp): 189 | if self.filt_size == 1: 190 | return inp 191 | return F.conv1d(self.pad(inp), self.filt, groups=inp.shape[1]) 192 | 193 | 194 | def get_pad_layer_1d(pad_type): 195 | if pad_type in ["refl", "reflect"]: 196 | PadLayer = nn.ReflectionPad1d 197 | elif pad_type in ["repl", "replicate"]: 198 | PadLayer = nn.ReplicationPad1d 199 | elif pad_type == "zero": 200 | PadLayer = nn.ZeroPad1d 201 | else: 202 | print("Pad type [%s] not recognized" % pad_type) 203 | return PadLayer 204 | 205 | 206 | class MovingAverageSmooth(torch.nn.Conv1d): 207 | def __init__(self, channels, window_len=3): 208 | """Initialize Conv1d module.""" 209 | super(MovingAverageSmooth, self).__init__( 210 | in_channels=channels, 211 | 
out_channels=channels, 212 | kernel_size=1, 213 | groups=channels, 214 | bias=False, 215 | ) 216 | 217 | torch.nn.init.constant_(self.weight, 1.0 / window_len) 218 | for p in self.parameters(): 219 | p.requires_grad = False 220 | 221 | def forward(self, data): 222 | return super(MovingAverageSmooth, self).forward(data) 223 | 224 | 225 | class Conv1d(torch.nn.Conv1d): 226 | """Conv1d module with customized initialization.""" 227 | 228 | def __init__(self, *args, **kwargs): 229 | """Initialize Conv1d module.""" 230 | super(Conv1d, self).__init__(*args, **kwargs) 231 | 232 | def reset_parameters(self): 233 | """Reset parameters.""" 234 | torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu") 235 | if self.bias is not None: 236 | torch.nn.init.constant_(self.bias, 0.0) 237 | 238 | 239 | class Stretch2d(torch.nn.Module): 240 | """Stretch2d module.""" 241 | 242 | def __init__(self, x_scale, y_scale, mode="nearest"): 243 | """Initialize Stretch2d module. 244 | Args: 245 | x_scale (int): X scaling factor (Time axis in spectrogram). 246 | y_scale (int): Y scaling factor (Frequency axis in spectrogram). 247 | mode (str): Interpolation mode. 248 | """ 249 | super(Stretch2d, self).__init__() 250 | self.x_scale = x_scale 251 | self.y_scale = y_scale 252 | self.mode = mode 253 | 254 | def forward(self, x): 255 | """Calculate forward propagation. 256 | Args: 257 | x (Tensor): Input tensor (B, C, F, T). 258 | Returns: 259 | Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale), 260 | """ 261 | return F.interpolate( 262 | x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode 263 | ) 264 | 265 | 266 | class Conv2d(torch.nn.Conv2d): 267 | """Conv2d module with customized initialization.""" 268 | 269 | def __init__(self, *args, **kwargs): 270 | """Initialize Conv2d module.""" 271 | super(Conv2d, self).__init__(*args, **kwargs) 272 | 273 | def reset_parameters(self): 274 | """Reset parameters.""" 275 | self.weight.data.fill_(1.0 / np.prod(self.kernel_size)) 276 | if self.bias is not None: 277 | torch.nn.init.constant_(self.bias, 0.0) 278 | 279 | 280 | class UpsampleNetwork(torch.nn.Module): 281 | """Upsampling network module.""" 282 | 283 | def __init__( 284 | self, 285 | upsample_scales, 286 | nonlinear_activation=None, 287 | nonlinear_activation_params={}, 288 | interpolate_mode="nearest", 289 | freq_axis_kernel_size=1, 290 | use_causal_conv=False, 291 | ): 292 | """Initialize upsampling network module. 293 | Args: 294 | upsample_scales (list): List of upsampling scales. 295 | nonlinear_activation (str): Activation function name. 296 | nonlinear_activation_params (dict): Arguments for specified activation function. 297 | interpolate_mode (str): Interpolation mode. 298 | freq_axis_kernel_size (int): Kernel size in the direction of frequency axis. 299 | """ 300 | super(UpsampleNetwork, self).__init__() 301 | self.use_causal_conv = use_causal_conv 302 | self.up_layers = torch.nn.ModuleList() 303 | for scale in upsample_scales: 304 | # interpolation layer 305 | stretch = Stretch2d(scale, 1, interpolate_mode) 306 | self.up_layers += [stretch] 307 | 308 | # conv layer 309 | assert ( 310 | freq_axis_kernel_size - 1 311 | ) % 2 == 0, "Not support even number freq axis kernel size." 
312 | freq_axis_padding = (freq_axis_kernel_size - 1) // 2 313 | kernel_size = (freq_axis_kernel_size, scale * 2 + 1) 314 | if use_causal_conv: 315 | padding = (freq_axis_padding, scale * 2) 316 | else: 317 | padding = (freq_axis_padding, scale) 318 | conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) 319 | self.up_layers += [conv] 320 | 321 | # nonlinear 322 | if nonlinear_activation is not None: 323 | nonlinear = getattr(torch.nn, nonlinear_activation)( 324 | **nonlinear_activation_params 325 | ) 326 | self.up_layers += [nonlinear] 327 | 328 | def forward(self, c): 329 | """Calculate forward propagation. 330 | Args: 331 | c : Input tensor (B, C, T). 332 | Returns: 333 | Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales). 334 | """ 335 | c = c.unsqueeze(1) # (B, 1, C, T) 336 | for f in self.up_layers: 337 | if self.use_causal_conv and isinstance(f, Conv2d): 338 | c = f(c)[..., : c.size(-1)] 339 | else: 340 | c = f(c) 341 | return c.squeeze(1) # (B, C, T') 342 | 343 | 344 | class ConvInUpsampleNetwork(torch.nn.Module): 345 | """Convolution + upsampling network module.""" 346 | 347 | def __init__( 348 | self, 349 | upsample_scales=[3, 4, 5, 5], 350 | nonlinear_activation="ReLU", 351 | nonlinear_activation_params={}, 352 | interpolate_mode="nearest", 353 | freq_axis_kernel_size=1, 354 | aux_channels=80, 355 | aux_context_window=0, 356 | use_causal_conv=False, 357 | ): 358 | """Initialize convolution + upsampling network module. 359 | Args: 360 | upsample_scales (list): List of upsampling scales. 361 | nonlinear_activation (str): Activation function name. 362 | nonlinear_activation_params (dict): Arguments for specified activation function. 363 | mode (str): Interpolation mode. 364 | freq_axis_kernel_size (int): Kernel size in the direction of frequency axis. 365 | aux_channels (int): Number of channels of pre-convolutional layer. 366 | aux_context_window (int): Context window size of the pre-convolutional layer. 367 | use_causal_conv (bool): Whether to use causal structure. 368 | """ 369 | super(ConvInUpsampleNetwork, self).__init__() 370 | self.aux_context_window = aux_context_window 371 | self.use_causal_conv = use_causal_conv and aux_context_window > 0 372 | # To capture wide-context information in conditional features 373 | kernel_size = ( 374 | aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1 375 | ) 376 | # NOTE(kan-bayashi): Here do not use padding because the input is already padded 377 | self.conv_in = Conv1d( 378 | aux_channels, aux_channels, kernel_size=kernel_size, bias=False 379 | ) 380 | self.upsample = UpsampleNetwork( 381 | upsample_scales=upsample_scales, 382 | nonlinear_activation=nonlinear_activation, 383 | nonlinear_activation_params=nonlinear_activation_params, 384 | interpolate_mode=interpolate_mode, 385 | freq_axis_kernel_size=freq_axis_kernel_size, 386 | use_causal_conv=use_causal_conv, 387 | ) 388 | 389 | def forward(self, c): 390 | """Calculate forward propagation. 391 | Args: 392 | c : Input tensor (B, C, T'). 393 | Returns: 394 | Tensor: Upsampled tensor (B, C, T), 395 | where T = (T' - aux_context_window * 2) * prod(upsample_scales). 396 | Note: 397 | The length of inputs considers the context window size. 
398 | """ 399 | c_ = self.conv_in(c) 400 | c = c_[:, :, : -self.aux_context_window] if self.use_causal_conv else c_ 401 | return self.upsample(c) 402 | 403 | 404 | class DownsampleNet(nn.Module): 405 | def __init__(self, input_size, output_size, upsample_factor, hp=None, index=0): 406 | super(DownsampleNet, self).__init__() 407 | self.input_size = input_size 408 | self.output_size = output_size 409 | self.upsample_factor = upsample_factor 410 | self.skip_conv = nn.Conv1d(input_size, output_size, kernel_size=1) 411 | self.index = index 412 | layer = nn.Conv1d( 413 | input_size, 414 | output_size, 415 | kernel_size=upsample_factor * 2, 416 | stride=upsample_factor, 417 | padding=upsample_factor // 2 + upsample_factor % 2, 418 | ) 419 | 420 | self.layer = nn.utils.weight_norm(layer) 421 | 422 | def forward(self, inputs): 423 | B, C, T = inputs.size() 424 | res = inputs[:, :, :: self.upsample_factor] 425 | skip = self.skip_conv(res) 426 | 427 | outputs = self.layer(inputs) 428 | outputs = outputs + skip 429 | 430 | return outputs 431 | 432 | 433 | class UpsampleNet(nn.Module): 434 | def __init__(self, input_size, output_size, upsample_factor, hp=None, index=0): 435 | 436 | super(UpsampleNet, self).__init__() 437 | self.up_type = Config.up_type 438 | self.use_smooth = Config.use_smooth 439 | self.use_drop = Config.use_drop 440 | self.input_size = input_size 441 | self.output_size = output_size 442 | self.upsample_factor = upsample_factor 443 | self.skip_conv = nn.Conv1d(input_size, output_size, kernel_size=1) 444 | self.index = index 445 | if self.use_smooth: 446 | window_lens = [5, 5, 4, 3] 447 | self.window_len = window_lens[index] 448 | 449 | if self.up_type != "pn" or self.index < 3: 450 | # if self.up_type != "pn": 451 | layer = nn.ConvTranspose1d( 452 | input_size, 453 | output_size, 454 | upsample_factor * 2, 455 | upsample_factor, 456 | padding=upsample_factor // 2 + upsample_factor % 2, 457 | output_padding=upsample_factor % 2, 458 | ) 459 | self.layer = nn.utils.weight_norm(layer) 460 | else: 461 | self.layer = nn.Sequential( 462 | nn.ReflectionPad1d(1), 463 | nn.utils.weight_norm( 464 | nn.Conv1d(input_size, output_size * upsample_factor, kernel_size=3) 465 | ), 466 | nn.LeakyReLU(), 467 | nn.ReflectionPad1d(1), 468 | nn.utils.weight_norm( 469 | nn.Conv1d( 470 | output_size * upsample_factor, 471 | output_size * upsample_factor, 472 | kernel_size=3, 473 | ) 474 | ), 475 | nn.LeakyReLU(), 476 | nn.ReflectionPad1d(1), 477 | nn.utils.weight_norm( 478 | nn.Conv1d( 479 | output_size * upsample_factor, 480 | output_size * upsample_factor, 481 | kernel_size=3, 482 | ) 483 | ), 484 | nn.LeakyReLU(), 485 | ) 486 | 487 | if hp is not None: 488 | self.org = Config.up_org 489 | self.no_skip = Config.no_skip 490 | else: 491 | self.org = False 492 | self.no_skip = True 493 | 494 | if self.use_smooth: 495 | self.mas = nn.Sequential( 496 | # LowpassBlur(output_size, self.window_len), 497 | MovingAverageSmooth(output_size, self.window_len), 498 | # MovingAverageSmooth(output_size, self.window_len), 499 | ) 500 | 501 | def forward(self, inputs): 502 | 503 | if not self.org: 504 | inputs = inputs + torch.sin(inputs) 505 | B, C, T = inputs.size() 506 | res = inputs.repeat(1, self.upsample_factor, 1).view(B, C, -1) 507 | skip = self.skip_conv(res) 508 | if self.up_type == "repeat": 509 | return skip 510 | 511 | outputs = self.layer(inputs) 512 | if self.up_type == "pn" and self.index > 2: 513 | B, c, l = outputs.size() 514 | outputs = outputs.view(B, -1, l * self.upsample_factor) 515 | 516 | if 
self.no_skip: 517 | return outputs 518 | 519 | if not self.org: 520 | outputs = outputs + skip 521 | 522 | if self.use_smooth: 523 | outputs = self.mas(outputs) 524 | 525 | if self.use_drop: 526 | outputs = F.dropout(outputs, p=0.05) 527 | 528 | return outputs 529 | 530 | 531 | class ResStack(nn.Module): 532 | def __init__(self, channel, kernel_size=3, resstack_depth=4, hp=None): 533 | super(ResStack, self).__init__() 534 | 535 | self.use_wn = Config.use_wn 536 | self.use_shift_scale = Config.use_shift_scale 537 | self.channel = channel 538 | 539 | def get_padding(kernel_size, dilation=1): 540 | return int((kernel_size * dilation - dilation) / 2) 541 | 542 | if self.use_shift_scale: 543 | self.scale_conv = nn.utils.weight_norm( 544 | nn.Conv1d( 545 | channel, 2 * channel, kernel_size=kernel_size, dilation=1, padding=1 546 | ) 547 | ) 548 | 549 | if not self.use_wn: 550 | self.layers = nn.ModuleList( 551 | [ 552 | nn.Sequential( 553 | nn.LeakyReLU(), 554 | nn.utils.weight_norm( 555 | nn.Conv1d( 556 | channel, 557 | channel, 558 | kernel_size=kernel_size, 559 | dilation=3 ** (i % 10), 560 | padding=get_padding(kernel_size, 3 ** (i % 10)), 561 | ) 562 | ), 563 | nn.LeakyReLU(), 564 | nn.utils.weight_norm( 565 | nn.Conv1d( 566 | channel, 567 | channel, 568 | kernel_size=kernel_size, 569 | dilation=1, 570 | padding=get_padding(kernel_size, 1), 571 | ) 572 | ), 573 | ) 574 | for i in range(resstack_depth) 575 | ] 576 | ) 577 | else: 578 | self.wn = WaveNet( 579 | in_channels=channel, 580 | out_channels=channel, 581 | cin_channels=-1, 582 | num_layers=resstack_depth, 583 | residual_channels=channel, 584 | gate_channels=channel, 585 | skip_channels=channel, 586 | # kernel_size=5, 587 | # dilation_rate=3, 588 | causal=False, 589 | use_downup=False, 590 | ) 591 | 592 | def forward(self, x): 593 | if not self.use_wn: 594 | for layer in self.layers: 595 | x = x + layer(x) 596 | else: 597 | x = self.wn(x) 598 | 599 | if self.use_shift_scale: 600 | m_s = self.scale_conv(x) 601 | m_s = m_s[:, :, :-1] 602 | 603 | m, s = torch.split(m_s, self.channel, dim=1) 604 | s = F.softplus(s) 605 | 606 | x = m + s * x[:, :, 1:] # key!!! 
607 |             x = F.pad(x, pad=(1, 0), mode="constant", value=0)  # restore the length lost by the one-step shift
608 | 
609 |         return x
610 | 
611 | # Non-causal WaveNet-style stack: gated residual blocks with growing dilation and summed skip connections.
612 | class WaveNet(nn.Module):
613 |     def __init__(
614 |         self,
615 |         in_channels=1,
616 |         out_channels=1,
617 |         num_layers=10,
618 |         residual_channels=64,
619 |         gate_channels=64,
620 |         skip_channels=64,
621 |         kernel_size=3,
622 |         dilation_rate=2,
623 |         cin_channels=80,
624 |         hp=None,
625 |         causal=False,
626 |         use_downup=False,
627 |     ):
628 |         super(WaveNet, self).__init__()
629 | 
630 |         self.in_channels = in_channels
631 |         self.causal = causal
632 |         self.num_layers = num_layers
633 |         self.out_channels = out_channels
634 |         self.gate_channels = gate_channels
635 |         self.residual_channels = residual_channels
636 |         self.skip_channels = skip_channels
637 |         self.cin_channels = cin_channels
638 |         self.kernel_size = kernel_size
639 |         self.dilation_rate = dilation_rate
640 |         self.use_downup = use_downup
641 |         self.front_conv = nn.Sequential(
642 |             nn.Conv1d(
643 |                 in_channels=self.in_channels,
644 |                 out_channels=self.residual_channels,
645 |                 kernel_size=3,
646 |                 padding=1,
647 |             ),
648 |             nn.ReLU(),
649 |         )
650 |         if self.use_downup:
651 |             self.downup_conv = nn.Sequential(
652 |                 nn.Conv1d(
653 |                     in_channels=self.residual_channels,
654 |                     out_channels=self.residual_channels,
655 |                     kernel_size=3,
656 |                     stride=2,
657 |                     padding=1,
658 |                 ),
659 |                 nn.ReLU(),
660 |                 nn.Conv1d(
661 |                     in_channels=self.residual_channels,
662 |                     out_channels=self.residual_channels,
663 |                     kernel_size=3,
664 |                     stride=2,
665 |                     padding=1,
666 |                 ),
667 |                 nn.ReLU(),
668 |                 UpsampleNet(self.residual_channels, self.residual_channels, 4, hp),
669 |             )
670 | 
671 |         self.res_blocks = nn.ModuleList()
672 |         for n in range(self.num_layers):
673 |             self.res_blocks.append(
674 |                 ResBlock(
675 |                     self.residual_channels,
676 |                     self.gate_channels,
677 |                     self.skip_channels,
678 |                     self.kernel_size,
679 |                     dilation=dilation_rate**n,
680 |                     cin_channels=self.cin_channels,
681 |                     local_conditioning=(self.cin_channels > 0),
682 |                     causal=self.causal,
683 |                     mode="SAME",
684 |                 )
685 |             )
686 |         self.final_conv = nn.Sequential(
687 |             nn.ReLU(),
688 |             Conv(self.skip_channels, self.skip_channels, 1, causal=self.causal),
689 |             nn.ReLU(),
690 |             Conv(self.skip_channels, self.out_channels, 1, causal=self.causal),
691 |         )
692 | 
693 |     def forward(self, x, c=None):
694 |         return self.wavenet(x, c)
695 | 
696 |     def wavenet(self, tensor, c=None):
697 | 
698 |         h = self.front_conv(tensor)
699 |         if self.use_downup:
700 |             h = self.downup_conv(h)
701 |         skip = 0
702 |         for i, f in enumerate(self.res_blocks):
703 |             h, s = f(h, c)
704 |             skip += s
705 |         out = self.final_conv(skip)
706 |         return out
707 | 
708 |     def receptive_field_size(self):
709 |         num_dir = 1 if self.causal else 2
710 |         dilations = [self.dilation_rate**i for i in range(self.num_layers)]
711 |         return (
712 |             num_dir * (self.kernel_size - 1) * sum(dilations)
713 |             + 1
714 |             + 2  # front_conv (kernel_size=3) widens the field by kernel_size - 1
715 |         )
716 | 
717 |     def remove_weight_norm(self):
718 |         for f in self.res_blocks:
719 |             f.remove_weight_norm()
720 | 
721 | # Weight-normalized Conv1d with "SAME" length handling; in causal mode the output is trimmed so no position sees future samples.
722 | class Conv(nn.Module):
723 |     def __init__(
724 |         self,
725 |         in_channels,
726 |         out_channels,
727 |         kernel_size,
728 |         dilation=1,
729 |         causal=False,
730 |         mode="SAME",
731 |     ):
732 |         super(Conv, self).__init__()
733 | 
734 |         self.causal = causal
735 |         self.mode = mode
736 |         if self.causal and self.mode == "SAME":
737 |             self.padding = dilation * (kernel_size - 1)
738 |         elif self.mode == "SAME":
739 |             self.padding = dilation * (kernel_size - 1) // 2
740 |         else:
741 |             self.padding = 0
742 |         self.conv = nn.Conv1d(
743 |             in_channels,
744 |             out_channels,
745 |             kernel_size,
746 |             dilation=dilation,
747 |             padding=self.padding,
748 |         )
749 |         self.conv = nn.utils.weight_norm(self.conv)
750 |         nn.init.kaiming_normal_(self.conv.weight)
751 | 
752 |     def forward(self, tensor):
753 |         out = self.conv(tensor)
754 |         if self.causal and self.padding != 0:
755 |             out = out[:, :, : -self.padding]
756 |         return out
757 | 
758 |     def remove_weight_norm(self):
759 |         nn.utils.remove_weight_norm(self.conv)
760 | 
761 | # Gated residual block: tanh/sigmoid gating, optional local conditioning, and 1x1 residual/skip projections.
762 | class ResBlock(nn.Module):
763 |     def __init__(
764 |         self,
765 |         in_channels,
766 |         out_channels,
767 |         skip_channels,
768 |         kernel_size,
769 |         dilation,
770 |         cin_channels=None,
771 |         local_conditioning=True,
772 |         causal=False,
773 |         mode="SAME",
774 |     ):
775 |         super(ResBlock, self).__init__()
776 |         self.causal = causal
777 |         self.local_conditioning = local_conditioning
778 |         self.cin_channels = cin_channels
779 |         self.mode = mode
780 | 
781 |         self.filter_conv = Conv(
782 |             in_channels, out_channels, kernel_size, dilation, causal, mode
783 |         )
784 |         self.gate_conv = Conv(
785 |             in_channels, out_channels, kernel_size, dilation, causal, mode
786 |         )
787 |         self.res_conv = nn.Conv1d(out_channels, in_channels, kernel_size=1)
788 |         self.skip_conv = nn.Conv1d(out_channels, skip_channels, kernel_size=1)
789 |         self.res_conv = nn.utils.weight_norm(self.res_conv)
790 |         self.skip_conv = nn.utils.weight_norm(self.skip_conv)
791 | 
792 |         if self.local_conditioning:
793 |             self.filter_conv_c = nn.Conv1d(cin_channels, out_channels, kernel_size=1)
794 |             self.gate_conv_c = nn.Conv1d(cin_channels, out_channels, kernel_size=1)
795 |             self.filter_conv_c = nn.utils.weight_norm(self.filter_conv_c)
796 |             self.gate_conv_c = nn.utils.weight_norm(self.gate_conv_c)
797 | 
798 |     def forward(self, tensor, c=None):
799 |         h_filter = self.filter_conv(tensor)
800 |         h_gate = self.gate_conv(tensor)
801 | 
802 |         if self.local_conditioning:
803 |             h_filter += self.filter_conv_c(c)
804 |             h_gate += self.gate_conv_c(c)
805 | 
806 |         out = torch.tanh(h_filter) * torch.sigmoid(h_gate)
807 | 
808 |         res = self.res_conv(out)
809 |         skip = self.skip_conv(out)
810 |         if self.mode == "SAME":
811 |             return (tensor + res) * math.sqrt(0.5), skip
812 |         else:
813 |             return (tensor[:, :, 1:] + res) * math.sqrt(0.5), skip
814 | 
815 |     def remove_weight_norm(self):
816 |         self.filter_conv.remove_weight_norm()
817 |         self.gate_conv.remove_weight_norm()
818 |         nn.utils.remove_weight_norm(self.res_conv)
819 |         nn.utils.remove_weight_norm(self.skip_conv)
820 |         if self.local_conditioning:  # the conditioning convs only exist in this case
821 |             nn.utils.remove_weight_norm(self.filter_conv_c)
822 |             nn.utils.remove_weight_norm(self.gate_conv_c)
823 | 
824 | @torch.jit.script
825 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
826 |     n_channels_int = n_channels[0]
827 |     in_act = input_a + input_b
828 |     t_act = torch.tanh(in_act[:, :n_channels_int])
829 |     s_act = torch.sigmoid(in_act[:, n_channels_int:])
830 |     acts = t_act * s_act
831 |     return acts
832 | 
833 | 
834 | @torch.jit.script
835 | def fused_res_skip(tensor, res_skip, n_channels):
836 |     n_channels_int = n_channels[0]
837 |     res = res_skip[:, :n_channels_int]
838 |     skip = res_skip[:, n_channels_int:]
839 |     return (tensor + res), skip
840 | 
841 | # Small residual stack of 2-D convolutions, dilated along the last axis of the feature map.
842 | class ResStack2D(nn.Module):
843 |     def __init__(self, channels=16, kernel_size=3, resstack_depth=4, hp=None):
844 |         super(ResStack2D, self).__init__()
845 |         channels = 16  # NOTE: the constructor arguments above are overridden by these fixed values
846 |         kernel_size = 3
847 |         resstack_depth = 2
848 |         self.channels = channels
849 | 
850 |         def get_padding(kernel_size, dilation=1):
851 |             return int((kernel_size * dilation - dilation) / 2)
852 | 
853 |         self.layers = nn.ModuleList(
854 |             [
855 |                 nn.Sequential(
856 |                     nn.LeakyReLU(),
857 |                     nn.utils.weight_norm(
858 |                         nn.Conv2d(
859 |                             1,
860 |                             self.channels,
861 |                             kernel_size,
862 |                             dilation=(1, 3 ** (i)),
863 |                             padding=(1, get_padding(kernel_size, 3 ** (i))),
864 |                         )
865 |                     ),
866 |                     nn.LeakyReLU(),
867 |                     nn.utils.weight_norm(
868 |                         nn.Conv2d(
869 |                             self.channels,
870 |                             self.channels,
871 |                             kernel_size,
872 |                             dilation=(1, 3 ** (i)),
873 |                             padding=(1, get_padding(kernel_size, 3 ** (i))),
874 |                         )
875 |                     ),
876 |                     nn.LeakyReLU(),
877 |                     nn.utils.weight_norm(nn.Conv2d(self.channels, 1, kernel_size=1)),
878 |                 )
879 |                 for i in range(resstack_depth)
880 |             ]
881 |         )
882 | 
883 |     def forward(self, tensor):
884 |         x = tensor.unsqueeze(1)
885 |         for layer in self.layers:
886 |             x = x + layer(x)
887 |         x = x.squeeze(1)
888 | 
889 |         return x
890 | 
891 | 
892 | class FiLM(nn.Module):
893 |     """
894 |     feature-wise linear modulation
895 |     """
896 | 
897 |     def __init__(self, input_dim, attribute_dim):
898 |         super().__init__()
899 |         self.input_dim = input_dim
900 |         self.generator = nn.Conv1d(
901 |             attribute_dim, input_dim * 2, kernel_size=3, padding=1
902 |         )
903 | 
904 |     def forward(self, x, c):
905 |         """
906 |         x: (B, input_dim, seq)
907 |         c: (B, attribute_dim, seq)
908 |         """
909 |         c = self.generator(c)
910 |         m, s = torch.split(c, self.input_dim, dim=1)
911 | 
912 |         return x * s + m
913 | 
914 | 
915 | class FiLMConv1d(nn.Module):
916 |     """
917 |     Conv1d with FiLMs in between
918 |     """
919 | 
920 |     def __init__(self, in_size, out_size, attribute_dim, ins_norm=True, loop=1):
921 |         super().__init__()
922 |         self.loop = loop
923 |         self.mlps = nn.ModuleList(
924 |             [nn.Conv1d(in_size, out_size, kernel_size=3, padding=1)]
925 |             + [
926 |                 nn.Conv1d(out_size, out_size, kernel_size=3, padding=1)
927 |                 for i in range(loop - 1)
928 |             ]
929 |         )
930 |         self.films = nn.ModuleList([FiLM(out_size, attribute_dim) for i in range(loop)])
931 |         self.ins_norm = ins_norm
932 |         if self.ins_norm:
933 |             self.norm = nn.InstanceNorm1d(attribute_dim)
934 | 
935 |     def forward(self, x, c):
936 |         """
937 |         x: (B, input_dim, seq)
938 |         c: (B, attribute_dim, seq)
939 |         """
940 |         if self.ins_norm:
941 |             c = self.norm(c)
942 |         for i in range(self.loop):
943 |             x = self.mlps[i](x)
944 |             x = F.relu(x)
945 |             x = self.films[i](x, c)
946 | 
947 |         return x
948 | 
--------------------------------------------------------------------------------
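
The tensor shapes flowing through these modules are the quickest way to see how they fit together. A minimal smoke-test sketch follows; it is illustrative only (not part of the repository) and assumes the package layout shown in the tree (voicefixer/vocoder/model/modules.py) with arbitrary batch sizes, channel counts, and sequence lengths:

import torch
from voicefixer.vocoder.model.modules import WaveNet, FiLMConv1d

# WaveNet here is a non-causal feature mapper conditioned on a mel-like tensor,
# not an autoregressive sample-by-sample vocoder.
net = WaveNet(in_channels=1, out_channels=1, num_layers=4, cin_channels=80)
x = torch.randn(2, 1, 256)   # (batch, in_channels, time)
c = torch.randn(2, 80, 256)  # local conditioning, same time resolution as x
print(net(x, c).shape)       # torch.Size([2, 1, 256])

# FiLMConv1d modulates conv features with a per-channel scale/shift predicted from c.
film = FiLMConv1d(in_size=32, out_size=64, attribute_dim=80, loop=2)
print(film(torch.randn(2, 32, 256), torch.randn(2, 80, 256)).shape)  # torch.Size([2, 64, 256])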
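
The causal branch of Conv is worth a note: nn.Conv1d pads dilation * (kernel_size - 1) samples on both sides, and the forward pass then trims the same amount from the right, so the output keeps the input length while position t never depends on inputs later than t. A small, illustrative check of that behaviour (again with arbitrary shapes):

import torch
from voicefixer.vocoder.model.modules import Conv

causal = Conv(8, 8, kernel_size=3, dilation=2, causal=True, mode="SAME")
y = causal(torch.randn(1, 8, 100))
# Length is preserved and the convolution is effectively left-aligned (causal).
print(y.shape)  # torch.Size([1, 8, 100])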