├── voicefixer ├── tools │ ├── modules │ │ ├── filters │ │ │ ├── f_2_64.mat │ │ │ ├── f_4_64.mat │ │ │ ├── f_8_64.mat │ │ │ ├── h_2_64.mat │ │ │ ├── h_4_64.mat │ │ │ └── h_8_64.mat │ │ ├── __init__.py │ │ ├── pqmf.py │ │ └── fDomainHelper.py │ ├── path.py │ ├── __init__.py │ ├── io.py │ ├── random_.py │ ├── pytorch_util.py │ ├── wav.py │ ├── base.py │ └── mel_scale.py ├── vocoder │ ├── model │ │ ├── __init__.py │ │ ├── pqmf.py │ │ ├── res_msd.py │ │ ├── util.py │ │ ├── generator.py │ │ └── modules.py │ ├── __init__.py │ ├── base.py │ └── config.py ├── __init__.py ├── restorer │ ├── __init__.py │ ├── model_kqq_bn.py │ ├── modules.py │ └── model.py ├── __main__.py └── base.py ├── bin ├── voicefixer.cmd └── voicefixer ├── .gitignore ├── test ├── figure.png ├── streamlit.png ├── utterance │ ├── output │ │ ├── oracle.flac │ │ ├── output_mode_0.flac │ │ ├── output_mode_1.flac │ │ └── output_mode_2.flac │ ├── target │ │ ├── oracle.flac │ │ ├── output_mode_0.flac │ │ ├── output_mode_1.flac │ │ └── output_mode_2.flac │ └── original │ │ ├── original.flac │ │ ├── original.wav │ │ └── p360_001_mic1.flac ├── streamlit.py ├── inference.py └── test.py ├── LICENSE ├── setup.py └── README.md /voicefixer/tools/modules/filters/f_2_64.mat: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /voicefixer/tools/modules/filters/f_4_64.mat: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /voicefixer/tools/modules/filters/f_8_64.mat: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /voicefixer/tools/modules/filters/h_2_64.mat: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /voicefixer/tools/modules/filters/h_4_64.mat: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /voicefixer/tools/modules/filters/h_8_64.mat: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bin/voicefixer.cmd: -------------------------------------------------------------------------------- 1 | @echo OFF 2 | python -m voicefixer %* -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS* 2 | __pycache__ 3 | dist 4 | *egg* 5 | .idea 6 | build -------------------------------------------------------------------------------- /test/figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/figure.png -------------------------------------------------------------------------------- /test/streamlit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/streamlit.png -------------------------------------------------------------------------------- /test/utterance/output/oracle.flac: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/output/oracle.flac -------------------------------------------------------------------------------- /test/utterance/target/oracle.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/target/oracle.flac -------------------------------------------------------------------------------- /test/utterance/original/original.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/original/original.flac -------------------------------------------------------------------------------- /test/utterance/original/original.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/original/original.wav -------------------------------------------------------------------------------- /test/utterance/output/output_mode_0.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/output/output_mode_0.flac -------------------------------------------------------------------------------- /test/utterance/output/output_mode_1.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/output/output_mode_1.flac -------------------------------------------------------------------------------- /test/utterance/output/output_mode_2.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/output/output_mode_2.flac -------------------------------------------------------------------------------- /test/utterance/target/output_mode_0.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/target/output_mode_0.flac -------------------------------------------------------------------------------- /test/utterance/target/output_mode_1.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/target/output_mode_1.flac -------------------------------------------------------------------------------- /test/utterance/target/output_mode_2.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/target/output_mode_2.flac -------------------------------------------------------------------------------- /test/utterance/original/p360_001_mic1.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/voicefixer/main/test/utterance/original/p360_001_mic1.flac -------------------------------------------------------------------------------- /voicefixer/tools/path.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def find_and_build(root, path): 5 | path = os.path.join(root, path) 6 | if not os.path.exists(path): 7 | os.makedirs(path, exist_ok=True) 8 | return path 9 | 10 
| 11 | def root_path(repo_name="voicefixer"): 12 | path = os.path.abspath(__file__) 13 | return path.split(repo_name)[0] 14 | -------------------------------------------------------------------------------- /voicefixer/tools/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | """ 4 | @File : __init__.py.py 5 | @Contact : haoheliu@gmail.com 6 | @License : (C)Copyright 2020-2100 7 | 8 | @Modify Time @Author @Version @Desciption 9 | ------------ ------- -------- ----------- 10 | 9/14/21 12:28 AM Haohe Liu 1.0 None 11 | """ 12 | -------------------------------------------------------------------------------- /voicefixer/tools/modules/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | """ 4 | @File : __init__.py.py 5 | @Contact : haoheliu@gmail.com 6 | @License : (C)Copyright 2020-2100 7 | 8 | @Modify Time @Author @Version @Desciption 9 | ------------ ------- -------- ----------- 10 | 9/14/21 12:29 AM Haohe Liu 1.0 None 11 | """ 12 | -------------------------------------------------------------------------------- /voicefixer/vocoder/model/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | """ 4 | @File : __init__.py.py 5 | @Contact : haoheliu@gmail.com 6 | @License : (C)Copyright 2020-2100 7 | 8 | @Modify Time @Author @Version @Desciption 9 | ------------ ------- -------- ----------- 10 | 9/14/21 1:00 AM Haohe Liu 1.0 None 11 | """ 12 | -------------------------------------------------------------------------------- /voicefixer/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | """ 4 | @File : __init__.py.py 5 | @Contact : haoheliu@gmail.com 6 | @License : (C)Copyright 2020-2100 7 | 8 | @Modify Time @Author @Version @Desciption 9 | ------------ ------- -------- ----------- 10 | 9/14/21 12:31 AM Haohe Liu 1.0 None 11 | """ 12 | 13 | from voicefixer.vocoder.base import Vocoder 14 | from voicefixer.base import VoiceFixer 15 | -------------------------------------------------------------------------------- /voicefixer/vocoder/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | """ 4 | @File : __init__.py.py 5 | @Contact : haoheliu@gmail.com 6 | @License : (C)Copyright 2020-2100 7 | 8 | @Modify Time @Author @Version @Desciption 9 | ------------ ------- -------- ----------- 10 | 9/14/21 1:00 AM Haohe Liu 1.0 None 11 | """ 12 | 13 | import os 14 | from voicefixer.vocoder.config import Config 15 | import urllib.request 16 | 17 | if not os.path.exists(Config.ckpt): 18 | os.makedirs(os.path.dirname(Config.ckpt), exist_ok=True) 19 | print("Downloading the weight of neural vocoder: TFGAN") 20 | urllib.request.urlretrieve( 21 | "https://zenodo.org/record/5469951/files/model.ckpt-1490000_trimed.pt?download=1", 22 | Config.ckpt, 23 | ) 24 | print( 25 | "Weights downloaded in: {} Size: {}".format( 26 | Config.ckpt, os.path.getsize(Config.ckpt) 27 | ) 28 | ) 29 | # cmd = "wget https://zenodo.org/record/5469951/files/model.ckpt-1490000_trimed.pt?download=1 -O " + Config.ckpt 30 | # os.system(cmd) 31 | -------------------------------------------------------------------------------- /voicefixer/tools/io.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | 4 | 5 | def read_list(fname): 6 | result = [] 7 | with open(fname, "r") as f: 8 | for each in f.readlines(): 9 | each = each.strip("\n") 10 | result.append(each) 11 | return result 12 | 13 | 14 | def write_list(list, fname): 15 | with open(fname, "w") as f: 16 | for word in list: 17 | f.write(word) 18 | f.write("\n") 19 | 20 | 21 | def write_json(my_dict, fname): 22 | # print("Save json file at "+fname) 23 | json_str = json.dumps(my_dict) 24 | with open(fname, "w") as json_file: 25 | json_file.write(json_str) 26 | 27 | 28 | def load_json(fname): 29 | with open(fname, "r") as f: 30 | data = json.load(f) 31 | return data 32 | 33 | 34 | def save_pickle(obj, fname): 35 | # print("Save pickle at "+fname) 36 | with open(fname, "wb") as f: 37 | pickle.dump(obj, f) 38 | 39 | 40 | def load_pickle(fname): 41 | # print("Load pickle at "+fname) 42 | with open(fname, "rb") as f: 43 | res = pickle.load(f) 44 | return res 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /voicefixer/tools/random_.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | RANDOM_RESOLUTION = 2**31 5 | 6 | 7 | def random_torch(high, to_int=True): 8 | if to_int: 9 | return int((torch.rand(1)) * high) # do not use numpy.random.random 10 | else: 11 | return (torch.rand(1)) * high # do not use numpy.random.random 12 | 13 | 14 | def shuffle_torch(list): 15 | length = len(list) 16 | res = [] 17 | order = torch.randperm(length) 18 | for each in order: 19 | res.append(list[each]) 20 | assert len(list) == len(res) 21 | return res 22 | 23 | 24 | def random_choose_list(list): 25 | num = int(uniform_torch(0, len(list))) 26 | return list[num] 27 | 28 | 29 | def normal_torch(mean=0, segma=1): 30 | return float(torch.normal(mean=mean, std=torch.Tensor([segma]))[0]) 31 | 32 | 33 | def uniform_torch(lower, upper): 34 | if abs(lower - upper) < 1e-5: 35 | return upper 36 | return (upper - lower) * torch.rand(1) + lower 37 | 38 | 39 | def random_key(keys: list, weights: list): 40 | return random.choices(keys, weights=weights)[0] 41 | 42 | 43 | def random_select(probs): 44 | res = [] 45 | chance = random_torch(RANDOM_RESOLUTION) 46 | threshold = None 47 | for prob in probs: 48 | # if(threshold is None):threshold=prob 49 | # else:threshold*=prob 50 | threshold = prob 51 | res.append(chance < threshold * RANDOM_RESOLUTION) 52 | return res, chance 53 | -------------------------------------------------------------------------------- /voicefixer/restorer/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | """ 4 | @File : __init__.py.py 5 | @Contact : haoheliu@gmail.com 6 | @License : (C)Copyright 2020-2100 7 | 8 | @Modify Time @Author @Version @Desciption 9 | ------------ ------- -------- ----------- 10 | 9/14/21 12:31 AM Haohe Liu 1.0 None 11 | """ 12 | 13 | import os 14 | import torch 15 | import urllib.request 16 | 17 | meta = { 18 | "voicefixer_fe": { 19 | "path": os.path.join( 20 | os.path.expanduser("~"), 21 | ".cache/voicefixer/analysis_module/checkpoints/vf.ckpt", 22 | ), 23 | "url": "https://zenodo.org/record/5600188/files/vf.ckpt?download=1", 24 | }, 25 | } 26 | 27 | if not os.path.exists(meta["voicefixer_fe"]["path"]): 28 | os.makedirs(os.path.dirname(meta["voicefixer_fe"]["path"]), exist_ok=True) 29 | print("Downloading the main structure of voicefixer") 30 | 31 | urllib.request.urlretrieve( 32 | meta["voicefixer_fe"]["url"], meta["voicefixer_fe"]["path"] 33 | ) 34 | print( 35 | "Weights downloaded in: {} Size: {}".format( 36 | meta["voicefixer_fe"]["path"], 37 | os.path.getsize(meta["voicefixer_fe"]["path"]), 38 | ) 39 | ) 40 | 41 | # cmd = "wget "+ meta["voicefixer_fe"]['url'] + " -O " + meta["voicefixer_fe"]['path'] 42 | # os.system(cmd) 43 | # temp = torch.load(meta["voicefixer_fe"]['path']) 44 | # torch.save(temp['state_dict'], os.path.join(os.path.expanduser('~'), ".cache/voicefixer/analysis_module/checkpoints/vf.ckpt")) 45 | -------------------------------------------------------------------------------- /test/streamlit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | import librosa 5 | import soundfile 6 | import streamlit as st 7 | import torch 8 | from io import BytesIO 9 | from voicefixer import VoiceFixer 10 | 11 | 12 | 
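# Caching note (illustrative): constructing VoiceFixer() loads the restorer and
# vocoder checkpoints (downloaded to ~/.cache/voicefixer on first use), so the app
# builds a single shared instance and reuses it across Streamlit reruns. On newer
# Streamlit releases, st.cache_resource plays the role of experimental_singleton.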
@st.experimental_singleton 13 | def init_voicefixer(): 14 | return VoiceFixer() 15 | 16 | 17 | # init with global shared singleton instance 18 | voice_fixer = init_voicefixer() 19 | 20 | 21 | sample_rate = 44100 22 | 23 | 24 | st.write("Wav player") 25 | 26 | 27 | w = st.file_uploader("Upload a wav file", type="wav") 28 | 29 | 30 | if w: 31 | st.write("Inference : ") 32 | 33 | # choose options 34 | mode = st.radio( 35 | "Voice fixer modes (0: original mode, 1: Add preprocessing module 2: Train mode (may work sometimes on seriously degraded speech))", 36 | [0, 1, 2], 37 | ) 38 | if torch.cuda.is_available(): 39 | is_cuda = st.radio("Turn on GPU", [True, False]) 40 | if is_cuda != list(voice_fixer._model.parameters())[0].is_cuda: 41 | device = "cuda" if is_cuda else "cpu" 42 | voice_fixer._model = voice_fixer._model.to(device) 43 | else: 44 | is_cuda = False 45 | 46 | t1 = time.time() 47 | 48 | # Load audio from binary 49 | audio, _ = librosa.load(w, sr=sample_rate, mono=True) 50 | 51 | # Inference 52 | pred_wav = voice_fixer.restore_inmem(audio, mode=mode, cuda=is_cuda) 53 | 54 | pred_time = time.time() - t1 55 | 56 | # original audio 57 | st.write("Original Audio : ") 58 | 59 | st.audio(w) 60 | 61 | # predicted audio 62 | st.write("Predicted Audio : ") 63 | 64 | # make buffer 65 | with BytesIO() as buffer: 66 | soundfile.write(buffer, pred_wav.T, samplerate=sample_rate, format="WAV") 67 | st.write("Time: {:.3f}s".format(pred_time)) 68 | st.audio(buffer.getvalue(), format="audio/wav") 69 | -------------------------------------------------------------------------------- /voicefixer/vocoder/model/pqmf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import scipy.io.wavfile 7 | 8 | 9 | class PQMF(nn.Module): 10 | def __init__(self, N, M, file_path="utils/pqmf_hk_4_64.dat"): 11 | super().__init__() 12 | self.N = N # nsubband 13 | self.M = M # nfilter 14 | self.ana_conv_filter = nn.Conv1d( 15 | 1, out_channels=N, kernel_size=M, stride=N, bias=False 16 | ) 17 | data = np.reshape(np.fromfile(file_path, dtype=np.float32), (N, M)) 18 | data = np.flipud(data.T).T 19 | gk = data.copy() 20 | data = np.reshape(data, (N, 1, M)).copy() 21 | dict_new = self.ana_conv_filter.state_dict().copy() 22 | dict_new["weight"] = torch.from_numpy(data) 23 | self.ana_pad = nn.ConstantPad1d((M - N, 0), 0) 24 | self.ana_conv_filter.load_state_dict(dict_new) 25 | 26 | self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0) 27 | self.syn_conv_filter = nn.Conv1d( 28 | N, out_channels=N, kernel_size=M // N, stride=1, bias=False 29 | ) 30 | gk = np.transpose(np.reshape(gk, (4, 16, 4)), (1, 0, 2)) * N 31 | gk = np.transpose(gk[::-1, :, :], (2, 1, 0)).copy() 32 | dict_new = self.syn_conv_filter.state_dict().copy() 33 | dict_new["weight"] = torch.from_numpy(gk) 34 | self.syn_conv_filter.load_state_dict(dict_new) 35 | 36 | for param in self.parameters(): 37 | param.requires_grad = False 38 | 39 | def analysis(self, inputs): 40 | return self.ana_conv_filter(self.ana_pad(inputs)) 41 | 42 | def synthesis(self, inputs): 43 | return self.syn_conv_filter(self.syn_pad(inputs)) 44 | 45 | def forward(self, inputs): 46 | return self.ana_conv_filter(self.ana_pad(inputs)) 47 | 48 | 49 | if __name__ == "__main__": 50 | a = PQMF(4, 64) 51 | # x = np.load('data/train/audio/010000.npy') 52 | x = np.zeros([8, 24000], np.float32) 53 | x = np.reshape(x, (8, 1, -1)) 54 | x = torch.from_numpy(x) 55 | b = a.analysis(x) 56 
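# Shape sketch (assuming the N=4 subband, M=64 tap setup above): x is [8, 1, 24000];
# analysis() left-pads by M - N = 60 samples and applies a stride-N Conv1d, so b
# should come out as roughly [8, 4, 6000]; synthesis() below filters the subbands
# without recombining them, so c keeps that [8, 4, 6000] layout.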
| c = a.synthesis(b) 57 | print(x.shape, b.shape, c.shape) 58 | b = (b * 32768).numpy() 59 | b = np.reshape(np.transpose(b, (0, 2, 1)), (-1, 1)).astype(np.int16) 60 | # b.tofile('1.pcm') 61 | # np.reshape(np.transpose(c.numpy()*32768, (0, 2, 1)), (-1,1)).astype(np.int16).tofile('2.pcm') 62 | -------------------------------------------------------------------------------- /voicefixer/vocoder/model/res_msd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 5 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 6 | 7 | LRELU_SLOPE = 0.1 8 | 9 | 10 | def init_weights(m, mean=0.0, std=0.01): 11 | classname = m.__class__.__name__ 12 | if classname.find("Conv") != -1: 13 | m.weight.data.normal_(mean, std) 14 | 15 | 16 | def get_padding(kernel_size, dilation=1): 17 | return int((kernel_size * dilation - dilation) / 2) 18 | 19 | 20 | class ResStack(nn.Module): 21 | def __init__(self, channels=384, kernel_size=3, resstack_depth=3, hp=None): 22 | super(ResStack, self).__init__() 23 | dilation = [2 * i + 1 for i in range(resstack_depth)] # [1, 3, 5] 24 | self.convs1 = nn.ModuleList( 25 | [ 26 | weight_norm( 27 | Conv1d( 28 | channels, 29 | channels, 30 | kernel_size, 31 | 1, 32 | dilation=dilation[i], 33 | padding=get_padding(kernel_size, dilation[i]), 34 | ) 35 | ) 36 | for i in range(resstack_depth) 37 | ] 38 | ) 39 | self.convs1.apply(init_weights) 40 | 41 | self.convs2 = nn.ModuleList( 42 | [ 43 | weight_norm( 44 | Conv1d( 45 | channels, 46 | channels, 47 | kernel_size, 48 | 1, 49 | dilation=1, 50 | padding=get_padding(kernel_size, 1), 51 | ) 52 | ) 53 | for i in range(resstack_depth) 54 | ] 55 | ) 56 | self.convs2.apply(init_weights) 57 | 58 | def forward(self, x): 59 | for c1, c2 in zip(self.convs1, self.convs2): 60 | xt = F.leaky_relu(x, LRELU_SLOPE) 61 | xt = c1(xt) 62 | xt = F.leaky_relu(xt, LRELU_SLOPE) 63 | xt = c2(xt) 64 | x = xt + x 65 | return x 66 | 67 | def remove_weight_norm(self): 68 | for l in self.convs1: 69 | remove_weight_norm(l) 70 | for l in self.convs2: 71 | remove_weight_norm(l) 72 | -------------------------------------------------------------------------------- /test/inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | """ 4 | @File : inference.py 5 | @Contact : haoheliu@gmail.com 6 | @License : (C)Copyright 2020-2100 7 | 8 | @Modify Time @Author @Version @Desciption 9 | ------------ ------- -------- ----------- 10 | 9/6/21 3:08 PM Haohe Liu 1.0 None 11 | """ 12 | 13 | from voicefixer import VoiceFixer 14 | from voicefixer import Vocoder 15 | 16 | from os.path import isdir, exists, basename, join 17 | from argparse import ArgumentParser 18 | from progressbar import * 19 | 20 | parser = ArgumentParser() 21 | 22 | parser.add_argument( 23 | "-i", 24 | "--input_file_path", 25 | default="/Users/liuhaohe/Desktop/test.wav", 26 | help="The .wav file or the audio folder to be processed", 27 | ) 28 | parser.add_argument( 29 | "-o", "--output_path", default=".", help="The output dirpath for the results" 30 | ) 31 | parser.add_argument("-m", "--models", default="voicefixer_fe") 32 | parser.add_argument( 33 | "--cuda", type=bool, default=False, help="Whether use GPU acceleration." 
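# (caveat: argparse's type=bool does not parse strings, so "--cuda False" still
#  evaluates to True; an action="store_true" flag is the usual way to expose this)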
34 | ) 35 | args = parser.parse_args() 36 | 37 | if __name__ == "__main__": 38 | voicefixer = VoiceFixer() 39 | 40 | if not isdir(args.output_path): 41 | raise ValueError("Error: output path need to be a directory, not a file name.") 42 | if not exists(args.output_path): 43 | os.makedirs(args.output_path, exist_ok=True) 44 | 45 | if not isdir(args.input_file_path): 46 | assert ( 47 | args.input_file_path[-3:] == "wav" or args.input_file_path[-4:] == "flac" 48 | ), ( 49 | "Error: invalid file " 50 | + args.input_file_path 51 | + ", we only accept .wav and .flac file." 52 | ) 53 | output_path = join(args.output_path, basename(args.input_file_path)) 54 | print("Start Prediction.") 55 | voicefixer.restore( 56 | input=args.input_file_path, output=output_path, cuda=args.cuda 57 | ) 58 | else: 59 | files = os.listdir(args.input_file_path) 60 | print("Found", len(files), "files in", args.input_file_path) 61 | widgets = [ 62 | "Performing Resotartion", 63 | " [", 64 | Timer(), 65 | "] ", 66 | Bar(), 67 | " (", 68 | ETA(), 69 | ") ", 70 | ] 71 | pbar = ProgressBar(widgets=widgets).start() 72 | print("Start Prediction.") 73 | for i, file in enumerate(files): 74 | if not file[-3:] == "wav" and not file[-4:] == "flac": 75 | print( 76 | "Ignore file", 77 | file, 78 | " unsupported file type. Please use wav or flac format.", 79 | ) 80 | continue 81 | output_path = join(args.output_path, basename(file)) 82 | voicefixer.restore( 83 | input=join(args.input_file_path, file), 84 | output=output_path, 85 | cuda=args.cuda, 86 | ) 87 | pbar.update(int((i / (len(files))) * 100)) 88 | print("Congratulations! Prediction Complete.") 89 | -------------------------------------------------------------------------------- /test/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | """ 4 | @File : test.py 5 | @Contact : haoheliu@gmail.com 6 | @License : (C)Copyright 2020-2100 7 | 8 | @Modify Time @Author @Version @Desciption 9 | ------------ ------- -------- ----------- 10 | 9/14/21 11:02 AM Haohe Liu 1.0 None 11 | """ 12 | 13 | import git 14 | import os 15 | import sys 16 | import librosa 17 | import numpy as np 18 | import torch 19 | 20 | git_root = git.Repo("", search_parent_directories=True).git.rev_parse("--show-toplevel") 21 | sys.path.append(git_root) 22 | from voicefixer import VoiceFixer, Vocoder 23 | 24 | os.makedirs(os.path.join(git_root, "test/utterance/output"), exist_ok=True) 25 | 26 | 27 | def check(fname): 28 | """ 29 | check if the output is normal 30 | """ 31 | output = os.path.join(git_root, "test/utterance/output", fname) 32 | target = os.path.join(git_root, "test/utterance/target", fname) 33 | output, _ = librosa.load(output, sr=44100) 34 | target, _ = librosa.load(target, sr=44100) 35 | assert np.mean(np.abs(output - target)) < 0.01 36 | 37 | 38 | # TEST VOICEFIXER 39 | ## Initialize a voicefixer 40 | print("Initializing VoiceFixer...") 41 | voicefixer = VoiceFixer() 42 | # Mode 0: Original Model (suggested by default) 43 | # Mode 1: Add preprocessing module (remove higher frequency) 44 | # Mode 2: Train mode (might work sometimes on seriously degraded real speech) 45 | for mode in [0, 1, 2]: 46 | print("Test voicefixer mode", mode, end=", ") 47 | print("Using CPU:") 48 | voicefixer.restore( 49 | input=os.path.join( 50 | git_root, "test/utterance/original/original.flac" 51 | ), # low quality .wav/.flac file 52 | output=os.path.join( 53 | git_root, "test/utterance/output/output_mode_" + str(mode) + ".flac" 
54 | ), # save file path 55 | cuda=False, # GPU acceleration 56 | mode=mode, 57 | ) 58 | if mode != 2: 59 | check("output_mode_" + str(mode) + ".flac") 60 | 61 | if torch.cuda.is_available(): 62 | print("Using GPU:") 63 | voicefixer.restore( 64 | input=os.path.join(git_root, "test/utterance/original/original.flac"), 65 | # low quality .wav/.flac file 66 | output=os.path.join( 67 | git_root, "test/utterance/output/output_mode_" + str(mode) + ".flac" 68 | ), 69 | # save file path 70 | cuda=True, # GPU acceleration 71 | mode=mode, 72 | ) 73 | if mode != 2: 74 | check("output_mode_" + str(mode) + ".flac") 75 | print("Pass") 76 | 77 | # TEST VOCODER 78 | ## Initialize a vocoder 79 | print("Initializing 44.1kHz speech vocoder...") 80 | vocoder = Vocoder(sample_rate=44100) 81 | 82 | ### read wave (fpath) -> mel spectrogram -> vocoder -> wave -> save wave (out_path) 83 | print("Test vocoder using groundtruth mel spectrogram...") 84 | print("Using CPU:") 85 | vocoder.oracle( 86 | fpath=os.path.join(git_root, "test/utterance/original/p360_001_mic1.flac"), 87 | out_path=os.path.join(git_root, "test/utterance/output/oracle.flac"), 88 | cuda=False, 89 | ) # GPU acceleration 90 | 91 | check("oracle.flac") 92 | 93 | if torch.cuda.is_available(): 94 | print("Using GPU:") 95 | vocoder.oracle( 96 | fpath=os.path.join(git_root, "test/utterance/original/p360_001_mic1.flac"), 97 | out_path=os.path.join(git_root, "test/utterance/output/oracle.flac"), 98 | cuda=True, 99 | ) # GPU acceleration 100 | # Another interface 101 | # vocoder.forward(mel=mel) 102 | check("oracle.flac") 103 | 104 | print("Pass") 105 | -------------------------------------------------------------------------------- /voicefixer/vocoder/base.py: -------------------------------------------------------------------------------- 1 | from voicefixer.vocoder.model.generator import Generator 2 | from voicefixer.tools.wav import read_wave, save_wave 3 | from voicefixer.tools.pytorch_util import * 4 | from voicefixer.vocoder.model.util import * 5 | from voicefixer.vocoder.config import Config 6 | import os 7 | import numpy as np 8 | 9 | 10 | class Vocoder(nn.Module): 11 | def __init__(self, sample_rate): 12 | super(Vocoder, self).__init__() 13 | Config.refresh(sample_rate) 14 | self.rate = sample_rate 15 | if(not os.path.exists(Config.ckpt)): 16 | raise RuntimeError("Error 1: The checkpoint for synthesis module / vocoder (model.ckpt-1490000_trimed) is not found in ~/.cache/voicefixer/synthesis_module/44100. \ 17 | By default the checkpoint should be download automatically by this program. Something bad may happened. Apologies for the inconvenience.\ 18 | But don't worry! Alternatively you can download it directly from Zenodo: https://zenodo.org/record/5600188/files/model.ckpt-1490000_trimed.pt?download=1") 19 | self._load_pretrain(Config.ckpt) 20 | self.weight_torch = Config.get_mel_weight_torch(percent=1.0)[ 21 | None, None, None, ... 
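# (the leading None axes add broadcast dimensions so the per-mel-bin weights can
#  divide the [batch, 1, t-steps, n_mel] input of forward() bin by bin)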
22 | ] 23 | 24 | def _load_pretrain(self, pth): 25 | self.model = Generator(Config.cin_channels) 26 | checkpoint = load_checkpoint(pth, torch.device("cpu")) 27 | load_try(checkpoint["generator"], self.model) 28 | self.model.eval() 29 | self.model.remove_weight_norm() 30 | self.model.remove_weight_norm() 31 | for p in self.model.parameters(): 32 | p.requires_grad = False 33 | 34 | # def vocoder_mel_npy(self, mel, save_dir, sample_rate, gain): 35 | # mel = mel / Config.get_mel_weight(percent=gain)[...,None] 36 | # mel = normalize(amp_to_db(np.abs(mel)) - 20) 37 | # mel = pre(np.transpose(mel, (1, 0))) 38 | # with torch.no_grad(): 39 | # wav_re = self.model(mel) # torch.Size([1, 1, 104076]) 40 | # save_wave(tensor2numpy(wav_re)*2**15,save_dir,sample_rate=sample_rate) 41 | 42 | def forward(self, mel, cuda=False): 43 | """ 44 | :param non normalized mel spectrogram: [batchsize, 1, t-steps, n_mel] 45 | :return: [batchsize, 1, samples] 46 | """ 47 | assert mel.size()[-1] == 128 48 | check_cuda_availability(cuda=cuda) 49 | self.model = try_tensor_cuda(self.model, cuda=cuda) 50 | mel = try_tensor_cuda(mel, cuda=cuda) 51 | self.weight_torch = self.weight_torch.type_as(mel) 52 | mel = mel / self.weight_torch 53 | mel = tr_normalize(tr_amp_to_db(torch.abs(mel)) - 20.0) 54 | mel = tr_pre(mel[:, 0, ...]) 55 | wav_re = self.model(mel) 56 | return wav_re 57 | 58 | def oracle(self, fpath, out_path, cuda=False): 59 | check_cuda_availability(cuda=cuda) 60 | self.model = try_tensor_cuda(self.model, cuda=cuda) 61 | wav = read_wave(fpath, sample_rate=self.rate)[..., 0] 62 | wav = wav / np.max(np.abs(wav)) 63 | stft = np.abs( 64 | librosa.stft( 65 | wav, 66 | hop_length=Config.hop_length, 67 | win_length=Config.win_size, 68 | n_fft=Config.n_fft, 69 | ) 70 | ) 71 | mel = linear_to_mel(stft) 72 | mel = normalize(amp_to_db(np.abs(mel)) - 20) 73 | mel = pre(np.transpose(mel, (1, 0))) 74 | mel = try_tensor_cuda(mel, cuda=cuda) 75 | with torch.no_grad(): 76 | wav_re = self.model(mel) 77 | save_wave(tensor2numpy(wav_re * 2**15), out_path, sample_rate=self.rate) 78 | 79 | 80 | if __name__ == "__main__": 81 | model = Vocoder(sample_rate=44100) 82 | print(model.device) 83 | # model.load_pretrain(Config.ckpt) 84 | # model.oracle(path="/Users/liuhaohe/Desktop/test.wav", 85 | # sample_rate=44100, 86 | # save_dir="/Users/liuhaohe/Desktop/test_vocoder.wav") 87 | -------------------------------------------------------------------------------- /voicefixer/tools/modules/pqmf.py: -------------------------------------------------------------------------------- 1 | """ 2 | @File : subband_util.py 3 | @Contact : liu.8948@buckeyemail.osu.edu 4 | @License : (C)Copyright 2020-2021 5 | @Modify Time @Author @Version @Desciption 6 | ------------ ------- -------- ----------- 7 | 2020/4/3 4:54 PM Haohe Liu 1.0 None 8 | """ 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | import torch.nn as nn 13 | import numpy as np 14 | import os.path as op 15 | from scipy.io import loadmat 16 | 17 | 18 | def load_mat2numpy(fname=""): 19 | if len(fname) == 0: 20 | return None 21 | else: 22 | return loadmat(fname) 23 | 24 | 25 | class PQMF(nn.Module): 26 | def __init__(self, N, M, project_root): 27 | super().__init__() 28 | self.N = N # nsubband 29 | self.M = M # nfilter 30 | try: 31 | assert (N, M) in [(8, 64), (4, 64), (2, 64)] 32 | except: 33 | print("Warning:", N, "subbandand ", M, " filter is not supported") 34 | self.pad_samples = 64 35 | self.name = str(N) + "_" + str(M) + ".mat" 36 | self.ana_conv_filter = nn.Conv1d( 37 | 1, 
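# (analysis filter bank: one full-band input channel is split into N subband
#  channels, stride=N critically downsamples each subband, and the weights are
#  taken from the f_*.mat filters loaded just below)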
out_channels=N, kernel_size=M, stride=N, bias=False 38 | ) 39 | data = load_mat2numpy( 40 | op.join( 41 | project_root, 42 | "arnold_workspace/restorer/tools/pytorch/modules/filters/f_" 43 | + self.name, 44 | ) 45 | ) 46 | data = data["f"].astype(np.float32) / N 47 | data = np.flipud(data.T).T 48 | data = np.reshape(data, (N, 1, M)).copy() 49 | dict_new = self.ana_conv_filter.state_dict().copy() 50 | dict_new["weight"] = torch.from_numpy(data) 51 | self.ana_pad = nn.ConstantPad1d((M - N, 0), 0) 52 | self.ana_conv_filter.load_state_dict(dict_new) 53 | 54 | self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0) 55 | self.syn_conv_filter = nn.Conv1d( 56 | N, out_channels=N, kernel_size=M // N, stride=1, bias=False 57 | ) 58 | gk = load_mat2numpy( 59 | op.join( 60 | project_root, 61 | "arnold_workspace/restorer/tools/pytorch/modules/filters/h_" 62 | + self.name, 63 | ) 64 | ) 65 | gk = gk["h"].astype(np.float32) 66 | gk = np.transpose(np.reshape(gk, (N, M // N, N)), (1, 0, 2)) * N 67 | gk = np.transpose(gk[::-1, :, :], (2, 1, 0)).copy() 68 | dict_new = self.syn_conv_filter.state_dict().copy() 69 | dict_new["weight"] = torch.from_numpy(gk) 70 | self.syn_conv_filter.load_state_dict(dict_new) 71 | 72 | for param in self.parameters(): 73 | param.requires_grad = False 74 | 75 | def __analysis_channel(self, inputs): 76 | return self.ana_conv_filter(self.ana_pad(inputs)) 77 | 78 | def __systhesis_channel(self, inputs): 79 | ret = self.syn_conv_filter(self.syn_pad(inputs)).permute(0, 2, 1) 80 | return torch.reshape(ret, (ret.shape[0], 1, -1)) 81 | 82 | def analysis(self, inputs): 83 | """ 84 | :param inputs: [batchsize,channel,raw_wav],value:[0,1] 85 | :return: 86 | """ 87 | inputs = F.pad(inputs, ((0, self.pad_samples))) 88 | ret = None 89 | for i in range(inputs.size()[1]): # channels 90 | if ret is None: 91 | ret = self.__analysis_channel(inputs[:, i : i + 1, :]) 92 | else: 93 | ret = torch.cat( 94 | (ret, self.__analysis_channel(inputs[:, i : i + 1, :])), dim=1 95 | ) 96 | return ret 97 | 98 | def synthesis(self, data): 99 | """ 100 | :param data: [batchsize,self.N*K,raw_wav_sub],value:[0,1] 101 | :return: 102 | """ 103 | ret = None 104 | # data = F.pad(data,((0,self.pad_samples//self.N))) 105 | for i in range(data.size()[1]): # channels 106 | if i % self.N == 0: 107 | if ret is None: 108 | ret = self.__systhesis_channel(data[:, i : i + self.N, :]) 109 | else: 110 | new = self.__systhesis_channel(data[:, i : i + self.N, :]) 111 | ret = torch.cat((ret, new), dim=1) 112 | ret = ret[..., : -self.pad_samples] 113 | return ret 114 | 115 | def forward(self, inputs): 116 | return self.ana_conv_filter(self.ana_pad(inputs)) 117 | -------------------------------------------------------------------------------- /voicefixer/vocoder/model/util.py: -------------------------------------------------------------------------------- 1 | from voicefixer.vocoder.config import Config 2 | from voicefixer.tools.pytorch_util import try_tensor_cuda, check_cuda_availability 3 | import torch 4 | import librosa 5 | import numpy as np 6 | 7 | 8 | def tr_normalize(S): 9 | if Config.allow_clipping_in_normalization: 10 | if Config.symmetric_mels: 11 | return torch.clip( 12 | (2 * Config.max_abs_value) * ((S - Config.min_db) / (-Config.min_db)) 13 | - Config.max_abs_value, 14 | -Config.max_abs_value, 15 | Config.max_abs_value, 16 | ) 17 | else: 18 | return torch.clip( 19 | Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db)), 20 | 0, 21 | Config.max_abs_value, 22 | ) 23 | 24 | assert S.max() <= 0 and S.min() - 
Config.min_db >= 0 25 | if Config.symmetric_mels: 26 | return (2 * Config.max_abs_value) * ( 27 | (S - Config.min_db) / (-Config.min_db) 28 | ) - Config.max_abs_value 29 | else: 30 | return Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db)) 31 | 32 | 33 | def tr_amp_to_db(x): 34 | min_level = torch.exp(Config.min_level_db / 20 * torch.log(torch.tensor(10.0))) 35 | min_level = min_level.type_as(x) 36 | return 20 * torch.log10(torch.maximum(min_level, x)) 37 | 38 | 39 | def normalize(S): 40 | if Config.allow_clipping_in_normalization: 41 | if Config.symmetric_mels: 42 | return np.clip( 43 | (2 * Config.max_abs_value) * ((S - Config.min_db) / (-Config.min_db)) 44 | - Config.max_abs_value, 45 | -Config.max_abs_value, 46 | Config.max_abs_value, 47 | ) 48 | else: 49 | return np.clip( 50 | Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db)), 51 | 0, 52 | Config.max_abs_value, 53 | ) 54 | 55 | assert S.max() <= 0 and S.min() - Config.min_db >= 0 56 | if Config.symmetric_mels: 57 | return (2 * Config.max_abs_value) * ( 58 | (S - Config.min_db) / (-Config.min_db) 59 | ) - Config.max_abs_value 60 | else: 61 | return Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db)) 62 | 63 | 64 | def amp_to_db(x): 65 | min_level = np.exp(Config.min_level_db / 20 * np.log(10)) 66 | return 20 * np.log10(np.maximum(min_level, x)) 67 | 68 | 69 | def tr_pre(npy): 70 | # conditions = torch.FloatTensor(npy).type_as(npy) # to(device) 71 | conditions = npy.transpose(1, 2) 72 | l = conditions.size(-1) 73 | pad_tail = l % 2 + 4 74 | zeros = ( 75 | torch.zeros([conditions.size()[0], Config.num_mels, pad_tail]).type_as( 76 | conditions 77 | ) 78 | + -4.0 79 | ) 80 | return torch.cat([conditions, zeros], dim=-1) 81 | 82 | 83 | def pre(npy): 84 | conditions = npy 85 | ## padding tail 86 | if type(conditions) == np.ndarray: 87 | conditions = torch.FloatTensor(conditions).unsqueeze(0) 88 | else: 89 | conditions = torch.FloatTensor(conditions.float()).unsqueeze(0) 90 | conditions = conditions.transpose(1, 2) 91 | l = conditions.size(-1) 92 | pad_tail = l % 2 + 4 93 | zeros = torch.zeros([1, Config.num_mels, pad_tail]) + -4.0 94 | return torch.cat([conditions, zeros], dim=-1) 95 | 96 | 97 | def load_try(state, model): 98 | model_dict = model.state_dict() 99 | try: 100 | model_dict.update(state) 101 | model.load_state_dict(model_dict) 102 | except RuntimeError as e: 103 | print(str(e)) 104 | model_dict = model.state_dict() 105 | for k, v in state.items(): 106 | model_dict[k] = v 107 | model.load_state_dict(model_dict) 108 | 109 | 110 | def load_checkpoint(checkpoint_path, device): 111 | checkpoint = torch.load(checkpoint_path, map_location=device) 112 | return checkpoint 113 | 114 | 115 | def build_mel_basis(): 116 | return librosa.filters.mel( 117 | Config.sample_rate, 118 | Config.n_fft, 119 | htk=True, 120 | n_mels=Config.num_mels, 121 | fmin=0, 122 | fmax=int(Config.sample_rate // 2), 123 | ) 124 | 125 | 126 | def linear_to_mel(spectogram): 127 | _mel_basis = build_mel_basis() 128 | return np.dot(_mel_basis, spectogram) 129 | 130 | 131 | if __name__ == "__main__": 132 | data = torch.randn((3, 5, 100)) 133 | b = normalize(amp_to_db(data.numpy())) 134 | a = tr_normalize(tr_amp_to_db(data)).numpy() 135 | print(a - b) 136 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | # python3 setup.py sdist bdist_wheel 4 | """ 5 
| @File    :   setup.py
6 | @Contact :   haoheliu@gmail.com
7 | @License :   (C)Copyright 2020-2100
8 |
9 | @Modify Time        @Author    @Version    @Description
10 | ------------        -------    --------    -----------
11 | 9/6/21 5:16 PM      Haohe Liu    1.0       None
12 | """
13 |
14 | # !/usr/bin/env python
15 | # -*- coding: utf-8 -*-
16 |
17 | # Note: To use the 'upload' functionality of this file, you must:
18 | #   $ pipenv install twine --dev
19 |
20 | import io
21 | import os
22 | import sys
23 | from shutil import rmtree
24 |
25 | from setuptools import find_packages, setup, Command
26 |
27 | # Package meta-data.
28 | NAME = "voicefixer"
29 | DESCRIPTION = "This package is written for the restoration of degraded speech"
30 | URL = "https://github.com/haoheliu/voicefixer"
31 | EMAIL = "haoheliu@gmail.com"
32 | AUTHOR = "Haohe Liu"
33 | REQUIRES_PYTHON = ">=3.7.0"
34 | VERSION = "0.1.2"
35 |
36 | # What packages are required for this module to be executed?
37 | REQUIRED = [
38 |     "librosa>=0.8.1,<0.9.0",
39 |     "matplotlib",
40 |     "torch>=1.7.0",
41 |     "progressbar",
42 |     "torchlibrosa==0.0.7",
43 |     "GitPython",
44 |     "streamlit>=1.12.0",
45 |     "pyyaml",
46 | ]
47 |
48 | # What packages are optional?
49 | EXTRAS = {}
50 |
51 | # The rest you shouldn't have to touch too much :)
52 | # ------------------------------------------------
53 | # Except, perhaps the License and Trove Classifiers!
54 | # If you do change the License, remember to change the Trove Classifier for that!
55 |
56 | here = os.path.abspath(os.path.dirname(__file__))
57 |
58 | # Import the README and use it as the long-description.
59 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file!
60 | try:
61 |     with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f:
62 |         long_description = "\n" + f.read()
63 | except FileNotFoundError:
64 |     long_description = DESCRIPTION
65 |
66 | # Load the package's __version__.py module as a dictionary.
67 | about = {}
68 | if not VERSION:
69 |     project_slug = NAME.lower().replace("-", "_").replace(" ", "_")
70 |     with open(os.path.join(here, project_slug, "__version__.py")) as f:
71 |         exec(f.read(), about)
72 | else:
73 |     about["__version__"] = VERSION
74 |
75 |
76 | class UploadCommand(Command):
77 |     """Support setup.py upload."""
78 |
79 |     description = "Build and publish the package."
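# (description/user_options, together with the initialize_options/finalize_options
#  no-ops below, are the boilerplate the setuptools Command interface expects)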
80 | user_options = [] 81 | 82 | @staticmethod 83 | def status(s): 84 | """Prints things in bold.""" 85 | print("\033[1m{0}\033[0m".format(s)) 86 | 87 | def initialize_options(self): 88 | pass 89 | 90 | def finalize_options(self): 91 | pass 92 | 93 | def run(self): 94 | try: 95 | self.status("Removing previous builds…") 96 | rmtree(os.path.join(here, "dist")) 97 | except OSError: 98 | pass 99 | 100 | self.status("Building Source and Wheel (universal) distribution…") 101 | os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable)) 102 | 103 | self.status("Uploading the package to PyPI via Twine…") 104 | os.system("twine upload dist/*") 105 | 106 | self.status("Pushing git tags…") 107 | os.system("git tag v{0}".format(about["__version__"])) 108 | os.system("git push --tags") 109 | 110 | sys.exit() 111 | 112 | 113 | # Where the magic happens: 114 | setup( 115 | name=NAME, 116 | version=about["__version__"], 117 | description=DESCRIPTION, 118 | long_description=long_description, 119 | long_description_content_type="text/markdown", 120 | author=AUTHOR, 121 | author_email=EMAIL, 122 | python_requires=REQUIRES_PYTHON, 123 | url=URL, 124 | # packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]), 125 | # If your package is a single module, use this instead of 'packages': 126 | py_modules=["voicefixer"], 127 | # entry_points={ 128 | # 'console_scripts': ['mycli=mymodule:cli'], 129 | # }, 130 | install_requires=REQUIRED, 131 | extras_require=EXTRAS, 132 | packages=find_packages(), 133 | include_package_data=True, 134 | license="MIT", 135 | classifiers=[ 136 | # Trove classifiers 137 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 138 | "License :: OSI Approved :: MIT License", 139 | "Programming Language :: Python", 140 | "Programming Language :: Python :: 3", 141 | "Programming Language :: Python :: 3.7", 142 | "Programming Language :: Python :: Implementation :: CPython", 143 | "Programming Language :: Python :: Implementation :: PyPy", 144 | ], 145 | # $ setup.py publish support. 
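# (maps "python setup.py upload" to the UploadCommand defined above, which rebuilds
#  the sdist/wheel, uploads them with twine, and pushes a version git tag)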
146 | cmdclass={ 147 | "upload": UploadCommand, 148 | }, 149 | scripts=['bin/voicefixer.cmd', "bin/voicefixer"] 150 | ) 151 | -------------------------------------------------------------------------------- /voicefixer/tools/pytorch_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | def check_cuda_availability(cuda): 7 | if cuda and not torch.cuda.is_available(): 8 | raise RuntimeError("Error: You set cuda=True but no cuda device found.") 9 | 10 | 11 | def try_tensor_cuda(tensor, cuda): 12 | if cuda and torch.cuda.is_available(): 13 | return tensor.cuda() 14 | else: 15 | return tensor.cpu() 16 | 17 | 18 | def to_log(input): 19 | assert torch.sum(input < 0) == 0, ( 20 | str(input) + " has negative values counts " + str(torch.sum(input < 0)) 21 | ) 22 | return torch.log10(torch.clip(input, min=1e-8)) 23 | 24 | 25 | def from_log(input): 26 | input = torch.clip(input, min=-np.inf, max=5) 27 | return 10**input 28 | 29 | 30 | def move_data_to_device(x, device): 31 | if "float" in str(x.dtype): 32 | x = torch.Tensor(x) 33 | elif "int" in str(x.dtype): 34 | x = torch.LongTensor(x) 35 | else: 36 | return x 37 | return x.to(device) 38 | 39 | 40 | def tensor2numpy(tensor): 41 | if "cuda" in str(tensor.device): 42 | return tensor.detach().cpu().numpy() 43 | else: 44 | return tensor.detach().numpy() 45 | 46 | 47 | def count_parameters(model): 48 | for p in model.parameters(): 49 | if p.requires_grad: 50 | print(p.shape) 51 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 52 | 53 | 54 | def count_flops(model, audio_length): 55 | multiply_adds = False 56 | list_conv2d = [] 57 | 58 | def conv2d_hook(self, input, output): 59 | batch_size, input_channels, input_height, input_width = input[0].size() 60 | output_channels, output_height, output_width = output[0].size() 61 | 62 | kernel_ops = ( 63 | self.kernel_size[0] 64 | * self.kernel_size[1] 65 | * (self.in_channels / self.groups) 66 | * (2 if multiply_adds else 1) 67 | ) 68 | bias_ops = 1 if self.bias is not None else 0 69 | 70 | params = output_channels * (kernel_ops + bias_ops) 71 | flops = batch_size * params * output_height * output_width 72 | 73 | list_conv2d.append(flops) 74 | 75 | list_conv1d = [] 76 | 77 | def conv1d_hook(self, input, output): 78 | batch_size, input_channels, input_length = input[0].size() 79 | output_channels, output_length = output[0].size() 80 | 81 | kernel_ops = ( 82 | self.kernel_size[0] 83 | * (self.in_channels / self.groups) 84 | * (2 if multiply_adds else 1) 85 | ) 86 | bias_ops = 1 if self.bias is not None else 0 87 | 88 | params = output_channels * (kernel_ops + bias_ops) 89 | flops = batch_size * params * output_length 90 | 91 | list_conv1d.append(flops) 92 | 93 | list_linear = [] 94 | 95 | def linear_hook(self, input, output): 96 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 97 | 98 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 99 | bias_ops = self.bias.nelement() 100 | 101 | flops = batch_size * (weight_ops + bias_ops) 102 | list_linear.append(flops) 103 | 104 | list_bn = [] 105 | 106 | def bn_hook(self, input, output): 107 | list_bn.append(input[0].nelement()) 108 | 109 | list_relu = [] 110 | 111 | def relu_hook(self, input, output): 112 | list_relu.append(input[0].nelement()) 113 | 114 | list_pooling2d = [] 115 | 116 | def pooling2d_hook(self, input, output): 117 | batch_size, input_channels, input_height, input_width = input[0].size() 118 | 
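# (same style of estimate as the conv hooks: FLOPs are approximated as
#  batch_size * number of output elements * kernel_size ** 2)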
output_channels, output_height, output_width = output[0].size() 119 | 120 | kernel_ops = self.kernel_size * self.kernel_size 121 | bias_ops = 0 122 | params = output_channels * (kernel_ops + bias_ops) 123 | flops = batch_size * params * output_height * output_width 124 | 125 | list_pooling2d.append(flops) 126 | 127 | list_pooling1d = [] 128 | 129 | def pooling1d_hook(self, input, output): 130 | batch_size, input_channels, input_length = input[0].size() 131 | output_channels, output_length = output[0].size() 132 | 133 | kernel_ops = self.kernel_size 134 | bias_ops = 0 135 | params = output_channels * (kernel_ops + bias_ops) 136 | flops = batch_size * params * output_length 137 | 138 | list_pooling2d.append(flops) 139 | 140 | def foo(net): 141 | childrens = list(net.children()) 142 | if not childrens: 143 | if isinstance(net, nn.Conv2d): 144 | net.register_forward_hook(conv2d_hook) 145 | elif isinstance(net, nn.ConvTranspose2d): 146 | net.register_forward_hook(conv2d_hook) 147 | elif isinstance(net, nn.Conv1d): 148 | net.register_forward_hook(conv1d_hook) 149 | elif isinstance(net, nn.Linear): 150 | net.register_forward_hook(linear_hook) 151 | elif isinstance(net, nn.BatchNorm2d) or isinstance(net, nn.BatchNorm1d): 152 | net.register_forward_hook(bn_hook) 153 | elif isinstance(net, nn.ReLU): 154 | net.register_forward_hook(relu_hook) 155 | elif isinstance(net, nn.AvgPool2d) or isinstance(net, nn.MaxPool2d): 156 | net.register_forward_hook(pooling2d_hook) 157 | elif isinstance(net, nn.AvgPool1d) or isinstance(net, nn.MaxPool1d): 158 | net.register_forward_hook(pooling1d_hook) 159 | else: 160 | print("Warning: flop of module {} is not counted!".format(net)) 161 | return 162 | for c in childrens: 163 | foo(c) 164 | 165 | foo(model) 166 | 167 | input = torch.rand(1, audio_length, 2) 168 | out = model(input) 169 | 170 | total_flops = ( 171 | sum(list_conv2d) 172 | + sum(list_conv1d) 173 | + sum(list_linear) 174 | + sum(list_bn) 175 | + sum(list_relu) 176 | + sum(list_pooling2d) 177 | + sum(list_pooling1d) 178 | ) 179 | 180 | return total_flops 181 | -------------------------------------------------------------------------------- /voicefixer/__main__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | from genericpath import exists 3 | import os.path 4 | import argparse 5 | from voicefixer import VoiceFixer 6 | import torch 7 | import os 8 | 9 | 10 | def writefile(infile, outfile, mode, append_mode, cuda, verbose=False): 11 | if append_mode is True: 12 | outbasename, outext = os.path.splitext(os.path.basename(outfile)) 13 | outfile = os.path.join( 14 | os.path.dirname(outfile), "{}-mode{}{}".format(outbasename, mode, outext) 15 | ) 16 | if verbose: 17 | print("Processing {}, mode={}".format(infile, mode)) 18 | voicefixer.restore(input=infile, output=outfile, cuda=cuda, mode=int(mode)) 19 | 20 | def check_arguments(args): 21 | process_file, process_folder = len(args.infile) != 0, len(args.infolder) != 0 22 | # assert len(args.infile) == 0 and len(args.outfile) == 0 or process_file, \ 23 | # "Error: You should give the input and output file path at the same time. The input and output file path we receive is %s and %s" % (args.infile, args.outfile) 24 | # assert len(args.infolder) == 0 and len(args.outfolder) == 0 or process_folder, \ 25 | # "Error: You should give the input and output folder path at the same time. 
The input and output folder path we receive is %s and %s" % (args.infolder, args.outfolder) 26 | assert ( 27 | process_file or process_folder 28 | ), "Error: You need to specify a input file path (--infile) or a input folder path (--infolder) to proceed. For more information please run: voicefixer -h" 29 | 30 | # if(args.cuda and not torch.cuda.is_available()): 31 | # print("Warning: You set --cuda while no cuda device found on your machine. We will use CPU instead.") 32 | if process_file: 33 | assert os.path.exists(args.infile), ( 34 | "Error: The input file %s is not found." % args.infile 35 | ) 36 | output_dirname = os.path.dirname(args.outfile) 37 | if len(output_dirname) > 1: 38 | os.makedirs(output_dirname, exist_ok=True) 39 | if process_folder: 40 | assert os.path.exists(args.infolder), ( 41 | "Error: The input folder %s is not found." % args.infile 42 | ) 43 | output_dirname = args.outfolder 44 | if len(output_dirname) > 1: 45 | os.makedirs(args.outfolder, exist_ok=True) 46 | 47 | return process_file, process_folder 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser( 52 | description="VoiceFixer - restores degraded speech" 53 | ) 54 | parser.add_argument( 55 | "-i", 56 | "--infile", 57 | type=str, 58 | default="", 59 | help="An input file to be processed by VoiceFixer.", 60 | ) 61 | parser.add_argument( 62 | "-o", 63 | "--outfile", 64 | type=str, 65 | default="outfile.wav", 66 | help="An output file to store the result.", 67 | ) 68 | 69 | parser.add_argument( 70 | "-ifdr", 71 | "--infolder", 72 | type=str, 73 | default="", 74 | help="Input folder. Place all your wav file that need process in this folder.", 75 | ) 76 | parser.add_argument( 77 | "-ofdr", 78 | "--outfolder", 79 | type=str, 80 | default="outfolder", 81 | help="Output folder. The processed files will be stored in this folder.", 82 | ) 83 | 84 | parser.add_argument( 85 | "--mode", help="mode", choices=["0", "1", "2", "all"], default="0" 86 | ) 87 | parser.add_argument('--disable-cuda', help='Set this flag if you do not want to use your gpu.', default=False, action="store_true") 88 | parser.add_argument( 89 | "--silent", 90 | help="Set this flag if you do not want to see any message.", 91 | default=False, 92 | action="store_true", 93 | ) 94 | 95 | args = parser.parse_args() 96 | 97 | if torch.cuda.is_available() and not args.disable_cuda: 98 | cuda = True 99 | else: 100 | cuda = False 101 | 102 | process_file, process_folder = check_arguments(args) 103 | 104 | if not args.silent: 105 | print("Initializing VoiceFixer") 106 | voicefixer = VoiceFixer() 107 | 108 | if not args.silent: 109 | print("Start processing the input file %s." % args.infile) 110 | 111 | if process_file: 112 | audioext = os.path.splitext(os.path.basename(args.infile))[-1] 113 | if audioext != ".wav": 114 | raise ValueError( 115 | "Error: Error processing the input file. We only support the .wav format currently. Please convert your %s format to .wav. Thanks." 
116 | % audioext 117 | ) 118 | if args.mode == "all": 119 | for file_mode in range(3): 120 | writefile( 121 | args.infile, 122 | args.outfile, 123 | file_mode, 124 | True, 125 | cuda, 126 | verbose=not args.silent, 127 | ) 128 | else: 129 | writefile( 130 | args.infile, 131 | args.outfile, 132 | args.mode, 133 | False, 134 | cuda, 135 | verbose=not args.silent, 136 | ) 137 | 138 | if process_folder: 139 | files = [ 140 | file 141 | for file in os.listdir(args.infolder) 142 | if (os.path.splitext(os.path.basename(file))[-1] == ".wav") 143 | ] 144 | if not args.silent: 145 | print( 146 | "Found %s .wav files in the input folder %s. Start processing." 147 | % (len(files), args.infolder) 148 | ) 149 | for file in files: 150 | outbasename, outext = os.path.splitext(os.path.basename(file)) 151 | in_file = os.path.join(args.infolder, file) 152 | out_file = os.path.join(args.outfolder, file) 153 | 154 | if args.mode == "all": 155 | for file_mode in range(3): 156 | writefile( 157 | in_file, 158 | out_file, 159 | file_mode, 160 | True, 161 | cuda, 162 | verbose=not args.silent, 163 | ) 164 | else: 165 | writefile( 166 | in_file, out_file, args.mode, False, cuda, verbose=not args.silent 167 | ) 168 | 169 | if not args.silent: 170 | print("Done") 171 | -------------------------------------------------------------------------------- /bin/voicefixer: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | from genericpath import exists 3 | import os.path 4 | import argparse 5 | from voicefixer import VoiceFixer 6 | import torch 7 | import os 8 | 9 | 10 | def writefile(infile, outfile, mode, append_mode, cuda, verbose=False): 11 | if append_mode is True: 12 | outbasename, outext = os.path.splitext(os.path.basename(outfile)) 13 | outfile = os.path.join( 14 | os.path.dirname(outfile), "{}-mode{}{}".format(outbasename, mode, outext) 15 | ) 16 | if verbose: 17 | print("Processing {}, mode={}".format(infile, mode)) 18 | voicefixer.restore(input=infile, output=outfile, cuda=cuda, mode=int(mode)) 19 | 20 | def check_arguments(args): 21 | process_file, process_folder = len(args.infile) != 0, len(args.infolder) != 0 22 | # assert len(args.infile) == 0 and len(args.outfile) == 0 or process_file, \ 23 | # "Error: You should give the input and output file path at the same time. The input and output file path we receive is %s and %s" % (args.infile, args.outfile) 24 | # assert len(args.infolder) == 0 and len(args.outfolder) == 0 or process_folder, \ 25 | # "Error: You should give the input and output folder path at the same time. The input and output folder path we receive is %s and %s" % (args.infolder, args.outfolder) 26 | assert ( 27 | process_file or process_folder 28 | ), "Error: You need to specify a input file path (--infile) or a input folder path (--infolder) to proceed. For more information please run: voicefixer -h" 29 | 30 | # if(args.cuda and not torch.cuda.is_available()): 31 | # print("Warning: You set --cuda while no cuda device found on your machine. We will use CPU instead.") 32 | if process_file: 33 | assert os.path.exists(args.infile), ( 34 | "Error: The input file %s is not found." % args.infile 35 | ) 36 | output_dirname = os.path.dirname(args.outfile) 37 | if len(output_dirname) > 1: 38 | os.makedirs(output_dirname, exist_ok=True) 39 | if process_folder: 40 | assert os.path.exists(args.infolder), ( 41 | "Error: The input folder %s is not found." 
% args.infile 42 | ) 43 | output_dirname = args.outfolder 44 | if len(output_dirname) > 1: 45 | os.makedirs(args.outfolder, exist_ok=True) 46 | 47 | return process_file, process_folder 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser( 52 | description="VoiceFixer - restores degraded speech" 53 | ) 54 | parser.add_argument( 55 | "-i", 56 | "--infile", 57 | type=str, 58 | default="", 59 | help="An input file to be processed by VoiceFixer.", 60 | ) 61 | parser.add_argument( 62 | "-o", 63 | "--outfile", 64 | type=str, 65 | default="outfile.wav", 66 | help="An output file to store the result.", 67 | ) 68 | 69 | parser.add_argument( 70 | "-ifdr", 71 | "--infolder", 72 | type=str, 73 | default="", 74 | help="Input folder. Place all your wav file that need process in this folder.", 75 | ) 76 | parser.add_argument( 77 | "-ofdr", 78 | "--outfolder", 79 | type=str, 80 | default="outfolder", 81 | help="Output folder. The processed files will be stored in this folder.", 82 | ) 83 | 84 | parser.add_argument( 85 | "--mode", help="mode", choices=["0", "1", "2", "all"], default="0" 86 | ) 87 | parser.add_argument('--disable-cuda', help='Set this flag if you do not want to use your gpu.', default=False, action="store_true") 88 | parser.add_argument( 89 | "--silent", 90 | help="Set this flag if you do not want to see any message.", 91 | default=False, 92 | action="store_true", 93 | ) 94 | 95 | args = parser.parse_args() 96 | 97 | if torch.cuda.is_available() and not args.disable_cuda: 98 | cuda = True 99 | else: 100 | cuda = False 101 | 102 | process_file, process_folder = check_arguments(args) 103 | 104 | if not args.silent: 105 | print("Initializing VoiceFixer") 106 | voicefixer = VoiceFixer() 107 | 108 | if not args.silent: 109 | print("Start processing the input file %s." % args.infile) 110 | 111 | if process_file: 112 | audioext = os.path.splitext(os.path.basename(args.infile))[-1] 113 | if audioext != ".wav": 114 | raise ValueError( 115 | "Error: Error processing the input file. We only support the .wav format currently. Please convert your %s format to .wav. Thanks." 116 | % audioext 117 | ) 118 | if args.mode == "all": 119 | for file_mode in range(3): 120 | writefile( 121 | args.infile, 122 | args.outfile, 123 | file_mode, 124 | True, 125 | cuda, 126 | verbose=not args.silent, 127 | ) 128 | else: 129 | writefile( 130 | args.infile, 131 | args.outfile, 132 | args.mode, 133 | False, 134 | cuda, 135 | verbose=not args.silent, 136 | ) 137 | 138 | if process_folder: 139 | if not args.silent: 140 | files = [ 141 | file 142 | for file in os.listdir(args.infolder) 143 | if (os.path.splitext(os.path.basename(file))[-1] == ".wav") 144 | ] 145 | print( 146 | "Found %s .wav files in the input folder %s. Start processing." 
147 | % (len(files), args.infolder) 148 | ) 149 | for file in os.listdir(args.infolder): 150 | outbasename, outext = os.path.splitext(os.path.basename(file)) 151 | in_file = os.path.join(args.infolder, file) 152 | out_file = os.path.join(args.outfolder, file) 153 | 154 | if args.mode == "all": 155 | for file_mode in range(3): 156 | writefile( 157 | in_file, 158 | out_file, 159 | file_mode, 160 | True, 161 | cuda, 162 | verbose=not args.silent, 163 | ) 164 | else: 165 | writefile( 166 | in_file, out_file, args.mode, False, cuda, verbose=not args.silent 167 | ) 168 | 169 | if not args.silent: 170 | print("Done") -------------------------------------------------------------------------------- /voicefixer/restorer/model_kqq_bn.py: -------------------------------------------------------------------------------- 1 | from voicefixer.restorer.modules import * 2 | 3 | from voicefixer.tools.pytorch_util import * 4 | 5 | 6 | class UNetResComplex_100Mb(nn.Module): 7 | def __init__(self, channels, nsrc=1): 8 | super(UNetResComplex_100Mb, self).__init__() 9 | activation = "relu" 10 | momentum = 0.01 11 | 12 | self.nsrc = nsrc 13 | self.channels = channels 14 | self.downsample_ratio = 2**6 # This number equals 2^{#encoder_blcoks} 15 | 16 | self.encoder_block1 = EncoderBlockRes( 17 | in_channels=channels * nsrc, 18 | out_channels=32, 19 | downsample=(2, 2), 20 | activation=activation, 21 | momentum=momentum, 22 | ) 23 | self.encoder_block2 = EncoderBlockRes( 24 | in_channels=32, 25 | out_channels=64, 26 | downsample=(2, 2), 27 | activation=activation, 28 | momentum=momentum, 29 | ) 30 | self.encoder_block3 = EncoderBlockRes( 31 | in_channels=64, 32 | out_channels=128, 33 | downsample=(2, 2), 34 | activation=activation, 35 | momentum=momentum, 36 | ) 37 | self.encoder_block4 = EncoderBlockRes( 38 | in_channels=128, 39 | out_channels=256, 40 | downsample=(2, 2), 41 | activation=activation, 42 | momentum=momentum, 43 | ) 44 | self.encoder_block5 = EncoderBlockRes( 45 | in_channels=256, 46 | out_channels=384, 47 | downsample=(2, 2), 48 | activation=activation, 49 | momentum=momentum, 50 | ) 51 | self.encoder_block6 = EncoderBlockRes( 52 | in_channels=384, 53 | out_channels=384, 54 | downsample=(2, 2), 55 | activation=activation, 56 | momentum=momentum, 57 | ) 58 | self.conv_block7 = ConvBlockRes( 59 | in_channels=384, 60 | out_channels=384, 61 | size=3, 62 | activation=activation, 63 | momentum=momentum, 64 | ) 65 | self.decoder_block1 = DecoderBlockRes( 66 | in_channels=384, 67 | out_channels=384, 68 | stride=(2, 2), 69 | activation=activation, 70 | momentum=momentum, 71 | ) 72 | self.decoder_block2 = DecoderBlockRes( 73 | in_channels=384, 74 | out_channels=384, 75 | stride=(2, 2), 76 | activation=activation, 77 | momentum=momentum, 78 | ) 79 | self.decoder_block3 = DecoderBlockRes( 80 | in_channels=384, 81 | out_channels=256, 82 | stride=(2, 2), 83 | activation=activation, 84 | momentum=momentum, 85 | ) 86 | self.decoder_block4 = DecoderBlockRes( 87 | in_channels=256, 88 | out_channels=128, 89 | stride=(2, 2), 90 | activation=activation, 91 | momentum=momentum, 92 | ) 93 | self.decoder_block5 = DecoderBlockRes( 94 | in_channels=128, 95 | out_channels=64, 96 | stride=(2, 2), 97 | activation=activation, 98 | momentum=momentum, 99 | ) 100 | self.decoder_block6 = DecoderBlockRes( 101 | in_channels=64, 102 | out_channels=32, 103 | stride=(2, 2), 104 | activation=activation, 105 | momentum=momentum, 106 | ) 107 | 108 | self.after_conv_block1 = ConvBlockRes( 109 | in_channels=32, 110 | out_channels=32, 111 | 
size=3, 112 | activation=activation, 113 | momentum=momentum, 114 | ) 115 | 116 | self.after_conv2 = nn.Conv2d( 117 | in_channels=32, 118 | out_channels=1, 119 | kernel_size=(1, 1), 120 | stride=(1, 1), 121 | padding=(0, 0), 122 | bias=True, 123 | ) 124 | 125 | self.init_weights() 126 | 127 | def init_weights(self): 128 | init_layer(self.after_conv2) 129 | 130 | def forward(self, sp): 131 | """ 132 | Args: 133 | input: (batch_size, channels_num, segment_samples) 134 | 135 | Outputs: 136 | output_dict: { 137 | 'wav': (batch_size, channels_num, segment_samples), 138 | 'sp': (batch_size, channels_num, time_steps, freq_bins)} 139 | """ 140 | 141 | # Batch normalization 142 | x = sp 143 | 144 | # Pad spectrogram to be evenly divided by downsample ratio. 145 | origin_len = x.shape[2] # time_steps 146 | pad_len = ( 147 | int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio 148 | - origin_len 149 | ) 150 | x = F.pad(x, pad=(0, 0, 0, pad_len)) 151 | x = x[..., 0 : x.shape[-1] - 1] # (bs, channels, T, F) 152 | 153 | # UNet 154 | (x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F / 2) 155 | (x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F / 4) 156 | (x3_pool, x3) = self.encoder_block3(x2_pool) # x3_pool: (bs, 128, T / 8, F / 8) 157 | (x4_pool, x4) = self.encoder_block4( 158 | x3_pool 159 | ) # x4_pool: (bs, 256, T / 16, F / 16) 160 | (x5_pool, x5) = self.encoder_block5( 161 | x4_pool 162 | ) # x5_pool: (bs, 512, T / 32, F / 32) 163 | (x6_pool, x6) = self.encoder_block6( 164 | x5_pool 165 | ) # x6_pool: (bs, 1024, T / 64, F / 64) 166 | x_center = self.conv_block7(x6_pool) # (bs, 2048, T / 64, F / 64) 167 | x7 = self.decoder_block1(x_center, x6) # (bs, 1024, T / 32, F / 32) 168 | x8 = self.decoder_block2(x7, x5) # (bs, 512, T / 16, F / 16) 169 | x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F / 8) 170 | x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F / 4) 171 | x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F / 2) 172 | x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F) 173 | x = self.after_conv_block1(x12) # (bs, 32, T, F) 174 | x = self.after_conv2(x) # (bs, channels, T, F) 175 | 176 | # Recover shape 177 | x = F.pad(x, pad=(0, 1)) 178 | x = x[:, :, 0:origin_len, :] 179 | 180 | output_dict = {"mel": x} 181 | return output_dict 182 | 183 | 184 | if __name__ == "__main__": 185 | model = UNetResComplex_100Mb(channels=1) 186 | print(model(torch.randn((1, 1, 101, 128)))["mel"].size()) 187 | -------------------------------------------------------------------------------- /voicefixer/vocoder/model/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from voicefixer.vocoder.model.modules import UpsampleNet, ResStack 5 | from voicefixer.vocoder.config import Config 6 | from voicefixer.vocoder.model.pqmf import PQMF 7 | import os 8 | 9 | os.environ["KMP_DUPLICATE_LIB_OK"] = "True" 10 | 11 | 12 | class Generator(nn.Module): 13 | def __init__( 14 | self, 15 | in_channels=128, 16 | use_elu=False, 17 | use_gcnn=False, 18 | up_org=False, 19 | group=1, 20 | hp=None, 21 | ): 22 | super(Generator, self).__init__() 23 | self.hp = hp 24 | channels = Config.channels 25 | self.upsample_scales = Config.upsample_scales 26 | self.use_condnet = Config.use_condnet 27 | self.out_channels = Config.out_channels 28 | self.resstack_depth = Config.resstack_depth 29 | self.use_postnet = Config.use_postnet 30 | self.use_cond_rnn = 
Config.use_cond_rnn 31 | if self.use_condnet: 32 | cond_channels = Config.cond_channels 33 | self.condnet = nn.Sequential( 34 | nn.utils.weight_norm( 35 | nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) 36 | ), 37 | nn.ELU(), 38 | nn.utils.weight_norm( 39 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 40 | ), 41 | nn.ELU(), 42 | nn.utils.weight_norm( 43 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 44 | ), 45 | nn.ELU(), 46 | nn.utils.weight_norm( 47 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 48 | ), 49 | nn.ELU(), 50 | nn.utils.weight_norm( 51 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 52 | ), 53 | nn.ELU(), 54 | ) 55 | in_channels = cond_channels 56 | if self.use_cond_rnn: 57 | self.rnn = nn.GRU( 58 | cond_channels, 59 | cond_channels // 2, 60 | num_layers=1, 61 | batch_first=True, 62 | bidirectional=True, 63 | ) 64 | 65 | if use_elu: 66 | act = nn.ELU() 67 | else: 68 | act = nn.LeakyReLU(0.2, True) 69 | 70 | kernel_size = Config.kernel_size 71 | 72 | if self.out_channels == 1: 73 | self.generator = nn.Sequential( 74 | nn.ReflectionPad1d(3), 75 | nn.utils.weight_norm(nn.Conv1d(in_channels, channels, kernel_size=7)), 76 | act, 77 | UpsampleNet(channels, channels // 2, self.upsample_scales[0], hp, 0), 78 | ResStack(channels // 2, kernel_size[0], self.resstack_depth[0], hp), 79 | act, 80 | UpsampleNet( 81 | channels // 2, channels // 4, self.upsample_scales[1], hp, 1 82 | ), 83 | ResStack(channels // 4, kernel_size[1], self.resstack_depth[1], hp), 84 | act, 85 | UpsampleNet( 86 | channels // 4, channels // 8, self.upsample_scales[2], hp, 2 87 | ), 88 | ResStack(channels // 8, kernel_size[2], self.resstack_depth[2], hp), 89 | act, 90 | UpsampleNet( 91 | channels // 8, channels // 16, self.upsample_scales[3], hp, 3 92 | ), 93 | ResStack(channels // 16, kernel_size[3], self.resstack_depth[3], hp), 94 | act, 95 | nn.ReflectionPad1d(3), 96 | nn.utils.weight_norm( 97 | nn.Conv1d(channels // 16, self.out_channels, kernel_size=7) 98 | ), 99 | nn.Tanh(), 100 | ) 101 | else: 102 | channels = Config.m_channels 103 | self.generator = nn.Sequential( 104 | nn.ReflectionPad1d(3), 105 | nn.utils.weight_norm(nn.Conv1d(in_channels, channels, kernel_size=7)), 106 | act, 107 | UpsampleNet(channels, channels // 2, self.upsample_scales[0], hp), 108 | ResStack(channels // 2, kernel_size[0], self.resstack_depth[0], hp), 109 | act, 110 | UpsampleNet(channels // 2, channels // 4, self.upsample_scales[1], hp), 111 | ResStack(channels // 4, kernel_size[1], self.resstack_depth[1], hp), 112 | act, 113 | UpsampleNet(channels // 4, channels // 8, self.upsample_scales[3], hp), 114 | ResStack(channels // 8, kernel_size[3], self.resstack_depth[2], hp), 115 | act, 116 | nn.ReflectionPad1d(3), 117 | nn.utils.weight_norm( 118 | nn.Conv1d(channels // 8, self.out_channels, kernel_size=7) 119 | ), 120 | nn.Tanh(), 121 | ) 122 | if self.out_channels > 1: 123 | self.pqmf = PQMF(4, 64) 124 | 125 | self.num_params() 126 | 127 | def forward(self, conditions, use_res=False, f0=None): 128 | res = conditions 129 | if self.use_condnet: 130 | conditions = self.condnet(conditions) 131 | if self.use_cond_rnn: 132 | conditions, _ = self.rnn(conditions.transpose(1, 2)) 133 | conditions = conditions.transpose(1, 2) 134 | 135 | wav = self.generator(conditions) 136 | if self.out_channels > 1: 137 | B = wav.size(0) 138 | f_wav = ( 139 | self.pqmf.synthesis(wav) 140 | .transpose(1, 2) 141 | .reshape(B, 1, -1) 142 | .clamp(-0.99, 0.99) 143 | ) 
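            # PQMF multi-band path: when out_channels > 1 the generator predicts sub-band
            # signals; PQMF.synthesis() above recombines them into a single full-band
            # waveform, which is reshaped to (B, 1, T) and clamped to (-0.99, 0.99), and
            # both the full-band and the raw sub-band outputs are returned.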
144 | return f_wav, wav 145 | return wav 146 | 147 | def num_params(self): 148 | parameters = filter(lambda p: p.requires_grad, self.parameters()) 149 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 150 | return parameters 151 | # print('Trainable Parameters: %.3f million' % parameters) 152 | 153 | def remove_weight_norm(self): 154 | def _remove_weight_norm(m): 155 | try: 156 | torch.nn.utils.remove_weight_norm(m) 157 | except ValueError: # this module didn't have weight norm 158 | return 159 | 160 | self.apply(_remove_weight_norm) 161 | 162 | 163 | if __name__ == "__main__": 164 | model = Generator(128) 165 | x = torch.randn(3, 128, 13) 166 | print(x.shape) 167 | y = model(x) 168 | print(y.shape) 169 | -------------------------------------------------------------------------------- /voicefixer/base.py: -------------------------------------------------------------------------------- 1 | import librosa.display 2 | from voicefixer.tools.pytorch_util import * 3 | from voicefixer.tools.wav import * 4 | from voicefixer.restorer.model import VoiceFixer as voicefixer_fe 5 | import os 6 | 7 | EPS = 1e-8 8 | 9 | 10 | class VoiceFixer(nn.Module): 11 | def __init__(self): 12 | super(VoiceFixer, self).__init__() 13 | self._model = voicefixer_fe(channels=2, sample_rate=44100) 14 | # print(os.path.join(os.path.expanduser('~'), ".cache/voicefixer/analysis_module/checkpoints/epoch=15_trimed_bn.ckpt")) 15 | self.analysis_module_ckpt = os.path.join( 16 | os.path.expanduser("~"), 17 | ".cache/voicefixer/analysis_module/checkpoints/vf.ckpt", 18 | ) 19 | if(not os.path.exists(self.analysis_module_ckpt)): 20 | raise RuntimeError("Error 0: The checkpoint for analysis module (vf.ckpt) is not found in ~/.cache/voicefixer/analysis_module/checkpoints. \ 21 | By default the checkpoint should be download automatically by this program. Something bad may happened.\ 22 | But don't worry! 
Alternatively you can download it directly from Zenodo: https://zenodo.org/record/5600188/files/vf.ckpt?download=1.") 23 | self._model.load_state_dict( 24 | torch.load( 25 | self.analysis_module_ckpt 26 | ) 27 | ) 28 | self._model.eval() 29 | 30 | def _load_wav_energy(self, path, sample_rate, threshold=0.95): 31 | wav_10k, _ = librosa.load(path, sr=sample_rate) 32 | stft = np.log10(np.abs(librosa.stft(wav_10k)) + 1.0) 33 | fbins = stft.shape[0] 34 | e_stft = np.sum(stft, axis=1) 35 | for i in range(e_stft.shape[0]): 36 | e_stft[-i - 1] = np.sum(e_stft[: -i - 1]) 37 | total = e_stft[-1] 38 | for i in range(e_stft.shape[0]): 39 | if e_stft[i] < total * threshold: 40 | continue 41 | else: 42 | break 43 | return wav_10k, int((sample_rate // 2) * (i / fbins)) 44 | 45 | def _load_wav(self, path, sample_rate, threshold=0.95): 46 | wav_10k, _ = librosa.load(path, sr=sample_rate) 47 | return wav_10k 48 | 49 | def _amp_to_original_f(self, mel_sp_est, mel_sp_target, cutoff=0.2): 50 | freq_dim = mel_sp_target.size()[-1] 51 | mel_sp_est_low, mel_sp_target_low = ( 52 | mel_sp_est[..., 5 : int(freq_dim * cutoff)], 53 | mel_sp_target[..., 5 : int(freq_dim * cutoff)], 54 | ) 55 | energy_est, energy_target = torch.mean(mel_sp_est_low, dim=(2, 3)), torch.mean( 56 | mel_sp_target_low, dim=(2, 3) 57 | ) 58 | amp_ratio = energy_target / energy_est 59 | return mel_sp_est * amp_ratio[..., None, None], mel_sp_target 60 | 61 | def _trim_center(self, est, ref): 62 | diff = np.abs(est.shape[-1] - ref.shape[-1]) 63 | if est.shape[-1] == ref.shape[-1]: 64 | return est, ref 65 | elif est.shape[-1] > ref.shape[-1]: 66 | min_len = min(est.shape[-1], ref.shape[-1]) 67 | est, ref = est[..., int(diff // 2) : -int(diff // 2)], ref 68 | est, ref = est[..., :min_len], ref[..., :min_len] 69 | return est, ref 70 | else: 71 | min_len = min(est.shape[-1], ref.shape[-1]) 72 | est, ref = est, ref[..., int(diff // 2) : -int(diff // 2)] 73 | est, ref = est[..., :min_len], ref[..., :min_len] 74 | return est, ref 75 | 76 | def _pre(self, model, input, cuda): 77 | input = input[None, None, ...] 78 | input = torch.tensor(input) 79 | input = try_tensor_cuda(input, cuda=cuda) 80 | sp, _, _ = model.f_helper.wav_to_spectrogram_phase(input) 81 | mel_orig = model.mel(sp.permute(0, 1, 3, 2)).permute(0, 1, 3, 2) 82 | # return models.to_log(sp), models.to_log(mel_orig) 83 | return sp, mel_orig 84 | 85 | def remove_higher_frequency(self, wav, ratio=0.95): 86 | stft = librosa.stft(wav) 87 | real, img = np.real(stft), np.imag(stft) 88 | mag = (real**2 + img**2) ** 0.5 89 | cos, sin = real / (mag + EPS), img / (mag + EPS) 90 | spec = np.abs(stft) # [1025,T] 91 | feature = spec.copy() 92 | feature = np.log10(feature + EPS) 93 | feature[feature < 0] = 0 94 | energy_level = np.sum(feature, axis=1) 95 | threshold = np.sum(energy_level) * ratio 96 | curent_level, i = energy_level[0], 0 97 | while i < energy_level.shape[0] and curent_level < threshold: 98 | curent_level += energy_level[i + 1, ...] 99 | i += 1 100 | spec[i:, ...] 
= np.zeros_like(spec[i:, ...]) 101 | stft = spec * cos + 1j * spec * sin 102 | return librosa.istft(stft) 103 | 104 | @torch.no_grad() 105 | def restore_inmem(self, wav_10k, cuda=False, mode=0, your_vocoder_func=None): 106 | check_cuda_availability(cuda=cuda) 107 | self._model = try_tensor_cuda(self._model, cuda=cuda) 108 | if mode == 0: 109 | self._model.eval() 110 | elif mode == 1: 111 | self._model.eval() 112 | elif mode == 2: 113 | self._model.train() # More effective on seriously demaged speech 114 | res = [] 115 | seg_length = 44100 * 30 116 | break_point = seg_length 117 | while break_point < wav_10k.shape[0] + seg_length: 118 | segment = wav_10k[break_point - seg_length : break_point] 119 | if mode == 1: 120 | segment = self.remove_higher_frequency(segment) 121 | sp, mel_noisy = self._pre(self._model, segment, cuda) 122 | out_model = self._model(sp, mel_noisy) 123 | denoised_mel = from_log(out_model["mel"]) 124 | if your_vocoder_func is None: 125 | out = self._model.vocoder(denoised_mel, cuda=cuda) 126 | else: 127 | out = your_vocoder_func(denoised_mel) 128 | # unify energy 129 | if torch.max(torch.abs(out)) > 1.0: 130 | out = out / torch.max(torch.abs(out)) 131 | print("Warning: Exceed energy limit,", input) 132 | # frame alignment 133 | out, _ = self._trim_center(out, segment) 134 | res.append(out) 135 | break_point += seg_length 136 | out = torch.cat(res, -1) 137 | return tensor2numpy(out.squeeze(0)) 138 | 139 | def restore(self, input, output, cuda=False, mode=0, your_vocoder_func=None): 140 | wav_10k = self._load_wav(input, sample_rate=44100) 141 | out_np_wav = self.restore_inmem( 142 | wav_10k, cuda=cuda, mode=mode, your_vocoder_func=your_vocoder_func 143 | ) 144 | save_wave(out_np_wav, fname=output, sample_rate=44100) 145 | -------------------------------------------------------------------------------- /voicefixer/restorer/modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | class ConvBlockRes(nn.Module): 8 | def __init__(self, in_channels, out_channels, size, activation, momentum): 9 | super(ConvBlockRes, self).__init__() 10 | 11 | self.activation = activation 12 | if type(size) == type((3, 4)): 13 | pad = size[0] // 2 14 | size = size[0] 15 | else: 16 | pad = size // 2 17 | size = size 18 | 19 | self.conv1 = nn.Conv2d( 20 | in_channels=in_channels, 21 | out_channels=out_channels, 22 | kernel_size=(size, size), 23 | stride=(1, 1), 24 | dilation=(1, 1), 25 | padding=(pad, pad), 26 | bias=False, 27 | ) 28 | 29 | self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum) 30 | # self.abn1 = InPlaceABN(num_features=in_channels, momentum=momentum, activation='leaky_relu') 31 | 32 | self.conv2 = nn.Conv2d( 33 | in_channels=out_channels, 34 | out_channels=out_channels, 35 | kernel_size=(size, size), 36 | stride=(1, 1), 37 | dilation=(1, 1), 38 | padding=(pad, pad), 39 | bias=False, 40 | ) 41 | 42 | self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum) 43 | 44 | # self.abn2 = InPlaceABN(num_features=out_channels, momentum=momentum, activation='leaky_relu') 45 | 46 | if in_channels != out_channels: 47 | self.shortcut = nn.Conv2d( 48 | in_channels=in_channels, 49 | out_channels=out_channels, 50 | kernel_size=(1, 1), 51 | stride=(1, 1), 52 | padding=(0, 0), 53 | ) 54 | self.is_shortcut = True 55 | else: 56 | self.is_shortcut = False 57 | 58 | self.init_weights() 59 | 60 | def init_weights(self): 61 | init_bn(self.bn1) 62 | 
init_layer(self.conv1) 63 | init_layer(self.conv2) 64 | 65 | if self.is_shortcut: 66 | init_layer(self.shortcut) 67 | 68 | def forward(self, x): 69 | origin = x 70 | x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01)) 71 | x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01)) 72 | 73 | if self.is_shortcut: 74 | return self.shortcut(origin) + x 75 | else: 76 | return origin + x 77 | 78 | 79 | class EncoderBlockRes(nn.Module): 80 | def __init__(self, in_channels, out_channels, downsample, activation, momentum): 81 | super(EncoderBlockRes, self).__init__() 82 | size = 3 83 | 84 | self.conv_block1 = ConvBlockRes( 85 | in_channels, out_channels, size, activation, momentum 86 | ) 87 | self.conv_block2 = ConvBlockRes( 88 | out_channels, out_channels, size, activation, momentum 89 | ) 90 | self.conv_block3 = ConvBlockRes( 91 | out_channels, out_channels, size, activation, momentum 92 | ) 93 | self.conv_block4 = ConvBlockRes( 94 | out_channels, out_channels, size, activation, momentum 95 | ) 96 | self.downsample = downsample 97 | 98 | def forward(self, x): 99 | encoder = self.conv_block1(x) 100 | encoder = self.conv_block2(encoder) 101 | encoder = self.conv_block3(encoder) 102 | encoder = self.conv_block4(encoder) 103 | encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample) 104 | return encoder_pool, encoder 105 | 106 | 107 | class DecoderBlockRes(nn.Module): 108 | def __init__(self, in_channels, out_channels, stride, activation, momentum): 109 | super(DecoderBlockRes, self).__init__() 110 | size = 3 111 | self.activation = activation 112 | 113 | self.conv1 = torch.nn.ConvTranspose2d( 114 | in_channels=in_channels, 115 | out_channels=out_channels, 116 | kernel_size=(size, size), 117 | stride=stride, 118 | padding=(0, 0), 119 | output_padding=(0, 0), 120 | bias=False, 121 | dilation=(1, 1), 122 | ) 123 | 124 | self.bn1 = nn.BatchNorm2d(in_channels) 125 | self.conv_block2 = ConvBlockRes( 126 | out_channels * 2, out_channels, size, activation, momentum 127 | ) 128 | self.conv_block3 = ConvBlockRes( 129 | out_channels, out_channels, size, activation, momentum 130 | ) 131 | self.conv_block4 = ConvBlockRes( 132 | out_channels, out_channels, size, activation, momentum 133 | ) 134 | self.conv_block5 = ConvBlockRes( 135 | out_channels, out_channels, size, activation, momentum 136 | ) 137 | 138 | def init_weights(self): 139 | init_layer(self.conv1) 140 | 141 | def prune(self, x, both=False): 142 | """Prune the shape of x after transpose convolution.""" 143 | if both: 144 | x = x[:, :, 0:-1, 0:-1] 145 | else: 146 | x = x[:, :, 0:-1, :] 147 | return x 148 | 149 | def forward(self, input_tensor, concat_tensor, both=False): 150 | x = self.conv1(F.relu_(self.bn1(input_tensor))) 151 | x = self.prune(x, both=both) 152 | x = torch.cat((x, concat_tensor), dim=1) 153 | x = self.conv_block2(x) 154 | x = self.conv_block3(x) 155 | x = self.conv_block4(x) 156 | x = self.conv_block5(x) 157 | return x 158 | 159 | 160 | def init_layer(layer): 161 | """Initialize a Linear or Convolutional layer.""" 162 | nn.init.xavier_uniform_(layer.weight) 163 | 164 | if hasattr(layer, "bias"): 165 | if layer.bias is not None: 166 | layer.bias.data.fill_(0.0) 167 | 168 | 169 | def init_bn(bn): 170 | """Initialize a Batchnorm layer.""" 171 | bn.bias.data.fill_(0.0) 172 | bn.weight.data.fill_(1.0) 173 | 174 | 175 | def init_gru(rnn): 176 | """Initialize a GRU layer.""" 177 | 178 | def _concat_init(tensor, init_funcs): 179 | (length, fan_out) = tensor.shape 180 | fan_in = length // len(init_funcs) 181 | 182 | 
for (i, init_func) in enumerate(init_funcs): 183 | init_func(tensor[i * fan_in : (i + 1) * fan_in, :]) 184 | 185 | def _inner_uniform(tensor): 186 | fan_in = nn.init._calculate_correct_fan(tensor, "fan_in") 187 | nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in)) 188 | 189 | for i in range(rnn.num_layers): 190 | _concat_init( 191 | getattr(rnn, "weight_ih_l{}".format(i)), 192 | [_inner_uniform, _inner_uniform, _inner_uniform], 193 | ) 194 | torch.nn.init.constant_(getattr(rnn, "bias_ih_l{}".format(i)), 0) 195 | 196 | _concat_init( 197 | getattr(rnn, "weight_hh_l{}".format(i)), 198 | [_inner_uniform, _inner_uniform, nn.init.orthogonal_], 199 | ) 200 | torch.nn.init.constant_(getattr(rnn, "bias_hh_l{}".format(i)), 0) 201 | 202 | 203 | from torch.cuda import init 204 | 205 | 206 | def act(x, activation): 207 | if activation == "relu": 208 | return F.relu_(x) 209 | 210 | elif activation == "leaky_relu": 211 | return F.leaky_relu_(x, negative_slope=0.2) 212 | 213 | elif activation == "swish": 214 | return x * torch.sigmoid(x) 215 | 216 | else: 217 | raise Exception("Incorrect activation!") 218 | -------------------------------------------------------------------------------- /voicefixer/tools/wav.py: -------------------------------------------------------------------------------- 1 | import wave 2 | import os 3 | import numpy as np 4 | import scipy.signal as signal 5 | import soundfile as sf 6 | import librosa 7 | 8 | 9 | def save_wave(frames: np.ndarray, fname, sample_rate=44100): 10 | shape = list(frames.shape) 11 | if len(shape) == 1: 12 | frames = frames[..., None] 13 | in_samples, in_channels = shape[-2], shape[-1] 14 | if in_channels >= 3: 15 | if len(shape) == 2: 16 | frames = np.transpose(frames, (1, 0)) 17 | elif len(shape) == 3: 18 | frames = np.transpose(frames, (0, 2, 1)) 19 | msg = ( 20 | "Warning: Save audio with " 21 | + str(in_channels) 22 | + " channels, save permute audio with shape " 23 | + str(list(frames.shape)) 24 | + " please check if it's correct." 25 | ) 26 | # print(msg) 27 | if ( 28 | np.max(frames) <= 1 29 | and frames.dtype == np.float32 30 | or frames.dtype == np.float16 31 | or frames.dtype == np.float64 32 | ): 33 | frames *= 2**15 34 | frames = frames.astype(np.short) 35 | if len(frames.shape) >= 3: 36 | frames = frames[0, ...] 37 | sf.write(fname, frames, samplerate=sample_rate) 38 | 39 | 40 | def constrain_length(chunk, length): 41 | frames_length = chunk.shape[0] 42 | if frames_length == length: 43 | return chunk 44 | elif frames_length < length: 45 | return np.pad(chunk, ((0, int(length - frames_length)), (0, 0)), "constant") 46 | else: 47 | return chunk[:length, ...] 
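# Illustrative usage sketch for the helpers above (a minimal example; the signal and
# file name are placeholders, not part of the original module):
#
#     tone = 0.1 * np.sin(2 * np.pi * 440.0 * np.arange(44100) / 44100.0)
#     tone = constrain_length(tone.astype(np.float32)[..., None], length=44100)
#     save_wave(tone, "tone_440hz.wav", sample_rate=44100)  # scaled to int16 and written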
48 | 49 | 50 | def random_chunk_wav_file(fname, chunk_length): 51 | """ 52 | fname: path to wav file 53 | chunk_length: frame length in seconds 54 | """ 55 | with wave.open(fname) as f: 56 | params = f.getparams() 57 | duration = params[3] / params[2] 58 | sample_rate = params[2] 59 | sample_length = params[3] 60 | if duration < chunk_length or abs(duration - chunk_length) < 1e-4: 61 | frames = read_wave(fname, sample_rate) 62 | return frames, duration, sample_rate # [-1,1] 63 | else: 64 | # Random trunk 65 | random_starts = np.random.randint( 66 | 0, sample_length - sample_rate * chunk_length 67 | ) 68 | random_end = random_starts + sample_rate * chunk_length 69 | random_starts, random_end = ( 70 | random_starts / sample_rate, 71 | random_end / sample_rate, 72 | ) 73 | random_starts, random_end = random_starts / duration, random_end / duration 74 | frames = read_wave( 75 | fname, sample_rate, portion_start=random_starts, portion_end=random_end 76 | ) 77 | frames = constrain_length(frames, length=int(chunk_length * sample_rate)) 78 | return frames, chunk_length, sample_rate 79 | 80 | 81 | def random_chunk_wav_file_v2(fname, chunk_length, random_starts=None, random_end=None): 82 | """ 83 | fname: path to wav file 84 | chunk_length: frame length in seconds 85 | """ 86 | with wave.open(fname) as f: 87 | params = f.getparams() 88 | duration = params[3] / params[2] 89 | sample_rate = params[2] 90 | sample_length = params[3] 91 | if duration < chunk_length or abs(duration - chunk_length) < 1e-4: 92 | frames = read_wave(fname, sample_rate) 93 | return frames, duration, sample_rate # [-1,1] 94 | else: 95 | # Random trunk 96 | if random_starts is None and random_end is None: 97 | random_starts = np.random.randint( 98 | 0, sample_length - sample_rate * chunk_length 99 | ) 100 | random_end = random_starts + sample_rate * chunk_length 101 | random_starts, random_end = ( 102 | random_starts / sample_rate, 103 | random_end / sample_rate, 104 | ) 105 | random_starts, random_end = ( 106 | random_starts / duration, 107 | random_end / duration, 108 | ) 109 | frames = read_wave( 110 | fname, sample_rate, portion_start=random_starts, portion_end=random_end 111 | ) 112 | frames = constrain_length(frames, length=int(chunk_length * sample_rate)) 113 | return frames, chunk_length, sample_rate, random_starts, random_end 114 | 115 | 116 | def read_wave( 117 | fname, 118 | sample_rate, 119 | portion_start=0, 120 | portion_end=1, 121 | ): # Whether you want raw bytes 122 | """ 123 | :param fname: wav file path 124 | :param sample_rate: 125 | :param portion_start: 126 | :param portion_end: 127 | :return: [sample, channels] 128 | """ 129 | # sr = get_sample_rate(fname) 130 | # if(sr != sample_rate): 131 | # print("Warning: Sample rate not match, may lead to unexpected behavior.") 132 | if portion_end > 1 and portion_end < 1.1: 133 | portion_end = 1 134 | if portion_end != 1: 135 | duration = get_duration(fname) 136 | wav, _ = librosa.load( 137 | fname, 138 | sr=sample_rate, 139 | offset=portion_start * duration, 140 | duration=(portion_end - portion_start) * duration, 141 | mono=False, 142 | ) 143 | else: 144 | wav, _ = librosa.load(fname, sr=sample_rate, mono=False) 145 | if len(list(wav.shape)) == 1: 146 | wav = wav[..., None] 147 | else: 148 | wav = np.transpose(wav, (1, 0)) 149 | return wav 150 | 151 | 152 | def get_channels_sampwidth_and_sample_rate(fname): 153 | with wave.open(fname) as f: 154 | params = f.getparams() 155 | return ( 156 | params[0], 157 | params[1], 158 | params[2], 159 | ) # == 
(2,2,44100),(params[0],params[1],params[2]) 160 | 161 | 162 | def get_channels(fname): 163 | with wave.open(fname) as f: 164 | params = f.getparams() 165 | return params[0] 166 | 167 | 168 | def get_sample_rate(fname): 169 | with wave.open(fname) as f: 170 | params = f.getparams() 171 | return params[2] 172 | 173 | 174 | def get_duration(fname): 175 | with wave.open(fname) as f: 176 | params = f.getparams() 177 | return params[3] / params[2] 178 | 179 | 180 | def get_framesLength(fname): 181 | with wave.open(fname) as f: 182 | params = f.getparams() 183 | return params[3] 184 | 185 | 186 | def restore_wave(zxx): 187 | _, w = signal.istft(zxx) 188 | return w 189 | 190 | 191 | def calculate_total_times(dir): 192 | total = 0 193 | for each in os.listdir(dir): 194 | fname = os.path.join(dir, each) 195 | try: 196 | duration = get_duration(fname) 197 | except: 198 | print(fname) 199 | total += duration 200 | return total 201 | 202 | 203 | def filter(pth): 204 | global dic 205 | temp = [] 206 | for each in os.listdir(pth): 207 | temp.append(os.path.join(pth, each)) 208 | for each in temp: 209 | sr = get_sample_rate(each) 210 | if sr not in dic.keys(): 211 | dic[sr] = [] 212 | dic[sr].append(each) 213 | for each in dic[16000]: 214 | # print(each) 215 | pass 216 | print(dic.keys()) 217 | for each in list(dic.keys()): 218 | print(each, len(dic[each])) 219 | 220 | 221 | if __name__ == "__main__": 222 | path = "/Users/admin/Desktop/p376_025.wav" 223 | stereo = "/Users/admin/Desktop/vocals.wav" 224 | path_16 = "/Users/admin/Desktop/SI869.WAV.wav" 225 | import time 226 | 227 | start = time.time() 228 | for i in range(1000): 229 | frames, duration, sample_rate = random_chunk_wav_file(stereo, chunk_length=3.0) 230 | print(frames.shape, np.max(frames)) 231 | save_wave(frames, "stero.wav", sample_rate=44100) 232 | frames, duration, sample_rate = random_chunk_wav_file(path, chunk_length=3.0) 233 | print(frames.shape, np.max(frames)) 234 | save_wave(frames, "mono.wav", sample_rate=44100) 235 | frames, duration, sample_rate = random_chunk_wav_file(path_16, chunk_length=3.0) 236 | print(frames.shape, np.max(frames)) 237 | save_wave(frames, "16.wav", sample_rate=16000) 238 | print(time.time() - start) 239 | # frames = read_wave(stereo,sample_rate=44100) 240 | print(frames.shape) 241 | 242 | print(frames) 243 | -------------------------------------------------------------------------------- /voicefixer/tools/base.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import os 6 | import torch.fft 7 | 8 | os.environ["KMP_DUPLICATE_LIB_OK"] = "True" 9 | 10 | 11 | def get_window(window_size, window_type, square_root_window=True): 12 | """Return the window""" 13 | window = { 14 | "hamming": torch.hamming_window(window_size), 15 | "hanning": torch.hann_window(window_size), 16 | }[window_type] 17 | if square_root_window: 18 | window = torch.sqrt(window) 19 | return window 20 | 21 | 22 | def fft_point(dim): 23 | assert dim > 0 24 | num = math.log(dim, 2) 25 | num_point = 2 ** (math.ceil(num)) 26 | return num_point 27 | 28 | 29 | def pre_emphasis(signal, coefficient=0.97): 30 | """Pre-emphasis original signal 31 | y(n) = x(n) - a*x(n-1) 32 | """ 33 | return np.append(signal[0], signal[1:] - coefficient * signal[:-1]) 34 | 35 | 36 | def de_emphasis(signal, coefficient=0.97): 37 | """De-emphasis original signal 38 | y(n) = x(n) + a*x(n-1) 39 | """ 40 | length = signal.shape[0] 41 | for i in range(1, length): 42 | signal[i] = 
signal[i] + coefficient * signal[i - 1] 43 | return signal 44 | 45 | 46 | def seperate_magnitude(magnitude, phase): 47 | real = torch.cos(phase) * magnitude 48 | imagine = torch.sin(phase) * magnitude 49 | expand_dim = len(list(real.size())) 50 | return torch.stack((real, imagine), expand_dim) 51 | 52 | 53 | def stft_single( 54 | signal, 55 | sample_rate=44100, 56 | frame_length=46, 57 | frame_shift=10, 58 | window_type="hanning", 59 | device=torch.device("cuda"), 60 | square_root_window=True, 61 | ): 62 | """Compute the Short Time Fourier Transform. 63 | 64 | Args: 65 | signal: input speech signal, 66 | sample_rate: waveform datas sample frequency (Hz) 67 | frame_length: frame length in milliseconds 68 | frame_shift: frame shift in milliseconds 69 | window_type: type of window 70 | square_root_window: square root window 71 | Return: 72 | fft: (n/2)+1 dim complex STFT restults 73 | """ 74 | hop_length = int( 75 | sample_rate * frame_shift / 1000 76 | ) # The greater sample_rate, the greater hop_length 77 | win_length = int(sample_rate * frame_length / 1000) 78 | # num_point = fft_point(win_length) 79 | num_point = win_length 80 | window = get_window(num_point, window_type, square_root_window) 81 | if "cuda" in str(device): 82 | window = window.cuda(device) 83 | feat = torch.stft( 84 | signal, 85 | n_fft=num_point, 86 | hop_length=hop_length, 87 | win_length=window.shape[0], 88 | window=window, 89 | ) 90 | real, imag = feat[..., 0], feat[..., 1] 91 | return real.permute(0, 2, 1).unsqueeze(1), imag.permute(0, 2, 1).unsqueeze(1) 92 | 93 | 94 | def istft( 95 | real, 96 | imag, 97 | length, 98 | sample_rate=44100, 99 | frame_length=46, 100 | frame_shift=10, 101 | window_type="hanning", 102 | preemphasis=0.0, 103 | device=torch.device("cuda"), 104 | square_root_window=True, 105 | ): 106 | """Convert frames to signal using overlap-and-add systhesis. 107 | Args: 108 | spectrum: magnitude spectrum [batchsize,x,y,2] 109 | signal: wave signal to supply phase information 110 | Return: 111 | wav: synthesied output waveform 112 | """ 113 | real = real.permute(0, 3, 2, 1) 114 | imag = imag.permute(0, 3, 2, 1) 115 | spectrum = torch.cat([real, imag], dim=-1) 116 | 117 | hop_length = int(sample_rate * frame_shift / 1000) 118 | win_length = int(sample_rate * frame_length / 1000) 119 | 120 | # num_point = fft_point(win_length) 121 | num_point = win_length 122 | if "cuda" in str(device): 123 | window = get_window(num_point, window_type, square_root_window).cuda(device) 124 | else: 125 | window = get_window(num_point, window_type, square_root_window) 126 | 127 | wav = torch_istft( 128 | spectrum, 129 | num_point, 130 | hop_length=hop_length, 131 | win_length=window.shape[0], 132 | window=window, 133 | ) 134 | return wav[..., :length] 135 | 136 | 137 | def torch_istft( 138 | stft_matrix, # type: Tensor 139 | n_fft, # type: int 140 | hop_length=None, # type: Optional[int] 141 | win_length=None, # type: Optional[int] 142 | window=None, # type: Optional[Tensor] 143 | center=True, # type: bool 144 | pad_mode="reflect", # type: str 145 | normalized=False, # type: bool 146 | onesided=True, # type: bool 147 | length=None, # type: Optional[int] 148 | ): 149 | # type: (...) 
-> Tensor 150 | 151 | stft_matrix_dim = stft_matrix.dim() 152 | assert 3 <= stft_matrix_dim <= 4, "Incorrect stft dimension: %d" % (stft_matrix_dim) 153 | 154 | if stft_matrix_dim == 3: 155 | # add a channel dimension 156 | stft_matrix = stft_matrix.unsqueeze(0) 157 | 158 | dtype = stft_matrix.dtype 159 | device = stft_matrix.device 160 | fft_size = stft_matrix.size(1) 161 | assert (onesided and n_fft // 2 + 1 == fft_size) or ( 162 | not onesided and n_fft == fft_size 163 | ), ( 164 | "one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. " 165 | + "Given values were onesided: %s, n_fft: %d, fft_size: %d" 166 | % ("True" if onesided else False, n_fft, fft_size) 167 | ) 168 | 169 | # use stft defaults for Optionals 170 | if win_length is None: 171 | win_length = n_fft 172 | 173 | if hop_length is None: 174 | hop_length = int(win_length // 4) 175 | 176 | # There must be overlap 177 | assert 0 < hop_length <= win_length 178 | assert 0 < win_length <= n_fft 179 | 180 | if window is None: 181 | window = torch.ones(win_length, requires_grad=False, device=device, dtype=dtype) 182 | 183 | assert window.dim() == 1 and window.size(0) == win_length 184 | 185 | if win_length != n_fft: 186 | # center window with pad left and right zeros 187 | left = (n_fft - win_length) // 2 188 | window = torch.nn.functional.pad(window, (left, n_fft - win_length - left)) 189 | assert window.size(0) == n_fft 190 | # win_length and n_fft are synonymous from here on 191 | 192 | stft_matrix = stft_matrix.transpose(1, 2) # size (channel, n_frames, fft_size, 2) 193 | stft_matrix = torch.irfft( 194 | stft_matrix, 1, normalized, onesided, signal_sizes=(n_fft,) 195 | ) # size (channel, n_frames, n_fft) 196 | 197 | assert stft_matrix.size(2) == n_fft 198 | n_frames = stft_matrix.size(1) 199 | 200 | ytmp = stft_matrix * window.view(1, 1, n_fft) # size (channel, n_frames, n_fft) 201 | # each column of a channel is a frame which needs to be overlap added at the right place 202 | ytmp = ytmp.transpose(1, 2) # size (channel, n_fft, n_frames) 203 | 204 | eye = torch.eye(n_fft, requires_grad=False, device=device, dtype=dtype).unsqueeze( 205 | 1 206 | ) # size (n_fft, 1, n_fft) 207 | 208 | # this does overlap add where the frames of ytmp are added such that the i'th frame of 209 | # ytmp is added starting at i*hop_length in the output 210 | y = torch.nn.functional.conv_transpose1d( 211 | ytmp, eye, stride=hop_length, padding=0 212 | ) # size (channel, 1, expected_signal_len) 213 | 214 | # do the same for the window function 215 | window_sq = ( 216 | window.pow(2).view(n_fft, 1).repeat((1, n_frames)).unsqueeze(0) 217 | ) # size (1, n_fft, n_frames) 218 | window_envelop = torch.nn.functional.conv_transpose1d( 219 | window_sq, eye, stride=hop_length, padding=0 220 | ) # size (1, 1, expected_signal_len) 221 | 222 | expected_signal_len = n_fft + hop_length * (n_frames - 1) 223 | assert y.size(2) == expected_signal_len 224 | assert window_envelop.size(2) == expected_signal_len 225 | 226 | half_n_fft = n_fft // 2 227 | # we need to trim the front padding away if center 228 | start = half_n_fft if center else 0 229 | end = -half_n_fft if length is None else start + length 230 | 231 | y = y[:, :, start:end] 232 | window_envelop = window_envelop[:, :, start:end] 233 | 234 | # check NOLA non-zero overlap condition 235 | window_envelop_lowest = window_envelop.abs().min() 236 | assert window_envelop_lowest > 1e-11, "window overlap add min: %f" % ( 237 | window_envelop_lowest 238 | ) 239 | 240 | y = (y / 
window_envelop).squeeze(1) # size (channel, expected_signal_len) 241 | 242 | if stft_matrix_dim == 3: # remove the channel dimension 243 | y = y.squeeze(0) 244 | return y 245 | -------------------------------------------------------------------------------- /voicefixer/tools/mel_scale.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import Optional 4 | import math 5 | 6 | import warnings 7 | 8 | 9 | class MelScale(torch.nn.Module): 10 | r"""Turn a normal STFT into a mel frequency STFT, using a conversion 11 | matrix. This uses triangular filter banks. 12 | 13 | User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)). 14 | 15 | Args: 16 | n_mels (int, optional): Number of mel filterbanks. (Default: ``128``) 17 | sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``) 18 | f_min (float, optional): Minimum frequency. (Default: ``0.``) 19 | f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``) 20 | n_stft (int, optional): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. (Default: ``201``) 21 | norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band 22 | (area normalization). (Default: ``None``) 23 | mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``) 24 | 25 | See also: 26 | :py:func:`torchaudio.functional.melscale_fbanks` - The function used to 27 | generate the filter banks. 28 | """ 29 | __constants__ = ["n_mels", "sample_rate", "f_min", "f_max"] 30 | 31 | def __init__( 32 | self, 33 | n_mels: int = 128, 34 | sample_rate: int = 16000, 35 | f_min: float = 0.0, 36 | f_max: Optional[float] = None, 37 | n_stft: int = 201, 38 | norm: Optional[str] = None, 39 | mel_scale: str = "htk", 40 | ) -> None: 41 | super(MelScale, self).__init__() 42 | self.n_mels = n_mels 43 | self.sample_rate = sample_rate 44 | self.f_max = f_max if f_max is not None else float(sample_rate // 2) 45 | self.f_min = f_min 46 | self.norm = norm 47 | self.mel_scale = mel_scale 48 | 49 | assert f_min <= self.f_max, "Require f_min: {} < f_max: {}".format( 50 | f_min, self.f_max 51 | ) 52 | fb = melscale_fbanks( 53 | n_stft, 54 | self.f_min, 55 | self.f_max, 56 | self.n_mels, 57 | self.sample_rate, 58 | self.norm, 59 | self.mel_scale, 60 | ) 61 | self.register_buffer("fb", fb) 62 | 63 | def forward(self, specgram: Tensor) -> Tensor: 64 | r""" 65 | Args: 66 | specgram (Tensor): A spectrogram STFT of dimension (..., freq, time). 67 | 68 | Returns: 69 | Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time). 70 | """ 71 | 72 | # (..., time, freq) dot (freq, n_mels) -> (..., n_mels, time) 73 | mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose( 74 | -1, -2 75 | ) 76 | 77 | return mel_specgram 78 | 79 | 80 | def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float: 81 | r"""Convert Hz to Mels. 82 | 83 | Args: 84 | freqs (float): Frequencies in Hz 85 | mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. 
(Default: ``htk``) 86 | 87 | Returns: 88 | mels (float): Frequency in Mels 89 | """ 90 | 91 | if mel_scale not in ["slaney", "htk"]: 92 | raise ValueError('mel_scale should be one of "htk" or "slaney".') 93 | 94 | if mel_scale == "htk": 95 | return 2595.0 * math.log10(1.0 + (freq / 700.0)) 96 | 97 | # Fill in the linear part 98 | f_min = 0.0 99 | f_sp = 200.0 / 3 100 | 101 | mels = (freq - f_min) / f_sp 102 | 103 | # Fill in the log-scale part 104 | min_log_hz = 1000.0 105 | min_log_mel = (min_log_hz - f_min) / f_sp 106 | logstep = math.log(6.4) / 27.0 107 | 108 | if freq >= min_log_hz: 109 | mels = min_log_mel + math.log(freq / min_log_hz) / logstep 110 | 111 | return mels 112 | 113 | 114 | def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor: 115 | """Convert mel bin numbers to frequencies. 116 | 117 | Args: 118 | mels (Tensor): Mel frequencies 119 | mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``) 120 | 121 | Returns: 122 | freqs (Tensor): Mels converted in Hz 123 | """ 124 | 125 | if mel_scale not in ["slaney", "htk"]: 126 | raise ValueError('mel_scale should be one of "htk" or "slaney".') 127 | 128 | if mel_scale == "htk": 129 | return 700.0 * (10.0 ** (mels / 2595.0) - 1.0) 130 | 131 | # Fill in the linear scale 132 | f_min = 0.0 133 | f_sp = 200.0 / 3 134 | freqs = f_min + f_sp * mels 135 | 136 | # And now the nonlinear scale 137 | min_log_hz = 1000.0 138 | min_log_mel = (min_log_hz - f_min) / f_sp 139 | logstep = math.log(6.4) / 27.0 140 | 141 | log_t = mels >= min_log_mel 142 | freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel)) 143 | 144 | return freqs 145 | 146 | 147 | def _create_triangular_filterbank( 148 | all_freqs: Tensor, 149 | f_pts: Tensor, 150 | ) -> Tensor: 151 | """Create a triangular filter bank. 152 | 153 | Args: 154 | all_freqs (Tensor): STFT freq points of size (`n_freqs`). 155 | f_pts (Tensor): Filter mid points of size (`n_filter`). 156 | 157 | Returns: 158 | fb (Tensor): The filter bank of size (`n_freqs`, `n_filter`). 159 | """ 160 | # Adopted from Librosa 161 | # calculate the difference between each filter mid point and each stft freq point in hertz 162 | f_diff = f_pts[1:] - f_pts[:-1] # (n_filter + 1) 163 | slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_filter + 2) 164 | # create overlapping triangles 165 | zero = torch.zeros(1) 166 | down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_filter) 167 | up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_filter) 168 | fb = torch.max(zero, torch.min(down_slopes, up_slopes)) 169 | 170 | return fb 171 | 172 | 173 | def melscale_fbanks( 174 | n_freqs: int, 175 | f_min: float, 176 | f_max: float, 177 | n_mels: int, 178 | sample_rate: int, 179 | norm: Optional[str] = None, 180 | mel_scale: str = "htk", 181 | ) -> Tensor: 182 | r"""Create a frequency bin conversion matrix. 183 | 184 | Note: 185 | For the sake of the numerical compatibility with librosa, not all the coefficients 186 | in the resulting filter bank has magnitude of 1. 187 | 188 | .. 
image:: https://download.pytorch.org/torchaudio/doc-assets/mel_fbanks.png 189 | :alt: Visualization of generated filter bank 190 | 191 | Args: 192 | n_freqs (int): Number of frequencies to highlight/apply 193 | f_min (float): Minimum frequency (Hz) 194 | f_max (float): Maximum frequency (Hz) 195 | n_mels (int): Number of mel filterbanks 196 | sample_rate (int): Sample rate of the audio waveform 197 | norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band 198 | (area normalization). (Default: ``None``) 199 | mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``) 200 | 201 | Returns: 202 | Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``) 203 | meaning number of frequencies to highlight/apply to x the number of filterbanks. 204 | Each column is a filterbank so that assuming there is a matrix A of 205 | size (..., ``n_freqs``), the applied result would be 206 | ``A * melscale_fbanks(A.size(-1), ...)``. 207 | 208 | """ 209 | 210 | if norm is not None and norm != "slaney": 211 | raise ValueError("norm must be one of None or 'slaney'") 212 | 213 | # freq bins 214 | all_freqs = torch.linspace(0, sample_rate // 2, n_freqs) 215 | 216 | # calculate mel freq bins 217 | m_min = _hz_to_mel(f_min, mel_scale=mel_scale) 218 | m_max = _hz_to_mel(f_max, mel_scale=mel_scale) 219 | 220 | m_pts = torch.linspace(m_min, m_max, n_mels + 2) 221 | f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale) 222 | 223 | # create filterbank 224 | fb = _create_triangular_filterbank(all_freqs, f_pts) 225 | 226 | if norm is not None and norm == "slaney": 227 | # Slaney-style mel is scaled to be approx constant energy per channel 228 | enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels]) 229 | fb *= enorm.unsqueeze(0) 230 | 231 | if (fb.max(dim=0).values == 0.0).any(): 232 | warnings.warn( 233 | "At least one mel filterbank has all zero values. " 234 | f"The value for `n_mels` ({n_mels}) may be set too high. " 235 | f"Or, the value for `n_freqs` ({n_freqs}) may be set too low." 236 | ) 237 | 238 | return fb 239 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [](https://arxiv.org/abs/2109.13731) [](https://colab.research.google.com/drive/1HYYUepIsl2aXsdET6P_AmNVXuWP1MCMf?usp=sharing) [](https://badge.fury.io/py/voicefixer) [](https://haoheliu.github.io/demopage-voicefixer)[](https://huggingface.co/spaces/akhaliq/VoiceFixer) 2 | 3 | - [:speaking_head: :wrench: VoiceFixer](#speaking_head-wrench-voicefixer) 4 | - [Demo](#demo) 5 | - [Usage](#usage) 6 | - [Command line](#command-line) 7 | - [Desktop App](#desktop-app) 8 | - [Python Examples](#python-examples) 9 | - [Others Features](#others-features) 10 | - [Materials](#materials) 11 | - [Change log](#change-log) 12 | 13 | # :speaking_head: :wrench: VoiceFixer 14 | 15 | *Voicefixer* aims to restore human speech regardless how serious its degraded. It can handle noise, reveberation, low resolution (2kHz~44.1kHz) and clipping (0.1-1.0 threshold) effect within one model. 16 | 17 | This package provides: 18 | - A pretrained *Voicefixer*, which is build based on neural vocoder. 19 | - A pretrained 44.1k universal speaker-independent neural vocoder. 
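For a quick preview from Python, here is a minimal sketch that mirrors the `restore()` call the command-line tool makes under the hood (the output path below is just a placeholder):

```python
from voicefixer import VoiceFixer

# Restore a degraded recording; the audio is resampled to 44.1 kHz internally.
# mode=0 is the default also used by the command-line tool.
voicefixer = VoiceFixer()
voicefixer.restore(input="test/utterance/original/original.wav",
                   output="restored.wav",
                   cuda=False,  # set True to run on a GPU
                   mode=0)
```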
20 | 21 |  22 | 23 | - If you found this repo helpful, please consider citing or [](https://www.buymeacoffee.com/haoheliuP) 24 | 25 | ```bib 26 | @misc{liu2021voicefixer, 27 | title={VoiceFixer: Toward General Speech Restoration With Neural Vocoder}, 28 | author={Haohe Liu and Qiuqiang Kong and Qiao Tian and Yan Zhao and DeLiang Wang and Chuanzeng Huang and Yuxuan Wang}, 29 | year={2021}, 30 | eprint={2109.13731}, 31 | archivePrefix={arXiv}, 32 | primaryClass={cs.SD} 33 | } 34 | ``` 35 | 36 | ## Demo 37 | 38 | Please visit [demo page](https://haoheliu.github.io/demopage-voicefixer/) to view what voicefixer can do. 39 | 40 | ## Usage 41 | 42 | ### Command line 43 | 44 | First, install voicefixer via pip: 45 | ```shell 46 | pip install voicefixer==0.1.2 47 | ``` 48 | 49 | Process a file: 50 | ```shell 51 | # Specify the input .wav file. Output file is outfile.wav. 52 | voicefixer --infile test/utterance/original/original.wav 53 | # Or specify a output path 54 | voicefixer --infile test/utterance/original/original.wav --outfile test/utterance/original/original_processed.wav 55 | ``` 56 | 57 | Process files in a folder: 58 | ```shell 59 | voicefixer --infolder /path/to/input --outfolder /path/to/output 60 | ``` 61 | 62 | Change mode (The default mode is 0): 63 | ```shell 64 | voicefixer --infile /path/to/input.wav --outfile /path/to/output.wav --mode 1 65 | ``` 66 | 67 | Run all modes: 68 | ```shell 69 | # output file saved to `/path/to/output-modeX.wav`. 70 | voicefixer --infile /path/to/input.wav --outfile /path/to/output.wav --mode all 71 | ``` 72 | 73 | For more helper information please run: 74 | 75 | ```shell 76 | voicefixer -h 77 | ``` 78 | 79 | ### Desktop App 80 | 81 | [Demo on Youtube](https://www.youtube.com/watch?v=d_j8UKTZ7J8) (Thanks @Justin John) 82 | 83 | Install voicefixer via pip: 84 | ```shell script 85 | pip install voicefixer==0.1.2 86 | ``` 87 | 88 | You can test audio samples on your desktop by running website (powered by [streamlit](https://streamlit.io/)) 89 | 90 | 1. Clone the repo first. 91 | ```shell script 92 | git clone https://github.com/haoheliu/voicefixer.git 93 | cd voicefixer 94 | ``` 95 | :warning: **For windows users**, please make sure you have installed [WGET](https://eternallybored.org/misc/wget) and added the wget command to the system path (thanks @justinjohn0306). 96 | 97 | 98 | 2. Initialize and start web page. 99 | ```shell script 100 | # Run streamlit 101 | streamlit run test/streamlit.py 102 | ``` 103 | 104 | - If you run for the first time: the web page may leave blank for several minutes for downloading models. You can checkout the terminal for downloading progresses. 105 | 106 | - You can use [this low quality speech file](https://github.com/haoheliu/voicefixer/blob/main/test/utterance/original/original.wav) we provided for a test run. The page after processing will look like the following. 107 | 108 |
