├── data ├── __init__.py ├── compress_audio.sh ├── data_loader.py ├── base_data_loader.py ├── base_dataset.py ├── custom_dataset_data_loader.py ├── audio_dataset.py └── test.csv ├── models ├── __init__.py ├── models.py ├── base_model.py ├── spectrogram.py ├── mdct.py └── networks.py ├── util ├── __init__.py ├── spectro_img.py ├── image_pool.py ├── html.py ├── visualizer.py └── util.py ├── options ├── __init__.py ├── audio_config.py ├── test_options.py ├── train_options.py └── base_options.py ├── .gitignore ├── requirements.txt ├── generate_audio.sh ├── train.sh ├── LICENSE ├── generate_audio.py ├── README.md └── train.py /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints 2 | */__pycache__ -------------------------------------------------------------------------------- /data/compress_audio.sh: -------------------------------------------------------------------------------- 1 | find ./ -name '*.wav' -exec bash -c 'ffmpeg -i $0 -c:a flac -ar 48000 -compression_level 12 ${0/.wav/.flac} && rm $0' {} \; -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bottleneck_transformer_pytorch==0.1.4 2 | dominate 3 | einops 4 | matplotlib 5 | numpy 6 | Pillow 7 | scipy 8 | torch 9 | torchaudio 10 | torchvision 11 | torch_scatter 12 | -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | def CreateDataLoader(opt): 3 | from data.custom_dataset_data_loader import CustomDatasetDataLoader 4 | data_loader = CustomDatasetDataLoader() 5 | print(data_loader.name()) 6 | data_loader.initialize(opt) 7 | return data_loader 8 | -------------------------------------------------------------------------------- /data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | class BaseDataLoader(): 3 | def __init__(self): 4 | pass 5 | 6 | def initialize(self, opt): 7 | self.opt = opt 8 | pass 9 | 10 | def get_train_dataloader(): 11 | return None 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /options/audio_config.py: -------------------------------------------------------------------------------- 1 | N_FFT = 512 2 | HOP_LENGTH = 256 3 | WIN_LENGTH = 512 4 | LR_SAMPLE_RATE = 8000 5 | HR_SAMPLE_RATE = 48000 6 | SR_SAMPLE_RATE = 48000 7 | BINS = 128 8 | assert BINS%16 == 0 #must divisable by 16 9 | CENTER = True 10 | if CENTER: 11 | FRAME_LENGTH = (BINS-1)*HOP_LENGTH 12 | else: 13 | FRAME_LENGTH = (BINS-1)*HOP_LENGTH + WIN_LENGTH 
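The `FRAME_LENGTH` formulas above are chosen so that every training segment produces a spectrogram with exactly `BINS` time frames. Below is a minimal sanity check of that arithmetic; it assumes the usual framing convention (`frames = samples // hop + 1` for a centered transform, `frames = (samples - win) // hop + 1` otherwise), reuses the constants from `options/audio_config.py`, and is not part of the repository itself.

```python
# Sketch only: re-derive the per-segment frame count implied by audio_config.py.
N_FFT, HOP_LENGTH, WIN_LENGTH, BINS = 512, 256, 512, 128

for center in (True, False):
    if center:
        frame_length = (BINS - 1) * HOP_LENGTH                 # 32512 samples
        n_frames = frame_length // HOP_LENGTH + 1              # centered transform pads, adding one frame
    else:
        frame_length = (BINS - 1) * HOP_LENGTH + WIN_LENGTH    # 33024 samples
        n_frames = (frame_length - WIN_LENGTH) // HOP_LENGTH + 1
    assert n_frames == BINS, (center, n_frames)

print("both FRAME_LENGTH choices yield", BINS, "frames per segment")
```

The centered value (32512 samples) is also the segment length used in the MDCT timing example in the README's Bonus section.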
-------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def create_model(opt): 4 | if opt.model == 'pix2pixHD': 5 | from .pix2pixHD_model import Pix2PixHDModel, InferenceModel 6 | if opt.isTrain: 7 | model = Pix2PixHDModel() 8 | else: 9 | model = InferenceModel() 10 | else: 11 | from .ui_model import UIModel 12 | model = UIModel() 13 | model.initialize(opt) 14 | if opt.verbose: 15 | print("model [%s] was created" % (model.name())) 16 | 17 | # if opt.isTrain and len(opt.gpu_ids) and not opt.fp16: 18 | # model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) 19 | 20 | return model 21 | -------------------------------------------------------------------------------- /generate_audio.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python generate_audio.py \ 4 | --name output_folder_name \ 5 | --load_pretrain /home/neoncloud/pix2pixHDAudioSR/checkpoints/vctk_fintune_G4A3L3_56ngf_3x \ 6 | --lr_sampling_rate 16000 --sr_sampling_rate 48000 \ 7 | --dataroot /mnt/d/Desktop/LJ001-0056.wav --batchSize 16 \ 8 | --gpu_id 0 --fp16 --nThreads 1 \ 9 | --arcsinh_transform --abs_spectro --arcsinh_gain 1000 --center \ 10 | --norm_range -1 1 --smooth 0.0 --abs_norm --src_range -5 5 \ 11 | --netG local --ngf 56 --niter 40 \ 12 | --n_downsample_global 3 --n_blocks_global 4 \ 13 | --n_blocks_attn_g 3 --dim_head_g 128 --heads_g 6 --proj_factor_g 4 \ 14 | --n_blocks_attn_l 0 --n_blocks_local 3 --gen_overlap 0 \ 15 | --fit_residual --upsample_type interpolate --downsample_type resconv --phase test -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py \ 4 | --name your_training_name \ 5 | --dataroot /home/neoncloud/VCTK-Corpus/train.csv --evalroot /home/neoncloud/VCTK-Corpus/test.csv \ 6 | --lr_sampling_rate 16000 --sr_sampling_rate 48000 \ 7 | --batchSize 20 \ 8 | --gpu_id 0 --fp16 --nThreads 16 --lr 1.5e-4 \ 9 | --arcsinh_transform --abs_spectro --arcsinh_gain 1000 --center \ 10 | --norm_range -1 1 --smooth 0.0 --abs_norm --src_range -5 5 \ 11 | --netG local --ngf 56 \ 12 | --n_downsample_global 3 --n_blocks_global 4 \ 13 | --n_blocks_attn_g 3 --dim_head_g 128 --heads_g 6 --proj_factor_g 4 \ 14 | --n_blocks_attn_l 0 --n_blocks_local 3 \ 15 | --fit_residual --upsample_type interpolate --downsample_type resconv \ 16 | --niter 60 --niter_decay 60 --num_D 3 \ 17 | --eval_freq 32000 --save_latest_freq 16000 --save_epoch_freq 10 --display_freq 16000 --tf_log -------------------------------------------------------------------------------- /util/spectro_img.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | plt.switch_backend('agg') 4 | def fig2img(fig): 5 | """Convert a Matplotlib figure to a PIL Image and return it""" 6 | fig.canvas.draw() 7 | return np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(fig.canvas.get_width_height()[::-1] + (3,)) 8 | 9 | def compute_visuals(sp=None, pha=None, abs=False): 10 | sp = sp.transpose() if sp is not None else None 11 | pha = pha.transpose() if pha is not None else None 12 | sp_spectro = None 13 | sp_hist = None 14 | if sp is not None: 15 | sp_fig, sp_ax = plt.subplots() 16 | sp_ax.pcolormesh(sp if not abs 
else np.abs(sp), cmap='PuBu_r') 17 | sp_spectro = fig2img(sp_fig) 18 | 19 | sp_hist_fig, sp_hist_ax = plt.subplots() 20 | sp_hist_ax.hist(sp.reshape(-1,1),bins=100) 21 | sp_hist = fig2img(sp_hist_fig) 22 | 23 | if pha is not None: 24 | pha_fig, pha_ax = plt.subplots() 25 | pha_ax.pcolormesh(pha, cmap='cool') 26 | pha = fig2img(pha_fig) 27 | 28 | plt.close('all') 29 | return sp_spectro, sp_hist, pha 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | --------------------------- LICENSE FOR pix2pixHD ---------------- 2 | Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu. 3 | BSD License. All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE. 17 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL 18 | DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, 19 | WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING 20 | OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
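For reference, a hypothetical usage sketch for the `compute_visuals` helper from `util/spectro_img.py` shown above (assumed to be run from the repository root with numpy and matplotlib installed; the random array only stands in for a real magnitude spectrogram):

```python
# Hypothetical usage of util/spectro_img.py: render a magnitude spectrogram and its
# value histogram as HWC uint8 arrays, e.g. to feed Visualizer.display_current_results.
import numpy as np
from util.spectro_img import compute_visuals

mag = np.abs(np.random.randn(128, 257)).astype(np.float32)    # stand-in (time, freq) magnitudes
spectro_img, hist_img, _ = compute_visuals(sp=mag, abs=True)  # third output is None without a phase input
print(spectro_img.shape, hist_img.shape)                      # HWC images at matplotlib's default figure size
```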
-------------------------------------------------------------------------------- /util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch.autograd import Variable 4 | class ImagePool(): 5 | def __init__(self, pool_size): 6 | self.pool_size = pool_size 7 | if self.pool_size > 0: 8 | self.num_imgs = 0 9 | self.images = [] 10 | 11 | def query(self, images): 12 | if self.pool_size == 0: 13 | return images 14 | return_images = [] 15 | for image in images.data: 16 | image = torch.unsqueeze(image, 0) 17 | if self.num_imgs < self.pool_size: 18 | self.num_imgs = self.num_imgs + 1 19 | self.images.append(image) 20 | return_images.append(image) 21 | else: 22 | p = random.uniform(0, 1) 23 | if p > 0.5: 24 | random_id = random.randint(0, self.pool_size-1) 25 | tmp = self.images[random_id].clone() 26 | self.images[random_id] = image 27 | return_images.append(tmp) 28 | else: 29 | return_images.append(image) 30 | return_images = Variable(torch.cat(return_images, 0)) 31 | return return_images 32 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | class TestOptions(BaseOptions): 4 | def initialize(self): 5 | BaseOptions.initialize(self) 6 | self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 7 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 8 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 9 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 10 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model') 11 | self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run') 12 | self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features') 13 | self.parser.add_argument('--use_encoded_image', action='store_true', help='if specified, encode the real image to get the feature map') 14 | self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file") 15 | self.parser.add_argument("--engine", type=str, help="run serialized TRT engine") 16 | self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT") 17 | self.isTrain = False 18 | -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, refresh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | self.img_dir = os.path.join(self.web_dir, 'images') 11 | if not os.path.exists(self.web_dir): 12 | os.makedirs(self.web_dir) 13 | if not os.path.exists(self.img_dir): 14 | os.makedirs(self.img_dir) 15 | 16 | self.doc = dominate.document(title=title) 17 | if refresh > 0: 18 | with self.doc.head: 19 | meta(http_equiv="refresh", content=str(refresh)) 20 | 21 | def get_image_dir(self): 22 | return self.img_dir 23 | 24 | def add_header(self, str): 25 | with self.doc: 26 | h3(str) 27 | 28 | def add_table(self, border=1): 29 | self.t = table(border=border, style="table-layout: fixed;") 30 | self.doc.add(self.t) 31 | 32 | def add_images(self, ims, txts, links, width=512): 33 | self.add_table() 34 | with self.t: 35 | with tr(): 36 | for im, txt, link in zip(ims, txts, links): 37 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 38 | with p(): 39 | with a(href=os.path.join('images', link)): 40 | img(style="width:%dpx" % (width), src=os.path.join('images', im)) 41 | br() 42 | p(txt) 43 | 44 | def save(self): 45 | html_file = '%s/index.html' % self.web_dir 46 | f = open(html_file, 'wt') 47 | f.write(self.doc.render()) 48 | f.close() 49 | 50 | 51 | if __name__ == '__main__': 52 | html = HTML('web/', 'test_html') 53 | html.add_header('hello world') 54 | 55 | ims = [] 56 | txts = [] 57 | links = [] 58 | for n in range(4): 59 | ims.append('image_%d.jpg' % n) 60 | txts.append('text_%d' % n) 61 | links.append('image_%d.jpg' % n) 62 | html.add_images(ims, txts, links) 63 | html.save() 64 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import random 6 | 7 | class BaseDataset(data.Dataset): 8 | def __init__(self): 9 | super(BaseDataset, self).__init__() 10 | 11 | def name(self): 12 | return 'BaseDataset' 13 | 14 | def initialize(self, opt): 15 | pass 16 | 17 | def get_params(opt, size): 18 | w, h = size 19 | new_h = h 20 | new_w = w 21 | if opt.resize_or_crop == 'resize_and_crop': 22 | new_h = new_w = opt.loadSize 23 | elif opt.resize_or_crop == 'scale_width_and_crop': 24 | new_w = opt.loadSize 25 | new_h = opt.loadSize * h // w 26 | 27 | x = random.randint(0, np.maximum(0, new_w - opt.fineSize)) 28 | y = random.randint(0, 
np.maximum(0, new_h - opt.fineSize)) 29 | 30 | flip = random.random() > 0.5 31 | return {'crop_pos': (x, y), 'flip': flip} 32 | 33 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True): 34 | transform_list = [] 35 | if 'resize' in opt.resize_or_crop: 36 | osize = [opt.loadSize, opt.loadSize] 37 | transform_list.append(transforms.Scale(osize, method)) 38 | elif 'scale_width' in opt.resize_or_crop: 39 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) 40 | 41 | if 'crop' in opt.resize_or_crop: 42 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) 43 | 44 | if opt.resize_or_crop == 'none': 45 | base = float(2 ** opt.n_downsample_global) 46 | if opt.netG == 'local': 47 | base *= (2 ** opt.n_local_enhancers) 48 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) 49 | 50 | if opt.isTrain and not opt.no_flip: 51 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 52 | 53 | transform_list += [transforms.ToTensor()] 54 | 55 | if normalize: 56 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 57 | (0.5, 0.5, 0.5))] 58 | return transforms.Compose(transform_list) 59 | 60 | def normalize(): 61 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 62 | 63 | def __make_power_2(img, base, method=Image.BICUBIC): 64 | ow, oh = img.size 65 | h = int(round(oh / base) * base) 66 | w = int(round(ow / base) * base) 67 | if (h == oh) and (w == ow): 68 | return img 69 | return img.resize((w, h), method) 70 | 71 | def __scale_width(img, target_width, method=Image.BICUBIC): 72 | ow, oh = img.size 73 | if (ow == target_width): 74 | return img 75 | w = target_width 76 | h = int(target_width * oh / ow) 77 | return img.resize((w, h), method) 78 | 79 | def __crop(img, pos, size): 80 | ow, oh = img.size 81 | x1, y1 = pos 82 | tw = th = size 83 | if (ow > tw or oh > th): 84 | return img.crop((x1, y1, x1 + tw, y1 + th)) 85 | return img 86 | 87 | def __flip(img, flip): 88 | if flip: 89 | return img.transpose(Image.FLIP_LEFT_RIGHT) 90 | return img 91 | -------------------------------------------------------------------------------- /generate_audio.py: -------------------------------------------------------------------------------- 1 | from util.util import compute_matrics 2 | import torch 3 | import torchaudio 4 | import os 5 | 6 | from options.train_options import TrainOptions 7 | from data.data_loader import CreateDataLoader 8 | from models.models import create_model 9 | from util.visualizer import Visualizer 10 | from util.spectro_img import compute_visuals 11 | 12 | # Initilize the setup 13 | opt = TrainOptions().parse() 14 | visualizer = Visualizer(opt) 15 | data_loader = CreateDataLoader(opt) 16 | dataset = data_loader.get_train_dataloader() 17 | dataset_size = len(data_loader) 18 | model = create_model(opt) 19 | print('#audio segments = %d' % dataset_size) 20 | 21 | 22 | # Forward pass 23 | spectro_mag = [] 24 | spectro_pha = [] 25 | norm_params = [] 26 | audio = [] 27 | model.eval() 28 | stride = opt.segment_length-opt.gen_overlap 29 | with torch.no_grad(): 30 | for i, data in enumerate(dataset): 31 | sr_spectro, sr_audio, lr_pha, norm_param, lr_spectro = model.inference( 32 | data['LR_audio'].cuda()) 33 | print(sr_spectro.size()) 34 | # spectro_mag.append(sr_spectro) 35 | # spectro_pha.append(lr_pha) 36 | # norm_params.append(norm_param) 37 | audio.append(sr_audio.cpu()) 38 | 39 | # Concatenate the audio 40 | if 
opt.gen_overlap > 0: 41 | from torch.nn.functional import fold 42 | out_len = (dataset_size-1) * stride + opt.segment_length 43 | print(out_len) 44 | audio = torch.cat(audio,dim=0) 45 | print(audio.shape) 46 | audio[...,:opt.gen_overlap] *= 0.5 47 | audio[...,-opt.gen_overlap:] *= 0.5 48 | audio = audio.squeeze().transpose(-1,-2) 49 | audio = fold(audio, kernel_size=(1,opt.segment_length), stride=(1,stride), output_size=(1,out_len)).squeeze(0) 50 | audio = audio[...,opt.gen_overlap:-opt.gen_overlap] 51 | print(audio.shape) 52 | else: 53 | audio = torch.cat(audio, dim=0).view(1, -1) 54 | audio_len = data_loader.train_dataset.raw_audio.size(-1) 55 | # print(audio.size()) 56 | 57 | # Evaluate the matrics 58 | audio_len = data_loader.train_dataset.raw_audio.size(-1) 59 | _mse, _snr_sr, _snr_lr, _ssnr_sr, _ssnr_lr, _pesq, _lsd = compute_matrics( 60 | data_loader.train_dataset.raw_audio, data_loader.train_dataset.lr_audio[..., :audio_len], audio[..., :audio_len], opt) 61 | print('MSE: %.4f' % _mse) 62 | print('SNR_SR: %.4f' % _snr_sr) 63 | print('SNR_LR: %.4f' % _snr_lr) 64 | #print('SSNR_SR: %.4f' % _ssnr_sr) 65 | #print('SSNR_LR: %.4f' % _ssnr_lr) 66 | #print('PESQ: %.4f' % _pesq) 67 | print('LSD: %.4f' % _lsd) 68 | 69 | # # Generate visuals 70 | # lr_mag, _, sr_mag, _, _, _, _, _ = model.encode_input( 71 | # lr_audio=data_loader.dataset.lr_audio, hr_audio=audio) 72 | # if opt.explicit_encoding: 73 | # lr_mag = 0.5*(lr_mag[:, 0, :, :]+lr_mag[:, 1, :, :]) 74 | # sr_mag = 0.5*(sr_mag[:, 0, :, :]+sr_mag[:, 1, :, :]) 75 | # lr_spectro, lr_hist, _ = compute_visuals( 76 | # sp=lr_mag.squeeze().detach().cpu().numpy(), abs=True) 77 | # sr_spectro, sr_hist, _ = compute_visuals( 78 | # sp=sr_mag.squeeze().detach().cpu().numpy(), abs=True) 79 | # visuals = {'lable_spectro': lr_spectro, 80 | # 'generated_spectro': sr_spectro, 81 | # 'lable_hist': lr_hist, 82 | # 'generated_hist': sr_hist} 83 | 84 | # # Save files 85 | # visualizer.display_current_results(visuals, 1, 1) 86 | with open(os.path.join(opt.checkpoints_dir, opt.name, 'metric.txt'), 'w') as f: 87 | f.write('MSE,SNR_SR,LSD\n') 88 | f.write('%f,%f,%f' % (_mse, _snr_sr, _lsd)) 89 | sr_path = os.path.join(opt.checkpoints_dir, opt.name, 'sr_audio.wav') 90 | lr_path = os.path.join(opt.checkpoints_dir, opt.name, 'lr_audio.wav') 91 | hr_path = os.path.join(opt.checkpoints_dir, opt.name, 'hr_audio.wav') 92 | torchaudio.save(sr_path, audio.cpu().to(torch.float32), opt.hr_sampling_rate) 93 | torchaudio.save(lr_path, data_loader.train_dataset.lr_audio.cpu(), 94 | opt.hr_sampling_rate) 95 | torchaudio.save(hr_path, data_loader.train_dataset.raw_audio.cpu(), 96 | data_loader.train_dataset.in_sampling_rate) 97 | -------------------------------------------------------------------------------- /data/custom_dataset_data_loader.py: -------------------------------------------------------------------------------- 1 | from queue import Queue 2 | import torch.utils.data 3 | from data.base_data_loader import BaseDataLoader 4 | from threading import Thread 5 | 6 | 7 | def CreateDataset(opt): 8 | train_dataset = None 9 | if opt.phase == 'train': 10 | from data.audio_dataset import AudioDataset 11 | train_dataset = AudioDataset(opt) 12 | eval_dataset = AudioDataset(opt, True) 13 | print("dataset [%s] was created" % (train_dataset.name())) 14 | print("dataset [%s] was created" % (eval_dataset.name())) 15 | elif opt.phase == 'test': 16 | from data.audio_dataset import AudioTestDataset 17 | train_dataset = AudioTestDataset(opt) 18 | print("dataset [%s] was created" % 
(train_dataset.name())) 19 | eval_dataset = None 20 | 21 | # dataset.initialize(opt) 22 | return train_dataset, eval_dataset 23 | 24 | class CustomDatasetDataLoader(BaseDataLoader): 25 | def name(self): 26 | return 'CustomDatasetDataLoader' 27 | 28 | def initialize(self, opt): 29 | self.q_size = 16 30 | self.idx = 0 31 | self.load_stream = torch.cuda.Stream(device='cuda') 32 | self.queue: Queue = Queue(maxsize=self.q_size) 33 | BaseDataLoader.initialize(self, opt) 34 | self.train_dataset, self.eval_dataset = CreateDataset(opt) 35 | self.data_lenth = len(self.train_dataset) 36 | self.eval_data_lenth = len(self.eval_dataset) if self.eval_dataset is not None else None 37 | if opt.phase == "train": 38 | self.train_dataloader = torch.utils.data.DataLoader( 39 | self.train_dataset, 40 | batch_size=opt.batchSize, 41 | shuffle=True, 42 | num_workers=int(opt.nThreads), 43 | prefetch_factor=8, 44 | pin_memory=True) 45 | 46 | self.eval_dataloder = torch.utils.data.DataLoader( 47 | self.eval_dataset, 48 | batch_size=opt.batchSize, 49 | shuffle=True, 50 | num_workers=int(opt.nThreads), 51 | pin_memory=True) 52 | 53 | elif opt.phase == "test": 54 | self.train_dataloader = torch.utils.data.DataLoader( 55 | self.train_dataset, 56 | batch_size=opt.batchSize, 57 | num_workers=int(opt.nThreads), 58 | shuffle=False, 59 | pin_memory=True) 60 | self.eval_dataloder = None 61 | self.eval_data_lenth = 0 62 | 63 | def load_loop(self) -> None: # The loop that will load into the queue in the background 64 | for i, sample in enumerate(self.train_dataloader): 65 | self.queue.put(self.load_instance(sample)) 66 | if i == len(self): 67 | break 68 | 69 | def load_instance(self, sample:dict): 70 | with torch.cuda.stream(self.load_stream): 71 | return {k:v.cuda(non_blocking=True) for k,v in sample.items()} 72 | 73 | def get_train_dataloader(self): 74 | return self.train_dataloader 75 | 76 | def async_load_data(self): 77 | return self 78 | 79 | def get_eval_dataloader(self): 80 | return self.eval_dataloder 81 | 82 | def eval_data_len(self): 83 | return self.eval_data_lenth 84 | 85 | def __len__(self): 86 | return self.data_lenth 87 | 88 | def __iter__(self): 89 | if_worker = not hasattr(self, "worker") or not self.worker.is_alive() # type: ignore[has-type] 90 | if if_worker and self.queue.empty() and self.idx == 0: 91 | self.worker = Thread(target=self.load_loop) 92 | self.worker.daemon = True 93 | self.worker.start() 94 | return self 95 | 96 | def __next__(self): 97 | # If we've reached the number of batches to return 98 | # or the queue is empty and the worker is dead then exit 99 | done = not self.worker.is_alive() and self.queue.empty() 100 | done = done or self.idx >= len(self) 101 | if done: 102 | self.idx = 0 103 | self.queue.join() 104 | self.worker.join() 105 | raise StopIteration 106 | # Otherwise return the next batch 107 | out = self.queue.get() 108 | self.queue.task_done() 109 | self.idx += 1 110 | return out 111 | 112 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import sys 4 | 5 | class BaseModel(torch.nn.Module): 6 | def name(self): 7 | return 'BaseModel' 8 | 9 | def initialize(self, opt): 10 | self.opt = opt 11 | self.gpu_ids = opt.gpu_ids 12 | self.isTrain = opt.isTrain 13 | #self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor 14 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 15 | self.device = 'cuda' 
if len(self.gpu_ids) > 0 else 'cpu' 16 | 17 | def set_input(self, input): 18 | self.input = input 19 | 20 | def forward(self): 21 | pass 22 | 23 | # used in test time, no backprop 24 | def test(self): 25 | pass 26 | 27 | def get_image_paths(self): 28 | pass 29 | 30 | def optimize_parameters(self): 31 | pass 32 | 33 | def get_current_visuals(self): 34 | return self.input 35 | 36 | def get_current_errors(self): 37 | return {} 38 | 39 | def save(self, label): 40 | pass 41 | 42 | # helper saving function that can be used by subclasses 43 | def save_network(self, network, network_label, epoch_label, gpu_ids): 44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 45 | save_path = os.path.join(self.save_dir, save_filename) 46 | torch.save(network.state_dict(), save_path) 47 | 48 | # helper loading function that can be used by subclasses 49 | def load_network(self, network, network_label, epoch_label, save_dir=''): 50 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 51 | if not save_dir: 52 | save_dir = self.save_dir 53 | save_path = os.path.join(save_dir, save_filename) 54 | if not os.path.isfile(save_path): 55 | print('%s not exists yet!' % save_path) 56 | if network_label == 'G': 57 | raise('Generator must exist!') 58 | else: 59 | #network.load_state_dict(torch.load(save_path)) 60 | try: 61 | network.load_state_dict(torch.load(save_path, map_location='cpu')) 62 | network.to(self.device) 63 | except: 64 | pretrained_dict = torch.load(save_path) 65 | model_dict = network.state_dict() 66 | try: 67 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 68 | network.load_state_dict(pretrained_dict) 69 | if self.opt.verbose: 70 | print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label) 71 | except: 72 | print('Pretrained network %s has fewer layers; The following layers are possibly matched:' % network_label) 73 | pretrained_dict = torch.load(save_path) 74 | module_map = self.opt.param_key_map 75 | for name, param in pretrained_dict.items(): 76 | if name not in model_dict or param.size()!=model_dict[name].size(): 77 | #print('No match %s. Try to find mapping...'%name) 78 | layer_name = name.split('.') 79 | key = layer_name[0]+'.'+layer_name[1] 80 | if key in module_map: 81 | layer_name[1] = module_map[key] 82 | name_ = name 83 | name = "." 
84 | name = name.join(layer_name) 85 | print(" ",name_,'->',name) 86 | else: 87 | for k, v in model_dict.items(): 88 | if v.size() == param.size(): 89 | print(" ",k,":",name) 90 | continue 91 | # if isinstance(param, torch.nn.Parameter): 92 | # # backwards compatibility for serialized parameters 93 | # param = param.data 94 | model_dict[name]=param 95 | # for k, v in pretrained_dict.items(): 96 | # if v.size() == model_dict[k].size(): 97 | # print('Layer %s initialized'%k) 98 | # model_dict[k] = v 99 | 100 | # if sys.version_info >= (3,0): 101 | # not_initialized = set() 102 | # else: 103 | # from sets import Set 104 | # not_initialized = Set() 105 | 106 | # for k, v in model_dict.items(): 107 | # if k not in pretrained_dict or v.size() != pretrained_dict[k].size(): 108 | # not_initialized.add(k.split('.')[0]) 109 | 110 | # print(sorted(not_initialized)) 111 | network.load_state_dict(model_dict) 112 | 113 | def update_learning_rate(): 114 | pass 115 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import ntpath 4 | import time 5 | from . import util 6 | from . import html 7 | import scipy.misc 8 | try: 9 | from StringIO import StringIO # Python 2.7 10 | except ImportError: 11 | from io import BytesIO # Python 3.x 12 | 13 | class Visualizer(): 14 | def __init__(self, opt): 15 | # self.opt = opt 16 | self.tf_log = opt.tf_log 17 | self.use_html = opt.isTrain and not opt.no_html 18 | self.win_size = opt.display_winsize 19 | self.name = opt.name 20 | if self.tf_log: 21 | from torch.utils.tensorboard import SummaryWriter 22 | self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs') 23 | self.writer = SummaryWriter(self.log_dir) 24 | 25 | if self.use_html: 26 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 27 | self.img_dir = os.path.join(self.web_dir, 'images') 28 | print('create web directory %s...' 
% self.web_dir) 29 | util.mkdirs([self.web_dir, self.img_dir]) 30 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 31 | with open(self.log_name, "a") as log_file: 32 | now = time.strftime("%c") 33 | log_file.write('================ Training Loss (%s) ================\n' % now) 34 | 35 | # |visuals|: dictionary of images to display or save 36 | def display_current_results(self, visuals, epoch, step): 37 | if self.tf_log: # show images in tensorboard output 38 | for label, image_numpy in visuals.items(): 39 | # Create an Image object 40 | if 'spectro' in label: 41 | cat = 'spctro/' 42 | elif 'hist' in label: 43 | cat = 'histogram/' 44 | elif 'pha' in label: 45 | cat = 'phase/' 46 | self.writer.add_image(cat+label, image_numpy, step, dataformats='HWC') 47 | 48 | if self.use_html: # save images to a html file 49 | for label, image_numpy in visuals.items(): 50 | if isinstance(image_numpy, list): 51 | for i in range(len(image_numpy)): 52 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i)) 53 | util.save_image(image_numpy[i], img_path) 54 | else: 55 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.jpg' % (epoch, label)) 56 | util.save_image(image_numpy, img_path) 57 | 58 | # update website 59 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=30) 60 | for n in range(epoch, 0, -1): 61 | webpage.add_header('epoch [%d]' % n) 62 | ims = [] 63 | txts = [] 64 | links = [] 65 | 66 | for label, image_numpy in visuals.items(): 67 | if isinstance(image_numpy, list): 68 | for i in range(len(image_numpy)): 69 | img_path = 'epoch%.3d_%s_%d.jpg' % (n, label, i) 70 | ims.append(img_path) 71 | txts.append(label+str(i)) 72 | links.append(img_path) 73 | else: 74 | img_path = 'epoch%.3d_%s.jpg' % (n, label) 75 | ims.append(img_path) 76 | txts.append(label) 77 | links.append(img_path) 78 | if len(ims) < 10: 79 | webpage.add_images(ims, txts, links, width=self.win_size) 80 | else: 81 | num = int(round(len(ims)/2.0)) 82 | webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size) 83 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size) 84 | webpage.save() 85 | 86 | # errors: dictionary of error labels and values 87 | def plot_current_errors(self, errors, step): 88 | if self.tf_log: 89 | self.writer.add_scalars('Losses', errors, step) 90 | 91 | # errors: same format as |errors| of plotCurrentErrors 92 | def print_current_errors(self, epoch, i, errors, t): 93 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t) 94 | for k, v in errors.items(): 95 | if v != 0: 96 | message += '%s: %.3f ' % (k, v) 97 | 98 | print(message) 99 | with open(self.log_name, "a") as log_file: 100 | log_file.write('%s\n' % message) 101 | 102 | # save image to the disk 103 | def save_images(self, webpage, visuals, image_path): 104 | image_dir = webpage.get_image_dir() 105 | short_path = ntpath.basename(image_path[0]) 106 | name = os.path.splitext(short_path)[0] 107 | 108 | webpage.add_header(name) 109 | ims = [] 110 | txts = [] 111 | links = [] 112 | 113 | for label, image_numpy in visuals.items(): 114 | image_name = '%s_%s.jpg' % (name, label) 115 | save_path = os.path.join(image_dir, image_name) 116 | util.save_image(image_numpy, save_path) 117 | 118 | ims.append(image_name) 119 | txts.append(label) 120 | links.append(image_name) 121 | webpage.add_images(ims, txts, links, width=self.win_size) 122 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | 2 | mdctGAN: Taming transformer-based GAN for speech super-resolution with Modified DCT spectra 3 | 4 |
5 | 6 | By: Chenhao Shuai, Chaohua Shi, Lu Gan and Hongqing Liu 7 | 8 |
9 | 10 | Accepted at INTERSPEECH 2023 [arXiv] 11 |
12 | 13 | ## Requirements 14 | * bottleneck_transformer_pytorch==0.1.4 15 | * dominate 16 | * einops 17 | * matplotlib 18 | * numpy 19 | * Pillow 20 | * scipy 21 | * torch 22 | * torchaudio 23 | * torchvision 24 | * torch_scatter (Optional if you want to use `FastMDCT4`) 25 | 26 | ## Pretrained Models 27 | [![HF Models](https://img.shields.io/badge/%F0%9F%A4%97Hugging%20Face-models-green)](https://huggingface.co/neoncloud/mdctGAN) 28 | 29 | ## Data Preparation 30 | Firstly, for excessively long speech audio file, we recommend that you remove long gaps and split it into smaller segments. Other than this no other pre-processing is required, the program will automatically sample a random section from the longer audio file. 31 | 32 | It also automatically resamples the high sample rate audio to the low sample rate and upsamples it again to the target sample rate. This process simulates the loss of speech after downsampling. And up-sampling again aligns the low-res audio with the original high sample rate audio. So you don't need to manually resample the original audio. 33 | 34 | Secondly, Prepare your dataset index file like this (VCTK dataset example): 35 | ``` 36 | wav48/p250/p250_328.wav 37 | wav48/p310/p310_345.wav 38 | wav48/p227/p227_020.wav 39 | wav48/p285/p285_050.wav 40 | wav48/p248/p248_011.wav 41 | wav48/p246/p246_030.wav 42 | wav48/p247/p247_191.wav 43 | wav48/p287/p287_127.wav 44 | wav48/p334/p334_220.wav 45 | wav48/p340/p340_414.wav 46 | wav48/p236/p236_231.wav 47 | wav48/p301/p301_334.wav 48 | ... 49 | ``` 50 | Save it to the root directory of your dataset as a text file and the program will splice the parent folder of index file with the relative path of the records in the file. You can also find the index file used in our experiments in `data/train.csv`. 51 | 52 | ## Train 53 | Modify & run `sh train.sh`. Detailed explanation of args can be found in `options/base_options.py` and `options/train_options.py` 54 | 55 | 56 | | Parameter Name | Description | 57 | |----------------------|--------------------------------------------------------------------------------------------------| 58 | | --name | Name of the experiment. It decides where to store samples and models. | 59 | | **--dataroot** | Path to your train set csv file. | 60 | | **--evalroot** | Path to your eval set csv file. | 61 | | **--lr_sampling_rate** | Input Low-res sampling rate. It will be automatically resampled to this value. | 62 | | **--sr_sampling_rate** | Target super-resolution sampling rate. | 63 | | --fp16 | Train with Automatic Mixed Precision (AMP). | 64 | | --nThreads | Number of threads for loading data. | 65 | | --lr | Initial learning rate for the Adam optimizer. | 66 | | --arcsinh_transform | Use $\log(x+\sqrt{x^2+1})$ to compress the range of input. | 67 | | --abs_spectro | Use the absolute value of the spectrogram. | 68 | | --arcsinh_gain | Gain parameter for the arcsinh_transform. | 69 | | --center | Centered MDCT. | 70 | | --norm_range | Specify the target distribution range. | 71 | | --abs_norm | Assume the spectrograms are all distributed in a fixed range. Normalize by an absolute range. | 72 | | --src_range | Specify the source distribution range. Used when --abs_norm is specified. | 73 | | --netG | Select the model to use for netG. | 74 | | --ngf | Number of generator filters in the first conv layer. | 75 | | --n_downsample_global| Number of downsampling layers in netG. | 76 | | --n_blocks_global | Number of residual blocks in the global generator network. 
| 77 | | --n_blocks_attn_g | Number of attention blocks in the global generator network. | 78 | | --dim_head_g | Dimension of attention heads in the global generator network. | 79 | | --heads_g | Number of attention heads in the global generator network. | 80 | | --proj_factor_g | Projection factor of attention blocks in the global generator network. | 81 | | --n_blocks_local | Number of residual blocks in the local enhancer network. | 82 | | --n_blocks_attn_l | Number of attention blocks in the local enhancer network. | 83 | | --fit_residual | If specified, fit $HR-LR$ than directly fit $HR$. | 84 | | --upsample_type | Select upsampling layers for netG. Supported options: interpolate, transconv. | 85 | | --downsample_type | Select downsampling layers for netG. Supported options: resconv, conv. | 86 | | --num_D | Number of discriminators to use. | 87 | | --eval_freq | Frequency of evaluating metrics. | 88 | | --save_latest_freq | Frequency of saving the latest results. | 89 | | --save_epoch_freq | Frequency of saving checkpoints at the end of epochs. | 90 | | --display_freq | Frequency of showing training results on screen. | 91 | | --tf_log | If specified, use TensorBoard logging. Requires TensorFlow installed. | 92 | 93 | ## Evaluate & Generate audio 94 | Modify & run `sh gen_audio.sh`. 95 | 96 | ## Acknowledgement 97 | This code repository refers heavily to the [official pix2pixHD implementation](https://github.com/NVIDIA/pix2pixHD). Also, this work is based on an improved version of my undergraduate Final Year Project, see: [pix2pixHDAudioSR](https://github.com/neoncloud/pix2pixHDAudioSR) 98 | 99 | ## Bonus 100 | Try `FastMDCT4`/`FastIMDCT4` in `models/mdct.py` to have faster MDCT conversion. You can use `FastMDCT4` as an in-place replacement for `MDCT4`, or modify the import statement in `models/pix2pixHD_model.py` to `from .mdct import FastMDCT4 as MDCT4, FastIMDCT4 as IMDCT4` 101 | 102 | On my computer (RTX3070 laptop, Intel Core i7 11800H), each forward transformation saves 2ms. 103 | 104 | ```python 105 | sig = torch.randn(64,32512, device='cuda') 106 | %timeit -r 20 -n 500 mdct(sig) 107 | # 9.61 ms ± 643 µs per loop (mean ± std. dev. of 20 runs, 500 loops each) 108 | %timeit -r 20 -n 500 fast_mdct(sig) 109 | # 7.68 ms ± 691 µs per loop (mean ± std. dev. 
of 20 runs, 500 loops each) 110 | ``` 111 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | from .audio_config import * 3 | 4 | class TrainOptions(BaseOptions): 5 | def initialize(self): 6 | BaseOptions.initialize(self) 7 | # for displays 8 | self.parser.add_argument('--display_freq', type=int, default=200, help='frequency of showing training results on screen') 9 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 10 | self.parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results') 11 | self.parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs') 12 | self.parser.add_argument('--eval_freq', type=int, default=32000, help='frequency of evaluating matrics') 13 | self.parser.add_argument('--loss_update_freq', type=int, default=256, help='frequency of updating scalers of auxiliary losses') 14 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 15 | self.parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration') 16 | self.parser.add_argument('--abs_spectro', action='store_true', help='use absolute value of spectrogram') 17 | 18 | # for training 19 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 20 | self.parser.add_argument('--freeze_g_d', action='store_true', help='freeze downsample in G_g') 21 | self.parser.add_argument('--freeze_g_u', action='store_true', help='freeze upsample in G_g') 22 | self.parser.add_argument('--freeze_l_d', action='store_true', help='freeze downsample in G_l') 23 | self.parser.add_argument('--freeze_l_u', action='store_true', help='freeze upsample in G_l') 24 | self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location') 25 | self.parser.add_argument('--param_key_map', type=lambda x: {str(k):str(v) for k,v in (i.split(':') for i in x.split(','))}, default={}, help='if the pretrained model do not match the current model, this is helpful to map the pretrained modules to current model.') 26 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model') 27 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 28 | self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate') 29 | self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero') 30 | self.parser.add_argument('--niter_limit_aux', type=int, default=20, help='# of iter to limit auxiliary losses') 31 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 32 | self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 33 | self.parser.add_argument('--validation_split', type=float, default=0.05, help='path to file containing validation split indices') 34 | self.parser.add_argument('--val_indices', type=str, help='proportion of training data to be used as validation data if validation_split is not specified') 35 | self.parser.add_argument('--eval_size', type=int, default=100, help='how many samples to evaluate') 36 | self.parser.add_argument('--phase_encoding_mode', type=str, default=None, help='norm_dist|uni_dist|None') 37 | 38 | # for discriminators 39 | self.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use') 40 | self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers') 41 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 42 | self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss') 43 | self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss') 44 | self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN') 45 | self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images') 46 | # self.parser.add_argument('--num_mr_D', type=int, default=2, help='number of multires discriminators to use') 47 | # self.parser.add_argument('--lambda_mat', type=float, default=0.01, help='weight for phase matching loss') 48 | # self.parser.add_argument('--lambda_time', type=float, default=0.4, help='weight for time domain loss') 49 | # self.parser.add_argument('--lambda_mr', type=float, default=0.08, help='weight for Multi-Res D loss') 50 | # self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss', default=True) 51 | # self.parser.add_argument('--use_match_loss', action='store_true', help='if specified, use matching loss') 52 | # self.parser.add_argument('--match_loss_count', type=int, default=1, help='let the match loss gradually increase after this number of iterations') 53 | # self.parser.add_argument('--match_loss_count_max', type=int, default=10000, help='let the match loss gradually increase after this number of iterations') 54 | # self.parser.add_argument('--match_loss_thres', type=float, default=0.5, help='if the matching loss is greater than this threshold, the loss will be discarded') 55 | # self.parser.add_argument('--use_hifigan_D', action='store_true', help='if specified, use multi-scale-multi-period hifigan time domain discriminator') 56 | # self.parser.add_argument('--use_time_D', action='store_true', help='if specified, use time domain discriminator') 
57 | # self.parser.add_argument('--use_multires_D', action='store_true', help='if specified, use Multi-Resolution discriminator') 58 | # self.parser.add_argument('--use_shifted_match', action='store_true', help='if specified, shift audios randomly, then concat to the original ones as Discriminator inputs') 59 | # self.parser.add_argument('--time_D_count', type=int, default=0, help='let the time D gradually increase after this number of iterations') 60 | # self.parser.add_argument('--time_D_count_max', type=int, default=150, help='let the time D gradually increase after this number of iterations') 61 | 62 | # STFT params 63 | self.parser.add_argument('--lr_sampling_rate', type=int, default=LR_SAMPLE_RATE, help='low resolution sampling rate') 64 | self.parser.add_argument('--hr_sampling_rate', type=int, default=HR_SAMPLE_RATE, help='high resolution sampling rate') 65 | self.parser.add_argument('--sr_sampling_rate', type=int, default=SR_SAMPLE_RATE, help='target resolution sampling rate') 66 | self.parser.add_argument('--segment_length', type=int, default=FRAME_LENGTH, help='audio segment length') 67 | self.parser.add_argument('--gen_overlap', type=int, default=0, help='overlap length when generating. It is helpful to eliminate the transient effect') 68 | self.parser.add_argument('--n_fft', type=int, default=N_FFT, help='num of FFT points') 69 | self.parser.add_argument('--bins', type=int, default=BINS, help='num of time bins. This does not effect anything.') 70 | self.parser.add_argument('--hop_length', type=int, default=HOP_LENGTH, help='sliding window increament') 71 | self.parser.add_argument('--win_length', type=int, default=WIN_LENGTH, help='sliding window width') 72 | self.parser.add_argument('--center', action='store_true', help='centered FFT') 73 | self.parser.add_argument('--is_lr_input', action='store_true', help='if specified, the audio generator will assert the input as low res. 
And it will only do upsampling.') 74 | self.isTrain = True 75 | -------------------------------------------------------------------------------- /data/audio_dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | from numpy import ceil 4 | import torch 5 | import torch.nn.functional as F 6 | import torchaudio 7 | import torchaudio.functional as aF 8 | from data.base_dataset import BaseDataset 9 | torchaudio.set_audio_backend('sox_io') 10 | 11 | class AudioDataset(BaseDataset): 12 | def __init__(self, opt, test=False) -> None: 13 | BaseDataset.__init__(self) 14 | self.lr_sampling_rate = opt.lr_sampling_rate 15 | self.hr_sampling_rate = opt.hr_sampling_rate 16 | self.segment_length = opt.segment_length 17 | self.n_fft = opt.n_fft 18 | self.hop_length = opt.hop_length 19 | self.win_length = opt.win_length 20 | self.audio_file = self.get_files(opt.evalroot if test else opt.dataroot) 21 | self.audio_len = [(0,0)]*len(self.audio_file) 22 | self.center = opt.center 23 | self.add_noise = opt.add_noise 24 | self.snr = opt.snr 25 | 26 | torch.manual_seed(opt.seed) 27 | 28 | def __len__(self): 29 | return len(self.audio_file) 30 | 31 | def name(self): 32 | return 'AudioMDCTSpectrogramDataset' 33 | 34 | def readaudio(self, idx): 35 | file_path = self.audio_file[idx] 36 | if self.audio_len[idx][1] == 0: 37 | metadata = torchaudio.info(file_path) 38 | audio_length = metadata.num_frames 39 | fs = metadata.sample_rate 40 | self.audio_len[idx] = (fs,audio_length) 41 | else: 42 | fs,audio_length = self.audio_len[idx] 43 | max_audio_start = int(audio_length - self.segment_length*fs/self.hr_sampling_rate) 44 | if max_audio_start > 0: 45 | offset = torch.randint( 46 | low=0, high=max_audio_start, size=(1,)).item() 47 | waveform, orig_sample_rate = torchaudio.load( 48 | file_path, frame_offset=offset, num_frames=self.segment_length) 49 | else: 50 | #print("Warning: %s is shorter than segment_length"%file_path, audio_length) 51 | waveform, orig_sample_rate = torchaudio.load(file_path) 52 | return waveform, orig_sample_rate 53 | 54 | def __getitem__(self, idx): 55 | try: 56 | waveform, orig_sample_rate = self.readaudio(idx) 57 | except: # try next until success 58 | i = 1 59 | while 1: 60 | print('Load failed!') 61 | try: 62 | waveform, orig_sample_rate = self.readaudio(idx+i) 63 | break 64 | except: 65 | i += 1 66 | hr_waveform = aF.resample( 67 | waveform=waveform, orig_freq=orig_sample_rate, new_freq=self.hr_sampling_rate) 68 | lr_waveform = aF.resample( 69 | waveform=waveform, orig_freq=orig_sample_rate, new_freq=self.lr_sampling_rate) 70 | lr_waveform = aF.resample( 71 | waveform=lr_waveform, orig_freq=self.lr_sampling_rate, new_freq=self.hr_sampling_rate) 72 | if self.add_noise: 73 | noise = torch.randn(lr_waveform.size()) 74 | noise = noise-noise.mean() 75 | signal_power = torch.sum(lr_waveform**2)/self.segment_length 76 | noise_var = signal_power / 10**(self.snr/10) 77 | noise = torch.sqrt(noise_var)/noise.std()*noise 78 | lr_waveform = lr_waveform + noise 79 | # lr_waveform = aF.lowpass_biquad(waveform, sample_rate=self.hr_sampling_rate, cutoff_freq = self.lr_sampling_rate//2) #Meet the Nyquest sampling theorem 80 | hr = self.seg_pad_audio(hr_waveform) 81 | lr = self.seg_pad_audio(lr_waveform) 82 | return {'HR_audio': hr.squeeze(0), 'LR_audio': lr.squeeze(0)} 83 | 84 | def get_files(self, file_path): 85 | if os.path.isdir(file_path): 86 | print("Searching for audio file") 87 | file_list = [] 88 | for root, dirs, files in 
os.walk(file_path, topdown=False): 89 | for name in files: 90 | if os.path.splitext(name)[1] == ".wav" or ".mp3" or ".flac": 91 | file_list.append(os.path.join(root, name)) 92 | else: 93 | print("Using csv file list") 94 | root, csv_file = os.path.split(file_path) 95 | with open(file_path, 'r') as csv_file: 96 | csv_reader = csv.reader(csv_file) 97 | file_list = [os.path.join(root, item) for sublist in list( 98 | csv_reader) for item in sublist] 99 | print(len(file_list)) 100 | return file_list 101 | 102 | def seg_pad_audio(self, waveform): 103 | if waveform.size(1) >= self.segment_length: 104 | waveform = waveform[0][:self.segment_length] 105 | else: 106 | waveform = F.pad( 107 | waveform, (0, self.segment_length - 108 | waveform.size(1)), 'constant' 109 | ).data 110 | return waveform 111 | 112 | 113 | class AudioTestDataset(BaseDataset): 114 | def __init__(self, opt) -> None: 115 | BaseDataset.__init__(self) 116 | self.lr_sampling_rate = opt.lr_sampling_rate 117 | self.hr_sampling_rate = opt.hr_sampling_rate 118 | self.segment_length = opt.segment_length 119 | self.n_fft = opt.n_fft 120 | self.hop_length = opt.hop_length 121 | self.win_length = opt.win_length 122 | self.center = opt.center 123 | self.dataroot = opt.dataroot 124 | self.is_lr_input = opt.is_lr_input 125 | self.overlap = opt.gen_overlap 126 | self.add_noise = opt.add_noise 127 | self.snr = opt.snr 128 | 129 | self.read_audio() 130 | self.post_processing() 131 | 132 | def __len__(self): 133 | return self.seg_audio.size(0) 134 | 135 | def name(self): 136 | return 'AudioMDCTSpectrogramTestDataset' 137 | 138 | def __getitem__(self, idx): 139 | return {'LR_audio': self.seg_audio[idx, :].squeeze(0)} 140 | 141 | def read_audio(self): 142 | try: 143 | self.raw_audio, self.in_sampling_rate = torchaudio.load( 144 | self.dataroot) 145 | self.audio_len = self.raw_audio.size(-1) 146 | self.raw_audio += 1e-4 - torch.mean(self.raw_audio) 147 | print("Audio length:", self.audio_len) 148 | except: 149 | self.raw_audio = [] 150 | print("load audio failed") 151 | exit(0) 152 | 153 | def seg_pad_audio(self, audio): 154 | audio = audio.squeeze(0) 155 | length = len(audio) 156 | if length >= self.segment_length: 157 | num_segments = int(ceil(length/self.segment_length)) 158 | audio = F.pad(audio, (self.overlap, self.segment_length * 159 | num_segments - length + self.overlap), "constant").data 160 | audio = audio.unfold( 161 | dimension=0, size=self.segment_length, step=self.segment_length-self.overlap) 162 | else: 163 | audio = F.pad( 164 | audio, (0, self.segment_length - length), 'constant').data 165 | audio = audio.unsqueeze(0) 166 | 167 | return audio 168 | 169 | def post_processing(self): 170 | if self.is_lr_input: 171 | self.lr_audio = aF.resample( 172 | waveform=self.raw_audio, orig_freq=self.in_sampling_rate, new_freq=self.hr_sampling_rate) 173 | else: 174 | self.lr_audio = aF.resample( 175 | waveform=self.raw_audio, orig_freq=self.in_sampling_rate, new_freq=self.lr_sampling_rate) 176 | self.lr_audio = aF.resample( 177 | waveform=self.lr_audio, orig_freq=self.lr_sampling_rate, new_freq=self.hr_sampling_rate) 178 | if self.add_noise: 179 | noise = torch.randn(self.lr_audio.size()) 180 | noise = noise-noise.mean() 181 | signal_power = torch.sum(self.lr_audio**2)/self.segment_length 182 | noise_var = signal_power / 10**(self.snr/10) 183 | noise = torch.sqrt(noise_var)/noise.std()*noise 184 | self.lr_audio = self.lr_audio + noise 185 | self.seg_audio = self.seg_pad_audio(self.lr_audio) 186 | 187 | class AudioAppDataset(AudioTestDataset): 
188 | def __init__(self, opt, audio:torch.Tensor, fs) -> None: 189 | self.lr_sampling_rate = opt.lr_sampling_rate 190 | self.hr_sampling_rate = opt.hr_sampling_rate 191 | self.segment_length = opt.segment_length 192 | self.n_fft = opt.n_fft 193 | self.hop_length = opt.hop_length 194 | self.win_length = opt.win_length 195 | self.center = opt.center 196 | self.dataroot = audio 197 | self.is_lr_input = opt.is_lr_input 198 | self.overlap = opt.gen_overlap 199 | self.add_noise = opt.add_noise 200 | self.snr = opt.snr 201 | self.raw_audio = audio 202 | self.in_sampling_rate = fs 203 | self.post_processing() 204 | def read_audio(self): 205 | pass -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import os 6 | import torchaudio.functional as aF 7 | from torch.nn.functional import conv1d 8 | #import pysepm 9 | 10 | # Converts a Tensor into a Numpy array 11 | # |imtype|: the desired type of the converted numpy array 12 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True): 13 | if isinstance(image_tensor, list): 14 | image_numpy = [] 15 | for i in range(len(image_tensor)): 16 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) 17 | return image_numpy 18 | image_numpy = image_tensor.cpu().float().numpy() 19 | if normalize: 20 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 21 | else: 22 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 23 | image_numpy = np.clip(image_numpy, 0, 255) 24 | if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3: 25 | image_numpy = image_numpy[:,:,0] 26 | return image_numpy.astype(imtype) 27 | 28 | # Converts a one-hot tensor into a colorful label map 29 | def tensor2label(label_tensor, n_label, imtype=np.uint8): 30 | if n_label == 0: 31 | return tensor2im(label_tensor, imtype) 32 | label_tensor = label_tensor.cpu().float() 33 | if label_tensor.size()[0] > 1: 34 | label_tensor = label_tensor.max(0, keepdim=True)[1] 35 | label_tensor = Colorize(n_label)(label_tensor) 36 | label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) 37 | return label_numpy.astype(imtype) 38 | 39 | def save_image(image_numpy, image_path): 40 | image_pil = Image.fromarray(image_numpy) 41 | image_pil.save(image_path) 42 | 43 | def mkdirs(paths): 44 | if isinstance(paths, list) and not isinstance(paths, str): 45 | for path in paths: 46 | mkdir(path) 47 | else: 48 | mkdir(paths) 49 | 50 | def mkdir(path): 51 | if not os.path.exists(path): 52 | os.makedirs(path) 53 | 54 | ############################################################################### 55 | # Code from 56 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py 57 | # Modified so it complies with the Citscape label map colors 58 | ############################################################################### 59 | def uint82bin(n, count=8): 60 | """returns the binary of integer n, count refers to amount of bits""" 61 | return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)]) 62 | 63 | def labelcolormap(N): 64 | if N == 35: # cityscape 65 | cmap = np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81), 66 | (128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153), 67 | (180,165,180), (150,100,100), (150,120, 90), (153,153,153), 
(153,153,153), (250,170, 30), (220,220, 0), 68 | (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70), 69 | ( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)], 70 | dtype=np.uint8) 71 | else: 72 | cmap = np.zeros((N, 3), dtype=np.uint8) 73 | for i in range(N): 74 | r, g, b = 0, 0, 0 75 | id = i 76 | for j in range(7): 77 | str_id = uint82bin(id) 78 | r = r ^ (np.uint8(str_id[-1]) << (7-j)) 79 | g = g ^ (np.uint8(str_id[-2]) << (7-j)) 80 | b = b ^ (np.uint8(str_id[-3]) << (7-j)) 81 | id = id >> 3 82 | cmap[i, 0] = r 83 | cmap[i, 1] = g 84 | cmap[i, 2] = b 85 | return cmap 86 | 87 | class Colorize(object): 88 | def __init__(self, n=35): 89 | self.cmap = labelcolormap(n) 90 | self.cmap = torch.from_numpy(self.cmap[:n]) 91 | 92 | def __call__(self, gray_image): 93 | size = gray_image.size() 94 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 95 | 96 | for label in range(0, len(self.cmap)): 97 | mask = (label == gray_image[0]).cpu() 98 | color_image[0][mask] = self.cmap[label][0] 99 | color_image[1][mask] = self.cmap[label][1] 100 | color_image[2][mask] = self.cmap[label][2] 101 | 102 | return color_image 103 | 104 | def imdct(spectro, pha, norm_param, _imdct, min_value=1e-7, up_ratio=1, explicit_encoding=False): 105 | device = spectro.device 106 | spectro = torch.abs(spectro)*(norm_param['max'].to(device)-norm_param['min'].to(device))+norm_param['min'].to(device) 107 | #log_mag = log_mag*norm_param['std']+norm_param['mean'] 108 | spectro = aF.DB_to_amplitude(spectro.to(device),10,0.5)-min_value 109 | if explicit_encoding: 110 | pha = pha.squeeze() 111 | psudo_pha = torch.sign(spectro[...,0,:,:]-spectro[...,1,:,:]) 112 | spectro = spectro[...,0,:,:]+spectro[...,1,:,:] 113 | if up_ratio > 1: 114 | size = pha.size(-2) 115 | if pha.dim() != 3: 116 | pha = pha.unsqueeze(0) 117 | pha = torch.cat((pha[...,:int(size*(1/up_ratio)),:],psudo_pha[...,int(size*(1/up_ratio)):,:]),dim=-2) 118 | else: 119 | if up_ratio > 1: 120 | size = pha.size(-2) 121 | psudo_pha = 2*torch.randint(low=0,high=2,size=pha.size(),device=device)-1 122 | pha = torch.cat((pha[...,:int(size*(1/up_ratio)),:],psudo_pha[...,int(size*(1/up_ratio)):,:]),dim=-2) 123 | # BCHW -> BWH 124 | #print(spectro.shape) 125 | spectro = spectro*pha 126 | if explicit_encoding: 127 | audio = _imdct(spectro.permute(0,2,1).contiguous())/2 128 | else: 129 | audio = _imdct(spectro.squeeze(1).permute(0,2,1).contiguous())/2 130 | return audio 131 | 132 | def compute_matrics(hr_audio,lr_audio,sr_audio,opt): 133 | #print(hr_audio.shape,lr_audio.shape,sr_audio.shape) 134 | device = sr_audio.device 135 | hr_audio = hr_audio.to(device) 136 | lr_audio = lr_audio.to(device) 137 | 138 | # Calculate error 139 | mse = ((sr_audio-hr_audio)**2).mean().item() 140 | 141 | # Calculate SNR 142 | snr_sr = 10*torch.log10(torch.sum(hr_audio**2, dim=-1)/torch.sum((sr_audio-hr_audio)**2, dim=-1)).mean().item() 143 | snr_lr = 10*torch.log10(torch.sum(hr_audio**2,dim=-1)/torch.sum((lr_audio-hr_audio)**2,dim=-1)).mean().item() 144 | 145 | # Calculate segmental SNR 146 | #ssnr_sr = pysepm.SNRseg(clean_speech=hr_audio.numpy(), processed_speech=sr_audio.numpy(), fs=opt.hr_sampling_rate) 147 | #ssnr_lr = pysepm.SNRseg(clean_speech=hr_audio.numpy(), processed_speech=lr_audio.numpy(), fs=opt.hr_sampling_rate) 148 | 149 | # Calculate PESQ 150 | """ if hr_audio.dim() > 1: 151 | hr_audio = hr_audio.squeeze() 152 | sr_audio = sr_audio.squeeze() 153 | for i in range(hr_audio.size(-2)): 154 
| p = [] 155 | h = hr_audio[i,:] 156 | s = sr_audio[i,:] 157 | try: 158 | pesq = pysepm.pesq(aF.resample(h, orig_freq=opt.hr_sampling_rate, new_freq=16000).numpy(), aF.resample(s, orig_freq=opt.hr_sampling_rate, new_freq=16000).numpy(), 16000) 159 | p.append(pesq) 160 | except: 161 | print('PESQ no utterance') 162 | 163 | pesq = np.mean(p) 164 | else: 165 | try: 166 | pesq = pysepm.pesq(aF.resample(hr_audio,orig_freq=opt.hr_sampling_rate, new_freq=16000).numpy(), aF.resample(sr_audio,orig_freq=opt.hr_sampling_rate, new_freq=16000).numpy(), 16000) 167 | except: 168 | pesq = 0 """ 169 | 170 | # Calculate STFT loss (LSD) 171 | hr_stft = aF.spectrogram(hr_audio, n_fft=2*opt.n_fft, hop_length=2*opt.hop_length, win_length=2*opt.win_length, window=kbdwin(2*opt.win_length).to(device), center=opt.center, pad=0, power=2, normalized=False) 172 | sr_stft = aF.spectrogram(sr_audio, n_fft=2*opt.n_fft, hop_length=2*opt.hop_length, win_length=2*opt.win_length, window=kbdwin(2*opt.win_length).to(device), center=opt.center, pad=0, power=2, normalized=False) 173 | hr_stft_log = torch.log10(hr_stft+1e-6) 174 | sr_stft_log = torch.log10(sr_stft+1e-6) 175 | lsd = torch.sqrt(torch.mean((hr_stft_log-sr_stft_log)**2,dim=-2)).mean().item() 176 | 177 | return mse,snr_sr,snr_lr,0,0,0,lsd 178 | 179 | def kbdwin(N:int, beta:float=12.0, device='cpu')->torch.Tensor: 180 | # Matlab-style Kaiser-Bessel derived (KBD) window 181 | # Author: Chenhao Shuai 182 | assert N%2==0, "N must be even" 183 | w = torch.kaiser_window(window_length=N//2+1, beta=beta*torch.pi, periodic=False, device=device) 184 | w_sum = w.sum() 185 | wdw_half = torch.sqrt(torch.cumsum(w,dim=0)/w_sum)[:-1] 186 | return torch.cat((wdw_half,wdw_half.flip(dims=(0,))),dim=0) 187 | 188 | def alignment(x,y,win_len=128): 189 | # cross-correlate a window of x (centred on its peak) against y (assumes 1-D waveforms) and return the lag that best aligns them 190 | x_max_idx = torch.argmax(x) 191 | x_sample = x[...,int(x_max_idx-win_len//2):int(x_max_idx+win_len//2)] 192 | corr = conv1d(y.reshape(1,1,-1), x_sample.reshape(1,1,-1)) 193 | return torch.argmax(corr) - (x_max_idx - win_len//2) 194 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | 6 | class BaseOptions(): 7 | def __init__(self): 8 | self.parser = argparse.ArgumentParser() 9 | self.initialized = False 10 | 11 | def initialize(self): 12 | # experiment specifics 13 | self.parser.add_argument('--name', type=str, default='label2city', help='name of the experiment. It decides where to store samples and models') 14 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 15 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 16 | self.parser.add_argument('--model', type=str, default='pix2pixHD', help='which model to use') 17 | self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') 18 | self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator') 19 | self.parser.add_argument('--data_type', default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e.
8, 16, 32 bit") 20 | self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose') 21 | self.parser.add_argument('--fp16', action='store_true', default=False, help='train with AMP') 22 | self.parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training') 23 | self.parser.add_argument('--seed', type=int, default=42, help='random seed for reproducing results') 24 | self.parser.add_argument('--fit_residual', action='store_true', default=False, help='if specified, fit HR-LR than directly fit HR') 25 | 26 | # input/output sizes 27 | self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size') 28 | self.parser.add_argument('--loadSize', type=int, default=1024, help='scale images to this size') 29 | self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size') 30 | self.parser.add_argument('--label_nc', type=int, default=0, help='# of input label channels') 31 | self.parser.add_argument('--input_nc', type=int, default=2, help='# of input spectro channels') 32 | self.parser.add_argument('--output_nc', type=int, default=1, help='# of output spectro channels') 33 | 34 | # for setting inputs 35 | self.parser.add_argument('--dataroot', type=str, default='./datasets/vctk/train.csv') 36 | self.parser.add_argument('--evalroot', type=str, default='./datasets/vctk/test.csv') 37 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 38 | self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data') 39 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 40 | self.parser.add_argument('--explicit_encoding', action='store_true', help='if selected, using trick to encode phase') 41 | self.parser.add_argument('--alpha', type=float, default=0.6, help='phase encoding factor') 42 | self.parser.add_argument('--norm_range', type=float, default=(0,1), nargs=2, help='specify the target ditribution range') 43 | self.parser.add_argument('--abs_norm', action='store_true', help='if selected, assuming the spectrograms are all distributed in a fixed range. Thus instead of normalizing by min and max each by each, normalize by an absolute range.') 44 | self.parser.add_argument('--src_range', type=float, default=(-5,5), nargs=2, help='specify the source ditribution range. This value is used when --abs_norm is specified.') 45 | self.parser.add_argument('--arcsinh_transform', action='store_true', help='if selected, using log(G*x+sqrt(((G*x)^2+1))) to compressing the range of input. Do not use this option with --explicit_encoding') 46 | self.parser.add_argument('--raw_mdct', action='store_true', help='if selected, DO NO transform. 
Do not use this option with --explicit_encoding|arcsinh_transform') 47 | self.parser.add_argument('--arcsinh_gain', type=float, default=500, help='gain of arcsinh_trasform input') 48 | self.parser.add_argument('--add_noise', action='store_true', help='if selected, add some noise to input waveform') 49 | self.parser.add_argument('--snr', type=float, default=55, help='add noise by SnR (working if --add_noise is selected)') 50 | 51 | # for displays 52 | self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size') 53 | self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed') 54 | 55 | # for generator 56 | self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG') 57 | self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 58 | self.parser.add_argument('--upsample_type', type=str, default='transconv', help='selects upsampling layers for netG [transconv|interpolate]') 59 | self.parser.add_argument('--downsample_type', type=str, default='conv', help='selects upsampling layers for netG [resconv|conv]') 60 | self.parser.add_argument('--n_downsample_global', type=int, default=4, help='number of downsampling layers in netG') 61 | self.parser.add_argument('--n_blocks_global', type=int, default=9, help='number of residual blocks in the global generator network') 62 | self.parser.add_argument('--n_blocks_attn_g', type=int, default=1, help='number of attention blocks in the global generator network') 63 | self.parser.add_argument('--proj_factor_g', type=int, default=4, help='projection factor of attention blocks in the global generator network') 64 | self.parser.add_argument('--dim_head_g', type=int, default=128, help='dim of attention heads in the global generator network') 65 | self.parser.add_argument('--heads_g', type=int, default=4, help='number of attention heads in the global generator network') 66 | self.parser.add_argument('--n_blocks_local', type=int, default=3, help='number of residual blocks in the local enhancer network') 67 | self.parser.add_argument('--n_blocks_attn_l', type=int, default=0, help='number of attention blocks in the local enhancer network') 68 | self.parser.add_argument('--proj_factor_l', type=int, default=4, help='projection factor of attention blocks in the local enhancers network') 69 | self.parser.add_argument('--dim_head_l', type=int, default=128, help='dim of attention heads in the local enhancers network') 70 | self.parser.add_argument('--heads_l', type=int, default=4, help='number of attention heads in the local enhancers network') 71 | self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use') 72 | self.parser.add_argument('--niter_fix_global', type=int, default=0, help='number of epochs that we only train the outmost local enhancer') 73 | 74 | # for instance-wise features 75 | # self.parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input', default=True) 76 | # self.parser.add_argument('--instance_feat', action='store_true', help='if specified, add encoded instance features as input') 77 | # self.parser.add_argument('--label_feat', action='store_true', help='if specified, add encoded label features as input') 78 | # self.parser.add_argument('--feat_num', type=int, default=3, help='vector length for encoded features') 79 | # 
self.parser.add_argument('--load_features', action='store_true', help='if specified, load precomputed feature maps') 80 | # self.parser.add_argument('--n_downsample_E', type=int, default=4, help='# of downsampling layers in encoder') 81 | # self.parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer') 82 | # self.parser.add_argument('--n_clusters', type=int, default=10, help='number of clusters for features') 83 | 84 | # input mask options 85 | self.parser.add_argument('--mask', action='store_true', help='mask high freq conponent of lr spectro') 86 | self.parser.add_argument('--smooth', type=float, default=0.0, help='smooth the edge of the sr and lr') 87 | self.parser.add_argument('--mask_hr', action='store_true', help='mask high freq conponent of hr spectro') 88 | self.parser.add_argument('--mask_mode', type=str, default=None, help='[None|mode0|mode1]') 89 | self.parser.add_argument('--min_value', type=float, default=1e-7, help='minimum value to cutoff the spectrogram') 90 | 91 | self.initialized = True 92 | 93 | def parse(self, save=True): 94 | if not self.initialized: 95 | self.initialize() 96 | self.opt = self.parser.parse_args() 97 | self.opt.isTrain = self.isTrain # train or test 98 | 99 | str_ids = self.opt.gpu_ids.split(',') 100 | self.opt.gpu_ids = [] 101 | for str_id in str_ids: 102 | id = int(str_id) 103 | if id >= 0: 104 | self.opt.gpu_ids.append(id) 105 | 106 | # set gpu ids 107 | if len(self.opt.gpu_ids) > 0: 108 | torch.cuda.set_device(self.opt.gpu_ids[0]) 109 | 110 | args = vars(self.opt) 111 | 112 | print('------------ Options -------------') 113 | for k, v in sorted(args.items()): 114 | print('%s: %s' % (str(k), str(v))) 115 | print('-------------- End ----------------') 116 | 117 | # save to the disk 118 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) 119 | util.mkdirs(expr_dir) 120 | if save and not self.opt.continue_train: 121 | file_name = os.path.join(expr_dir, 'opt.txt') 122 | with open(file_name, 'wt') as opt_file: 123 | opt_file.write('------------ Options -------------\n') 124 | for k, v in sorted(args.items()): 125 | opt_file.write('%s: %s\n' % (str(k), str(v))) 126 | opt_file.write('-------------- End ----------------\n') 127 | return self.opt 128 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from data.data_loader import CreateDataLoader 2 | import signal 3 | from util.util import compute_matrics 4 | from util.visualizer import Visualizer 5 | from options.train_options import TrainOptions 6 | from models.models import create_model 7 | 8 | import math 9 | import os 10 | import time 11 | import csv 12 | import gc 13 | 14 | import numpy as np 15 | import torch 16 | 17 | 18 | 19 | def lcm(a, b): return abs(a * b)/math.gcd(a, b) if a and b else 0 20 | 21 | 22 | # import debugpy 23 | # debugpy.listen(("localhost", 5678)) 24 | # debugpy.wait_for_client() 25 | # os.environ['CUDA_VISIBLE_DEVICES']='0' 26 | torch.backends.cudnn.benchmark = True 27 | # Get the training options 28 | opt = TrainOptions().parse() 29 | # Set the seed 30 | torch.manual_seed(opt.seed) 31 | # Set the path for save the trainning losses 32 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') 33 | eval_path = os.path.join(opt.checkpoints_dir, opt.name, 'eval.csv') 34 | 35 | if opt.continue_train: 36 | try: 37 | start_epoch, epoch_iter = np.loadtxt( 38 | iter_path, delimiter=',', 
dtype=int) 39 | except: 40 | start_epoch, epoch_iter = 1, 0 41 | print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter)) 42 | else: 43 | start_epoch, epoch_iter = 1, 0 44 | 45 | # Create the data loader 46 | data_loader = CreateDataLoader(opt) 47 | train_dataloader = data_loader.get_train_dataloader() 48 | train_dataset_size = len(data_loader) 49 | eval_dataloader = data_loader.get_eval_dataloader() 50 | eval_dataset_size = data_loader.eval_data_len() 51 | print('#training data = %d' % train_dataset_size) 52 | print('#evaluating data = %d' % eval_dataset_size) 53 | 54 | # Create the model 55 | model = create_model(opt) 56 | visualizer = Visualizer(opt) 57 | optimizer_G, optimizer_D = model.optimizer_G, model.optimizer_D 58 | 59 | # IMDCT for evaluation 60 | # from util.util import kbdwin, imdct 61 | # # from dct.dct import IDCT 62 | # # _idct = IDCT() 63 | # _imdct = IMDCT4(window=kbdwin, win_length=opt.win_length, hop_length=opt.hop_length, n_fft=opt.n_fft, center=opt.center, out_length=opt.segment_length, device = 'cuda') 64 | 65 | if opt.fp16: 66 | from torch.cuda.amp import autocast as autocast 67 | from torch.cuda.amp import GradScaler 68 | # According to the offical tutorial, use only one GradScaler and backward losses separately 69 | # https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-multiple-models-losses-and-optimizers 70 | scaler = GradScaler() 71 | 72 | 73 | # Set frequency for displaying information and saving 74 | opt.print_freq = lcm(opt.print_freq, opt.batchSize) 75 | if opt.debug: 76 | opt.display_freq = 1 77 | opt.print_freq = 1 78 | opt.niter = 1 79 | opt.niter_decay = 0 80 | opt.max_dataset_size = 10 81 | total_steps = (start_epoch-1) * train_dataset_size + epoch_iter 82 | display_delta = total_steps % opt.display_freq 83 | print_delta = total_steps % opt.print_freq 84 | save_delta = total_steps % opt.save_latest_freq 85 | eval_delta = total_steps % opt.eval_freq if opt.validation_split > 0 else -1 86 | # loss_update_delta = total_steps % opt.loss_update_freq if opt.use_time_D or opt.use_match_loss else -1 87 | 88 | # Safe ctrl-c 89 | end = False 90 | 91 | 92 | def signal_handler(signal, frame): 93 | print('You pressed Ctrl+C!') 94 | global end 95 | end = True 96 | 97 | 98 | signal.signal(signal.SIGINT, signal_handler) 99 | 100 | # Evaluation process 101 | # Wrap it as a function so that I dont have to free up memory manually 102 | 103 | 104 | def eval_model(): 105 | err = [] 106 | snr = [] 107 | snr_seg = [] 108 | pesq = [] 109 | lsd = [] 110 | for j, eval_data in enumerate(eval_dataloader): 111 | model.eval() 112 | lr_audio = eval_data['LR_audio'].cuda() 113 | hr_audio = eval_data['HR_audio'].cuda() 114 | with torch.no_grad(): 115 | _, sr_audio, _, _, _ = model.inference(lr_audio) 116 | _mse, _snr_sr, _snr_lr, _ssnr_sr, _ssnr_lr, _pesq, _lsd = compute_matrics( 117 | hr_audio.squeeze(), lr_audio.squeeze(), sr_audio.squeeze(), opt) 118 | err.append(_mse) 119 | snr.append((_snr_lr, _snr_sr)) 120 | snr_seg.append((_ssnr_lr, _ssnr_sr)) 121 | pesq.append(_pesq) 122 | lsd.append(_lsd) 123 | if j >= opt.eval_size: 124 | break 125 | 126 | eval_result = {'err': np.mean(err), 'snr': np.mean(snr), 'snr_seg': np.mean( 127 | snr_seg), 'pesq': np.mean(pesq), 'lsd': np.mean(lsd)} 128 | with open(eval_path, 'a') as csv_file: 129 | writer = csv.DictWriter(csv_file, fieldnames=eval_result.keys()) 130 | if csv_file.tell() == 0: 131 | writer.writeheader() 132 | writer.writerow(eval_result) 133 | print('Evaluation:', eval_result) 134 | model.train() 
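# A minimal illustrative sketch (not called anywhere in this script) of the two
# headline metrics gathered by eval_model() above, mirrored from util.compute_matrics;
# the helper names _snr_db/_lsd and the arguments ref/est are illustrative only.
# The SSNR and PESQ columns of eval.csv stay at zero because the pysepm-based
# code in util.compute_matrics is commented out.
def _snr_db(ref, est):
    # SNR in dB: 10*log10( sum(ref^2) / sum((est-ref)^2) ) over the last (time) axis
    return 10*torch.log10(torch.sum(ref**2, dim=-1)/torch.sum((est-ref)**2, dim=-1))

def _lsd(ref_log_spec, est_log_spec):
    # log-spectral distance: per-frame RMS of the log-magnitude difference, averaged over frames
    return torch.sqrt(torch.mean((ref_log_spec-est_log_spec)**2, dim=-2)).mean()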
135 | 136 | 137 | # Training... 138 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1): 139 | epoch_start_time = time.time() 140 | if epoch != start_epoch: 141 | epoch_iter = epoch_iter % train_dataset_size 142 | if epoch > opt.niter_limit_aux: 143 | model.limit_aux_loss = True 144 | for i, data in enumerate(train_dataloader, start=epoch_iter): 145 | if end: 146 | print('exiting and saving the model at the epoch %d, iters %d' % 147 | (epoch, total_steps)) 148 | model.save('latest') 149 | model.save(epoch) 150 | np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d') 151 | exit(0) 152 | if total_steps % opt.print_freq == print_delta: 153 | iter_start_time = time.time() 154 | total_steps += opt.batchSize 155 | epoch_iter += opt.batchSize 156 | 157 | # Whether to collect output images 158 | save_fake = total_steps % opt.display_freq == display_delta 159 | 160 | ############## Forward Pass ###################### 161 | if opt.fp16: 162 | with autocast(): 163 | losses, _ = model._forward( 164 | data['LR_audio'].cuda(), data['HR_audio'].cuda(), infer=False) 165 | else: 166 | losses, _ = model._forward( 167 | data['LR_audio'].cuda(), data['HR_audio'].cuda(), infer=False) 168 | 169 | # Sum per device losses 170 | losses = [torch.mean(x) if not isinstance(x, int) 171 | else x for x in losses] 172 | loss_dict = dict(zip(model.loss_names, losses)) 173 | 174 | # Calculate final loss scalar 175 | loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5 + (loss_dict.get('D_fake_t', 0) + loss_dict.get( 176 | 'D_real_t', 0))*0.5 + (loss_dict.get('D_fake_mr', 0) + loss_dict.get('D_real_mr', 0))*0.5 177 | loss_G = loss_dict['G_GAN'] + loss_dict.get('G_mat', 0) + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get( 178 | 'G_VGG', 0) + loss_dict.get('G_GAN_t', 0) + loss_dict.get('G_GAN_mr', 0) + loss_dict.get('G_shift', 0) 179 | 180 | ############### Backward Pass #################### 181 | # update generator weights 182 | optimizer_G.zero_grad() 183 | if opt.fp16: 184 | #with amp.scale_loss(loss_G, optimizer_G) as scaled_loss: scaled_loss.backward() 185 | scaler.scale(loss_G).backward() 186 | scaler.step(optimizer_G) 187 | # update the scaler only once per iteration 188 | # scaler.update() 189 | else: 190 | loss_G.backward() 191 | optimizer_G.step() 192 | 193 | # update discriminator weights 194 | optimizer_D.zero_grad() 195 | if opt.fp16: 196 | #with amp.scale_loss(loss_D, optimizer_D) as scaled_loss: scaled_loss.backward() 197 | scaler.scale(loss_D).backward() 198 | scaler.step(optimizer_D) 199 | scaler.update() 200 | else: 201 | loss_D.backward() 202 | optimizer_D.step() 203 | 204 | ############## Display results and errors ########## 205 | # print out errors 206 | if total_steps % opt.print_freq == print_delta: 207 | errors = {k: v.data.item() if not isinstance( 208 | v, int) else v for k, v in loss_dict.items()} 209 | t = (time.time() - iter_start_time) / opt.print_freq 210 | visualizer.print_current_errors(epoch, epoch_iter, errors, t) 211 | visualizer.plot_current_errors(errors, total_steps) 212 | #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"]) 213 | 214 | # display output images 215 | if save_fake: 216 | visuals = model.get_current_visuals() 217 | visualizer.display_current_results(visuals, epoch, total_steps) 218 | del visuals 219 | 220 | # save latest model 221 | if total_steps % opt.save_latest_freq == save_delta: 222 | print('saving the latest model (epoch %d, total_steps %d)' % 223 | (epoch, total_steps)) 224 | model.save('latest') 225 | 
np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d') 226 | 227 | if total_steps % opt.eval_freq == eval_delta: 228 | del losses, loss_D, loss_G, loss_dict 229 | torch.cuda.empty_cache() 230 | gc.collect() 231 | eval_model() 232 | torch.cuda.empty_cache() 233 | gc.collect() 234 | # if total_steps % opt.loss_update_freq == loss_update_delta: 235 | # if opt.use_match_loss: 236 | # model.update_match_loss_scaler() 237 | # if opt.use_time_D: 238 | # model.update_time_D_loss_scaler() 239 | 240 | if epoch_iter >= train_dataset_size: 241 | break 242 | 243 | # end of epoch 244 | iter_end_time = time.time() 245 | print('End of epoch %d / %d \t Time Taken: %d sec' % 246 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 247 | 248 | # save model for this epoch 249 | if epoch % opt.save_epoch_freq == 0: 250 | print('saving the model at the end of epoch %d, iters %d' % 251 | (epoch, total_steps)) 252 | model.save('latest') 253 | model.save(epoch) 254 | np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d') 255 | 256 | # instead of only training the local enhancer, train the entire network after certain iterations 257 | if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global): 258 | model.update_fixed_params() 259 | 260 | # linearly decay learning rate after certain iterations 261 | if epoch > opt.niter: 262 | model.update_learning_rate() 263 | -------------------------------------------------------------------------------- /models/spectrogram.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module to transform signals 3 | from https://github.com/nils-werner/stft/ 4 | 5 | """ 6 | from __future__ import division, absolute_import 7 | import itertools 8 | import torch 9 | import torch.fft 10 | import torch.nn.functional 11 | import numpy as np 12 | 13 | def _pad(data, frame_length): 14 | return torch.nn.functional.pad( 15 | data, 16 | pad=( 17 | 0, 18 | int( 19 | np.ceil( 20 | len(data) / frame_length 21 | ) * frame_length - len(data) 22 | ) 23 | ), 24 | mode='constant', 25 | value=0 26 | ) 27 | 28 | def unpad(data, outlength): 29 | slicetuple = [slice(None)] * data.ndim 30 | slicetuple[0] = slice(None, outlength) 31 | return data[tuple(slicetuple)] 32 | 33 | 34 | def center_pad(data, frame_length): 35 | #padtuple = [(0, 0)] * data.ndim 36 | padtuple = (frame_length // 2, frame_length // 2) 37 | #print(padtuple) 38 | return torch.nn.functional.pad( 39 | data, 40 | pad = padtuple, 41 | mode = 'constant', 42 | value = 0 43 | ) 44 | 45 | def center_unpad(data, frame_length): 46 | slicetuple = [slice(None)] * data.ndim 47 | slicetuple[1] = slice(frame_length // 2, -frame_length // 2) 48 | return data[tuple(slicetuple)] 49 | 50 | def process( 51 | data, 52 | window_function, 53 | halved, 54 | transform, 55 | padding=0, 56 | n_fft=1024 57 | ): 58 | """Calculate a windowed transform of a signal 59 | 60 | Parameters 61 | ---------- 62 | data : array_like 63 | The signal to be calculated. Must be a 1D array. 64 | window : array_like 65 | Tapering window 66 | halved : boolean 67 | Switch for turning on signal truncation. For real signals, the fourier 68 | transform of real signals returns a symmetrically mirrored spectrum. 69 | This additional data is not needed and can be removed. 70 | transform : callable 71 | The transform to be used. 72 | padding : int 73 | Zero-pad signal with x times the number of samples. 
74 | 75 | Returns 76 | ------- 77 | data : array_like 78 | The spectrum 79 | 80 | """ 81 | 82 | data = data * window_function 83 | 84 | if padding > 0: 85 | data = torch.nn.functional.pad( 86 | data, 87 | pad=( 88 | 0, 89 | len(data) * padding 90 | ), 91 | mode='constant', 92 | value=0 93 | ) 94 | 95 | result = transform(data,n_fft) 96 | 97 | if(halved): 98 | result = result[0:result.size // 2 + 1] 99 | 100 | return result 101 | 102 | 103 | def iprocess( 104 | data, 105 | window_function, 106 | halved, 107 | transform, 108 | padding=0, 109 | n_fft=1024 110 | ): 111 | """Calculate the inverse short time fourier transform of a spectrum 112 | 113 | Parameters 114 | ---------- 115 | data : array_like 116 | The spectrum to be calculated. Must be a 1D array. 117 | window : array_like 118 | Tapering window 119 | halved : boolean 120 | Switch for turning on signal truncation. For real output signals, the 121 | inverse fourier transform consumes a symmetrically mirrored spectrum. 122 | This additional data is not needed and can be removed. Setting this 123 | value to :code:`True` will automatically create a mirrored spectrum. 124 | transform : callable 125 | The transform to be used. 126 | padding : int 127 | Signal before FFT transform was padded with x zeros. 128 | 129 | 130 | Returns 131 | ------- 132 | data : array_like 133 | The signal 134 | 135 | """ 136 | if halved: 137 | data = torch.nn.functional.pad(data, (0, data.shape[0] - 2), 'reflect') 138 | start = data.shape[0] // 2 + 1 139 | data[start:] = data[start:].conjugate() 140 | 141 | output = transform(data,n_fft) 142 | if torch.is_complex(output): 143 | output = torch.real(output) 144 | 145 | if padding > 0: 146 | output = output[0:-(len(data) * padding // (padding + 1))] 147 | 148 | return output * window_function 149 | 150 | 151 | def spectrogram( 152 | data, 153 | frame_length=1024, 154 | step_length=None, 155 | overlap=None, 156 | centered=True, 157 | n_fft=1024, 158 | window_function=None, 159 | halved=True, 160 | transform=None, 161 | padding=0 162 | ): 163 | """Calculate the spectrogram of a signal 164 | 165 | Parameters 166 | ---------- 167 | data : array_like 168 | The signal to be transformed. May be a 1D vector for single channel or 169 | a 2D matrix for multi channel data. In case of a mono signal, the data 170 | is must be a 1D vector of length :code:`samples`. In case of a multi 171 | channel signal, the data must be in the shape of :code:`samples x 172 | channels`. 173 | frame_length : int 174 | The signal frame length. Defaults to :code:`1024`. 175 | step_length : int 176 | The signal frame step_length. Defaults to :code:`None`. Setting this 177 | value will override :code:`overlap`. 178 | overlap : int 179 | The signal frame overlap coefficient. Value :code:`x` means 180 | :code:`1/x` overlap. Defaults to :code:`2`. 181 | centered : boolean 182 | Pad input signal so that the first and last window are centered around 183 | the beginning of the signal. Defaults to true. 184 | window : callable, array_like 185 | Window to be used for deringing. Can be :code:`False` to disable 186 | windowing. Defaults to :code:`scipy.signal.cosine`. 187 | halved : boolean 188 | Switch for turning on signal truncation. For real signals, the fourier 189 | transform of real signals returns a symmetrically mirrored spectrum. 190 | This additional data is not needed and can be removed. Defaults to 191 | :code:`True`. 192 | transform : callable 193 | The transform to be used. Defaults to :code:`scipy.fft.fft`. 
194 | padding : int 195 | Zero-pad signal with x times the number of samples. 196 | save_settings : boolean 197 | Save settings used here in attribute :code:`out.stft_settings` so that 198 | :func:`ispectrogram` can infer these settings without the developer 199 | having to pass them again. 200 | 201 | Returns 202 | ------- 203 | data : array_like 204 | The spectrogram (or tensor of spectograms) In case of a mono signal, 205 | the data is formatted as :code:`bins x frames`. In case of a multi 206 | channel signal, the data is formatted as :code:`bins x frames x 207 | channels`. 208 | 209 | Notes 210 | ----- 211 | The data will be padded to be a multiple of the desired FFT length. 212 | 213 | See Also 214 | -------- 215 | stft.stft.process : The function used to transform the data 216 | 217 | """ 218 | 219 | if overlap is None: 220 | overlap = 2 221 | 222 | if step_length is None: 223 | step_length = frame_length // overlap 224 | 225 | if halved and torch.any(torch.iscomplex(data)): 226 | raise ValueError("You cannot treat a complex input signal as real " 227 | "valued. Please set keyword argument halved=False.") 228 | 229 | data = torch.squeeze(data) 230 | 231 | if transform is None: 232 | transform = torch.fft.fft 233 | 234 | if not isinstance(transform, (list, tuple)): 235 | transform = [transform] 236 | 237 | transforms = itertools.cycle(transform) 238 | 239 | if centered: 240 | #print("center pad") 241 | data = center_pad(data, frame_length) 242 | 243 | if window_function is None: 244 | window_array = torch.ones(frame_length) 245 | 246 | if callable(window_function): 247 | window_array = window_function(frame_length) 248 | else: 249 | window_array = window_function 250 | frame_length = len(window_array) 251 | 252 | def traf(data): 253 | # Pad input signal so it fits into frame_length spec 254 | #print(data.size()) 255 | #data = _pad(data, frame_length) 256 | #print(data.size()) 257 | 258 | values = list(enumerate( 259 | range(0, len(data) - frame_length + step_length, step_length) 260 | )) 261 | #print(values) 262 | for j, i in values: 263 | sig = process( 264 | data[i:i + frame_length], 265 | window_function=window_array, 266 | halved=halved, 267 | transform=next(transforms), 268 | padding=padding, 269 | n_fft = n_fft 270 | ) / (frame_length // step_length // 2) 271 | if(i == 0): 272 | output_ = torch.zeros( 273 | (sig.shape[0], len(values)), dtype=sig.dtype 274 | ) 275 | 276 | output_[:, j] = sig 277 | 278 | return output_ 279 | 280 | if data.ndim > 2: 281 | raise ValueError("spectrogram: Only 1D or 2D input data allowed") 282 | if data.ndim == 1: 283 | out = traf(data) 284 | elif data.ndim == 2: 285 | #print(data.size()) 286 | for i in range(data.shape[0]): 287 | tmp = traf(data[i,:]) 288 | #print(tmp.size()) 289 | if i == 0: 290 | out = torch.empty( 291 | ((data.shape[0],)+tmp.shape), dtype=tmp.dtype 292 | ) 293 | out[i, :, :] = tmp 294 | return out 295 | 296 | 297 | def ispectrogram( 298 | data, 299 | frame_length=1024, 300 | step_length=None, 301 | overlap=None, 302 | centered=True, 303 | n_fft=1024, 304 | window_function=None, 305 | halved=True, 306 | transform=None, 307 | padding=0, 308 | out_length=None 309 | ): 310 | """Calculate the inverse spectrogram of a signal 311 | 312 | Parameters 313 | ---------- 314 | data : array_like 315 | The spectrogram to be inverted. May be a 2D matrix for single channel 316 | or a 3D tensor for multi channel data. In case of a mono signal, the 317 | data must be in the shape of :code:`bins x frames`. 
In case of a multi 318 | channel signal, the data must be in the shape of :code:`bins x frames x 319 | channels`. 320 | frame_length : int 321 | The signal frame length. Defaults to infer from data. 322 | step_length : int 323 | The signal frame step_length. Defaults to infer from data. Setting this 324 | value will override :code:`overlap`. 325 | overlap : int 326 | The signal frame overlap coefficient. Value :code:`x` means 327 | :code:`1/x` overlap. Defaults to infer from data. 328 | centered : boolean 329 | Pad input signal so that the first and last window are centered around 330 | the beginning of the signal. Defaults to to infer from data. 331 | window : callable, array_like 332 | Window to be used for deringing. Can be :code:`False` to disable 333 | windowing. Defaults to to infer from data. 334 | halved : boolean 335 | Switch to reconstruct the other halve of the spectrum if the forward 336 | transform has been truncated. Defaults to to infer from data. 337 | transform : callable 338 | The transform to be used. Defaults to infer from data. 339 | padding : int 340 | Zero-pad signal with x times the number of samples. Defaults to infer 341 | from data. 342 | outlength : int 343 | Crop output signal to length. Useful when input length of spectrogram 344 | did not fit into frame_length and input data had to be padded. Not 345 | setting this value will disable cropping, the output data may be 346 | longer than expected. 347 | 348 | Returns 349 | ------- 350 | data : array_like 351 | The signal (or matrix of signals). In case of a mono output signal, the 352 | data is formatted as a 1D vector of length :code:`samples`. In case of 353 | a multi channel output signal, the data is formatted as :code:`samples 354 | x channels`. 355 | 356 | Notes 357 | ----- 358 | By default :func:`spectrogram` saves its transformation parameters in 359 | the output array. This data is used to infer the transform parameters 360 | here. Any aspect of the settings can be overridden by passing the according 361 | parameter to this function. 362 | 363 | During transform the data will be padded to be a multiple of the desired 364 | FFT length. Hence, the result of the inverse transform might be longer 365 | than the input signal. However it is safe to remove the additional data, 366 | e.g. by using 367 | 368 | .. 
code:: python 369 | 370 | output.resize(input.shape) 371 | 372 | where :code:`input` is the input of :code:`stft.spectrogram()` and 373 | :code:`output` is the output of :code:`stft.ispectrogram()` 374 | 375 | See Also 376 | -------- 377 | stft.stft.iprocess : The function used to transform the data 378 | 379 | """ 380 | 381 | if overlap is None: 382 | overlap = 2 383 | 384 | if step_length is None: 385 | step_length = frame_length // overlap 386 | 387 | if window_function is None: 388 | window_array = torch.ones(frame_length) 389 | 390 | if callable(window_function): 391 | window_array = window_function(frame_length) 392 | else: 393 | window_array = window_function 394 | 395 | if transform is None: 396 | transform = torch.fft.ifft 397 | 398 | if not isinstance(transform, (list, tuple)): 399 | transform = [transform] 400 | 401 | transforms = itertools.cycle(transform) 402 | 403 | def traf(data): 404 | i = 0 405 | values = range(0, data.shape[1]) 406 | for j in values: 407 | sig = iprocess( 408 | data[:, j], 409 | window_function=window_array, 410 | halved=halved, 411 | transform=next(transforms), 412 | padding=padding, 413 | n_fft=n_fft 414 | ) 415 | 416 | if(i == 0): 417 | output = torch.zeros( 418 | frame_length + (len(values) - 1) * step_length, 419 | dtype=sig.dtype 420 | ).cuda() 421 | 422 | output[i:i + frame_length] += sig 423 | 424 | i += step_length 425 | 426 | return output 427 | 428 | if data.ndim == 2: 429 | out = traf(data) 430 | elif data.ndim == 3: 431 | for i in range(data.shape[0]): 432 | tmp = traf(data[i ,:, :]) 433 | 434 | if i == 0: 435 | out = torch.empty( 436 | ((data.shape[0],)+tmp.shape), dtype=tmp.dtype 437 | ) 438 | out[i,:] = tmp 439 | else: 440 | raise ValueError("ispectrogram: Only 2D or 3D input data allowed") 441 | 442 | if centered: 443 | print(out.size()) 444 | out = center_unpad(out, frame_length) 445 | 446 | return unpad(out, out_length) -------------------------------------------------------------------------------- /data/test.csv: -------------------------------------------------------------------------------- 1 | wav48/p250/p250_328.wav 2 | wav48/p310/p310_345.wav 3 | wav48/p227/p227_020.wav 4 | wav48/p285/p285_050.wav 5 | wav48/p248/p248_011.wav 6 | wav48/p246/p246_030.wav 7 | wav48/p247/p247_191.wav 8 | wav48/p287/p287_127.wav 9 | wav48/p334/p334_220.wav 10 | wav48/p340/p340_414.wav 11 | wav48/p236/p236_231.wav 12 | wav48/p301/p301_334.wav 13 | wav48/p258/p258_021.wav 14 | wav48/p374/p374_403.wav 15 | wav48/p312/p312_182.wav 16 | wav48/p232/p232_237.wav 17 | wav48/p334/p334_080.wav 18 | wav48/p347/p347_206.wav 19 | wav48/p313/p313_347.wav 20 | wav48/p304/p304_274.wav 21 | wav48/p351/p351_117.wav 22 | wav48/p351/p351_051.wav 23 | wav48/p312/p312_109.wav 24 | wav48/p323/p323_148.wav 25 | wav48/p329/p329_187.wav 26 | wav48/p308/p308_255.wav 27 | wav48/p281/p281_039.wav 28 | wav48/p294/p294_004.wav 29 | wav48/p310/p310_056.wav 30 | wav48/p271/p271_016.wav 31 | wav48/p249/p249_276.wav 32 | wav48/p285/p285_305.wav 33 | wav48/p330/p330_266.wav 34 | wav48/p229/p229_006.wav 35 | wav48/p273/p273_313.wav 36 | wav48/p329/p329_392.wav 37 | wav48/p305/p305_384.wav 38 | wav48/p256/p256_177.wav 39 | wav48/p237/p237_327.wav 40 | wav48/p308/p308_028.wav 41 | wav48/p339/p339_401.wav 42 | wav48/p298/p298_049.wav 43 | wav48/p257/p257_331.wav 44 | wav48/p301/p301_302.wav 45 | wav48/p227/p227_330.wav 46 | wav48/p255/p255_292.wav 47 | wav48/p362/p362_040.wav 48 | wav48/p237/p237_218.wav 49 | wav48/p281/p281_233.wav 50 | wav48/p312/p312_350.wav 51 | wav48/p312/p312_130.wav 
52 | wav48/p287/p287_056.wav 53 | wav48/p376/p376_051.wav 54 | wav48/p253/p253_148.wav 55 | wav48/p266/p266_336.wav 56 | wav48/p288/p288_307.wav 57 | wav48/p254/p254_165.wav 58 | wav48/p333/p333_395.wav 59 | wav48/p307/p307_171.wav 60 | wav48/p243/p243_255.wav 61 | wav48/p274/p274_439.wav 62 | wav48/p302/p302_235.wav 63 | wav48/p248/p248_047.wav 64 | wav48/p363/p363_273.wav 65 | wav48/p292/p292_143.wav 66 | wav48/p286/p286_001.wav 67 | wav48/p259/p259_223.wav 68 | wav48/p233/p233_334.wav 69 | wav48/p267/p267_109.wav 70 | wav48/p339/p339_006.wav 71 | wav48/p295/p295_006.wav 72 | wav48/p281/p281_018.wav 73 | wav48/p252/p252_175.wav 74 | wav48/p297/p297_108.wav 75 | wav48/p243/p243_398.wav 76 | wav48/p226/p226_112.wav 77 | wav48/p341/p341_232.wav 78 | wav48/p363/p363_368.wav 79 | wav48/p362/p362_240.wav 80 | wav48/p306/p306_100.wav 81 | wav48/p301/p301_049.wav 82 | wav48/p257/p257_238.wav 83 | wav48/p245/p245_039.wav 84 | wav48/p313/p313_133.wav 85 | wav48/p278/p278_147.wav 86 | wav48/p280/p280_130.wav 87 | wav48/p271/p271_052.wav 88 | wav48/p341/p341_131.wav 89 | wav48/p233/p233_074.wav 90 | wav48/p283/p283_181.wav 91 | wav48/p260/p260_034.wav 92 | wav48/p273/p273_294.wav 93 | wav48/p286/p286_245.wav 94 | wav48/p314/p314_009.wav 95 | wav48/p300/p300_277.wav 96 | wav48/p271/p271_022.wav 97 | wav48/p361/p361_214.wav 98 | wav48/p271/p271_385.wav 99 | wav48/p313/p313_026.wav 100 | wav48/p277/p277_096.wav 101 | wav48/p270/p270_352.wav 102 | wav48/p279/p279_257.wav 103 | wav48/p297/p297_015.wav 104 | wav48/p241/p241_014.wav 105 | wav48/p228/p228_072.wav 106 | wav48/p251/p251_092.wav 107 | wav48/p252/p252_169.wav 108 | wav48/p249/p249_003.wav 109 | wav48/p278/p278_023.wav 110 | wav48/p339/p339_134.wav 111 | wav48/p295/p295_193.wav 112 | wav48/p283/p283_397.wav 113 | wav48/p266/p266_173.wav 114 | wav48/p276/p276_394.wav 115 | wav48/p311/p311_354.wav 116 | wav48/p246/p246_194.wav 117 | wav48/p307/p307_256.wav 118 | wav48/p306/p306_010.wav 119 | wav48/p278/p278_159.wav 120 | wav48/p229/p229_359.wav 121 | wav48/p376/p376_086.wav 122 | wav48/p263/p263_265.wav 123 | wav48/p243/p243_058.wav 124 | wav48/p244/p244_215.wav 125 | wav48/p257/p257_174.wav 126 | wav48/p300/p300_295.wav 127 | wav48/p247/p247_113.wav 128 | wav48/p286/p286_032.wav 129 | wav48/p314/p314_344.wav 130 | wav48/p347/p347_234.wav 131 | wav48/p255/p255_302.wav 132 | wav48/p376/p376_140.wav 133 | wav48/p362/p362_182.wav 134 | wav48/p294/p294_123.wav 135 | wav48/p330/p330_380.wav 136 | wav48/p294/p294_242.wav 137 | wav48/p243/p243_280.wav 138 | wav48/p298/p298_240.wav 139 | wav48/p274/p274_302.wav 140 | wav48/p281/p281_285.wav 141 | wav48/p271/p271_093.wav 142 | wav48/p323/p323_314.wav 143 | wav48/p249/p249_140.wav 144 | wav48/p333/p333_060.wav 145 | wav48/p300/p300_108.wav 146 | wav48/p236/p236_450.wav 147 | wav48/p251/p251_196.wav 148 | wav48/p262/p262_308.wav 149 | wav48/p292/p292_259.wav 150 | wav48/p311/p311_361.wav 151 | wav48/p226/p226_001.wav 152 | wav48/p239/p239_394.wav 153 | wav48/p240/p240_028.wav 154 | wav48/p304/p304_143.wav 155 | wav48/p279/p279_209.wav 156 | wav48/p286/p286_072.wav 157 | wav48/p374/p374_330.wav 158 | wav48/p302/p302_208.wav 159 | wav48/p287/p287_344.wav 160 | wav48/p261/p261_220.wav 161 | wav48/p345/p345_253.wav 162 | wav48/p326/p326_002.wav 163 | wav48/p277/p277_417.wav 164 | wav48/p249/p249_211.wav 165 | wav48/p283/p283_441.wav 166 | wav48/p351/p351_292.wav 167 | wav48/p301/p301_309.wav 168 | wav48/p363/p363_252.wav 169 | wav48/p329/p329_005.wav 170 | wav48/p229/p229_098.wav 171 | wav48/p261/p261_312.wav 
172 | wav48/p334/p334_422.wav 173 | wav48/p258/p258_059.wav 174 | wav48/p262/p262_183.wav 175 | wav48/p271/p271_441.wav 176 | wav48/p276/p276_301.wav 177 | wav48/p280/p280_221.wav 178 | wav48/p245/p245_293.wav 179 | wav48/p225/p225_301.wav 180 | wav48/p275/p275_272.wav 181 | wav48/p275/p275_310.wav 182 | wav48/p263/p263_002.wav 183 | wav48/p317/p317_302.wav 184 | wav48/p248/p248_252.wav 185 | wav48/p236/p236_499.wav 186 | wav48/p266/p266_356.wav 187 | wav48/p274/p274_233.wav 188 | wav48/p275/p275_246.wav 189 | wav48/p232/p232_025.wav 190 | wav48/p240/p240_009.wav 191 | wav48/p232/p232_236.wav 192 | wav48/p227/p227_096.wav 193 | wav48/p347/p347_172.wav 194 | wav48/p305/p305_363.wav 195 | wav48/p288/p288_270.wav 196 | wav48/p258/p258_286.wav 197 | wav48/p243/p243_220.wav 198 | wav48/p252/p252_409.wav 199 | wav48/p247/p247_283.wav 200 | wav48/p254/p254_355.wav 201 | wav48/p226/p226_246.wav 202 | wav48/p283/p283_468.wav 203 | wav48/p233/p233_132.wav 204 | wav48/p343/p343_066.wav 205 | wav48/p351/p351_339.wav 206 | wav48/p313/p313_099.wav 207 | wav48/p305/p305_383.wav 208 | wav48/p234/p234_318.wav 209 | wav48/p272/p272_228.wav 210 | wav48/p239/p239_308.wav 211 | wav48/p265/p265_287.wav 212 | wav48/p250/p250_424.wav 213 | wav48/p260/p260_357.wav 214 | wav48/p239/p239_228.wav 215 | wav48/p277/p277_130.wav 216 | wav48/p317/p317_269.wav 217 | wav48/p236/p236_287.wav 218 | wav48/p310/p310_151.wav 219 | wav48/p225/p225_166.wav 220 | wav48/p341/p341_290.wav 221 | wav48/p323/p323_167.wav 222 | wav48/p245/p245_208.wav 223 | wav48/p308/p308_287.wav 224 | wav48/p307/p307_154.wav 225 | wav48/p239/p239_137.wav 226 | wav48/p326/p326_306.wav 227 | wav48/p299/p299_394.wav 228 | wav48/p270/p270_203.wav 229 | wav48/p255/p255_143.wav 230 | wav48/p286/p286_143.wav 231 | wav48/p278/p278_248.wav 232 | wav48/p232/p232_035.wav 233 | wav48/p297/p297_404.wav 234 | wav48/p345/p345_207.wav 235 | wav48/p268/p268_107.wav 236 | wav48/p362/p362_148.wav 237 | wav48/p364/p364_223.wav 238 | wav48/p252/p252_404.wav 239 | wav48/p267/p267_238.wav 240 | wav48/p268/p268_182.wav 241 | wav48/p238/p238_387.wav 242 | wav48/p230/p230_207.wav 243 | wav48/p306/p306_126.wav 244 | wav48/p310/p310_201.wav 245 | wav48/p227/p227_065.wav 246 | wav48/p228/p228_214.wav 247 | wav48/p264/p264_082.wav 248 | wav48/p257/p257_226.wav 249 | wav48/p294/p294_302.wav 250 | wav48/p312/p312_184.wav 251 | wav48/p376/p376_105.wav 252 | wav48/p364/p364_056.wav 253 | wav48/p234/p234_067.wav 254 | wav48/p252/p252_061.wav 255 | wav48/p225/p225_027.wav 256 | wav48/p254/p254_136.wav 257 | wav48/p361/p361_078.wav 258 | wav48/p256/p256_088.wav 259 | wav48/p303/p303_331.wav 260 | wav48/p334/p334_406.wav 261 | wav48/p281/p281_243.wav 262 | wav48/p343/p343_293.wav 263 | wav48/p253/p253_126.wav 264 | wav48/p297/p297_152.wav 265 | wav48/p258/p258_226.wav 266 | wav48/p255/p255_363.wav 267 | wav48/p244/p244_107.wav 268 | wav48/p233/p233_295.wav 269 | wav48/p335/p335_420.wav 270 | wav48/p271/p271_082.wav 271 | wav48/p323/p323_136.wav 272 | wav48/p280/p280_259.wav 273 | wav48/p362/p362_221.wav 274 | wav48/p246/p246_121.wav 275 | wav48/p253/p253_049.wav 276 | wav48/p333/p333_196.wav 277 | wav48/p329/p329_048.wav 278 | wav48/p305/p305_062.wav 279 | wav48/p341/p341_103.wav 280 | wav48/p283/p283_264.wav 281 | wav48/p252/p252_038.wav 282 | wav48/p260/p260_314.wav 283 | wav48/p313/p313_352.wav 284 | wav48/p295/p295_360.wav 285 | wav48/p266/p266_169.wav 286 | wav48/p279/p279_108.wav 287 | wav48/p303/p303_123.wav 288 | wav48/p282/p282_087.wav 289 | wav48/p345/p345_265.wav 290 | 
wav48/p306/p306_292.wav 291 | wav48/p268/p268_321.wav 292 | wav48/p286/p286_311.wav 293 | wav48/p259/p259_268.wav 294 | wav48/p295/p295_328.wav 295 | wav48/p230/p230_185.wav 296 | wav48/p226/p226_257.wav 297 | wav48/p288/p288_042.wav 298 | wav48/p254/p254_334.wav 299 | wav48/p298/p298_227.wav 300 | wav48/p299/p299_075.wav 301 | wav48/p251/p251_053.wav 302 | wav48/p301/p301_284.wav 303 | wav48/p278/p278_313.wav 304 | wav48/p264/p264_491.wav 305 | wav48/p310/p310_032.wav 306 | wav48/p292/p292_120.wav 307 | wav48/p345/p345_386.wav 308 | wav48/p253/p253_305.wav 309 | wav48/p335/p335_072.wav 310 | wav48/p241/p241_228.wav 311 | wav48/p360/p360_362.wav 312 | wav48/p333/p333_088.wav 313 | wav48/p279/p279_122.wav 314 | wav48/p336/p336_127.wav 315 | wav48/p376/p376_045.wav 316 | wav48/p250/p250_240.wav 317 | wav48/p299/p299_273.wav 318 | wav48/p231/p231_226.wav 319 | wav48/p281/p281_367.wav 320 | wav48/p293/p293_185.wav 321 | wav48/p241/p241_032.wav 322 | wav48/p279/p279_029.wav 323 | wav48/p362/p362_116.wav 324 | wav48/p225/p225_243.wav 325 | wav48/p228/p228_323.wav 326 | wav48/p265/p265_342.wav 327 | wav48/p250/p250_282.wav 328 | wav48/p339/p339_355.wav 329 | wav48/p287/p287_357.wav 330 | wav48/p275/p275_390.wav 331 | wav48/p360/p360_249.wav 332 | wav48/p341/p341_012.wav 333 | wav48/p237/p237_150.wav 334 | wav48/p253/p253_121.wav 335 | wav48/p339/p339_417.wav 336 | wav48/p238/p238_355.wav 337 | wav48/p256/p256_129.wav 338 | wav48/p304/p304_047.wav 339 | wav48/p248/p248_200.wav 340 | wav48/p274/p274_268.wav 341 | wav48/p362/p362_150.wav 342 | wav48/p234/p234_069.wav 343 | wav48/p283/p283_288.wav 344 | wav48/p243/p243_080.wav 345 | wav48/p267/p267_338.wav 346 | wav48/p299/p299_188.wav 347 | wav48/p292/p292_229.wav 348 | wav48/p256/p256_076.wav 349 | wav48/p343/p343_174.wav 350 | wav48/p345/p345_321.wav 351 | wav48/p334/p334_194.wav 352 | wav48/p305/p305_204.wav 353 | wav48/p230/p230_128.wav 354 | wav48/p360/p360_286.wav 355 | wav48/p305/p305_324.wav 356 | wav48/p252/p252_315.wav 357 | wav48/p297/p297_137.wav 358 | wav48/p326/p326_059.wav 359 | wav48/p301/p301_394.wav 360 | wav48/p306/p306_180.wav 361 | wav48/p281/p281_132.wav 362 | wav48/p268/p268_394.wav 363 | wav48/p264/p264_116.wav 364 | wav48/p239/p239_259.wav 365 | wav48/p313/p313_087.wav 366 | wav48/p264/p264_108.wav 367 | wav48/p315/p315_250.wav 368 | wav48/p306/p306_290.wav 369 | wav48/p281/p281_036.wav 370 | wav48/p248/p248_113.wav 371 | wav48/p285/p285_038.wav 372 | wav48/p238/p238_027.wav 373 | wav48/p249/p249_278.wav 374 | wav48/p267/p267_091.wav 375 | wav48/p249/p249_311.wav 376 | wav48/p247/p247_012.wav 377 | wav48/p301/p301_293.wav 378 | wav48/p257/p257_361.wav 379 | wav48/p363/p363_017.wav 380 | wav48/p301/p301_314.wav 381 | wav48/p260/p260_180.wav 382 | wav48/p247/p247_456.wav 383 | wav48/p280/p280_066.wav 384 | wav48/p234/p234_056.wav 385 | wav48/p243/p243_352.wav 386 | wav48/p298/p298_400.wav 387 | wav48/p227/p227_003.wav 388 | wav48/p226/p226_350.wav 389 | wav48/p282/p282_036.wav 390 | wav48/p271/p271_273.wav 391 | wav48/p374/p374_042.wav 392 | wav48/p374/p374_186.wav 393 | wav48/p238/p238_046.wav 394 | wav48/p293/p293_361.wav 395 | wav48/p330/p330_148.wav 396 | wav48/p334/p334_190.wav 397 | wav48/p230/p230_220.wav 398 | wav48/p232/p232_393.wav 399 | wav48/p310/p310_179.wav 400 | wav48/p252/p252_391.wav 401 | wav48/p257/p257_370.wav 402 | wav48/p340/p340_184.wav 403 | wav48/p339/p339_171.wav 404 | wav48/p272/p272_160.wav 405 | wav48/p334/p334_185.wav 406 | wav48/p258/p258_052.wav 407 | wav48/p248/p248_314.wav 408 | 
wav48/p310/p310_421.wav 409 | wav48/p238/p238_232.wav 410 | wav48/p298/p298_097.wav 411 | wav48/p316/p316_099.wav 412 | wav48/p287/p287_182.wav 413 | wav48/p241/p241_062.wav 414 | wav48/p317/p317_350.wav 415 | wav48/p314/p314_105.wav 416 | wav48/p333/p333_096.wav 417 | wav48/p301/p301_345.wav 418 | wav48/p243/p243_090.wav 419 | wav48/p268/p268_277.wav 420 | wav48/p237/p237_253.wav 421 | wav48/p307/p307_021.wav 422 | wav48/p307/p307_413.wav 423 | wav48/p225/p225_303.wav 424 | wav48/p312/p312_414.wav 425 | wav48/p253/p253_340.wav 426 | wav48/p345/p345_117.wav 427 | wav48/p268/p268_246.wav 428 | wav48/p263/p263_264.wav 429 | wav48/p301/p301_130.wav 430 | wav48/p271/p271_295.wav 431 | wav48/p233/p233_072.wav 432 | wav48/p333/p333_406.wav 433 | wav48/p303/p303_202.wav 434 | wav48/p307/p307_386.wav 435 | wav48/p239/p239_162.wav 436 | wav48/p227/p227_040.wav 437 | wav48/p329/p329_256.wav 438 | wav48/p275/p275_020.wav 439 | wav48/p302/p302_116.wav 440 | wav48/p340/p340_308.wav 441 | wav48/p249/p249_189.wav 442 | wav48/p306/p306_096.wav 443 | wav48/p254/p254_335.wav 444 | wav48/p272/p272_042.wav 445 | wav48/p313/p313_258.wav 446 | wav48/p253/p253_168.wav 447 | wav48/p251/p251_225.wav 448 | wav48/p293/p293_293.wav 449 | wav48/p345/p345_370.wav 450 | wav48/p280/p280_161.wav 451 | wav48/p256/p256_124.wav 452 | wav48/p313/p313_329.wav 453 | wav48/p305/p305_072.wav 454 | wav48/p363/p363_364.wav 455 | wav48/p286/p286_251.wav 456 | wav48/p288/p288_006.wav 457 | wav48/p295/p295_034.wav 458 | wav48/p336/p336_200.wav 459 | wav48/p336/p336_104.wav 460 | wav48/p295/p295_090.wav 461 | wav48/p311/p311_145.wav 462 | wav48/p264/p264_266.wav 463 | wav48/p318/p318_381.wav 464 | wav48/p238/p238_206.wav 465 | wav48/p313/p313_409.wav 466 | wav48/p272/p272_285.wav 467 | wav48/p299/p299_193.wav 468 | wav48/p279/p279_040.wav 469 | wav48/p311/p311_042.wav 470 | wav48/p313/p313_035.wav 471 | wav48/p281/p281_196.wav 472 | wav48/p274/p274_083.wav 473 | wav48/p345/p345_099.wav 474 | wav48/p230/p230_069.wav 475 | wav48/p226/p226_163.wav 476 | wav48/p275/p275_411.wav 477 | wav48/p364/p364_266.wav 478 | wav48/p286/p286_404.wav 479 | wav48/p364/p364_054.wav 480 | wav48/p244/p244_147.wav 481 | wav48/p255/p255_073.wav 482 | wav48/p298/p298_024.wav 483 | wav48/p329/p329_028.wav 484 | wav48/p314/p314_128.wav 485 | wav48/p273/p273_066.wav 486 | wav48/p360/p360_325.wav 487 | wav48/p283/p283_434.wav 488 | wav48/p281/p281_197.wav 489 | wav48/p279/p279_359.wav 490 | wav48/p252/p252_162.wav 491 | wav48/p276/p276_242.wav 492 | wav48/p252/p252_191.wav 493 | wav48/p312/p312_279.wav 494 | wav48/p308/p308_337.wav 495 | wav48/p288/p288_188.wav 496 | wav48/p272/p272_044.wav 497 | wav48/p260/p260_284.wav 498 | wav48/p316/p316_224.wav 499 | wav48/p298/p298_355.wav 500 | wav48/p230/p230_079.wav 501 | wav48/p374/p374_059.wav 502 | wav48/p247/p247_399.wav 503 | wav48/p245/p245_067.wav 504 | wav48/p250/p250_423.wav 505 | wav48/p363/p363_204.wav 506 | wav48/p318/p318_175.wav 507 | wav48/p248/p248_163.wav 508 | wav48/p248/p248_293.wav 509 | wav48/p279/p279_133.wav 510 | wav48/p244/p244_359.wav 511 | wav48/p239/p239_176.wav 512 | wav48/p335/p335_423.wav 513 | wav48/p232/p232_030.wav 514 | wav48/p314/p314_219.wav 515 | wav48/p334/p334_103.wav 516 | wav48/p252/p252_099.wav 517 | wav48/p288/p288_409.wav 518 | wav48/p288/p288_316.wav 519 | wav48/p288/p288_030.wav 520 | wav48/p340/p340_187.wav 521 | wav48/p228/p228_272.wav 522 | wav48/p230/p230_044.wav 523 | wav48/p245/p245_022.wav 524 | wav48/p266/p266_222.wav 525 | wav48/p240/p240_138.wav 526 | 
wav48/p334/p334_062.wav 527 | wav48/p363/p363_220.wav 528 | wav48/p298/p298_011.wav 529 | wav48/p227/p227_258.wav 530 | wav48/p360/p360_114.wav 531 | wav48/p254/p254_283.wav 532 | wav48/p236/p236_065.wav 533 | wav48/p278/p278_139.wav 534 | wav48/p298/p298_270.wav 535 | wav48/p229/p229_134.wav 536 | wav48/p339/p339_218.wav 537 | wav48/p312/p312_358.wav 538 | wav48/p284/p284_079.wav 539 | wav48/p225/p225_005.wav 540 | wav48/p282/p282_356.wav 541 | wav48/p376/p376_257.wav 542 | wav48/p227/p227_200.wav 543 | wav48/p268/p268_276.wav 544 | wav48/p268/p268_245.wav 545 | wav48/p229/p229_053.wav 546 | wav48/p231/p231_332.wav 547 | wav48/p272/p272_077.wav 548 | wav48/p264/p264_053.wav 549 | wav48/p265/p265_038.wav 550 | wav48/p303/p303_335.wav 551 | wav48/p264/p264_100.wav 552 | wav48/p351/p351_042.wav 553 | wav48/p253/p253_336.wav 554 | wav48/p225/p225_324.wav 555 | wav48/p277/p277_307.wav 556 | wav48/p238/p238_459.wav 557 | wav48/p311/p311_320.wav 558 | wav48/p271/p271_322.wav 559 | wav48/p285/p285_039.wav 560 | wav48/p266/p266_145.wav 561 | wav48/p361/p361_383.wav 562 | wav48/p237/p237_285.wav 563 | wav48/p310/p310_386.wav 564 | wav48/p240/p240_281.wav 565 | wav48/p232/p232_369.wav 566 | wav48/p362/p362_231.wav 567 | wav48/p248/p248_080.wav 568 | wav48/p255/p255_257.wav 569 | wav48/p340/p340_279.wav 570 | wav48/p323/p323_174.wav 571 | wav48/p275/p275_060.wav 572 | wav48/p364/p364_083.wav 573 | wav48/p261/p261_462.wav 574 | wav48/p226/p226_221.wav 575 | wav48/p275/p275_112.wav 576 | wav48/p288/p288_211.wav 577 | wav48/p297/p297_049.wav 578 | wav48/p263/p263_160.wav 579 | wav48/p347/p347_259.wav 580 | wav48/p299/p299_367.wav 581 | wav48/p339/p339_297.wav 582 | wav48/p276/p276_186.wav 583 | wav48/p268/p268_330.wav 584 | wav48/p312/p312_232.wav 585 | wav48/p312/p312_262.wav 586 | wav48/p238/p238_389.wav 587 | wav48/p234/p234_052.wav 588 | wav48/p294/p294_406.wav 589 | wav48/p318/p318_335.wav 590 | wav48/p333/p333_176.wav 591 | wav48/p240/p240_304.wav 592 | wav48/p330/p330_230.wav 593 | wav48/p347/p347_245.wav 594 | wav48/p326/p326_368.wav 595 | wav48/p256/p256_025.wav 596 | wav48/p282/p282_123.wav 597 | wav48/p341/p341_316.wav 598 | wav48/p282/p282_292.wav 599 | wav48/p250/p250_385.wav 600 | wav48/p306/p306_352.wav 601 | wav48/p292/p292_371.wav 602 | wav48/p259/p259_291.wav 603 | wav48/p317/p317_397.wav 604 | wav48/p275/p275_181.wav 605 | wav48/p253/p253_334.wav 606 | wav48/p361/p361_055.wav 607 | wav48/p304/p304_406.wav 608 | wav48/p351/p351_295.wav 609 | wav48/p278/p278_164.wav 610 | wav48/p236/p236_091.wav 611 | wav48/p318/p318_169.wav 612 | wav48/p236/p236_386.wav 613 | wav48/p311/p311_195.wav 614 | wav48/p294/p294_376.wav 615 | wav48/p270/p270_219.wav 616 | wav48/p276/p276_265.wav 617 | wav48/p316/p316_374.wav 618 | wav48/p302/p302_243.wav 619 | wav48/p229/p229_177.wav 620 | wav48/p315/p315_309.wav 621 | wav48/p265/p265_331.wav 622 | wav48/p323/p323_353.wav 623 | wav48/p248/p248_286.wav 624 | wav48/p303/p303_205.wav 625 | wav48/p307/p307_012.wav 626 | wav48/p249/p249_295.wav 627 | wav48/p286/p286_235.wav 628 | wav48/p292/p292_386.wav 629 | wav48/p260/p260_105.wav 630 | wav48/p250/p250_137.wav 631 | wav48/p329/p329_372.wav 632 | wav48/p329/p329_342.wav 633 | wav48/p310/p310_169.wav 634 | wav48/p310/p310_228.wav 635 | wav48/p302/p302_017.wav 636 | wav48/p226/p226_136.wav 637 | wav48/p343/p343_145.wav 638 | wav48/p335/p335_069.wav 639 | wav48/p335/p335_143.wav 640 | wav48/p262/p262_312.wav 641 | wav48/p263/p263_334.wav 642 | wav48/p231/p231_421.wav 643 | wav48/p263/p263_385.wav 644 | 
wav48/p343/p343_104.wav 645 | wav48/p364/p364_050.wav 646 | wav48/p236/p236_200.wav 647 | wav48/p233/p233_240.wav 648 | wav48/p317/p317_299.wav 649 | wav48/p313/p313_089.wav 650 | wav48/p284/p284_050.wav 651 | wav48/p284/p284_056.wav 652 | wav48/p266/p266_272.wav 653 | wav48/p362/p362_008.wav 654 | wav48/p265/p265_323.wav 655 | wav48/p239/p239_390.wav 656 | wav48/p313/p313_027.wav 657 | wav48/p347/p347_334.wav 658 | wav48/p312/p312_077.wav 659 | wav48/p307/p307_305.wav 660 | wav48/p347/p347_073.wav 661 | wav48/p278/p278_349.wav 662 | wav48/p266/p266_320.wav 663 | wav48/p292/p292_344.wav 664 | wav48/p326/p326_023.wav 665 | wav48/p257/p257_301.wav 666 | wav48/p312/p312_273.wav 667 | wav48/p311/p311_024.wav 668 | wav48/p254/p254_276.wav 669 | wav48/p237/p237_235.wav 670 | wav48/p250/p250_141.wav 671 | wav48/p343/p343_177.wav 672 | wav48/p245/p245_354.wav 673 | wav48/p250/p250_302.wav 674 | wav48/p362/p362_360.wav 675 | wav48/p241/p241_064.wav 676 | wav48/p255/p255_248.wav 677 | wav48/p310/p310_042.wav 678 | wav48/p314/p314_222.wav 679 | wav48/p256/p256_115.wav 680 | wav48/p254/p254_249.wav 681 | wav48/p251/p251_281.wav 682 | wav48/p257/p257_074.wav 683 | wav48/p312/p312_272.wav 684 | wav48/p270/p270_008.wav 685 | wav48/p262/p262_171.wav 686 | wav48/p361/p361_176.wav 687 | wav48/p256/p256_258.wav 688 | wav48/p234/p234_193.wav 689 | wav48/p265/p265_343.wav 690 | wav48/p287/p287_338.wav 691 | wav48/p295/p295_294.wav 692 | wav48/p298/p298_362.wav 693 | wav48/p238/p238_395.wav 694 | wav48/p318/p318_067.wav 695 | wav48/p283/p283_249.wav 696 | wav48/p300/p300_148.wav 697 | wav48/p261/p261_422.wav 698 | wav48/p261/p261_272.wav 699 | wav48/p244/p244_323.wav 700 | wav48/p317/p317_055.wav 701 | wav48/p267/p267_246.wav 702 | wav48/p360/p360_043.wav 703 | wav48/p274/p274_196.wav 704 | wav48/p256/p256_087.wav 705 | wav48/p285/p285_379.wav 706 | wav48/p258/p258_069.wav 707 | wav48/p277/p277_414.wav 708 | wav48/p255/p255_118.wav 709 | wav48/p280/p280_392.wav 710 | wav48/p374/p374_325.wav 711 | wav48/p336/p336_322.wav 712 | wav48/p361/p361_399.wav 713 | wav48/p292/p292_407.wav 714 | wav48/p351/p351_350.wav 715 | wav48/p294/p294_285.wav 716 | wav48/p259/p259_374.wav 717 | wav48/p351/p351_297.wav 718 | wav48/p280/p280_028.wav 719 | wav48/p310/p310_144.wav 720 | wav48/p245/p245_203.wav 721 | wav48/p304/p304_326.wav 722 | wav48/p266/p266_382.wav 723 | wav48/p343/p343_288.wav 724 | wav48/p284/p284_398.wav 725 | wav48/p283/p283_120.wav 726 | wav48/p268/p268_297.wav 727 | wav48/p303/p303_284.wav 728 | wav48/p251/p251_116.wav 729 | wav48/p364/p364_071.wav 730 | wav48/p268/p268_283.wav 731 | wav48/p323/p323_045.wav 732 | wav48/p243/p243_019.wav 733 | wav48/p248/p248_306.wav 734 | wav48/p297/p297_256.wav 735 | wav48/p250/p250_459.wav 736 | wav48/p266/p266_062.wav 737 | wav48/p340/p340_394.wav 738 | wav48/p240/p240_338.wav 739 | wav48/p281/p281_273.wav 740 | wav48/p245/p245_112.wav 741 | wav48/p345/p345_015.wav 742 | wav48/p259/p259_014.wav 743 | wav48/p303/p303_065.wav 744 | wav48/p275/p275_161.wav 745 | wav48/p230/p230_374.wav 746 | wav48/p269/p269_043.wav 747 | wav48/p279/p279_251.wav 748 | wav48/p341/p341_306.wav 749 | wav48/p316/p316_401.wav 750 | wav48/p273/p273_231.wav 751 | wav48/p284/p284_420.wav 752 | wav48/p247/p247_048.wav 753 | wav48/p361/p361_375.wav 754 | wav48/p330/p330_143.wav 755 | wav48/p240/p240_102.wav 756 | wav48/p230/p230_296.wav 757 | wav48/p258/p258_400.wav 758 | wav48/p245/p245_015.wav 759 | wav48/p323/p323_349.wav 760 | wav48/p233/p233_155.wav 761 | wav48/p341/p341_039.wav 762 | 
wav48/p360/p360_200.wav 763 | wav48/p308/p308_421.wav 764 | wav48/p317/p317_401.wav 765 | wav48/p247/p247_180.wav 766 | wav48/p240/p240_088.wav 767 | wav48/p287/p287_303.wav 768 | wav48/p314/p314_319.wav 769 | wav48/p255/p255_235.wav 770 | wav48/p287/p287_268.wav 771 | wav48/p306/p306_320.wav 772 | wav48/p231/p231_247.wav 773 | wav48/p316/p316_156.wav 774 | wav48/p310/p310_416.wav 775 | wav48/p288/p288_037.wav 776 | wav48/p228/p228_204.wav 777 | wav48/p297/p297_168.wav 778 | wav48/p340/p340_177.wav 779 | wav48/p230/p230_190.wav 780 | wav48/p330/p330_278.wav 781 | wav48/p336/p336_041.wav 782 | wav48/p275/p275_115.wav 783 | wav48/p306/p306_234.wav 784 | wav48/p237/p237_008.wav 785 | wav48/p263/p263_194.wav 786 | wav48/p240/p240_219.wav 787 | wav48/p361/p361_043.wav 788 | wav48/p299/p299_109.wav 789 | wav48/p305/p305_111.wav 790 | wav48/p341/p341_372.wav 791 | wav48/p260/p260_321.wav 792 | wav48/p247/p247_373.wav 793 | wav48/p335/p335_369.wav 794 | wav48/p308/p308_311.wav 795 | wav48/p295/p295_068.wav 796 | wav48/p292/p292_374.wav 797 | wav48/p277/p277_411.wav 798 | wav48/p294/p294_263.wav 799 | wav48/p276/p276_233.wav 800 | wav48/p360/p360_281.wav 801 | wav48/p278/p278_345.wav 802 | wav48/p269/p269_294.wav 803 | wav48/p305/p305_420.wav 804 | wav48/p227/p227_137.wav 805 | wav48/p229/p229_329.wav 806 | wav48/p360/p360_366.wav 807 | wav48/p363/p363_275.wav 808 | wav48/p287/p287_022.wav 809 | wav48/p249/p249_249.wav 810 | wav48/p295/p295_318.wav 811 | wav48/p244/p244_067.wav 812 | wav48/p310/p310_303.wav 813 | wav48/p302/p302_285.wav 814 | wav48/p260/p260_062.wav 815 | wav48/p277/p277_064.wav 816 | wav48/p326/p326_157.wav 817 | wav48/p259/p259_003.wav 818 | wav48/p275/p275_168.wav 819 | wav48/p272/p272_243.wav 820 | wav48/p304/p304_164.wav 821 | wav48/p314/p314_119.wav 822 | wav48/p241/p241_067.wav 823 | wav48/p282/p282_243.wav 824 | wav48/p305/p305_178.wav 825 | wav48/p272/p272_136.wav 826 | wav48/p271/p271_138.wav 827 | wav48/p282/p282_144.wav 828 | wav48/p299/p299_359.wav 829 | wav48/p363/p363_388.wav 830 | wav48/p281/p281_417.wav 831 | wav48/p239/p239_010.wav 832 | wav48/p295/p295_041.wav 833 | wav48/p300/p300_157.wav 834 | wav48/p275/p275_157.wav 835 | wav48/p245/p245_062.wav 836 | wav48/p292/p292_203.wav 837 | wav48/p255/p255_301.wav 838 | wav48/p334/p334_029.wav 839 | wav48/p313/p313_305.wav 840 | wav48/p293/p293_351.wav 841 | wav48/p334/p334_324.wav 842 | wav48/p274/p274_114.wav 843 | wav48/p266/p266_034.wav 844 | wav48/p307/p307_230.wav 845 | wav48/p295/p295_403.wav 846 | wav48/p227/p227_257.wav 847 | wav48/p311/p311_334.wav 848 | wav48/p288/p288_227.wav 849 | wav48/p374/p374_399.wav 850 | wav48/p301/p301_386.wav 851 | wav48/p264/p264_260.wav 852 | wav48/p333/p333_068.wav 853 | wav48/p364/p364_095.wav 854 | wav48/p286/p286_082.wav 855 | wav48/p281/p281_206.wav 856 | wav48/p240/p240_060.wav 857 | wav48/p263/p263_118.wav 858 | wav48/p245/p245_066.wav 859 | wav48/p251/p251_345.wav 860 | wav48/p270/p270_400.wav 861 | wav48/p264/p264_019.wav 862 | wav48/p278/p278_301.wav 863 | wav48/p277/p277_086.wav 864 | wav48/p257/p257_115.wav 865 | wav48/p276/p276_039.wav 866 | wav48/p329/p329_320.wav 867 | wav48/p278/p278_021.wav 868 | wav48/p268/p268_298.wav 869 | wav48/p329/p329_170.wav 870 | wav48/p236/p236_169.wav 871 | wav48/p340/p340_198.wav 872 | wav48/p306/p306_340.wav 873 | wav48/p285/p285_351.wav 874 | wav48/p247/p247_017.wav 875 | wav48/p287/p287_005.wav 876 | wav48/p288/p288_019.wav 877 | wav48/p230/p230_144.wav 878 | wav48/p376/p376_262.wav 879 | wav48/p312/p312_286.wav 880 | 
wav48/p258/p258_002.wav 881 | wav48/p246/p246_155.wav 882 | wav48/p268/p268_275.wav 883 | wav48/p259/p259_370.wav 884 | wav48/p298/p298_235.wav 885 | wav48/p343/p343_016.wav 886 | wav48/p323/p323_282.wav 887 | wav48/p288/p288_058.wav 888 | wav48/p260/p260_119.wav 889 | wav48/p229/p229_274.wav 890 | wav48/p318/p318_399.wav 891 | wav48/p306/p306_260.wav 892 | wav48/p313/p313_109.wav 893 | wav48/p314/p314_318.wav 894 | wav48/p226/p226_314.wav 895 | wav48/p261/p261_338.wav 896 | wav48/p247/p247_090.wav 897 | wav48/p266/p266_117.wav 898 | wav48/p279/p279_038.wav 899 | wav48/p276/p276_352.wav 900 | wav48/p268/p268_399.wav 901 | wav48/p363/p363_314.wav 902 | wav48/p277/p277_202.wav 903 | wav48/p340/p340_337.wav 904 | wav48/p288/p288_049.wav 905 | wav48/p232/p232_267.wav 906 | wav48/p294/p294_167.wav 907 | wav48/p275/p275_179.wav 908 | wav48/p312/p312_124.wav 909 | wav48/p284/p284_279.wav 910 | wav48/p267/p267_048.wav 911 | wav48/p335/p335_043.wav 912 | wav48/p265/p265_187.wav 913 | wav48/p272/p272_229.wav 914 | wav48/p347/p347_355.wav 915 | wav48/p294/p294_048.wav 916 | wav48/p298/p298_219.wav 917 | wav48/p326/p326_237.wav 918 | wav48/p304/p304_194.wav 919 | wav48/p265/p265_352.wav 920 | wav48/p275/p275_147.wav 921 | wav48/p276/p276_276.wav 922 | wav48/p231/p231_152.wav 923 | wav48/p231/p231_012.wav 924 | wav48/p265/p265_308.wav 925 | wav48/p343/p343_186.wav 926 | wav48/p225/p225_319.wav 927 | wav48/p232/p232_242.wav 928 | wav48/p293/p293_330.wav 929 | wav48/p308/p308_053.wav 930 | wav48/p279/p279_162.wav 931 | wav48/p263/p263_301.wav 932 | wav48/p297/p297_193.wav 933 | wav48/p297/p297_206.wav 934 | wav48/p292/p292_298.wav 935 | wav48/p263/p263_388.wav 936 | wav48/p280/p280_255.wav 937 | wav48/p283/p283_453.wav 938 | wav48/p287/p287_092.wav 939 | wav48/p234/p234_241.wav 940 | wav48/p255/p255_369.wav 941 | wav48/p323/p323_086.wav 942 | wav48/p304/p304_203.wav 943 | wav48/p259/p259_233.wav 944 | wav48/p285/p285_244.wav 945 | wav48/p255/p255_352.wav 946 | wav48/p269/p269_393.wav 947 | wav48/p244/p244_091.wav 948 | wav48/p229/p229_236.wav 949 | wav48/p231/p231_300.wav 950 | wav48/p258/p258_100.wav 951 | wav48/p273/p273_267.wav 952 | wav48/p283/p283_158.wav 953 | wav48/p292/p292_085.wav 954 | wav48/p347/p347_243.wav 955 | wav48/p272/p272_335.wav 956 | wav48/p312/p312_363.wav 957 | wav48/p345/p345_125.wav 958 | wav48/p376/p376_163.wav 959 | wav48/p262/p262_313.wav 960 | wav48/p240/p240_345.wav 961 | wav48/p292/p292_293.wav 962 | wav48/p263/p263_275.wav 963 | wav48/p360/p360_313.wav 964 | wav48/p251/p251_078.wav 965 | wav48/p294/p294_041.wav 966 | wav48/p246/p246_164.wav 967 | wav48/p335/p335_343.wav 968 | wav48/p301/p301_157.wav 969 | wav48/p281/p281_230.wav 970 | wav48/p329/p329_098.wav 971 | wav48/p263/p263_123.wav 972 | wav48/p240/p240_051.wav 973 | wav48/p283/p283_004.wav 974 | wav48/p316/p316_037.wav 975 | wav48/p351/p351_283.wav 976 | wav48/p268/p268_093.wav 977 | wav48/p267/p267_290.wav 978 | wav48/p228/p228_092.wav 979 | wav48/p265/p265_102.wav 980 | wav48/p286/p286_067.wav 981 | wav48/p274/p274_064.wav 982 | wav48/p374/p374_292.wav 983 | wav48/p288/p288_107.wav 984 | wav48/p341/p341_249.wav 985 | wav48/p226/p226_228.wav 986 | wav48/p351/p351_357.wav 987 | wav48/p361/p361_346.wav 988 | wav48/p334/p334_184.wav 989 | wav48/p316/p316_378.wav 990 | wav48/p258/p258_276.wav 991 | wav48/p276/p276_030.wav 992 | wav48/p266/p266_115.wav 993 | wav48/p335/p335_002.wav 994 | wav48/p232/p232_412.wav 995 | wav48/p232/p232_406.wav 996 | wav48/p310/p310_384.wav 997 | wav48/p227/p227_153.wav 998 | 
wav48/p363/p363_244.wav 999 | wav48/p335/p335_027.wav 1000 | wav48/p336/p336_388.wav 1001 | -------------------------------------------------------------------------------- /models/mdct.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Callable 2 | from torch_scatter import scatter 3 | from einops import rearrange 4 | import numpy as np 5 | import torch 6 | import torch.nn 7 | import torch.fft 8 | from torch.nn.functional import pad, fold 9 | from . import spectrogram 10 | 11 | import functools 12 | #import debugpy 13 | #debugpy.listen(("localhost", 5678)) 14 | 15 | 16 | class MDCT(torch.nn.Module): 17 | """ 18 | Serial version of the MDCT. 19 | Adapted from https://github.com/nils-werner/mdct/ 20 | I modified the original code from scipy+numpy to torch, and I made it into torch layers with GPU support. 21 | """ 22 | 23 | def __init__(self, window_function, step_length=None, n_fft=2048, center=True, device='cpu') -> None: 24 | super().__init__() 25 | self.window_function = window_function 26 | self.step_length = step_length 27 | self.n_fft = n_fft 28 | self.center = center 29 | self.device = device 30 | 31 | def mdct( 32 | self, 33 | x, 34 | odd=True, 35 | center=True, 36 | **kwargs 37 | ): 38 | """ Calculate lapped MDCT of input signal 39 | Parameters 40 | ---------- 41 | x : array_like 42 | The signal to be transformed. May be a 1D vector for single channel or 43 | a 2D matrix for multi channel data. In case of a mono signal, the data 44 | is must be a 1D vector of length :code:`samples`. In case of a multi 45 | channel signal, the data must be in the shape of :code:`samples x 46 | channels`. 47 | odd : boolean, optional 48 | Switch to oddly stacked transform. Defaults to :code:`True`. 49 | framelength : int 50 | The signal frame length. Defaults to :code:`2048`. 51 | hopsize : int 52 | The signal frame hopsize. Defaults to :code:`None`. Setting this 53 | value will override :code:`overlap`. 54 | overlap : int 55 | The signal frame overlap coefficient. Value :code:`x` means 56 | :code:`1/x` overlap. Defaults to :code:`2`. Note that anything but 57 | :code:`2` will result in a filterbank without perfect reconstruction. 58 | centered : boolean 59 | Pad input signal so that the first and last window are centered around 60 | the beginning of the signal. Defaults to :code:`True`. 61 | Disabling this will result in aliasing 62 | in the first and last half-frame. 63 | window : callable, array_like 64 | Window to be used for deringing. Can be :code:`False` to disable 65 | windowing. Defaults to :code:`scipy.signal.cosine`. 66 | transforms : module, optional 67 | Module reference to core transforms. Mostly used to replace 68 | fast with slow core transforms, for testing. Defaults to 69 | :mod:`mdct.fast` 70 | padding : int 71 | Zero-pad signal with x times the number of samples. 72 | Defaults to :code:`0`. 73 | save_settings : boolean 74 | Save settings used here in attribute :code:`out.stft_settings` so that 75 | :func:`ispectrogram` can infer these settings without the developer 76 | having to pass them again. 77 | Returns 78 | ------- 79 | out : array_like 80 | The signal (or matrix of signals). In case of a mono output signal, the 81 | data is formatted as a 1D vector of length :code:`samples`. In case of 82 | a multi channel output signal, the data is formatted as :code:`samples 83 | x channels`. 
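Notes
-----
With the default ``odd=True`` only the MDCT core is applied; passing ``odd=False``
runs both the MDCT and MDST cores through :func:`spectrogram.spectrogram`
(see the two branches at the end of this method).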
84 | See Also 85 | -------- 86 | mdct.fast.transforms.mdct : MDCT 87 | """ 88 | def cmdct(x, odd=True): 89 | """ Calculate complex MDCT/MCLT of input signal 90 | Parameters 91 | ---------- 92 | x : array_like 93 | The input signal 94 | odd : boolean, optional 95 | Switch to oddly stacked transform. Defaults to :code:`True`. 96 | Returns 97 | ------- 98 | out : array_like 99 | The output signal 100 | """ 101 | N = len(x) // 2 102 | n0 = (N + 1) / 2 103 | if odd: 104 | outlen = N 105 | pre_twiddle = torch.exp(-1j * np.pi * 106 | torch.arange(N * 2) / (N * 2)).to(self.device) 107 | offset = 0.5 108 | else: 109 | outlen = N + 1 110 | pre_twiddle = 1.0 111 | offset = 0.0 112 | 113 | post_twiddle = torch.exp( 114 | -1j * np.pi * n0 * (torch.arange(outlen) + offset) / N 115 | ).to(self.device) 116 | 117 | X = torch.fft.fft(x * pre_twiddle)[:outlen] 118 | 119 | if not odd: 120 | X[0] *= np.sqrt(0.5) 121 | X[-1] *= np.sqrt(0.5) 122 | 123 | return X * post_twiddle * np.sqrt(1 / N) 124 | 125 | def _mdct(x, odd=True): 126 | """ Calculate modified discrete cosine transform of input signal 127 | Parameters 128 | ---------- 129 | X : array_like 130 | The input signal 131 | odd : boolean, optional 132 | Switch to oddly stacked transform. Defaults to :code:`True`. 133 | Returns 134 | ------- 135 | out : array_like 136 | The output signal 137 | """ 138 | return torch.real(cmdct(x, odd=odd)) * np.sqrt(2) 139 | 140 | def _mdst(x, odd=True): 141 | """ Calculate modified discrete sine transform of input signal 142 | Parameters 143 | ---------- 144 | X : array_like 145 | The input signal 146 | odd : boolean, optional 147 | Switch to oddly stacked transform. Defaults to :code:`True`. 148 | Returns 149 | ------- 150 | out : array_like 151 | The output signal 152 | """ 153 | return -1 * torch.imag(cmdct(x, odd=odd)) * np.sqrt(2) 154 | 155 | frame_length = len(self.window_function) 156 | 157 | if not odd: 158 | return spectrogram.spectrogram( 159 | x, 160 | transform=[ 161 | functools.partial(_mdct, odd=False), 162 | functools.partial(_mdst, odd=False), 163 | ], 164 | halved=False, 165 | frame_length=frame_length, 166 | **kwargs 167 | ) 168 | else: 169 | return spectrogram.spectrogram( 170 | x, 171 | transform=_mdct, 172 | halved=False, 173 | frame_length=frame_length, 174 | **kwargs 175 | ) 176 | 177 | def forward(self, x): 178 | x = self.mdct(x=x, window_function=self.window_function, 179 | step_length=self.step_length, n_fft=self.n_fft, center=self.center, padding=0) 180 | return x 181 | 182 | 183 | class IMDCT(torch.nn.Module): 184 | def __init__(self, window_function, step_length=None, device='cuda', n_fft=2048, out_length=48000, center=True): 185 | super().__init__() 186 | self.window_function = window_function 187 | self.step_length = step_length 188 | self.n_fft = n_fft 189 | self.out_length = out_length 190 | self.device = device 191 | 192 | def imdct( 193 | self, 194 | X, 195 | odd=True, 196 | **kwargs 197 | ): 198 | """ Calculate lapped inverse MDCT of input signal 199 | Parameters 200 | ---------- 201 | x : array_like 202 | The spectrogram to be inverted. May be a 2D matrix for single channel 203 | or a 3D tensor for multi channel data. In case of a mono signal, the 204 | data must be in the shape of :code:`bins x frames`. In case of a multi 205 | channel signal, the data must be in the shape of :code:`bins x frames x 206 | channels`. 207 | odd : boolean, optional 208 | Switch to oddly stacked transform. Defaults to :code:`True`. 209 | framelength : int 210 | The signal frame length. 
Defaults to infer from data. 211 | hopsize : int 212 | The signal frame hopsize. Defaults to infer from data. Setting this 213 | value will override :code:`overlap`. 214 | overlap : int 215 | The signal frame overlap coefficient. Value :code:`x` means 216 | :code:`1/x` overlap. Defaults to infer from data. Note that anything 217 | but :code:`2` will result in a filterbank without perfect 218 | reconstruction. 219 | centered : boolean 220 | Pad input signal so that the first and last window are centered around 221 | the beginning of the signal. Defaults to to infer from data. 222 | The first and last half-frame will have aliasing, so using 223 | centering during forward MDCT is recommended. 224 | window : callable, array_like 225 | Window to be used for deringing. Can be :code:`False` to disable 226 | windowing. Defaults to to infer from data. 227 | halved : boolean 228 | Switch to reconstruct the other halve of the spectrum if the forward 229 | transform has been truncated. Defaults to to infer from data. 230 | transforms : module, optional 231 | Module reference to core transforms. Mostly used to replace 232 | fast with slow core transforms, for testing. Defaults to 233 | :mod:`mdct.fast` 234 | padding : int 235 | Zero-pad signal with x times the number of samples. Defaults to infer 236 | from data. 237 | outlength : int 238 | Crop output signal to length. Useful when input length of spectrogram 239 | did not fit into framelength and input data had to be padded. Not 240 | setting this value will disable cropping, the output data may be 241 | longer than expected. 242 | Returns 243 | ------- 244 | out : array_like 245 | The output signal 246 | See Also 247 | -------- 248 | mdct.fast.transforms.imdct : inverse MDCT 249 | """ 250 | def icmdct(X, odd=True): 251 | """ Calculate inverse complex MDCT/MCLT of input signal 252 | Parameters 253 | ---------- 254 | X : array_like 255 | The input signal 256 | odd : boolean, optional 257 | Switch to oddly stacked transform. Defaults to :code:`True`. 258 | Returns 259 | ------- 260 | out : array_like 261 | The output signal 262 | """ 263 | if not odd and len(X) % 2 == 0: 264 | raise ValueError( 265 | "Even inverse CMDCT requires an odd number " 266 | "of coefficients" 267 | ) 268 | 269 | if odd: 270 | N = len(X) 271 | n0 = (N + 1) / 2 272 | 273 | post_twiddle = torch.exp( 274 | 1j * np.pi * (torch.arange(N * 2) + n0) / (N * 2) 275 | ).to(self.device) 276 | 277 | Y = torch.zeros(N * 2, dtype=X.dtype) 278 | Y[:N] = X 279 | #Y[N:] = -1 * torch.conj(X[::-1]) 280 | Y[N:] = -1 * torch.conj(X.flip(dims=(0,))) 281 | else: 282 | N = len(X) - 1 283 | n0 = (N + 1) / 2 284 | 285 | post_twiddle = 1.0 286 | 287 | X[0] *= torch.sqrt(2) 288 | X[-1] *= torch.sqrt(2) 289 | 290 | Y = torch.zeros(N * 2, dtype=X.dtype) 291 | Y[:N+1] = X 292 | #Y[N+1:] = -1 * torch.conj(X[-2:0:-1]) 293 | Y[N+1:] = -1 * torch.conj(X[:-2].flip(dims=(0,))) 294 | 295 | pre_twiddle = (torch.exp(1j * np.pi * n0 * 296 | torch.arange(N * 2) / N)).to(self.device) 297 | 298 | y = torch.fft.ifft(Y.to(self.device) * pre_twiddle) 299 | 300 | return torch.real(y * post_twiddle) * np.sqrt(N) 301 | 302 | def _imdct(X, odd=True): 303 | """ Calculate inverse modified discrete cosine transform of input signal 304 | Parameters 305 | ---------- 306 | X : array_like 307 | The input signal 308 | odd : boolean, optional 309 | Switch to oddly stacked transform. Defaults to :code:`True`. 
310 | Returns 311 | ------- 312 | out : array_like 313 | The output signal 314 | """ 315 | return icmdct(X, odd=odd) * np.sqrt(2) 316 | 317 | def _imdst(X, odd=True): 318 | """ Calculate inverse modified discrete sine transform of input signal 319 | Parameters 320 | ---------- 321 | X : array_like 322 | The input signal 323 | odd : boolean, optional 324 | Switch to oddly stacked transform. Defaults to :code:`True`. 325 | Returns 326 | ------- 327 | out : array_like 328 | The output signal 329 | """ 330 | return -1 * icmdct(X * 1j, odd=odd) * np.sqrt(2) 331 | 332 | frame_length = len(self.window_function) 333 | 334 | if not odd: 335 | return spectrogram.ispectrogram( 336 | X, 337 | transform=[ 338 | functools.partial(_imdct, odd=False), 339 | functools.partial(_imdst, odd=False), 340 | ], 341 | halved=False, 342 | **kwargs 343 | ) 344 | else: 345 | return spectrogram.ispectrogram( 346 | X, 347 | transform=_imdct, 348 | halved=False, 349 | frame_length=frame_length, 350 | **kwargs 351 | ) 352 | 353 | def forward(self, X): 354 | X = self.imdct(X=X, window_function=self.window_function, step_length=self.step_length, 355 | n_fft=self.n_fft, padding=0, out_length=self.out_length) 356 | return X 357 | 358 | 359 | class MDCT4(torch.nn.Module): 360 | """ 361 | The exact version of the MDCT, using modified DCT-IV. 362 | Borrowed from MATLAB implementation. 363 | """ 364 | 365 | def __init__(self, n_fft=2048, hop_length=None, win_length=None, window=None, center=True, pad_mode='constant', device='cuda') -> None: 366 | super().__init__() 367 | self.n_fft = n_fft 368 | self.pad_mode = pad_mode 369 | self.device = device 370 | self.hop_length = hop_length 371 | self.center = center 372 | 373 | # making window 374 | if window is None: 375 | window = torch.ones 376 | if callable(window): 377 | self.win_length = int(win_length) 378 | self.window = window(self.win_length).to(self.device) 379 | else: 380 | self.window = window.to(self.device) 381 | self.win_length = len(window) 382 | 383 | assert self.win_length <= self.n_fft, 'Window lenth %d should be no more than fft length %d' % ( 384 | self.win_length, self.n_fft) 385 | assert self.hop_length <= self.win_length, 'You hopped more than one frame' 386 | 387 | self.exp1 = torch.exp(-1j*torch.pi/self.n_fft*torch.arange(start=0, 388 | end=self.n_fft, step=1, dtype=torch.float64, device=self.device)) 389 | self.exp2 = torch.exp(-1j*(torch.pi/(2*self.n_fft)+torch.pi/4)*torch.arange( 390 | start=1, end=self.n_fft, step=2, dtype=torch.float64, device=self.device)) 391 | 392 | def forward(self, signal, return_frames: bool = False): 393 | # Pad the signal to a proper length 394 | signal_len = int(len(signal)) 395 | start_pad = 0 396 | # Pad the signal so that the t-th frame is centered at time t * hop_length. Otherwise, the t-th frame begins at time t * hop_length. 
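# Illustrative example (hop_length=256 and win_length=512 assumed for illustration only):
# a 1024-sample input gives additional_len = 1024 % 256 = 0, so start_pad = end_pad = 256,
# the padded signal has 1536 samples, and unfold(size=512, step=256) below yields 5 frames.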
397 | if self.center: 398 | start_pad = self.hop_length 399 | additional_len = signal_len % self.hop_length 400 | end_pad = start_pad 401 | if additional_len: 402 | end_pad = start_pad + self.hop_length - additional_len 403 | signal = pad(signal, (start_pad, end_pad), mode=self.pad_mode) 404 | 405 | # Slice the signal with overlapping 406 | signal = signal.unfold( 407 | dimension=-1, size=self.win_length, step=self.hop_length) 408 | 409 | # Apply windows to each pieces 410 | signal = torch.mul(signal, self.window) 411 | if return_frames: 412 | frames = signal.clone() 413 | else: 414 | frames = torch.empty(1) 415 | 416 | # Pad zeros for DCT 417 | if self.n_fft > self.win_length: 418 | signal = pad(signal, (0, self.n_fft-self.win_length), 419 | mode='constant') 420 | 421 | signal = signal*self.exp1 422 | signal = torch.fft.fft(signal)[..., :self.n_fft//2] 423 | signal = torch.real(self.exp2*signal) 424 | 425 | return signal, frames 426 | 427 | 428 | class IMDCT4(torch.nn.Module): 429 | def __init__(self, n_fft=2048, hop_length=None, win_length=None, window=None, center=True, pad_mode='constant', out_length=None, device='cuda') -> None: 430 | super().__init__() 431 | self.n_fft = n_fft 432 | self.pad_mode = pad_mode 433 | self.device = device 434 | self.hop_length = hop_length 435 | self.center = center 436 | self.out_length = out_length 437 | 438 | # making window 439 | if window is None: 440 | window = torch.ones 441 | if callable(window): 442 | self.win_length = int(win_length) 443 | self.window = window(self.win_length).to(self.device) 444 | else: 445 | self.window = window.to(self.device) 446 | self.win_length = len(window) 447 | 448 | assert self.win_length <= self.n_fft, 'Window lenth %d should be no more than fft length %d' % ( 449 | self.win_length, self.n_fft) 450 | assert self.hop_length <= self.win_length, 'You hopped more than one frame' 451 | 452 | self.exp1 = torch.exp(-1j*(torch.pi/(2*self.n_fft)+torch.pi/4)*torch.arange( 453 | start=1, end=self.n_fft, step=2, dtype=torch.float64, device=self.device)) 454 | self.exp2 = torch.exp(-1j*torch.pi/(2*self.n_fft)*torch.arange( 455 | start=0, end=2*self.n_fft, step=2, dtype=torch.float64, device=self.device)) 456 | 457 | def forward(self, signal, return_frames: bool = False): 458 | assert signal.dim() == 3, 'Only tensors shaped in BHW are supported, got tensor of shape %s' % ( 459 | str(signal.size())) 460 | assert signal.size( 461 | )[-1] == self.n_fft//2, 'The last dim of input tensor should match the n_fft. 
Expected %d ,got %d' % (self.n_fft, signal.size()[-1]) 462 | 463 | # Inverse transform at the last dim 464 | signal = self.exp1*signal 465 | signal = torch.fft.fft(signal, n=self.n_fft) 466 | signal = torch.real(signal*self.exp2) 467 | 468 | # Remove padded zeros when doing dct 469 | if self.n_fft > self.win_length: 470 | signal = signal[..., :self.win_length] 471 | 472 | # Apply windows to each pieces 473 | signal = torch.mul(signal, self.window) 474 | if return_frames: 475 | frames = signal.clone() 476 | else: 477 | frames = torch.zeros(1) 478 | 479 | # Overlapping adding by fold() 480 | out_len = (signal.size()[-2]-1) * self.hop_length + self.win_length 481 | signal = 4/self.n_fft*fold(signal.transpose_(-1, -2), kernel_size=( 482 | 1, self.win_length), stride=(1, self.hop_length), output_size=(1, out_len)) 483 | 484 | if self.center: 485 | # extract the middle part 486 | signal = signal[..., self.win_length//2:-self.win_length//2] 487 | signal = signal if self.out_length is None else signal[..., 488 | :self.out_length] 489 | return signal, frames 490 | 491 | 492 | class FastMDCT4(torch.nn.Module): 493 | ''' 494 | Fast implementation of MDCT. Port to PyTorch by Chenhao Shuai 495 | Ref:_https://ccrma.stanford.edu/~bosse/proj/node28.html 496 | Sporer T, Brandenburg K, Edler B, The Use of Multirate Filter Banks for Coding of High Quality Digital Audio, 6th European Signal Processing Conference (EUSIPCO), Amsterdam, June 1992, Vol.1 pp. 211-214. 497 | ''' 498 | def __init__(self, n_fft: Optional[int] = 2048, hop_length: Optional[int] = None, win_length: Optional[int] = None, window: Union[torch.Tensor, np.ndarray, list, Callable, None] = None, center: bool = True, pad_mode: str = 'constant', device: str = 'cuda') -> None: 499 | super().__init__() 500 | self.n_fft = n_fft 501 | self.pad_mode = pad_mode 502 | self.device = device 503 | self.hop_length = hop_length 504 | self.center = center 505 | 506 | if callable(window): 507 | self.win_length = int(win_length) 508 | self.window = window(self.win_length).to( 509 | device=self.device, dtype=torch.float64) 510 | elif isinstance(window, torch.Tensor): 511 | self.window = window.to(device=self.device, dtype=torch.float64) 512 | self.win_length = len(window) 513 | elif isinstance(window, np.ndarray) or isinstance(window, list): 514 | self.window = torch.tensor( 515 | window, device=self.device, dtype=torch.float64) 516 | self.win_length = len(window) 517 | elif window is None: 518 | if win_length is not None: 519 | self.win_length = win_length 520 | elif n_fft is not None: 521 | self.win_length = n_fft 522 | else: 523 | assert False, 'You should specify window length or n_fft' 524 | self.window = torch.ones( 525 | (self.win_length,), device=self.device, dtype=torch.float64) 526 | else: 527 | raise NotImplementedError 528 | 529 | assert self.win_length <= self.n_fft, f'Window lenth {self.win_length} should be no more than fft length {self.n_fft}' 530 | assert self.hop_length <= self.win_length, 'You hopped more than one frame' 531 | 532 | self.idx = torch.stack(( 533 | torch.arange( 534 | start=0, end=n_fft//2, step=2, 535 | dtype=torch.long, device=self.device), 536 | torch.arange( 537 | start=n_fft-1, end=n_fft//2, step=-2, 538 | dtype=torch.long, device=self.device), 539 | torch.arange( 540 | start=n_fft//2, end=n_fft, step=2, 541 | dtype=torch.long, device=self.device), 542 | torch.arange( 543 | start=n_fft//2-1, end=0, step=-2, 544 | dtype=torch.long, device=self.device) 545 | ), dim=0) 546 | 547 | # self.sqrtN = torch.sqrt(torch.tensor( 548 | 
# [self.n_fft], device=self.device, dtype=torch.float64)) 549 | self.post_exp = torch.exp( 550 | -2j*torch.pi/self.n_fft*( 551 | torch.arange( 552 | start=0, 553 | end=self.n_fft//4, 554 | step=1, 555 | dtype=torch.float64, 556 | device=self.device 557 | )+1/8 558 | ) 559 | ).to(torch.complex64) 560 | 561 | self.pre_exp = (self.make_pre_exp()*self.window).to(torch.complex64) 562 | self.pre_idx = self.make_pre_idx() 563 | self.post_idx = self.make_post_idx() 564 | # self.idx = self.idx.clone().mT.roll(self.n_fft//8,0).contiguous() 565 | 566 | def make_pre_exp(self): 567 | sgn = torch.ones(1, self.n_fft, dtype=torch.complex128, 568 | device=self.device) 569 | # Shift for Time-Domain Aliasing Cancellation (TDAC) 570 | sgn[..., -self.n_fft//4:] *= -1 571 | sgn = sgn.roll(self.n_fft//4, dims=-1) 572 | sgn[..., self.idx[0]] *= self.post_exp 573 | sgn[..., self.idx[1]] *= -self.post_exp 574 | sgn[..., self.idx[2]] *= -1j*self.post_exp 575 | sgn[..., self.idx[3]] *= 1j*self.post_exp 576 | return sgn.roll(-self.n_fft//4, dims=-1).contiguous() 577 | 578 | def make_pre_idx(self): 579 | i = torch.arange(start=0, end=self.n_fft, step=1, 580 | dtype=torch.long, device=self.device) 581 | i = i.roll(self.n_fft//4, dims=-1) 582 | idx_ = torch.stack([i[self.idx[0]], i[self.idx[1]], 583 | i[self.idx[2]], i[self.idx[3]]], dim=1) 584 | index = torch.zeros( 585 | 1, self.n_fft, device=self.device, dtype=torch.long) 586 | for i in torch.arange(0, self.n_fft//4, dtype=torch.long): 587 | index[..., idx_[i]] = i 588 | return index.squeeze().contiguous() 589 | 590 | def make_post_idx(self): 591 | idx = torch.arange(self.n_fft//2, dtype=torch.long, 592 | device=self.device).reshape(-1, 2) 593 | idx[:, 1] = idx[:, 1].flip(-1) 594 | return idx.flatten().contiguous() 595 | 596 | def forward(self, signal: torch.tensor, return_frames: bool = False): 597 | if signal.dim() == 2: # B T (mono) 598 | signal = signal[:, None, :] 599 | elif signal.dim() == 3: # B C T (stereo) 600 | pass 601 | else: 602 | raise NotImplementedError 603 | 604 | # Pad the signal to a proper length 605 | B, C, T = signal.shape 606 | start_pad = 0 607 | # Pad the signal so that the t-th frame is centered at time t * hop_length. Otherwise, the t-th frame begins at time t * hop_length. 
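# Overview of the steps below: after framing, the pre-twiddle and scatter-sum reduce each
# windowed frame to n_fft//4 complex values, a length-(n_fft//4) FFT is applied, and the
# post-twiddle plus the post_idx re-indexing unpack the result into n_fft//2 real MDCT
# coefficients per frame.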
608 | if self.center: 609 | start_pad = self.hop_length 610 | additional_len = T % self.hop_length 611 | end_pad = start_pad 612 | if additional_len: 613 | end_pad = start_pad + self.hop_length - additional_len 614 | signal = pad(signal, (start_pad, end_pad), mode=self.pad_mode) 615 | 616 | # Slice the signal with overlapping 617 | signal = signal.unfold( 618 | dimension=-1, size=self.win_length, step=self.hop_length) 619 | signal = signal*self.pre_exp 620 | signal = scatter(signal, self.pre_idx, dim=-1, reduce='sum') 621 | signal = torch.fft.fft(signal, dim=-1) 622 | # post-twiddle 623 | signal = torch.conj_physical(signal*self.post_exp) 624 | # rearranging 625 | signal = torch.view_as_real(signal) 626 | signal = signal.flatten(-2)[..., self.post_idx] 627 | 628 | return signal, None 629 | 630 | 631 | class FastIMDCT4(torch.nn.Module): 632 | def __init__(self, n_fft: Optional[int] = 2048, hop_length: Optional[int] = None, win_length: Optional[int] = None, window: Union[torch.Tensor, np.ndarray, list, Callable, None] = None, center: bool = True, pad_mode: str = 'constant', out_length: Optional[int] = None, device: str = 'cuda') -> None: 633 | super().__init__() 634 | self.n_fft = n_fft 635 | self.pad_mode = pad_mode 636 | self.device = device 637 | self.hop_length = hop_length 638 | self.center = center 639 | self.out_length = out_length 640 | 641 | if callable(window): 642 | self.win_length = int(win_length) 643 | self.window = window(self.win_length).to( 644 | device=self.device, dtype=torch.float64) 645 | elif isinstance(window, torch.Tensor): 646 | self.window = window.to(device=self.device, dtype=torch.float64) 647 | self.win_length = len(window) 648 | elif isinstance(window, np.ndarray) or isinstance(window, list): 649 | self.window = torch.tensor( 650 | window, device=self.device, dtype=torch.float64) 651 | self.win_length = len(window) 652 | elif window is None: 653 | if win_length is not None: 654 | self.win_length = win_length 655 | elif n_fft is not None: 656 | self.win_length = n_fft 657 | else: 658 | assert False, 'You should specify window length or n_fft' 659 | self.window = torch.ones( 660 | (self.win_length,), device=self.device, dtype=torch.float64) 661 | else: 662 | raise NotImplementedError 663 | 664 | assert self.win_length <= self.n_fft, f'Window length {self.win_length} should be no more than fft length {self.n_fft}' 665 | assert self.hop_length <= self.win_length, 'You hopped more than one frame' 666 | 667 | self.exp = torch.exp( 668 | -2j*torch.pi/self.n_fft*( 669 | torch.arange( 670 | start=0, 671 | end=self.n_fft//4, 672 | step=1, 673 | dtype=torch.float32, 674 | device=self.device 675 | )+1/8 676 | ) 677 | ).contiguous() 678 | self.pre_idx = self.make_pre_idx() 679 | self.post_idx = self.make_post_index() 680 | self.window = (4.0*self.make_sign()*self.window / 681 | self.n_fft).to(torch.float32).contiguous() 682 | 683 | def make_pre_idx(self): 684 | a = torch.arange(self.n_fft//2, dtype=torch.long, 685 | device=self.device).unfold(-1, 2, 2) 686 | return torch.stack((a[..., 0], a[..., 1].flip(-1)), dim=-1).contiguous() 687 | 688 | def make_post_index(self): 689 | a = torch.arange(0, self.n_fft//2, 2, 690 | dtype=torch.long, device=self.device) 691 | b = torch.arange(self.n_fft//2-1, 0, -2, 692 | dtype=torch.long, device=self.device) 693 | idx = torch.empty((self.n_fft,), dtype=torch.long, device=self.device) 694 | idx[0:self.n_fft//2:2] = a 695 | idx[1:self.n_fft//2:2] = b 696 | idx[self.n_fft//2:] = idx[:self.n_fft//2].flip(0) 697 | return
idx.roll(-self.n_fft//4).contiguous() 698 | 699 | def make_sign(self): 700 | sign = torch.ones((self.n_fft,), device=self.device, 701 | dtype=torch.float64) 702 | sign[1::2] *= -1 703 | sign[..., 0:self.n_fft//4] *= -1 704 | return sign.roll(-self.n_fft//4).contiguous() 705 | 706 | def forward(self, signal: torch.Tensor, return_frames: bool = False): 707 | assert signal.dim( 708 | ) <= 4, f'Only tensors shaped in BHW or BCHW are supported, got tensor of shape {signal.shape}' 709 | assert signal.shape[ 710 | -1] == self.n_fft//2, f'The last dim of the input tensor should be n_fft//2. Expected {self.n_fft//2}, got {signal.shape[-1]}' 711 | 712 | if signal.dim() == 4: 713 | C = signal.shape[1] 714 | signal = rearrange(signal, 'B C T N -> (B C) T N') 715 | else: 716 | C = 1 717 | 718 | signal = signal.to(self.device) 719 | # Inverse transform at the last dim 720 | signal = torch.view_as_complex(signal[..., self.pre_idx]) 721 | 722 | signal = self.exp*signal 723 | signal = torch.fft.fft(signal) 724 | signal = self.exp*signal 725 | 726 | # [0+4j, 1+5j, 2+6j, 3+7j] -> [2,-5, 3, -4, 4, -3, 5, -2, 6, -1, 7, 0, 0, 7, -1, 6] 727 | signal = torch.view_as_real(signal).flatten(-2)[..., self.post_idx] 728 | 729 | # Apply the window to each piece 730 | signal = self.window*signal 731 | if return_frames: 732 | frames = signal.clone() 733 | else: 734 | frames = torch.empty(1) 735 | 736 | # Overlapping adding by fold() 737 | out_len = (signal.shape[-2]-1) * self.hop_length + self.win_length 738 | signal = fold(signal.mT, kernel_size=(1, self.win_length), 739 | stride=(1, self.hop_length), output_size=(1, out_len)) 740 | 741 | if self.center: # extract the middle part 742 | signal = signal[..., self.win_length//2:-self.win_length//2] 743 | if self.out_length is not None: 744 | signal = signal[..., :self.out_length] 745 | if C != 1: 746 | signal = rearrange(signal, '(B C) T N -> B C T N', C=C) 747 | return signal, frames 748 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | from torchvision import models 2 | import torch 3 | import torch.nn as nn 4 | import functools 5 | import numpy as np 6 | from torch.nn.functional import interpolate, pad 7 | 8 | ############################################################################### 9 | # Functions 10 | ############################################################################### 11 | 12 | 13 | def weights_init(m): 14 | classname = m.__class__.__name__ 15 | if classname.find('Conv2d') != -1: 16 | m.weight.data.normal_(0.0, 0.02) 17 | elif classname.find('BatchNorm2d') != -1: 18 | m.weight.data.normal_(1.0, 0.02) 19 | m.bias.data.fill_(0) 20 | 21 | 22 | def get_norm_layer(norm_type='instance'): 23 | if norm_type == 'batch': 24 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 25 | elif norm_type == 'instance': 26 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) 27 | else: 28 | raise NotImplementedError( 29 | 'normalization layer [%s] is not found' % norm_type) 30 | return norm_layer 31 | 32 | 33 | def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1, 34 | n_blocks_local=3, norm='instance', gpu_ids=[], upsample_type='transconv', downsample_type='conv', input_size=(128, 256), n_attn_g=0, n_attn_l=0, proj_factor_g=4, heads_g=4, dim_head_g=128, proj_factor_l=4, heads_l=4, dim_head_l=128): 35 | norm_layer = get_norm_layer(norm_type=norm) 36 | if netG ==
'global': 37 | netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, 38 | n_blocks_global, norm_layer, downsample_type=downsample_type, upsample_type=upsample_type, 39 | input_size=input_size, 40 | n_attn_g=n_attn_g, proj_factor_g=proj_factor_g, heads_g=heads_g, dim_head_g=dim_head_g) 41 | elif netG == 'local': 42 | netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, 43 | n_local_enhancers, n_blocks_local, norm_layer, downsample_type=downsample_type, upsample_type=upsample_type, 44 | input_size=input_size, 45 | n_attn_g=n_attn_g, proj_factor_g=proj_factor_g, heads_g=heads_g, dim_head_g=dim_head_g, n_attn_l=n_attn_l, proj_factor_l=proj_factor_l, heads_l=heads_l, dim_head_l=dim_head_l) 46 | elif netG == 'encoder': 47 | netG = Encoder(input_nc, output_nc, ngf, 48 | n_downsample_global, norm_layer) 49 | else: 50 | raise('generator not implemented!') 51 | print(netG) 52 | if len(gpu_ids) > 0: 53 | assert(torch.cuda.is_available()) 54 | netG.cuda(gpu_ids[0]) 55 | netG.apply(weights_init) 56 | return netG 57 | 58 | 59 | def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]): 60 | norm_layer = get_norm_layer(norm_type=norm) 61 | netD = MultiscaleDiscriminator( 62 | input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat) 63 | print(netD) 64 | if len(gpu_ids) > 0: 65 | assert(torch.cuda.is_available()) 66 | netD.cuda(gpu_ids[0]) 67 | netD.apply(weights_init) 68 | return netD 69 | 70 | 71 | def define_MR_D(ndf, n_layers_D, input_nc, norm='instance', use_sigmoid=False, num_D=1, gpu_ids=[], base_nfft=2048, window=None, min_value=1e-7, mdct_type='4', normalizer=None, getIntermFeat=False, abs_spectro=False): 72 | norm_layer = get_norm_layer(norm_type=norm) 73 | netD = MultiResolutionDiscriminator(ndf=ndf, n_layers=n_layers_D, input_nc=input_nc, norm_layer=norm_layer, num_D=num_D, base_nfft=base_nfft, window=window, 74 | min_value=min_value, mdct_type=mdct_type, use_sigmoid=use_sigmoid, normalizer=normalizer, getIntermFeat=getIntermFeat, abs_spectro=abs_spectro) 75 | print(netD) 76 | if len(gpu_ids) > 0: 77 | assert(torch.cuda.is_available()) 78 | netD.cuda(gpu_ids[0]) 79 | netD.apply(weights_init) 80 | return netD 81 | 82 | 83 | def print_network(net): 84 | if isinstance(net, list): 85 | net = net[0] 86 | num_params = 0 87 | for param in net.parameters(): 88 | num_params += param.numel() 89 | print(net) 90 | print('Total number of parameters: %d' % num_params) 91 | 92 | ############################################################################## 93 | # Losses 94 | ############################################################################## 95 | 96 | 97 | class GANLoss(nn.Module): 98 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, 99 | device='cuda'): 100 | super(GANLoss, self).__init__() 101 | self.real_label = target_real_label 102 | self.fake_label = target_fake_label 103 | self.real_label_var = None 104 | self.fake_label_var = None 105 | self.device = device 106 | if use_lsgan: 107 | self.loss = nn.MSELoss() 108 | else: 109 | self.loss = nn.BCELoss() 110 | 111 | def get_target_tensor(self, input, target_is_real): 112 | target_tensor = None 113 | if target_is_real: 114 | create_label = ((self.real_label_var is None) or 115 | (self.real_label_var.shape != input.shape)) 116 | if create_label: 117 | self.real_label_var = torch.full(size=input.size(),fill_value=self.real_label, device=self.device, requires_grad=False) 118 | target_tensor = 
self.real_label_var 119 | else: 120 | create_label = ((self.fake_label_var is None) or 121 | (self.fake_label_var.shape != input.shape)) 122 | if create_label: 123 | self.fake_label_var = torch.full(size=input.size(),fill_value=self.fake_label, device=self.device, requires_grad=False) 124 | target_tensor = self.fake_label_var 125 | return target_tensor 126 | 127 | def __call__(self, input, target_is_real): 128 | if isinstance(input[0], list): 129 | loss = 0 130 | for input_i in input: 131 | pred = input_i[-1] 132 | target_tensor = self.get_target_tensor(pred, target_is_real) 133 | loss += self.loss(pred, target_tensor) 134 | return loss 135 | else: 136 | target_tensor = self.get_target_tensor(input[-1], target_is_real) 137 | return self.loss(input[-1], target_tensor) 138 | 139 | 140 | class VGGLoss(nn.Module): 141 | def __init__(self, gpu_ids): 142 | super(VGGLoss, self).__init__() 143 | self.vgg = Vgg19().cuda() 144 | self.criterion = nn.L1Loss() 145 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] 146 | 147 | def forward(self, x, y): 148 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 149 | loss = 0 150 | for i in range(len(x_vgg)): 151 | loss += self.weights[i] * \ 152 | self.criterion(x_vgg[i], y_vgg[i].detach()) 153 | return loss 154 | 155 | 156 | class SpecLoss(nn.Module): 157 | def __init__(self) -> None: 158 | super(SpecLoss, self).__init__() 159 | 160 | def forward(self, x, y): 161 | # input shape B,C,H,W 162 | N = x.shape[-1] 163 | spec_loss = torch.norm(x-y, p='fro', dim=(-1, -2)) / \ 164 | torch.norm(x, p='fro', dim=(-1, -2)) 165 | mag_loss = torch.norm(torch.log10( 166 | torch.abs(x)+1e-7) - torch.log10(torch.abs(y)+1e-7), p=1, dim=(-1, -2)) / N 167 | return torch.mean(spec_loss+mag_loss) 168 | ############################################################################## 169 | # Generator 170 | ############################################################################## 171 | 172 | 173 | class LocalEnhancer(nn.Module): 174 | def __init__(self, input_nc, output_nc, ngf=32, n_downsample_global=3, n_blocks_global=9, 175 | n_local_enhancers=1, n_blocks_local=3, norm_layer=nn.BatchNorm2d, padding_type='reflect', downsample_type='conv', upsample_type='transconv', n_attn_g=0, n_attn_l=0, input_size=(128, 256), proj_factor_g=4, heads_g=4, dim_head_g=128, proj_factor_l=4, heads_l=4, dim_head_l=128): 176 | super(LocalEnhancer, self).__init__() 177 | self.n_local_enhancers = n_local_enhancers 178 | 179 | ###### global generator model ##### 180 | ngf_global = ngf * (2**n_local_enhancers) 181 | model_global = GlobalGenerator(input_nc, output_nc, ngf_global, n_downsample_global, n_blocks_global, norm_layer, 182 | downsample_type=downsample_type, upsample_type=upsample_type, 183 | input_size=tuple(map(lambda x: x//2, input_size)), n_attn_g=n_attn_g, proj_factor_g=proj_factor_g, heads_g=heads_g, dim_head_g=dim_head_g).model 184 | # get rid of final convolution layers 185 | model_global = [model_global[i] for i in range(len(model_global)-3)] 186 | self.model = nn.Sequential(*model_global) 187 | 188 | # downsample 189 | if downsample_type == 'conv': 190 | downsample_layer = nn.Conv2d 191 | elif downsample_type == 'resconv': 192 | downsample_layer = ConvResBlock 193 | else: 194 | raise NotImplementedError( 195 | 'downsample layer [{:s}] is not found'.format(downsample_type)) 196 | # upsample 197 | if upsample_type == 'transconv': 198 | upsample_layer = nn.ConvTranspose2d 199 | elif upsample_type == 'interpolate': 200 | upsample_layer = InterpolateUpsample 201 | else: 202 | raise 
NotImplementedError( 203 | 'upsample layer [{:s}] is not found'.format(upsample_type)) 204 | ###### local enhancer layers ##### 205 | # downsample 206 | ngf_global = ngf * (2**(n_local_enhancers-1)) 207 | model_downsample = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0), 208 | norm_layer(ngf_global), nn.ReLU(True), 209 | downsample_layer(ngf_global, ngf_global * 2, 210 | kernel_size=3, stride=2, padding=1), 211 | norm_layer(ngf_global * 2), nn.ReLU(True)] 212 | # residual blocks 213 | model_upsample = [] 214 | for i in range(n_blocks_local): 215 | model_upsample += [ResnetBlock(ngf_global * 2, 216 | padding_type=padding_type, norm_layer=norm_layer)] 217 | # attention bottleneck 218 | if n_attn_l > 0: 219 | middle = n_blocks_local//2 220 | # 8x downsample 221 | down = [downsample_layer(ngf_global * 2, ngf_global, 222 | kernel_size=3, stride=2, padding=1), 223 | norm_layer(ngf_global), nn.ReLU(True)] 224 | down += [downsample_layer(ngf_global, ngf_global, 225 | kernel_size=3, stride=2, padding=1), 226 | norm_layer(ngf_global), nn.ReLU(True)]*2 227 | down = nn.Sequential(*down) 228 | model_upsample.insert(middle, down) 229 | 230 | middle += 1 231 | input_size = tuple(map(lambda x: x//16, input_size)) 232 | from bottleneck_transformer_pytorch import BottleStack 233 | attn_block = BottleStack(dim=ngf_global, fmap_size=input_size, dim_out=ngf_global*2, num_layers=n_attn_l, proj_factor=proj_factor_l, 234 | downsample=False, heads=heads_l, dim_head=dim_head_l, activation=nn.ReLU(True), rel_pos_emb=False) 235 | model_upsample.insert(middle, attn_block) 236 | model_upsample += [upsample_layer(in_channels=ngf_global*2, out_channels=ngf_global*2, kernel_size=3, stride=2, padding=1, output_padding=1), 237 | norm_layer(ngf_global), nn.ReLU(True)]*3 238 | 239 | model_upsample += [upsample_layer(in_channels=ngf_global*2, out_channels=ngf_global, kernel_size=3, stride=2, padding=1, output_padding=1), 240 | norm_layer(ngf_global), nn.ReLU(True)] 241 | 242 | # final convolution 243 | model_upsample += [nn.ReflectionPad2d(3), nn.Conv2d( 244 | ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()] 245 | 246 | self.model1_1 = nn.Sequential(*model_downsample) 247 | self.model1_2 = nn.Sequential(*model_upsample) 248 | 249 | self.downsample = nn.AvgPool2d( 250 | 3, stride=2, padding=[1, 1], count_include_pad=False) 251 | self.freeze = False 252 | 253 | def forward(self, input): 254 | # create input pyramid 255 | input_downsampled = [input] 256 | for i in range(self.n_local_enhancers): 257 | input_downsampled.append(self.downsample(input_downsampled[-1])) 258 | 259 | # output at coarest level 260 | output_prev = self.model(input_downsampled[-1]) 261 | # build up one layer at a time 262 | model_downsample = self.model1_1 263 | model_upsample = self.model1_2 264 | input_i = input_downsampled[0] 265 | output_prev = model_upsample( 266 | model_downsample(input_i) + output_prev) 267 | return output_prev 268 | 269 | def set_freeze(self, freeze_global_d=True, freeze_global_u=False, freeze_local_d=True, freeze_local_u=False): 270 | print("The following layers will be freezed:") 271 | '''Freeze downsample layers''' 272 | print('Global:') 273 | for name, layer in self.model.named_children(): 274 | module_name = layer.__class__.__name__ 275 | if 'Conv2d' in module_name or 'ConvResBlock' in module_name: 276 | if freeze_global_d: 277 | print(name, module_name) 278 | for param in layer.parameters(): 279 | param.requires_grad = not freeze_global_d 280 | elif 'InterpolateUpsample' in module_name 
or 'ConvTranspose2d' in module_name or 'ResnetBlock' in module_name or 'BottleStack' in module_name: 281 | if freeze_global_u: 282 | print(name, module_name) 283 | for param in layer.parameters(): 284 | param.requires_grad = not freeze_global_u 285 | print('Loacl:') 286 | for name, layer in self.model1_1.named_children(): 287 | module_name = layer.__class__.__name__ 288 | for param in layer.parameters(): 289 | if freeze_local_d: 290 | print(name, module_name) 291 | param.requires_grad = not freeze_local_d 292 | 293 | for name, layer in self.model1_2.named_children(): 294 | module_name = layer.__class__.__name__ 295 | for param in layer.parameters(): 296 | if freeze_local_u: 297 | print(name, module_name) 298 | param.requires_grad = not freeze_local_u 299 | 300 | 301 | class GlobalGenerator(nn.Module): 302 | def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, 303 | padding_type='reflect', upsample_type='transconv', downsample_type='conv', n_attn_g=0, input_size=(128, 256), proj_factor_g=4, heads_g=4, dim_head_g=128): 304 | assert(n_blocks >= 0) 305 | super(GlobalGenerator, self).__init__() 306 | activation = nn.ReLU(True) 307 | 308 | model = [nn.ReflectionPad2d(3), nn.Conv2d( 309 | input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation] 310 | # downsample 311 | if downsample_type == 'conv': 312 | downsample_layer = nn.Conv2d 313 | elif downsample_type == 'resconv': 314 | downsample_layer = ConvResBlock 315 | else: 316 | raise NotImplementedError( 317 | 'downsample layer [{:s}] is not found'.format(downsample_type)) 318 | # upsample 319 | if upsample_type == 'transconv': 320 | upsample_layer = nn.ConvTranspose2d 321 | elif upsample_type == 'interpolate': 322 | upsample_layer = InterpolateUpsample 323 | else: 324 | raise NotImplementedError( 325 | 'upsample layer [{:s}] is not found'.format(upsample_type)) 326 | 327 | for i in range(n_downsampling): 328 | mult = 2**i 329 | model += [downsample_layer(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), 330 | norm_layer(ngf * mult * 2), activation] 331 | 332 | # resnet blocks 333 | mult = 2**n_downsampling 334 | bottle_neck = [] 335 | for i in range(n_blocks): 336 | bottle_neck += [ResnetBlock(ngf * mult, padding_type=padding_type, 337 | activation=activation, norm_layer=norm_layer)] 338 | if n_attn_g > 0: 339 | middle = n_blocks//2 340 | input_size = tuple(map(lambda x: x//mult, input_size)) 341 | from bottleneck_transformer_pytorch import BottleStack 342 | attn_block = BottleStack(dim=ngf * mult, fmap_size=input_size, dim_out=ngf * mult, num_layers=n_attn_g, proj_factor=proj_factor_g, 343 | downsample=False, heads=heads_g, dim_head=dim_head_g, activation=activation, rel_pos_emb=False) 344 | bottle_neck.insert(middle, attn_block) 345 | model += bottle_neck 346 | 347 | for i in range(n_downsampling): 348 | mult = 2**(n_downsampling - i) 349 | model += [upsample_layer(in_channels=ngf * mult, out_channels=int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1), 350 | norm_layer(int(ngf * mult / 2)), activation] 351 | model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, 352 | output_nc, kernel_size=7, padding=0), nn.Tanh()] 353 | self.model = nn.Sequential(*model) 354 | self.freeze = False 355 | 356 | def forward(self, input): 357 | return self.model(input) 358 | 359 | def set_freeze(self, freeze=True): 360 | if self.freeze == freeze: 361 | return 362 | else: 363 | self.freeze = freeze 364 | print("The following layers will be freezed:") 365 | '''Freeze 
downsample layers''' 366 | for name, layer in self.model.named_children(): 367 | module_name = layer.__class__.__name__ 368 | if 'ResnetBlock' in module_name or 'BottleStack' in module_name: 369 | break 370 | print(name, module_name) 371 | for param in layer.parameters(): 372 | param.requires_grad = not freeze 373 | 374 | 375 | class InterpolateUpsample(nn.Module): 376 | """ 377 | An upsampling layer with an optional convolution. 378 | :param channels: channels in the inputs and outputs. 379 | :param use_conv: a bool determining if a convolution is applied. 380 | 381 | """ 382 | 383 | def __init__(self, *args, **kwargs): 384 | super(InterpolateUpsample, self).__init__() 385 | self.in_channels = kwargs['in_channels'] 386 | self.out_channels = kwargs['out_channels'] 387 | self.conv1 = nn.Conv2d( 388 | self.in_channels, self.out_channels, 5, padding=1) 389 | self.conv2 = nn.Conv2d( 390 | self.out_channels, self.out_channels, 3, padding=2) 391 | self.conv_res = nn.Conv2d( 392 | self.in_channels, self.out_channels, 3, padding=1) 393 | 394 | def forward(self, x): 395 | assert x.shape[1] == self.in_channels 396 | x = interpolate(x, scale_factor=2.0, mode="nearest") 397 | res_x = self.conv_res(x) 398 | x = self.conv1(x) 399 | x = self.conv2(x) 400 | return x+res_x 401 | 402 | 403 | class ConvResBlock(nn.Module): 404 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding): 405 | super(ConvResBlock, self).__init__() 406 | self.conv1 = nn.Conv2d(in_channels, in_channels, 407 | kernel_size, stride, padding) 408 | self.conv2 = nn.Conv2d( 409 | in_channels, out_channels, 5, padding=2) 410 | self.conv_res = nn.Conv2d( 411 | in_channels, out_channels, kernel_size=3, stride=1, padding=1) 412 | 413 | def forward(self, x): 414 | x = self.conv1(x) 415 | res_x = self.conv_res(x) 416 | x = self.conv2(x) 417 | return x+res_x 418 | # Define a resnet block 419 | 420 | 421 | class ResnetBlock(nn.Module): 422 | def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False): 423 | super(ResnetBlock, self).__init__() 424 | self.conv_block = self.build_conv_block( 425 | dim, padding_type, norm_layer, activation, use_dropout) 426 | 427 | def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout): 428 | conv_block = [] 429 | p = 0 430 | if padding_type == 'reflect': 431 | conv_block += [nn.ReflectionPad2d(1)] 432 | elif padding_type == 'replicate': 433 | conv_block += [nn.ReplicationPad2d(1)] 434 | elif padding_type == 'zero': 435 | p = 1 436 | else: 437 | raise NotImplementedError( 438 | 'padding [%s] is not implemented' % padding_type) 439 | 440 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), 441 | norm_layer(dim), 442 | activation] 443 | if use_dropout: 444 | conv_block += [nn.Dropout(0.5)] 445 | 446 | p = 0 447 | if padding_type == 'reflect': 448 | conv_block += [nn.ReflectionPad2d(1)] 449 | elif padding_type == 'replicate': 450 | conv_block += [nn.ReplicationPad2d(1)] 451 | elif padding_type == 'zero': 452 | p = 1 453 | else: 454 | raise NotImplementedError( 455 | 'padding [%s] is not implemented' % padding_type) 456 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), 457 | norm_layer(dim)] 458 | 459 | return nn.Sequential(*conv_block) 460 | 461 | def forward(self, x): 462 | out = x + self.conv_block(x) 463 | return out 464 | 465 | 466 | class Encoder(nn.Module): 467 | def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d): 468 | super(Encoder, self).__init__() 469 | 
self.output_nc = output_nc 470 | 471 | model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), 472 | norm_layer(ngf), nn.ReLU(True)] 473 | # downsample 474 | for i in range(n_downsampling): 475 | mult = 2**i 476 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), 477 | norm_layer(ngf * mult * 2), nn.ReLU(True)] 478 | 479 | # upsample 480 | for i in range(n_downsampling): 481 | mult = 2**(n_downsampling - i) 482 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1), 483 | norm_layer(int(ngf * mult / 2)), nn.ReLU(True)] 484 | 485 | model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, 486 | output_nc, kernel_size=7, padding=0), nn.Tanh()] 487 | self.model = nn.Sequential(*model) 488 | 489 | def forward(self, input, inst): 490 | outputs = self.model(input) 491 | 492 | # instance-wise average pooling 493 | outputs_mean = outputs.clone() 494 | inst_list = np.unique(inst.cpu().numpy().astype(int)) 495 | for i in inst_list: 496 | for b in range(input.size()[0]): 497 | indices = (inst[b:b+1] == int(i)).nonzero() # n x 4 498 | for j in range(self.output_nc): 499 | output_ins = outputs[indices[:, 0] + b, 500 | indices[:, 1] + j, indices[:, 2], indices[:, 3]] 501 | mean_feat = torch.mean(output_ins).expand_as(output_ins) 502 | outputs_mean[indices[:, 0] + b, indices[:, 1] + 503 | j, indices[:, 2], indices[:, 3]] = mean_feat 504 | return outputs_mean 505 | 506 | 507 | class MultiscaleDiscriminator(nn.Module): 508 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 509 | use_sigmoid=False, num_D=3, getIntermFeat=False): 510 | super(MultiscaleDiscriminator, self).__init__() 511 | self.num_D = num_D 512 | self.n_layers = n_layers 513 | self.getIntermFeat = getIntermFeat 514 | 515 | for i in range(num_D): 516 | netD = NLayerDiscriminator( 517 | input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat) 518 | if getIntermFeat: 519 | for j in range(n_layers+2): 520 | setattr(self, 'scale'+str(i)+'_layer' + 521 | str(j), getattr(netD, 'model'+str(j))) 522 | else: 523 | setattr(self, 'layer'+str(i), netD.model) 524 | 525 | self.downsample = nn.AvgPool2d( 526 | 3, stride=2, padding=[1, 1], count_include_pad=False) 527 | 528 | def singleD_forward(self, model, input): 529 | if self.getIntermFeat: 530 | result = [input] 531 | for i in range(len(model)): 532 | result.append(model[i](result[-1])) 533 | return result[1:] 534 | else: 535 | return [model(input)] 536 | 537 | def forward(self, input): 538 | num_D = self.num_D 539 | result = [] 540 | input_downsampled = input 541 | for i in range(num_D): 542 | if self.getIntermFeat: 543 | model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) 544 | for j in range(self.n_layers+2)] 545 | else: 546 | model = getattr(self, 'layer'+str(num_D-1-i)) 547 | result.append(self.singleD_forward(model, input_downsampled)) 548 | if i != (num_D-1): 549 | input_downsampled = self.downsample(input_downsampled) 550 | return result 551 | 552 | 553 | class MultiResolutionDiscriminator(nn.Module): 554 | def __init__(self, input_nc=2, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 555 | use_sigmoid=False, num_D=3, base_nfft=2048, window=None, min_value=1e-7, mdct_type='4', normalizer=None, getIntermFeat=False, abs_spectro=False): 556 | super(MultiResolutionDiscriminator, self).__init__() 557 | self.num_D = num_D 558 | self.n_layers = n_layers 559 | self.base_nfft = base_nfft 560 | self.window = window 561 | self.min_value = min_value 
562 | self.mdct = [] 563 | self.normalizer = normalizer 564 | self.getIntermFeat = getIntermFeat 565 | self.abs_spectro = abs_spectro 566 | 567 | if mdct_type == '4': 568 | from .mdct import MDCT4 569 | elif mdct_type == '2': 570 | from .mdct import MDCT2 571 | from dct.dct_native import DCT_2N_native 572 | else: 573 | raise NotImplementedError( 574 | 'MDCT type [%s] is not implemented' % mdct_type) 575 | 576 | for i in range(num_D): 577 | netD = NLayerDiscriminator( 578 | input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat) 579 | if getIntermFeat: 580 | for j in range(n_layers+2): 581 | setattr(self, 'scale'+str(i)+'_layer' + 582 | str(j), getattr(netD, 'model'+str(j))) 583 | else: 584 | setattr(self, 'layer'+str(i), netD.model) 585 | 586 | if i == 0: 587 | N = int(self.base_nfft*2) 588 | else: 589 | N = int(self.base_nfft//(2**i)) 590 | if mdct_type == '4': 591 | self.mdct.append(MDCT4(n_fft=N, hop_length=N//2, 592 | win_length=N, window=self.window, center=True)) 593 | elif mdct_type == '2': 594 | _dct = DCT_2N_native() 595 | self.mdct.append(MDCT2(n_fft=N, hop_length=N//2, win_length=N, 596 | window=self.window, dct_op=_dct, center=True)) 597 | 598 | def singleD_forward(self, model, input): 599 | if self.getIntermFeat: 600 | result = [input] 601 | for i in range(len(model)): 602 | result.append(model[i](result[-1])) 603 | return result[1:] 604 | else: 605 | return [model(input)] 606 | 607 | def forward(self, waveform): 608 | result = [] 609 | # FRAME_LENGTH = (BINS-1)*HOP_LENGTH 610 | bins = waveform.size(-1)//self.base_nfft//2 + 1 611 | for i in range(self.num_D): 612 | if i == 0: 613 | frame_len = int((bins//2-1)*self.base_nfft) 614 | else: 615 | N = int(self.base_nfft//(2**i)) 616 | frame_len = int((bins*(2**i)-1)*N) 617 | len_diff = frame_len - waveform.size(-1) 618 | if len_diff < 0: 619 | waveform_ = waveform[..., :len_diff] 620 | else: 621 | waveform_ = pad(waveform, (0, len_diff)) 622 | spectro = self.mdct[i](waveform_) 623 | if self.abs_spectro: 624 | # [LR, HR/SR, abs(HR/SR)] 625 | spectro = torch.cat( 626 | (spectro, spectro[:, 1, :, :].abs().unsqueeze(1)), dim=1) 627 | if callable(self.normalizer): 628 | # [0] avoids multiple return values 629 | spectro = self.normalizer(spectro)[0] 630 | if self.getIntermFeat: 631 | model = [getattr(self, 'scale'+str(self.num_D-1-i)+'_layer'+str(j)) 632 | for j in range(self.n_layers+2)] 633 | else: 634 | model = getattr(self, 'layer'+str(self.num_D-1-i)) 635 | result.append(self.singleD_forward(model, spectro.float())) 636 | return result 637 | 638 | # Defines the PatchGAN discriminator with the specified arguments. 
639 | 640 | 641 | class NLayerDiscriminator(nn.Module): 642 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False): 643 | super(NLayerDiscriminator, self).__init__() 644 | self.getIntermFeat = getIntermFeat 645 | self.n_layers = n_layers 646 | 647 | kw = 4 648 | padw = int(np.ceil((kw-1.0)/2)) 649 | sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, 650 | stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] 651 | 652 | nf = ndf 653 | for n in range(1, n_layers): 654 | nf_prev = nf 655 | nf = min(nf * 2, 512) 656 | sequence += [[ 657 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), 658 | norm_layer(nf), nn.LeakyReLU(0.2, True) 659 | ]] 660 | 661 | nf_prev = nf 662 | nf = min(nf * 2, 512) 663 | sequence += [[ 664 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), 665 | norm_layer(nf), 666 | nn.LeakyReLU(0.2, True) 667 | ]] 668 | 669 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, 670 | stride=1, padding=padw)]] 671 | 672 | if use_sigmoid: 673 | sequence += [[nn.Sigmoid()]] 674 | 675 | if getIntermFeat: 676 | for n in range(len(sequence)): 677 | setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) 678 | else: 679 | sequence_stream = [] 680 | for n in range(len(sequence)): 681 | sequence_stream += sequence[n] 682 | self.model = nn.Sequential(*sequence_stream) 683 | 684 | def forward(self, input): 685 | if self.getIntermFeat: 686 | res = [input] 687 | for n in range(self.n_layers+2): 688 | model = getattr(self, 'model'+str(n)) 689 | res.append(model(res[-1])) 690 | return res[1:] 691 | else: 692 | return self.model(input) 693 | 694 | 695 | class Vgg19(torch.nn.Module): 696 | def __init__(self, requires_grad=False): 697 | super(Vgg19, self).__init__() 698 | vgg_pretrained_features = models.vgg19(pretrained=True).features 699 | self.slice1 = torch.nn.Sequential() 700 | self.slice2 = torch.nn.Sequential() 701 | self.slice3 = torch.nn.Sequential() 702 | self.slice4 = torch.nn.Sequential() 703 | self.slice5 = torch.nn.Sequential() 704 | for x in range(2): 705 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 706 | for x in range(2, 7): 707 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 708 | for x in range(7, 12): 709 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 710 | for x in range(12, 21): 711 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 712 | for x in range(21, 30): 713 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 714 | if not requires_grad: 715 | for param in self.parameters(): 716 | param.requires_grad = False 717 | 718 | def forward(self, X): 719 | h_relu1 = self.slice1(X) 720 | h_relu2 = self.slice2(h_relu1) 721 | h_relu3 = self.slice3(h_relu2) 722 | h_relu4 = self.slice4(h_relu3) 723 | h_relu5 = self.slice5(h_relu4) 724 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 725 | return out 726 | --------------------------------------------------------------------------------
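Usage sketch for models/networks.py (illustrative only): in the repository these classes are constructed through the model wrapper rather than directly, so the snippet below is a minimal standalone sketch. The batch shape, channel counts, and hyper-parameters are assumptions chosen for demonstration, and the import path assumes the interpreter is started from the repository root.

# Minimal sketch: instantiate the global generator and a single PatchGAN
# discriminator on a dummy spectrogram batch (assumed shapes/hyper-parameters).
import torch
from torch import nn
from models.networks import GlobalGenerator, NLayerDiscriminator

# Dummy batch of single-channel spectrograms: (batch, channels, freq_bins, frames).
lr_spec = torch.randn(2, 1, 128, 256)

# Global generator with the default transposed-conv upsampling and no attention blocks.
netG = GlobalGenerator(input_nc=1, output_nc=1, ngf=64,
                       n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d)
sr_spec = netG(lr_spec)  # same spatial size as the input, Tanh-bounded values

# Single-scale PatchGAN discriminator fed the (input, output) pair, as in a
# conditional GAN setup; hence input_nc=2 for the concatenated channels.
netD = NLayerDiscriminator(input_nc=2, ndf=64, n_layers=3)
patch_logits = netD(torch.cat((lr_spec, sr_spec), dim=1))  # per-patch real/fake logits
print(sr_spec.shape, patch_logits.shape)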