├── data
│   ├── __init__.py
│   ├── compress_audio.sh
│   ├── data_loader.py
│   ├── base_data_loader.py
│   ├── base_dataset.py
│   ├── custom_dataset_data_loader.py
│   ├── audio_dataset.py
│   └── test.csv
├── models
│   ├── __init__.py
│   ├── models.py
│   ├── base_model.py
│   ├── spectrogram.py
│   ├── mdct.py
│   └── networks.py
├── util
│   ├── __init__.py
│   ├── spectro_img.py
│   ├── image_pool.py
│   ├── html.py
│   ├── visualizer.py
│   └── util.py
├── options
│   ├── __init__.py
│   ├── audio_config.py
│   ├── test_options.py
│   ├── train_options.py
│   └── base_options.py
├── .gitignore
├── requirements.txt
├── generate_audio.sh
├── train.sh
├── LICENSE
├── generate_audio.py
├── README.md
└── train.py
/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoints
2 | */__pycache__
--------------------------------------------------------------------------------
/data/compress_audio.sh:
--------------------------------------------------------------------------------
1 | find ./ -name '*.wav' -exec bash -c 'ffmpeg -i "$0" -c:a flac -ar 48000 -compression_level 12 "${0/.wav/.flac}" && rm "$0"' {} \;
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | bottleneck_transformer_pytorch==0.1.4
2 | dominate
3 | einops
4 | matplotlib
5 | numpy
6 | Pillow
7 | scipy
8 | torch
9 | torchaudio
10 | torchvision
11 | torch_scatter
12 |
--------------------------------------------------------------------------------
/data/data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | def CreateDataLoader(opt):
3 | from data.custom_dataset_data_loader import CustomDatasetDataLoader
4 | data_loader = CustomDatasetDataLoader()
5 | print(data_loader.name())
6 | data_loader.initialize(opt)
7 | return data_loader
8 |
--------------------------------------------------------------------------------
/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | class BaseDataLoader():
3 | def __init__(self):
4 | pass
5 |
6 | def initialize(self, opt):
7 | self.opt = opt
8 | pass
9 |
10 | def get_train_dataloader(self):
11 | return None
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/options/audio_config.py:
--------------------------------------------------------------------------------
1 | N_FFT = 512
2 | HOP_LENGTH = 256
3 | WIN_LENGTH = 512
4 | LR_SAMPLE_RATE = 8000
5 | HR_SAMPLE_RATE = 48000
6 | SR_SAMPLE_RATE = 48000
7 | BINS = 128
8 | assert BINS % 16 == 0  # must be divisible by 16
9 | CENTER = True
10 | if CENTER:
11 | FRAME_LENGTH = (BINS-1)*HOP_LENGTH  # with a centered transform the edges are padded, giving BINS time frames
12 | else:
13 | FRAME_LENGTH = (BINS-1)*HOP_LENGTH + WIN_LENGTH  # without padding, one extra window length is needed for BINS frames
--------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def create_model(opt):
4 | if opt.model == 'pix2pixHD':
5 | from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
6 | if opt.isTrain:
7 | model = Pix2PixHDModel()
8 | else:
9 | model = InferenceModel()
10 | else:
11 | from .ui_model import UIModel
12 | model = UIModel()
13 | model.initialize(opt)
14 | if opt.verbose:
15 | print("model [%s] was created" % (model.name()))
16 |
17 | # if opt.isTrain and len(opt.gpu_ids) and not opt.fp16:
18 | # model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
19 |
20 | return model
21 |
--------------------------------------------------------------------------------
/generate_audio.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python generate_audio.py \
4 | --name output_folder_name \
5 | --load_pretrain /home/neoncloud/pix2pixHDAudioSR/checkpoints/vctk_fintune_G4A3L3_56ngf_3x \
6 | --lr_sampling_rate 16000 --sr_sampling_rate 48000 \
7 | --dataroot /mnt/d/Desktop/LJ001-0056.wav --batchSize 16 \
8 | --gpu_id 0 --fp16 --nThreads 1 \
9 | --arcsinh_transform --abs_spectro --arcsinh_gain 1000 --center \
10 | --norm_range -1 1 --smooth 0.0 --abs_norm --src_range -5 5 \
11 | --netG local --ngf 56 --niter 40 \
12 | --n_downsample_global 3 --n_blocks_global 4 \
13 | --n_blocks_attn_g 3 --dim_head_g 128 --heads_g 6 --proj_factor_g 4 \
14 | --n_blocks_attn_l 0 --n_blocks_local 3 --gen_overlap 0 \
15 | --fit_residual --upsample_type interpolate --downsample_type resconv --phase test
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python train.py \
4 | --name your_training_name \
5 | --dataroot /home/neoncloud/VCTK-Corpus/train.csv --evalroot /home/neoncloud/VCTK-Corpus/test.csv \
6 | --lr_sampling_rate 16000 --sr_sampling_rate 48000 \
7 | --batchSize 20 \
8 | --gpu_id 0 --fp16 --nThreads 16 --lr 1.5e-4 \
9 | --arcsinh_transform --abs_spectro --arcsinh_gain 1000 --center \
10 | --norm_range -1 1 --smooth 0.0 --abs_norm --src_range -5 5 \
11 | --netG local --ngf 56 \
12 | --n_downsample_global 3 --n_blocks_global 4 \
13 | --n_blocks_attn_g 3 --dim_head_g 128 --heads_g 6 --proj_factor_g 4 \
14 | --n_blocks_attn_l 0 --n_blocks_local 3 \
15 | --fit_residual --upsample_type interpolate --downsample_type resconv \
16 | --niter 60 --niter_decay 60 --num_D 3 \
17 | --eval_freq 32000 --save_latest_freq 16000 --save_epoch_freq 10 --display_freq 16000 --tf_log
--------------------------------------------------------------------------------
/util/spectro_img.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | plt.switch_backend('agg')
4 | def fig2img(fig):
5 | """Render a Matplotlib figure and return it as an RGB numpy array of shape (H, W, 3)"""
6 | fig.canvas.draw()
7 | return np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(fig.canvas.get_width_height()[::-1] + (3,))
8 |
9 | def compute_visuals(sp=None, pha=None, abs=False):
10 | sp = sp.transpose() if sp is not None else None
11 | pha = pha.transpose() if pha is not None else None
12 | sp_spectro = None
13 | sp_hist = None
14 | if sp is not None:
15 | sp_fig, sp_ax = plt.subplots()
16 | sp_ax.pcolormesh(sp if not abs else np.abs(sp), cmap='PuBu_r')
17 | sp_spectro = fig2img(sp_fig)
18 |
19 | sp_hist_fig, sp_hist_ax = plt.subplots()
20 | sp_hist_ax.hist(sp.reshape(-1,1),bins=100)
21 | sp_hist = fig2img(sp_hist_fig)
22 |
23 | if pha is not None:
24 | pha_fig, pha_ax = plt.subplots()
25 | pha_ax.pcolormesh(pha, cmap='cool')
26 | pha = fig2img(pha_fig)
27 |
28 | plt.close('all')
29 | return sp_spectro, sp_hist, pha
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | --------------------------- LICENSE FOR pix2pixHD ----------------
2 | Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu.
3 | BSD License. All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | * Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | * Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
17 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
18 | DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
19 | WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
20 | OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
--------------------------------------------------------------------------------
/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from torch.autograd import Variable
4 | class ImagePool():
5 | def __init__(self, pool_size):
6 | self.pool_size = pool_size
7 | if self.pool_size > 0:
8 | self.num_imgs = 0
9 | self.images = []
10 |
11 | def query(self, images):
12 | if self.pool_size == 0:
13 | return images
14 | return_images = []
15 | for image in images.data:
16 | image = torch.unsqueeze(image, 0)
17 | if self.num_imgs < self.pool_size:
18 | self.num_imgs = self.num_imgs + 1
19 | self.images.append(image)
20 | return_images.append(image)
21 | else:
22 | p = random.uniform(0, 1)
23 | if p > 0.5:
24 | random_id = random.randint(0, self.pool_size-1)
25 | tmp = self.images[random_id].clone()
26 | self.images[random_id] = image
27 | return_images.append(tmp)
28 | else:
29 | return_images.append(image)
30 | return_images = Variable(torch.cat(return_images, 0))
31 | return return_images
32 |
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 | class TestOptions(BaseOptions):
4 | def initialize(self):
5 | BaseOptions.initialize(self)
6 | self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
7 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
8 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
9 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
10 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
11 | self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
12 | self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features')
13 | self.parser.add_argument('--use_encoded_image', action='store_true', help='if specified, encode the real image to get the feature map')
14 | self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file")
15 | self.parser.add_argument("--engine", type=str, help="run serialized TRT engine")
16 | self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT")
17 | self.isTrain = False
18 |
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import *
3 | import os
4 |
5 |
6 | class HTML:
7 | def __init__(self, web_dir, title, refresh=0):
8 | self.title = title
9 | self.web_dir = web_dir
10 | self.img_dir = os.path.join(self.web_dir, 'images')
11 | if not os.path.exists(self.web_dir):
12 | os.makedirs(self.web_dir)
13 | if not os.path.exists(self.img_dir):
14 | os.makedirs(self.img_dir)
15 |
16 | self.doc = dominate.document(title=title)
17 | if refresh > 0:
18 | with self.doc.head:
19 | meta(http_equiv="refresh", content=str(refresh))
20 |
21 | def get_image_dir(self):
22 | return self.img_dir
23 |
24 | def add_header(self, str):
25 | with self.doc:
26 | h3(str)
27 |
28 | def add_table(self, border=1):
29 | self.t = table(border=border, style="table-layout: fixed;")
30 | self.doc.add(self.t)
31 |
32 | def add_images(self, ims, txts, links, width=512):
33 | self.add_table()
34 | with self.t:
35 | with tr():
36 | for im, txt, link in zip(ims, txts, links):
37 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
38 | with p():
39 | with a(href=os.path.join('images', link)):
40 | img(style="width:%dpx" % (width), src=os.path.join('images', im))
41 | br()
42 | p(txt)
43 |
44 | def save(self):
45 | html_file = '%s/index.html' % self.web_dir
46 | f = open(html_file, 'wt')
47 | f.write(self.doc.render())
48 | f.close()
49 |
50 |
51 | if __name__ == '__main__':
52 | html = HTML('web/', 'test_html')
53 | html.add_header('hello world')
54 |
55 | ims = []
56 | txts = []
57 | links = []
58 | for n in range(4):
59 | ims.append('image_%d.jpg' % n)
60 | txts.append('text_%d' % n)
61 | links.append('image_%d.jpg' % n)
62 | html.add_images(ims, txts, links)
63 | html.save()
64 |
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import random
6 |
7 | class BaseDataset(data.Dataset):
8 | def __init__(self):
9 | super(BaseDataset, self).__init__()
10 |
11 | def name(self):
12 | return 'BaseDataset'
13 |
14 | def initialize(self, opt):
15 | pass
16 |
17 | def get_params(opt, size):
18 | w, h = size
19 | new_h = h
20 | new_w = w
21 | if opt.resize_or_crop == 'resize_and_crop':
22 | new_h = new_w = opt.loadSize
23 | elif opt.resize_or_crop == 'scale_width_and_crop':
24 | new_w = opt.loadSize
25 | new_h = opt.loadSize * h // w
26 |
27 | x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
28 | y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
29 |
30 | flip = random.random() > 0.5
31 | return {'crop_pos': (x, y), 'flip': flip}
32 |
33 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
34 | transform_list = []
35 | if 'resize' in opt.resize_or_crop:
36 | osize = [opt.loadSize, opt.loadSize]
37 | transform_list.append(transforms.Resize(osize, method))
38 | elif 'scale_width' in opt.resize_or_crop:
39 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
40 |
41 | if 'crop' in opt.resize_or_crop:
42 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
43 |
44 | if opt.resize_or_crop == 'none':
45 | base = float(2 ** opt.n_downsample_global)
46 | if opt.netG == 'local':
47 | base *= (2 ** opt.n_local_enhancers)
48 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
49 |
50 | if opt.isTrain and not opt.no_flip:
51 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
52 |
53 | transform_list += [transforms.ToTensor()]
54 |
55 | if normalize:
56 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
57 | (0.5, 0.5, 0.5))]
58 | return transforms.Compose(transform_list)
59 |
60 | def normalize():
61 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
62 |
63 | def __make_power_2(img, base, method=Image.BICUBIC):
64 | ow, oh = img.size
65 | h = int(round(oh / base) * base)
66 | w = int(round(ow / base) * base)
67 | if (h == oh) and (w == ow):
68 | return img
69 | return img.resize((w, h), method)
70 |
71 | def __scale_width(img, target_width, method=Image.BICUBIC):
72 | ow, oh = img.size
73 | if (ow == target_width):
74 | return img
75 | w = target_width
76 | h = int(target_width * oh / ow)
77 | return img.resize((w, h), method)
78 |
79 | def __crop(img, pos, size):
80 | ow, oh = img.size
81 | x1, y1 = pos
82 | tw = th = size
83 | if (ow > tw or oh > th):
84 | return img.crop((x1, y1, x1 + tw, y1 + th))
85 | return img
86 |
87 | def __flip(img, flip):
88 | if flip:
89 | return img.transpose(Image.FLIP_LEFT_RIGHT)
90 | return img
91 |
--------------------------------------------------------------------------------
/generate_audio.py:
--------------------------------------------------------------------------------
1 | from util.util import compute_matrics
2 | import torch
3 | import torchaudio
4 | import os
5 |
6 | from options.train_options import TrainOptions
7 | from data.data_loader import CreateDataLoader
8 | from models.models import create_model
9 | from util.visualizer import Visualizer
10 | from util.spectro_img import compute_visuals
11 |
12 | # Initialize the setup
13 | opt = TrainOptions().parse()
14 | visualizer = Visualizer(opt)
15 | data_loader = CreateDataLoader(opt)
16 | dataset = data_loader.get_train_dataloader()
17 | dataset_size = len(data_loader)
18 | model = create_model(opt)
19 | print('#audio segments = %d' % dataset_size)
20 |
21 |
22 | # Forward pass
23 | spectro_mag = []
24 | spectro_pha = []
25 | norm_params = []
26 | audio = []
27 | model.eval()
28 | stride = opt.segment_length-opt.gen_overlap
29 | with torch.no_grad():
30 | for i, data in enumerate(dataset):
31 | sr_spectro, sr_audio, lr_pha, norm_param, lr_spectro = model.inference(
32 | data['LR_audio'].cuda())
33 | print(sr_spectro.size())
34 | # spectro_mag.append(sr_spectro)
35 | # spectro_pha.append(lr_pha)
36 | # norm_params.append(norm_param)
37 | audio.append(sr_audio.cpu())
38 |
39 | # Concatenate the audio
40 | if opt.gen_overlap > 0:
41 | from torch.nn.functional import fold
42 | out_len = (dataset_size-1) * stride + opt.segment_length
43 | print(out_len)
44 | audio = torch.cat(audio,dim=0)
45 | print(audio.shape)
46 | audio[...,:opt.gen_overlap] *= 0.5
47 | audio[...,-opt.gen_overlap:] *= 0.5
48 | audio = audio.squeeze().transpose(-1,-2)
49 | audio = fold(audio, kernel_size=(1,opt.segment_length), stride=(1,stride), output_size=(1,out_len)).squeeze(0)
50 | audio = audio[...,opt.gen_overlap:-opt.gen_overlap]
51 | print(audio.shape)
52 | else:
53 | audio = torch.cat(audio, dim=0).view(1, -1)
54 | audio_len = data_loader.train_dataset.raw_audio.size(-1)
55 | # print(audio.size())
56 |
57 | # Evaluate the metrics
58 | audio_len = data_loader.train_dataset.raw_audio.size(-1)
59 | _mse, _snr_sr, _snr_lr, _ssnr_sr, _ssnr_lr, _pesq, _lsd = compute_matrics(
60 | data_loader.train_dataset.raw_audio, data_loader.train_dataset.lr_audio[..., :audio_len], audio[..., :audio_len], opt)
61 | print('MSE: %.4f' % _mse)
62 | print('SNR_SR: %.4f' % _snr_sr)
63 | print('SNR_LR: %.4f' % _snr_lr)
64 | #print('SSNR_SR: %.4f' % _ssnr_sr)
65 | #print('SSNR_LR: %.4f' % _ssnr_lr)
66 | #print('PESQ: %.4f' % _pesq)
67 | print('LSD: %.4f' % _lsd)
68 |
69 | # # Generate visuals
70 | # lr_mag, _, sr_mag, _, _, _, _, _ = model.encode_input(
71 | # lr_audio=data_loader.dataset.lr_audio, hr_audio=audio)
72 | # if opt.explicit_encoding:
73 | # lr_mag = 0.5*(lr_mag[:, 0, :, :]+lr_mag[:, 1, :, :])
74 | # sr_mag = 0.5*(sr_mag[:, 0, :, :]+sr_mag[:, 1, :, :])
75 | # lr_spectro, lr_hist, _ = compute_visuals(
76 | # sp=lr_mag.squeeze().detach().cpu().numpy(), abs=True)
77 | # sr_spectro, sr_hist, _ = compute_visuals(
78 | # sp=sr_mag.squeeze().detach().cpu().numpy(), abs=True)
79 | # visuals = {'lable_spectro': lr_spectro,
80 | # 'generated_spectro': sr_spectro,
81 | # 'lable_hist': lr_hist,
82 | # 'generated_hist': sr_hist}
83 |
84 | # # Save files
85 | # visualizer.display_current_results(visuals, 1, 1)
86 | with open(os.path.join(opt.checkpoints_dir, opt.name, 'metric.txt'), 'w') as f:
87 | f.write('MSE,SNR_SR,LSD\n')
88 | f.write('%f,%f,%f' % (_mse, _snr_sr, _lsd))
89 | sr_path = os.path.join(opt.checkpoints_dir, opt.name, 'sr_audio.wav')
90 | lr_path = os.path.join(opt.checkpoints_dir, opt.name, 'lr_audio.wav')
91 | hr_path = os.path.join(opt.checkpoints_dir, opt.name, 'hr_audio.wav')
92 | torchaudio.save(sr_path, audio.cpu().to(torch.float32), opt.hr_sampling_rate)
93 | torchaudio.save(lr_path, data_loader.train_dataset.lr_audio.cpu(),
94 | opt.hr_sampling_rate)
95 | torchaudio.save(hr_path, data_loader.train_dataset.raw_audio.cpu(),
96 | data_loader.train_dataset.in_sampling_rate)
97 |
--------------------------------------------------------------------------------
/data/custom_dataset_data_loader.py:
--------------------------------------------------------------------------------
1 | from queue import Queue
2 | import torch.utils.data
3 | from data.base_data_loader import BaseDataLoader
4 | from threading import Thread
5 |
6 |
7 | def CreateDataset(opt):
8 | train_dataset = None
9 | if opt.phase == 'train':
10 | from data.audio_dataset import AudioDataset
11 | train_dataset = AudioDataset(opt)
12 | eval_dataset = AudioDataset(opt, True)
13 | print("dataset [%s] was created" % (train_dataset.name()))
14 | print("dataset [%s] was created" % (eval_dataset.name()))
15 | elif opt.phase == 'test':
16 | from data.audio_dataset import AudioTestDataset
17 | train_dataset = AudioTestDataset(opt)
18 | print("dataset [%s] was created" % (train_dataset.name()))
19 | eval_dataset = None
20 |
21 | # dataset.initialize(opt)
22 | return train_dataset, eval_dataset
23 |
24 | class CustomDatasetDataLoader(BaseDataLoader):
25 | def name(self):
26 | return 'CustomDatasetDataLoader'
27 |
28 | def initialize(self, opt):
29 | self.q_size = 16
30 | self.idx = 0
31 | self.load_stream = torch.cuda.Stream(device='cuda')
32 | self.queue: Queue = Queue(maxsize=self.q_size)
33 | BaseDataLoader.initialize(self, opt)
34 | self.train_dataset, self.eval_dataset = CreateDataset(opt)
35 | self.data_lenth = len(self.train_dataset)
36 | self.eval_data_lenth = len(self.eval_dataset) if self.eval_dataset is not None else None
37 | if opt.phase == "train":
38 | self.train_dataloader = torch.utils.data.DataLoader(
39 | self.train_dataset,
40 | batch_size=opt.batchSize,
41 | shuffle=True,
42 | num_workers=int(opt.nThreads),
43 | prefetch_factor=8,
44 | pin_memory=True)
45 |
46 | self.eval_dataloder = torch.utils.data.DataLoader(
47 | self.eval_dataset,
48 | batch_size=opt.batchSize,
49 | shuffle=True,
50 | num_workers=int(opt.nThreads),
51 | pin_memory=True)
52 |
53 | elif opt.phase == "test":
54 | self.train_dataloader = torch.utils.data.DataLoader(
55 | self.train_dataset,
56 | batch_size=opt.batchSize,
57 | num_workers=int(opt.nThreads),
58 | shuffle=False,
59 | pin_memory=True)
60 | self.eval_dataloder = None
61 | self.eval_data_lenth = 0
62 |
63 | def load_loop(self) -> None: # The loop that will load into the queue in the background
64 | for i, sample in enumerate(self.train_dataloader):
65 | self.queue.put(self.load_instance(sample))
66 | if i == len(self):
67 | break
68 |
69 | def load_instance(self, sample:dict):
70 | with torch.cuda.stream(self.load_stream):
71 | return {k:v.cuda(non_blocking=True) for k,v in sample.items()}
72 |
73 | def get_train_dataloader(self):
74 | return self.train_dataloader
75 |
76 | def async_load_data(self):
77 | return self
78 |
79 | def get_eval_dataloader(self):
80 | return self.eval_dataloder
81 |
82 | def eval_data_len(self):
83 | return self.eval_data_lenth
84 |
85 | def __len__(self):
86 | return self.data_lenth
87 |
88 | def __iter__(self):
89 | if_worker = not hasattr(self, "worker") or not self.worker.is_alive() # type: ignore[has-type]
90 | if if_worker and self.queue.empty() and self.idx == 0:
91 | self.worker = Thread(target=self.load_loop)
92 | self.worker.daemon = True
93 | self.worker.start()
94 | return self
95 |
96 | def __next__(self):
97 | # If we've reached the number of batches to return
98 | # or the queue is empty and the worker is dead then exit
99 | done = not self.worker.is_alive() and self.queue.empty()
100 | done = done or self.idx >= len(self)
101 | if done:
102 | self.idx = 0
103 | self.queue.join()
104 | self.worker.join()
105 | raise StopIteration
106 | # Otherwise return the next batch
107 | out = self.queue.get()
108 | self.queue.task_done()
109 | self.idx += 1
110 | return out
111 |
112 |
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import sys
4 |
5 | class BaseModel(torch.nn.Module):
6 | def name(self):
7 | return 'BaseModel'
8 |
9 | def initialize(self, opt):
10 | self.opt = opt
11 | self.gpu_ids = opt.gpu_ids
12 | self.isTrain = opt.isTrain
13 | #self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
14 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
15 | self.device = 'cuda' if len(self.gpu_ids) > 0 else 'cpu'
16 |
17 | def set_input(self, input):
18 | self.input = input
19 |
20 | def forward(self):
21 | pass
22 |
23 | # used in test time, no backprop
24 | def test(self):
25 | pass
26 |
27 | def get_image_paths(self):
28 | pass
29 |
30 | def optimize_parameters(self):
31 | pass
32 |
33 | def get_current_visuals(self):
34 | return self.input
35 |
36 | def get_current_errors(self):
37 | return {}
38 |
39 | def save(self, label):
40 | pass
41 |
42 | # helper saving function that can be used by subclasses
43 | def save_network(self, network, network_label, epoch_label, gpu_ids):
44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
45 | save_path = os.path.join(self.save_dir, save_filename)
46 | torch.save(network.state_dict(), save_path)
47 |
48 | # helper loading function that can be used by subclasses
49 | def load_network(self, network, network_label, epoch_label, save_dir=''):
50 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
51 | if not save_dir:
52 | save_dir = self.save_dir
53 | save_path = os.path.join(save_dir, save_filename)
54 | if not os.path.isfile(save_path):
55 | print('%s does not exist yet!' % save_path)
56 | if network_label == 'G':
57 | raise FileNotFoundError('Generator must exist!')
58 | else:
59 | #network.load_state_dict(torch.load(save_path))
60 | try:
61 | network.load_state_dict(torch.load(save_path, map_location='cpu'))
62 | network.to(self.device)
63 | except:
64 | pretrained_dict = torch.load(save_path)
65 | model_dict = network.state_dict()
66 | try:
67 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
68 | network.load_state_dict(pretrained_dict)
69 | if self.opt.verbose:
70 | print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
71 | except:
72 | print('Pretrained network %s has fewer layers; The following layers are possibly matched:' % network_label)
73 | pretrained_dict = torch.load(save_path)
74 | module_map = self.opt.param_key_map
75 | for name, param in pretrained_dict.items():
76 | if name not in model_dict or param.size()!=model_dict[name].size():
77 | #print('No match %s. Try to find mapping...'%name)
78 | layer_name = name.split('.')
79 | key = layer_name[0]+'.'+layer_name[1]
80 | if key in module_map:
81 | layer_name[1] = module_map[key]
82 | name_ = name
83 | name = "."
84 | name = name.join(layer_name)
85 | print(" ",name_,'->',name)
86 | else:
87 | for k, v in model_dict.items():
88 | if v.size() == param.size():
89 | print(" ",k,":",name)
90 | continue
91 | # if isinstance(param, torch.nn.Parameter):
92 | # # backwards compatibility for serialized parameters
93 | # param = param.data
94 | model_dict[name]=param
95 | # for k, v in pretrained_dict.items():
96 | # if v.size() == model_dict[k].size():
97 | # print('Layer %s initialized'%k)
98 | # model_dict[k] = v
99 |
100 | # if sys.version_info >= (3,0):
101 | # not_initialized = set()
102 | # else:
103 | # from sets import Set
104 | # not_initialized = Set()
105 |
106 | # for k, v in model_dict.items():
107 | # if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
108 | # not_initialized.add(k.split('.')[0])
109 |
110 | # print(sorted(not_initialized))
111 | network.load_state_dict(model_dict)
112 |
113 | def update_learning_rate(self):
114 | pass
115 |
--------------------------------------------------------------------------------
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import ntpath
4 | import time
5 | from . import util
6 | from . import html
7 | import scipy.misc
8 | try:
9 | from StringIO import StringIO # Python 2.7
10 | except ImportError:
11 | from io import BytesIO # Python 3.x
12 |
13 | class Visualizer():
14 | def __init__(self, opt):
15 | # self.opt = opt
16 | self.tf_log = opt.tf_log
17 | self.use_html = opt.isTrain and not opt.no_html
18 | self.win_size = opt.display_winsize
19 | self.name = opt.name
20 | if self.tf_log:
21 | from torch.utils.tensorboard import SummaryWriter
22 | self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs')
23 | self.writer = SummaryWriter(self.log_dir)
24 |
25 | if self.use_html:
26 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
27 | self.img_dir = os.path.join(self.web_dir, 'images')
28 | print('create web directory %s...' % self.web_dir)
29 | util.mkdirs([self.web_dir, self.img_dir])
30 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
31 | with open(self.log_name, "a") as log_file:
32 | now = time.strftime("%c")
33 | log_file.write('================ Training Loss (%s) ================\n' % now)
34 |
35 | # |visuals|: dictionary of images to display or save
36 | def display_current_results(self, visuals, epoch, step):
37 | if self.tf_log: # show images in tensorboard output
38 | for label, image_numpy in visuals.items():
39 | cat = ''  # default category prefix for labels that match none of the cases below
40 | if 'spectro' in label:
41 | cat = 'spectro/'
42 | elif 'hist' in label:
43 | cat = 'histogram/'
44 | elif 'pha' in label:
45 | cat = 'phase/'
46 | self.writer.add_image(cat+label, image_numpy, step, dataformats='HWC')
47 |
48 | if self.use_html: # save images to a html file
49 | for label, image_numpy in visuals.items():
50 | if isinstance(image_numpy, list):
51 | for i in range(len(image_numpy)):
52 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i))
53 | util.save_image(image_numpy[i], img_path)
54 | else:
55 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.jpg' % (epoch, label))
56 | util.save_image(image_numpy, img_path)
57 |
58 | # update website
59 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=30)
60 | for n in range(epoch, 0, -1):
61 | webpage.add_header('epoch [%d]' % n)
62 | ims = []
63 | txts = []
64 | links = []
65 |
66 | for label, image_numpy in visuals.items():
67 | if isinstance(image_numpy, list):
68 | for i in range(len(image_numpy)):
69 | img_path = 'epoch%.3d_%s_%d.jpg' % (n, label, i)
70 | ims.append(img_path)
71 | txts.append(label+str(i))
72 | links.append(img_path)
73 | else:
74 | img_path = 'epoch%.3d_%s.jpg' % (n, label)
75 | ims.append(img_path)
76 | txts.append(label)
77 | links.append(img_path)
78 | if len(ims) < 10:
79 | webpage.add_images(ims, txts, links, width=self.win_size)
80 | else:
81 | num = int(round(len(ims)/2.0))
82 | webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size)
83 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size)
84 | webpage.save()
85 |
86 | # errors: dictionary of error labels and values
87 | def plot_current_errors(self, errors, step):
88 | if self.tf_log:
89 | self.writer.add_scalars('Losses', errors, step)
90 |
91 | # errors: same format as |errors| of plotCurrentErrors
92 | def print_current_errors(self, epoch, i, errors, t):
93 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
94 | for k, v in errors.items():
95 | if v != 0:
96 | message += '%s: %.3f ' % (k, v)
97 |
98 | print(message)
99 | with open(self.log_name, "a") as log_file:
100 | log_file.write('%s\n' % message)
101 |
102 | # save image to the disk
103 | def save_images(self, webpage, visuals, image_path):
104 | image_dir = webpage.get_image_dir()
105 | short_path = ntpath.basename(image_path[0])
106 | name = os.path.splitext(short_path)[0]
107 |
108 | webpage.add_header(name)
109 | ims = []
110 | txts = []
111 | links = []
112 |
113 | for label, image_numpy in visuals.items():
114 | image_name = '%s_%s.jpg' % (name, label)
115 | save_path = os.path.join(image_dir, image_name)
116 | util.save_image(image_numpy, save_path)
117 |
118 | ims.append(image_name)
119 | txts.append(label)
120 | links.append(image_name)
121 | webpage.add_images(ims, txts, links, width=self.win_size)
122 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # mdctGAN: Taming transformer-based GAN for speech super-resolution with Modified DCT spectra
2 |
3 | By: Chenhao Shuai, Chaohua Shi, Lu Gan and Hongqing Liu
4 |
13 | ## Requirements
14 | * bottleneck_transformer_pytorch==0.1.4
15 | * dominate
16 | * einops
17 | * matplotlib
18 | * numpy
19 | * Pillow
20 | * scipy
21 | * torch
22 | * torchaudio
23 | * torchvision
24 | * torch_scatter (optional; only required for `FastMDCT4`)
25 |
26 | ## Pretrained Models
27 | [neoncloud/mdctGAN on Hugging Face](https://huggingface.co/neoncloud/mdctGAN)
28 |
29 | ## Data Preparation
30 | Firstly, for excessively long speech audio files, we recommend removing long gaps and splitting them into smaller segments. No other pre-processing is required: the program automatically samples a random section from longer audio files.
31 |
32 | It also automatically resamples the high-sample-rate audio down to the low sample rate and then upsamples it again to the target sample rate. Downsampling simulates the loss of high-frequency speech content, and upsampling again aligns the low-res audio with the original high-sample-rate audio, so you do not need to resample the original audio manually.
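
For reference, here is a minimal sketch of that degradation pipeline, using the same `torchaudio.functional.resample` calls as `data/audio_dataset.py` (the file path and rates below are only placeholders):

```python
import torchaudio
import torchaudio.functional as aF

# Placeholder input; in practice each entry comes from the csv index file.
waveform, orig_fs = torchaudio.load("example.wav")

hr_sampling_rate = 48000  # ground-truth / target rate (--hr_sampling_rate)
lr_sampling_rate = 16000  # simulated low-res rate (--lr_sampling_rate)

# Ground-truth high-res target.
hr_waveform = aF.resample(waveform, orig_freq=orig_fs, new_freq=hr_sampling_rate)

# Simulated low-res input: downsample, then upsample back to the target rate
# so the low-res and high-res waveforms stay time-aligned.
lr_waveform = aF.resample(waveform, orig_freq=orig_fs, new_freq=lr_sampling_rate)
lr_waveform = aF.resample(lr_waveform, orig_freq=lr_sampling_rate, new_freq=hr_sampling_rate)
```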
33 |
34 | Secondly, prepare your dataset index file like this (VCTK dataset example):
35 | ```
36 | wav48/p250/p250_328.wav
37 | wav48/p310/p310_345.wav
38 | wav48/p227/p227_020.wav
39 | wav48/p285/p285_050.wav
40 | wav48/p248/p248_011.wav
41 | wav48/p246/p246_030.wav
42 | wav48/p247/p247_191.wav
43 | wav48/p287/p287_127.wav
44 | wav48/p334/p334_220.wav
45 | wav48/p340/p340_414.wav
46 | wav48/p236/p236_231.wav
47 | wav48/p301/p301_334.wav
48 | ...
49 | ```
50 | Save it to the root directory of your dataset as a text file; the program will join the parent folder of the index file with the relative path of each record in the file. You can also find the index file used in our experiments in `data/train.csv`.
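
For illustration, a rough sketch of how the index is resolved (mirroring `get_files` in `data/audio_dataset.py`; the index path below is a placeholder):

```python
import csv
import os

index_path = "/path/to/VCTK-Corpus/train.csv"  # placeholder dataset index location

root = os.path.dirname(index_path)  # the dataset root is the parent folder of the index file
with open(index_path, 'r') as f:
    file_list = [os.path.join(root, rel) for row in csv.reader(f) for rel in row]
# e.g. '/path/to/VCTK-Corpus/wav48/p250/p250_328.wav'
```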
51 |
52 | ## Train
53 | Modify & run `sh train.sh`. Detailed explanations of the arguments can be found in `options/base_options.py` and `options/train_options.py`.
54 |
55 |
56 | | Parameter Name | Description |
57 | |----------------------|--------------------------------------------------------------------------------------------------|
58 | | --name | Name of the experiment. It decides where to store samples and models. |
59 | | **--dataroot** | Path to your train set csv file. |
60 | | **--evalroot** | Path to your eval set csv file. |
61 | | **--lr_sampling_rate** | Input Low-res sampling rate. It will be automatically resampled to this value. |
62 | | **--sr_sampling_rate** | Target super-resolution sampling rate. |
63 | | --fp16 | Train with Automatic Mixed Precision (AMP). |
64 | | --nThreads | Number of threads for loading data. |
65 | | --lr | Initial learning rate for the Adam optimizer. |
66 | | --arcsinh_transform | Use $\log(x+\sqrt{x^2+1})$ to compress the range of input. |
67 | | --abs_spectro | Use the absolute value of the spectrogram. |
68 | | --arcsinh_gain | Gain parameter for the arcsinh_transform. |
69 | | --center | Centered MDCT. |
70 | | --norm_range | Specify the target distribution range. |
71 | | --abs_norm | Assume the spectrograms are all distributed in a fixed range. Normalize by an absolute range. |
72 | | --src_range | Specify the source distribution range. Used when --abs_norm is specified. |
73 | | --netG | Select the model to use for netG. |
74 | | --ngf | Number of generator filters in the first conv layer. |
75 | | --n_downsample_global| Number of downsampling layers in netG. |
76 | | --n_blocks_global | Number of residual blocks in the global generator network. |
77 | | --n_blocks_attn_g | Number of attention blocks in the global generator network. |
78 | | --dim_head_g | Dimension of attention heads in the global generator network. |
79 | | --heads_g | Number of attention heads in the global generator network. |
80 | | --proj_factor_g | Projection factor of attention blocks in the global generator network. |
81 | | --n_blocks_local | Number of residual blocks in the local enhancer network. |
82 | | --n_blocks_attn_l | Number of attention blocks in the local enhancer network. |
83 | | --fit_residual | If specified, fit $HR-LR$ rather than directly fitting $HR$. |
84 | | --upsample_type | Select upsampling layers for netG. Supported options: interpolate, transconv. |
85 | | --downsample_type | Select downsampling layers for netG. Supported options: resconv, conv. |
86 | | --num_D | Number of discriminators to use. |
87 | | --eval_freq | Frequency of evaluating metrics. |
88 | | --save_latest_freq | Frequency of saving the latest results. |
89 | | --save_epoch_freq | Frequency of saving checkpoints at the end of epochs. |
90 | | --display_freq | Frequency of showing training results on screen. |
91 | | --tf_log | If specified, use TensorBoard logging. Requires TensorFlow installed. |
92 |
93 | ## Evaluate & Generate audio
94 | Modify & run `sh generate_audio.sh`.
95 |
96 | ## Acknowledgement
97 | This code repository borrows heavily from the [official pix2pixHD implementation](https://github.com/NVIDIA/pix2pixHD). Also, this work builds on my undergraduate Final Year Project; see: [pix2pixHDAudioSR](https://github.com/neoncloud/pix2pixHDAudioSR)
98 |
99 | ## Bonus
100 | Try `FastMDCT4`/`FastIMDCT4` in `models/mdct.py` for faster MDCT conversion. You can use `FastMDCT4` as a drop-in replacement for `MDCT4`, or modify the import statement in `models/pix2pixHD_model.py` to `from .mdct import FastMDCT4 as MDCT4, FastIMDCT4 as IMDCT4`.
101 |
102 | On my computer (RTX 3070 laptop, Intel Core i7-11800H), each forward transform is about 2 ms faster.
103 |
104 | ```python
105 | sig = torch.randn(64,32512, device='cuda')
106 | %timeit -r 20 -n 500 mdct(sig)
107 | # 9.61 ms ± 643 µs per loop (mean ± std. dev. of 20 runs, 500 loops each)
108 | %timeit -r 20 -n 500 fast_mdct(sig)
109 | # 7.68 ms ± 691 µs per loop (mean ± std. dev. of 20 runs, 500 loops each)
110 | ```
111 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | from .audio_config import *
3 |
4 | class TrainOptions(BaseOptions):
5 | def initialize(self):
6 | BaseOptions.initialize(self)
7 | # for displays
8 | self.parser.add_argument('--display_freq', type=int, default=200, help='frequency of showing training results on screen')
9 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
10 | self.parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results')
11 | self.parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
12 | self.parser.add_argument('--eval_freq', type=int, default=32000, help='frequency of evaluating metrics')
13 | self.parser.add_argument('--loss_update_freq', type=int, default=256, help='frequency of updating scalers of auxiliary losses')
14 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
15 | self.parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration')
16 | self.parser.add_argument('--abs_spectro', action='store_true', help='use absolute value of spectrogram')
17 |
18 | # for training
19 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
20 | self.parser.add_argument('--freeze_g_d', action='store_true', help='freeze downsample in G_g')
21 | self.parser.add_argument('--freeze_g_u', action='store_true', help='freeze upsample in G_g')
22 | self.parser.add_argument('--freeze_l_d', action='store_true', help='freeze downsample in G_l')
23 | self.parser.add_argument('--freeze_l_u', action='store_true', help='freeze upsample in G_l')
24 | self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location')
25 | self.parser.add_argument('--param_key_map', type=lambda x: {str(k):str(v) for k,v in (i.split(':') for i in x.split(','))}, default={}, help='if the pretrained model does not match the current model, this helps map the pretrained modules to the current model.')
26 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
27 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
28 | self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
29 | self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
30 | self.parser.add_argument('--niter_limit_aux', type=int, default=20, help='# of iter to limit auxiliary losses')
31 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
32 | self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
33 | self.parser.add_argument('--validation_split', type=float, default=0.05, help='proportion of training data to be used as validation data if val_indices is not specified')
34 | self.parser.add_argument('--val_indices', type=str, help='path to file containing validation split indices')
35 | self.parser.add_argument('--eval_size', type=int, default=100, help='how many samples to evaluate')
36 | self.parser.add_argument('--phase_encoding_mode', type=str, default=None, help='norm_dist|uni_dist|None')
37 |
38 | # for discriminators
39 | self.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use')
40 | self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
41 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
42 | self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
43 | self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
44 | self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
45 | self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images')
46 | # self.parser.add_argument('--num_mr_D', type=int, default=2, help='number of multires discriminators to use')
47 | # self.parser.add_argument('--lambda_mat', type=float, default=0.01, help='weight for phase matching loss')
48 | # self.parser.add_argument('--lambda_time', type=float, default=0.4, help='weight for time domain loss')
49 | # self.parser.add_argument('--lambda_mr', type=float, default=0.08, help='weight for Multi-Res D loss')
50 | # self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss', default=True)
51 | # self.parser.add_argument('--use_match_loss', action='store_true', help='if specified, use matching loss')
52 | # self.parser.add_argument('--match_loss_count', type=int, default=1, help='let the match loss gradually increase after this number of iterations')
53 | # self.parser.add_argument('--match_loss_count_max', type=int, default=10000, help='let the match loss gradually increase after this number of iterations')
54 | # self.parser.add_argument('--match_loss_thres', type=float, default=0.5, help='if the matching loss is greater than this threshold, the loss will be discarded')
55 | # self.parser.add_argument('--use_hifigan_D', action='store_true', help='if specified, use multi-scale-multi-period hifigan time domain discriminator')
56 | # self.parser.add_argument('--use_time_D', action='store_true', help='if specified, use time domain discriminator')
57 | # self.parser.add_argument('--use_multires_D', action='store_true', help='if specified, use Multi-Resolution discriminator')
58 | # self.parser.add_argument('--use_shifted_match', action='store_true', help='if specified, shift audios randomly, then concat to the original ones as Discriminator inputs')
59 | # self.parser.add_argument('--time_D_count', type=int, default=0, help='let the time D gradually increase after this number of iterations')
60 | # self.parser.add_argument('--time_D_count_max', type=int, default=150, help='let the time D gradually increase after this number of iterations')
61 |
62 | # STFT params
63 | self.parser.add_argument('--lr_sampling_rate', type=int, default=LR_SAMPLE_RATE, help='low resolution sampling rate')
64 | self.parser.add_argument('--hr_sampling_rate', type=int, default=HR_SAMPLE_RATE, help='high resolution sampling rate')
65 | self.parser.add_argument('--sr_sampling_rate', type=int, default=SR_SAMPLE_RATE, help='target resolution sampling rate')
66 | self.parser.add_argument('--segment_length', type=int, default=FRAME_LENGTH, help='audio segment length')
67 | self.parser.add_argument('--gen_overlap', type=int, default=0, help='overlap length when generating. It is helpful to eliminate the transient effect')
68 | self.parser.add_argument('--n_fft', type=int, default=N_FFT, help='num of FFT points')
69 | self.parser.add_argument('--bins', type=int, default=BINS, help='num of time bins. This does not affect anything.')
70 | self.parser.add_argument('--hop_length', type=int, default=HOP_LENGTH, help='sliding window increment')
71 | self.parser.add_argument('--win_length', type=int, default=WIN_LENGTH, help='sliding window width')
72 | self.parser.add_argument('--center', action='store_true', help='centered FFT')
73 | self.parser.add_argument('--is_lr_input', action='store_true', help='if specified, the audio generator assumes the input is already low-res and will only do upsampling.')
74 | self.isTrain = True
75 |
--------------------------------------------------------------------------------
/data/audio_dataset.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | from numpy import ceil
4 | import torch
5 | import torch.nn.functional as F
6 | import torchaudio
7 | import torchaudio.functional as aF
8 | from data.base_dataset import BaseDataset
9 | torchaudio.set_audio_backend('sox_io')
10 |
11 | class AudioDataset(BaseDataset):
12 | def __init__(self, opt, test=False) -> None:
13 | BaseDataset.__init__(self)
14 | self.lr_sampling_rate = opt.lr_sampling_rate
15 | self.hr_sampling_rate = opt.hr_sampling_rate
16 | self.segment_length = opt.segment_length
17 | self.n_fft = opt.n_fft
18 | self.hop_length = opt.hop_length
19 | self.win_length = opt.win_length
20 | self.audio_file = self.get_files(opt.evalroot if test else opt.dataroot)
21 | self.audio_len = [(0,0)]*len(self.audio_file)
22 | self.center = opt.center
23 | self.add_noise = opt.add_noise
24 | self.snr = opt.snr
25 |
26 | torch.manual_seed(opt.seed)
27 |
28 | def __len__(self):
29 | return len(self.audio_file)
30 |
31 | def name(self):
32 | return 'AudioMDCTSpectrogramDataset'
33 |
34 | def readaudio(self, idx):
35 | file_path = self.audio_file[idx]
36 | if self.audio_len[idx][1] == 0:
37 | metadata = torchaudio.info(file_path)
38 | audio_length = metadata.num_frames
39 | fs = metadata.sample_rate
40 | self.audio_len[idx] = (fs,audio_length)
41 | else:
42 | fs,audio_length = self.audio_len[idx]
43 | max_audio_start = int(audio_length - self.segment_length*fs/self.hr_sampling_rate)
44 | if max_audio_start > 0:
45 | offset = torch.randint(
46 | low=0, high=max_audio_start, size=(1,)).item()
47 | waveform, orig_sample_rate = torchaudio.load(
48 | file_path, frame_offset=offset, num_frames=self.segment_length)
49 | else:
50 | #print("Warning: %s is shorter than segment_length"%file_path, audio_length)
51 | waveform, orig_sample_rate = torchaudio.load(file_path)
52 | return waveform, orig_sample_rate
53 |
54 | def __getitem__(self, idx):
55 | try:
56 | waveform, orig_sample_rate = self.readaudio(idx)
57 | except: # try next until success
58 | i = 1
59 | while 1:
60 | print('Load failed!')
61 | try:
62 | waveform, orig_sample_rate = self.readaudio((idx+i) % len(self.audio_file))  # wrap around so the retry never runs off the end of the file list
63 | break
64 | except:
65 | i += 1
66 | hr_waveform = aF.resample(
67 | waveform=waveform, orig_freq=orig_sample_rate, new_freq=self.hr_sampling_rate)
68 | lr_waveform = aF.resample(
69 | waveform=waveform, orig_freq=orig_sample_rate, new_freq=self.lr_sampling_rate)
70 | lr_waveform = aF.resample(
71 | waveform=lr_waveform, orig_freq=self.lr_sampling_rate, new_freq=self.hr_sampling_rate)
72 | if self.add_noise:
73 | noise = torch.randn(lr_waveform.size())
74 | noise = noise-noise.mean()
75 | signal_power = torch.sum(lr_waveform**2)/self.segment_length
76 | noise_var = signal_power / 10**(self.snr/10)
77 | noise = torch.sqrt(noise_var)/noise.std()*noise
78 | lr_waveform = lr_waveform + noise
79 | # lr_waveform = aF.lowpass_biquad(waveform, sample_rate=self.hr_sampling_rate, cutoff_freq = self.lr_sampling_rate//2) #Meet the Nyquest sampling theorem
80 | hr = self.seg_pad_audio(hr_waveform)
81 | lr = self.seg_pad_audio(lr_waveform)
82 | return {'HR_audio': hr.squeeze(0), 'LR_audio': lr.squeeze(0)}
83 |
84 | def get_files(self, file_path):
85 | if os.path.isdir(file_path):
86 | print("Searching for audio file")
87 | file_list = []
88 | for root, dirs, files in os.walk(file_path, topdown=False):
89 | for name in files:
90 | if os.path.splitext(name)[1] in (".wav", ".mp3", ".flac"):
91 | file_list.append(os.path.join(root, name))
92 | else:
93 | print("Using csv file list")
94 | root, csv_file = os.path.split(file_path)
95 | with open(file_path, 'r') as csv_file:
96 | csv_reader = csv.reader(csv_file)
97 | file_list = [os.path.join(root, item) for sublist in list(
98 | csv_reader) for item in sublist]
99 | print(len(file_list))
100 | return file_list
101 |
102 | def seg_pad_audio(self, waveform):
103 | if waveform.size(1) >= self.segment_length:
104 | waveform = waveform[0][:self.segment_length]
105 | else:
106 | waveform = F.pad(
107 | waveform, (0, self.segment_length -
108 | waveform.size(1)), 'constant'
109 | ).data
110 | return waveform
111 |
112 |
113 | class AudioTestDataset(BaseDataset):
114 | def __init__(self, opt) -> None:
115 | BaseDataset.__init__(self)
116 | self.lr_sampling_rate = opt.lr_sampling_rate
117 | self.hr_sampling_rate = opt.hr_sampling_rate
118 | self.segment_length = opt.segment_length
119 | self.n_fft = opt.n_fft
120 | self.hop_length = opt.hop_length
121 | self.win_length = opt.win_length
122 | self.center = opt.center
123 | self.dataroot = opt.dataroot
124 | self.is_lr_input = opt.is_lr_input
125 | self.overlap = opt.gen_overlap
126 | self.add_noise = opt.add_noise
127 | self.snr = opt.snr
128 |
129 | self.read_audio()
130 | self.post_processing()
131 |
132 | def __len__(self):
133 | return self.seg_audio.size(0)
134 |
135 | def name(self):
136 | return 'AudioMDCTSpectrogramTestDataset'
137 |
138 | def __getitem__(self, idx):
139 | return {'LR_audio': self.seg_audio[idx, :].squeeze(0)}
140 |
141 | def read_audio(self):
142 | try:
143 | self.raw_audio, self.in_sampling_rate = torchaudio.load(
144 | self.dataroot)
145 | self.audio_len = self.raw_audio.size(-1)
146 | self.raw_audio += 1e-4 - torch.mean(self.raw_audio)
147 | print("Audio length:", self.audio_len)
148 | except:
149 | self.raw_audio = []
150 | print("load audio failed")
151 | exit(0)
152 |
153 | def seg_pad_audio(self, audio):
154 | audio = audio.squeeze(0)
155 | length = len(audio)
156 | if length >= self.segment_length:
157 | num_segments = int(ceil(length/self.segment_length))
158 | audio = F.pad(audio, (self.overlap, self.segment_length *
159 | num_segments - length + self.overlap), "constant").data
160 | audio = audio.unfold(
161 | dimension=0, size=self.segment_length, step=self.segment_length-self.overlap)
162 | else:
163 | audio = F.pad(
164 | audio, (0, self.segment_length - length), 'constant').data
165 | audio = audio.unsqueeze(0)
166 |
167 | return audio
168 |
169 | def post_processing(self):
170 | if self.is_lr_input:
171 | self.lr_audio = aF.resample(
172 | waveform=self.raw_audio, orig_freq=self.in_sampling_rate, new_freq=self.hr_sampling_rate)
173 | else:
174 | self.lr_audio = aF.resample(
175 | waveform=self.raw_audio, orig_freq=self.in_sampling_rate, new_freq=self.lr_sampling_rate)
176 | self.lr_audio = aF.resample(
177 | waveform=self.lr_audio, orig_freq=self.lr_sampling_rate, new_freq=self.hr_sampling_rate)
178 | if self.add_noise:
179 | noise = torch.randn(self.lr_audio.size())
180 | noise = noise-noise.mean()
181 | signal_power = torch.sum(self.lr_audio**2)/self.segment_length
182 | noise_var = signal_power / 10**(self.snr/10)
183 | noise = torch.sqrt(noise_var)/noise.std()*noise
184 | self.lr_audio = self.lr_audio + noise
185 | self.seg_audio = self.seg_pad_audio(self.lr_audio)
186 |
187 | class AudioAppDataset(AudioTestDataset):
188 | def __init__(self, opt, audio:torch.Tensor, fs) -> None:
189 | self.lr_sampling_rate = opt.lr_sampling_rate
190 | self.hr_sampling_rate = opt.hr_sampling_rate
191 | self.segment_length = opt.segment_length
192 | self.n_fft = opt.n_fft
193 | self.hop_length = opt.hop_length
194 | self.win_length = opt.win_length
195 | self.center = opt.center
196 | self.dataroot = audio
197 | self.is_lr_input = opt.is_lr_input
198 | self.overlap = opt.gen_overlap
199 | self.add_noise = opt.add_noise
200 | self.snr = opt.snr
201 | self.raw_audio = audio
202 | self.in_sampling_rate = fs
203 | self.post_processing()
204 | def read_audio(self):
205 | pass
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import os
6 | import torchaudio.functional as aF
7 | from torch.nn.functional import conv1d
8 | #import pysepm
9 |
10 | # Converts a Tensor into a Numpy array
11 | # |imtype|: the desired type of the converted numpy array
12 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
13 | if isinstance(image_tensor, list):
14 | image_numpy = []
15 | for i in range(len(image_tensor)):
16 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
17 | return image_numpy
18 | image_numpy = image_tensor.cpu().float().numpy()
19 | if normalize:
20 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
21 | else:
22 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
23 | image_numpy = np.clip(image_numpy, 0, 255)
24 | if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
25 | image_numpy = image_numpy[:,:,0]
26 | return image_numpy.astype(imtype)
27 |
28 | # Converts a one-hot tensor into a colorful label map
29 | def tensor2label(label_tensor, n_label, imtype=np.uint8):
30 | if n_label == 0:
31 | return tensor2im(label_tensor, imtype)
32 | label_tensor = label_tensor.cpu().float()
33 | if label_tensor.size()[0] > 1:
34 | label_tensor = label_tensor.max(0, keepdim=True)[1]
35 | label_tensor = Colorize(n_label)(label_tensor)
36 | label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
37 | return label_numpy.astype(imtype)
38 |
39 | def save_image(image_numpy, image_path):
40 | image_pil = Image.fromarray(image_numpy)
41 | image_pil.save(image_path)
42 |
43 | def mkdirs(paths):
44 | if isinstance(paths, list) and not isinstance(paths, str):
45 | for path in paths:
46 | mkdir(path)
47 | else:
48 | mkdir(paths)
49 |
50 | def mkdir(path):
51 | if not os.path.exists(path):
52 | os.makedirs(path)
53 |
54 | ###############################################################################
55 | # Code from
56 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py
57 | # Modified so it complies with the Cityscapes label map colors
58 | ###############################################################################
59 | def uint82bin(n, count=8):
60 |     """returns the binary of integer n; count is the number of bits"""
61 | return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)])
62 |
63 | def labelcolormap(N):
64 | if N == 35: # cityscape
65 | cmap = np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81),
66 | (128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153),
67 | (180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220, 0),
68 | (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70),
69 | ( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)],
70 | dtype=np.uint8)
71 | else:
72 | cmap = np.zeros((N, 3), dtype=np.uint8)
73 | for i in range(N):
74 | r, g, b = 0, 0, 0
75 | id = i
76 | for j in range(7):
77 | str_id = uint82bin(id)
78 | r = r ^ (np.uint8(str_id[-1]) << (7-j))
79 | g = g ^ (np.uint8(str_id[-2]) << (7-j))
80 | b = b ^ (np.uint8(str_id[-3]) << (7-j))
81 | id = id >> 3
82 | cmap[i, 0] = r
83 | cmap[i, 1] = g
84 | cmap[i, 2] = b
85 | return cmap
86 |
87 | class Colorize(object):
88 | def __init__(self, n=35):
89 | self.cmap = labelcolormap(n)
90 | self.cmap = torch.from_numpy(self.cmap[:n])
91 |
92 | def __call__(self, gray_image):
93 | size = gray_image.size()
94 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
95 |
96 | for label in range(0, len(self.cmap)):
97 | mask = (label == gray_image[0]).cpu()
98 | color_image[0][mask] = self.cmap[label][0]
99 | color_image[1][mask] = self.cmap[label][1]
100 | color_image[2][mask] = self.cmap[label][2]
101 |
102 | return color_image
103 |
104 | def imdct(spectro, pha, norm_param, _imdct, min_value=1e-7, up_ratio=1, explicit_encoding=False):
105 | device = spectro.device
106 | spectro = torch.abs(spectro)*(norm_param['max'].to(device)-norm_param['min'].to(device))+norm_param['min'].to(device)
107 | #log_mag = log_mag*norm_param['std']+norm_param['mean']
108 | spectro = aF.DB_to_amplitude(spectro.to(device),10,0.5)-min_value
109 | if explicit_encoding:
110 | pha = pha.squeeze()
111 | psudo_pha = torch.sign(spectro[...,0,:,:]-spectro[...,1,:,:])
112 | spectro = spectro[...,0,:,:]+spectro[...,1,:,:]
113 | if up_ratio > 1:
114 | size = pha.size(-2)
115 | if pha.dim() != 3:
116 | pha = pha.unsqueeze(0)
117 | pha = torch.cat((pha[...,:int(size*(1/up_ratio)),:],psudo_pha[...,int(size*(1/up_ratio)):,:]),dim=-2)
118 | else:
119 | if up_ratio > 1:
120 | size = pha.size(-2)
121 | psudo_pha = 2*torch.randint(low=0,high=2,size=pha.size(),device=device)-1
122 | pha = torch.cat((pha[...,:int(size*(1/up_ratio)),:],psudo_pha[...,int(size*(1/up_ratio)):,:]),dim=-2)
123 | # BCHW -> BWH
124 | #print(spectro.shape)
125 | spectro = spectro*pha
126 | if explicit_encoding:
127 | audio = _imdct(spectro.permute(0,2,1).contiguous())/2
128 | else:
129 | audio = _imdct(spectro.squeeze(1).permute(0,2,1).contiguous())/2
130 | return audio
131 |
132 | def compute_matrics(hr_audio,lr_audio,sr_audio,opt):
133 | #print(hr_audio.shape,lr_audio.shape,sr_audio.shape)
134 | device = sr_audio.device
135 | hr_audio = hr_audio.to(device)
136 | lr_audio = lr_audio.to(device)
137 |
138 | # Calculate error
139 | mse = ((sr_audio-hr_audio)**2).mean().item()
140 |
141 | # Calculate SNR
142 | snr_sr = 10*torch.log10(torch.sum(hr_audio**2, dim=-1)/torch.sum((sr_audio-hr_audio)**2, dim=-1)).mean().item()
143 | snr_lr = 10*torch.log10(torch.sum(hr_audio**2,dim=-1)/torch.sum((lr_audio-hr_audio)**2,dim=-1)).mean().item()
144 |
145 | # Calculate segmental SNR
146 | #ssnr_sr = pysepm.SNRseg(clean_speech=hr_audio.numpy(), processed_speech=sr_audio.numpy(), fs=opt.hr_sampling_rate)
147 | #ssnr_lr = pysepm.SNRseg(clean_speech=hr_audio.numpy(), processed_speech=lr_audio.numpy(), fs=opt.hr_sampling_rate)
148 |
149 | # Calculate PESQ
150 | """ if hr_audio.dim() > 1:
151 | hr_audio = hr_audio.squeeze()
152 | sr_audio = sr_audio.squeeze()
153 | for i in range(hr_audio.size(-2)):
154 | p = []
155 | h = hr_audio[i,:]
156 | s = sr_audio[i,:]
157 | try:
158 | pesq = pysepm.pesq(aF.resample(h, orig_freq=opt.hr_sampling_rate, new_freq=16000).numpy(), aF.resample(s, orig_freq=opt.hr_sampling_rate, new_freq=16000).numpy(), 16000)
159 | p.append(pesq)
160 | except:
161 | print('PESQ no utterance')
162 |
163 | pesq = np.mean(p)
164 | else:
165 | try:
166 | pesq = pysepm.pesq(aF.resample(hr_audio,orig_freq=opt.hr_sampling_rate, new_freq=16000).numpy(), aF.resample(sr_audio,orig_freq=opt.hr_sampling_rate, new_freq=16000).numpy(), 16000)
167 | except:
168 | pesq = 0 """
169 |
170 |     # Calculate STFT loss (LSD)
171 | hr_stft = aF.spectrogram(hr_audio, n_fft=2*opt.n_fft, hop_length=2*opt.hop_length, win_length=2*opt.win_length, window=kbdwin(2*opt.win_length).to(device), center=opt.center, pad=0, power=2, normalized=False)
172 | sr_stft = aF.spectrogram(sr_audio, n_fft=2*opt.n_fft, hop_length=2*opt.hop_length, win_length=2*opt.win_length, window=kbdwin(2*opt.win_length).to(device), center=opt.center, pad=0, power=2, normalized=False)
173 | hr_stft_log = torch.log10(hr_stft+1e-6)
174 | sr_stft_log = torch.log10(sr_stft+1e-6)
175 | lsd = torch.sqrt(torch.mean((hr_stft_log-sr_stft_log)**2,dim=-2)).mean().item()
176 |
177 | return mse,snr_sr,snr_lr,0,0,0,lsd
178 |
179 | def kbdwin(N:int, beta:float=12.0, device='cpu')->torch.Tensor:
180 | # Matlab style Kaiser-Bessel window
181 | # Author: Chenhao Shuai
182 | assert N%2==0, "N must be even"
183 | w = torch.kaiser_window(window_length=N//2+1, beta=beta*torch.pi, periodic=False, device=device)
184 | w_sum = w.sum()
185 | wdw_half = torch.sqrt(torch.cumsum(w,dim=0)/w_sum)[:-1]
186 | return torch.cat((wdw_half,wdw_half.flip(dims=(0,))),dim=0)
187 |
188 | def alignment(x,y,win_len=128):
189 | x_max_idx = torch.argmax(x)
190 | x_sample = x[...,int(x_max_idx-win_len//2):int(x_max_idx+win_len//2)]
191 |     corr = conv1d(y, x_sample, dilation=1)  # dilation must be >= 1 for conv1d
192 |     return corr
193 | 
194 |
--------------------------------------------------------------------------------
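
For reference, here is a self-contained sketch of the two metrics that compute_matrics reports besides MSE: waveform SNR in dB and the log-spectral distance (LSD) between power spectrograms. The helper names are mine; the formulas mirror the ones used in the function above.

    import torch

    def snr_db(ref: torch.Tensor, est: torch.Tensor) -> float:
        # 10*log10(sum(ref^2) / sum((est-ref)^2)), averaged over batch items
        return (10 * torch.log10(torch.sum(ref ** 2, dim=-1)
                                 / torch.sum((est - ref) ** 2, dim=-1))).mean().item()

    def lsd(ref_spec: torch.Tensor, est_spec: torch.Tensor) -> float:
        # ref_spec / est_spec: power spectrograms shaped (..., freq, time)
        diff = torch.log10(ref_spec + 1e-6) - torch.log10(est_spec + 1e-6)
        return torch.sqrt(torch.mean(diff ** 2, dim=-2)).mean().item()
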
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 |
6 | class BaseOptions():
7 | def __init__(self):
8 | self.parser = argparse.ArgumentParser()
9 | self.initialized = False
10 |
11 | def initialize(self):
12 | # experiment specifics
13 | self.parser.add_argument('--name', type=str, default='label2city', help='name of the experiment. It decides where to store samples and models')
14 |         self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0, or 0,1,2, or 0,2; use -1 for CPU')
15 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
16 | self.parser.add_argument('--model', type=str, default='pix2pixHD', help='which model to use')
17 | self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
18 | self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
19 | self.parser.add_argument('--data_type', default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 8, 16, 32 bit")
20 | self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose')
21 | self.parser.add_argument('--fp16', action='store_true', default=False, help='train with AMP')
22 | self.parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training')
23 | self.parser.add_argument('--seed', type=int, default=42, help='random seed for reproducing results')
24 |         self.parser.add_argument('--fit_residual', action='store_true', default=False, help='if specified, fit the residual HR-LR rather than directly fitting HR')
25 |
26 | # input/output sizes
27 | self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
28 | self.parser.add_argument('--loadSize', type=int, default=1024, help='scale images to this size')
29 | self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size')
30 | self.parser.add_argument('--label_nc', type=int, default=0, help='# of input label channels')
31 | self.parser.add_argument('--input_nc', type=int, default=2, help='# of input spectro channels')
32 | self.parser.add_argument('--output_nc', type=int, default=1, help='# of output spectro channels')
33 |
34 | # for setting inputs
35 | self.parser.add_argument('--dataroot', type=str, default='./datasets/vctk/train.csv')
36 | self.parser.add_argument('--evalroot', type=str, default='./datasets/vctk/test.csv')
37 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
38 | self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
39 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
40 | self.parser.add_argument('--explicit_encoding', action='store_true', help='if selected, using trick to encode phase')
41 | self.parser.add_argument('--alpha', type=float, default=0.6, help='phase encoding factor')
42 |         self.parser.add_argument('--norm_range', type=float, default=(0,1), nargs=2, help='specify the target distribution range')
43 |         self.parser.add_argument('--abs_norm', action='store_true', help='if selected, assume the spectrograms are all distributed in a fixed range. Instead of normalizing each spectrogram by its own min and max, normalize by an absolute range.')
44 |         self.parser.add_argument('--src_range', type=float, default=(-5,5), nargs=2, help='specify the source distribution range. This value is used when --abs_norm is specified.')
45 |         self.parser.add_argument('--arcsinh_transform', action='store_true', help='if selected, use log(G*x+sqrt((G*x)^2+1)) to compress the range of the input. Do not use this option with --explicit_encoding')
46 |         self.parser.add_argument('--raw_mdct', action='store_true', help='if selected, DO NOT transform. Do not use this option with --explicit_encoding|arcsinh_transform')
47 |         self.parser.add_argument('--arcsinh_gain', type=float, default=500, help='gain of the arcsinh_transform input')
48 |         self.parser.add_argument('--add_noise', action='store_true', help='if selected, add some noise to the input waveform')
49 |         self.parser.add_argument('--snr', type=float, default=55, help='add noise at this SNR (used when --add_noise is selected)')
50 |
51 | # for displays
52 | self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size')
53 | self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')
54 |
55 | # for generator
56 | self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG')
57 | self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
58 | self.parser.add_argument('--upsample_type', type=str, default='transconv', help='selects upsampling layers for netG [transconv|interpolate]')
59 |         self.parser.add_argument('--downsample_type', type=str, default='conv', help='selects downsampling layers for netG [resconv|conv]')
60 | self.parser.add_argument('--n_downsample_global', type=int, default=4, help='number of downsampling layers in netG')
61 | self.parser.add_argument('--n_blocks_global', type=int, default=9, help='number of residual blocks in the global generator network')
62 | self.parser.add_argument('--n_blocks_attn_g', type=int, default=1, help='number of attention blocks in the global generator network')
63 | self.parser.add_argument('--proj_factor_g', type=int, default=4, help='projection factor of attention blocks in the global generator network')
64 | self.parser.add_argument('--dim_head_g', type=int, default=128, help='dim of attention heads in the global generator network')
65 | self.parser.add_argument('--heads_g', type=int, default=4, help='number of attention heads in the global generator network')
66 | self.parser.add_argument('--n_blocks_local', type=int, default=3, help='number of residual blocks in the local enhancer network')
67 | self.parser.add_argument('--n_blocks_attn_l', type=int, default=0, help='number of attention blocks in the local enhancer network')
68 | self.parser.add_argument('--proj_factor_l', type=int, default=4, help='projection factor of attention blocks in the local enhancers network')
69 | self.parser.add_argument('--dim_head_l', type=int, default=128, help='dim of attention heads in the local enhancers network')
70 | self.parser.add_argument('--heads_l', type=int, default=4, help='number of attention heads in the local enhancers network')
71 | self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use')
72 |         self.parser.add_argument('--niter_fix_global', type=int, default=0, help='number of epochs during which we only train the outermost local enhancer')
73 |
74 | # for instance-wise features
75 | # self.parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input', default=True)
76 | # self.parser.add_argument('--instance_feat', action='store_true', help='if specified, add encoded instance features as input')
77 | # self.parser.add_argument('--label_feat', action='store_true', help='if specified, add encoded label features as input')
78 | # self.parser.add_argument('--feat_num', type=int, default=3, help='vector length for encoded features')
79 | # self.parser.add_argument('--load_features', action='store_true', help='if specified, load precomputed feature maps')
80 | # self.parser.add_argument('--n_downsample_E', type=int, default=4, help='# of downsampling layers in encoder')
81 | # self.parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer')
82 | # self.parser.add_argument('--n_clusters', type=int, default=10, help='number of clusters for features')
83 |
84 | # input mask options
85 |         self.parser.add_argument('--mask', action='store_true', help='mask high freq component of lr spectro')
86 | self.parser.add_argument('--smooth', type=float, default=0.0, help='smooth the edge of the sr and lr')
87 |         self.parser.add_argument('--mask_hr', action='store_true', help='mask high freq component of hr spectro')
88 | self.parser.add_argument('--mask_mode', type=str, default=None, help='[None|mode0|mode1]')
89 | self.parser.add_argument('--min_value', type=float, default=1e-7, help='minimum value to cutoff the spectrogram')
90 |
91 | self.initialized = True
92 |
93 | def parse(self, save=True):
94 | if not self.initialized:
95 | self.initialize()
96 | self.opt = self.parser.parse_args()
97 | self.opt.isTrain = self.isTrain # train or test
98 |
99 | str_ids = self.opt.gpu_ids.split(',')
100 | self.opt.gpu_ids = []
101 | for str_id in str_ids:
102 | id = int(str_id)
103 | if id >= 0:
104 | self.opt.gpu_ids.append(id)
105 |
106 | # set gpu ids
107 | if len(self.opt.gpu_ids) > 0:
108 | torch.cuda.set_device(self.opt.gpu_ids[0])
109 |
110 | args = vars(self.opt)
111 |
112 | print('------------ Options -------------')
113 | for k, v in sorted(args.items()):
114 | print('%s: %s' % (str(k), str(v)))
115 | print('-------------- End ----------------')
116 |
117 | # save to the disk
118 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
119 | util.mkdirs(expr_dir)
120 | if save and not self.opt.continue_train:
121 | file_name = os.path.join(expr_dir, 'opt.txt')
122 | with open(file_name, 'wt') as opt_file:
123 | opt_file.write('------------ Options -------------\n')
124 | for k, v in sorted(args.items()):
125 | opt_file.write('%s: %s\n' % (str(k), str(v)))
126 | opt_file.write('-------------- End ----------------\n')
127 | return self.opt
128 |
--------------------------------------------------------------------------------
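
BaseOptions.parse() relies on attributes that only its subclasses provide: self.isTrain, and flags such as --continue_train that train.py (below) reads. A hedged sketch of that subclassing pattern follows; the real definitions live in options/train_options.py and options/test_options.py, which are not reproduced here, so the class and argument below are illustrative only.

    from options.base_options import BaseOptions

    class ExampleTrainOptions(BaseOptions):
        def initialize(self):
            BaseOptions.initialize(self)
            # parse() checks opt.continue_train before rewriting opt.txt
            self.parser.add_argument('--continue_train', action='store_true',
                                     help='continue training: load the latest model')
            self.isTrain = True  # parse() copies this onto opt.isTrain
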
/train.py:
--------------------------------------------------------------------------------
1 | from data.data_loader import CreateDataLoader
2 | import signal
3 | from util.util import compute_matrics
4 | from util.visualizer import Visualizer
5 | from options.train_options import TrainOptions
6 | from models.models import create_model
7 |
8 | import math
9 | import os
10 | import time
11 | import csv
12 | import gc
13 |
14 | import numpy as np
15 | import torch
16 |
17 |
18 |
19 | def lcm(a, b): return abs(a * b) // math.gcd(a, b) if a and b else 0
20 |
21 |
22 | # import debugpy
23 | # debugpy.listen(("localhost", 5678))
24 | # debugpy.wait_for_client()
25 | # os.environ['CUDA_VISIBLE_DEVICES']='0'
26 | torch.backends.cudnn.benchmark = True
27 | # Get the training options
28 | opt = TrainOptions().parse()
29 | # Set the seed
30 | torch.manual_seed(opt.seed)
31 | # Set the paths for saving the training losses and evaluation results
32 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
33 | eval_path = os.path.join(opt.checkpoints_dir, opt.name, 'eval.csv')
34 |
35 | if opt.continue_train:
36 | try:
37 | start_epoch, epoch_iter = np.loadtxt(
38 | iter_path, delimiter=',', dtype=int)
39 | except:
40 | start_epoch, epoch_iter = 1, 0
41 | print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
42 | else:
43 | start_epoch, epoch_iter = 1, 0
44 |
45 | # Create the data loader
46 | data_loader = CreateDataLoader(opt)
47 | train_dataloader = data_loader.get_train_dataloader()
48 | train_dataset_size = len(data_loader)
49 | eval_dataloader = data_loader.get_eval_dataloader()
50 | eval_dataset_size = data_loader.eval_data_len()
51 | print('#training data = %d' % train_dataset_size)
52 | print('#evaluating data = %d' % eval_dataset_size)
53 |
54 | # Create the model
55 | model = create_model(opt)
56 | visualizer = Visualizer(opt)
57 | optimizer_G, optimizer_D = model.optimizer_G, model.optimizer_D
58 |
59 | # IMDCT for evaluation
60 | # from util.util import kbdwin, imdct
61 | # # from dct.dct import IDCT
62 | # # _idct = IDCT()
63 | # _imdct = IMDCT4(window=kbdwin, win_length=opt.win_length, hop_length=opt.hop_length, n_fft=opt.n_fft, center=opt.center, out_length=opt.segment_length, device = 'cuda')
64 |
65 | if opt.fp16:
66 | from torch.cuda.amp import autocast as autocast
67 | from torch.cuda.amp import GradScaler
68 |     # According to the official tutorial, use only one GradScaler and backward each loss separately
69 | # https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-multiple-models-losses-and-optimizers
70 | scaler = GradScaler()
71 |
72 |
73 | # Set frequency for displaying information and saving
74 | opt.print_freq = lcm(opt.print_freq, opt.batchSize)
75 | if opt.debug:
76 | opt.display_freq = 1
77 | opt.print_freq = 1
78 | opt.niter = 1
79 | opt.niter_decay = 0
80 | opt.max_dataset_size = 10
81 | total_steps = (start_epoch-1) * train_dataset_size + epoch_iter
82 | display_delta = total_steps % opt.display_freq
83 | print_delta = total_steps % opt.print_freq
84 | save_delta = total_steps % opt.save_latest_freq
85 | eval_delta = total_steps % opt.eval_freq if opt.validation_split > 0 else -1
86 | # loss_update_delta = total_steps % opt.loss_update_freq if opt.use_time_D or opt.use_match_loss else -1
87 |
88 | # Safe ctrl-c
89 | end = False
90 |
91 |
92 | def signal_handler(signal, frame):
93 | print('You pressed Ctrl+C!')
94 | global end
95 | end = True
96 |
97 |
98 | signal.signal(signal.SIGINT, signal_handler)
99 |
100 | # Evaluation process
101 | # Wrap it in a function so that I don't have to free up memory manually
102 |
103 |
104 | def eval_model():
105 | err = []
106 | snr = []
107 | snr_seg = []
108 | pesq = []
109 | lsd = []
110 | for j, eval_data in enumerate(eval_dataloader):
111 | model.eval()
112 | lr_audio = eval_data['LR_audio'].cuda()
113 | hr_audio = eval_data['HR_audio'].cuda()
114 | with torch.no_grad():
115 | _, sr_audio, _, _, _ = model.inference(lr_audio)
116 | _mse, _snr_sr, _snr_lr, _ssnr_sr, _ssnr_lr, _pesq, _lsd = compute_matrics(
117 | hr_audio.squeeze(), lr_audio.squeeze(), sr_audio.squeeze(), opt)
118 | err.append(_mse)
119 | snr.append((_snr_lr, _snr_sr))
120 | snr_seg.append((_ssnr_lr, _ssnr_sr))
121 | pesq.append(_pesq)
122 | lsd.append(_lsd)
123 | if j >= opt.eval_size:
124 | break
125 |
126 | eval_result = {'err': np.mean(err), 'snr': np.mean(snr), 'snr_seg': np.mean(
127 | snr_seg), 'pesq': np.mean(pesq), 'lsd': np.mean(lsd)}
128 | with open(eval_path, 'a') as csv_file:
129 | writer = csv.DictWriter(csv_file, fieldnames=eval_result.keys())
130 | if csv_file.tell() == 0:
131 | writer.writeheader()
132 | writer.writerow(eval_result)
133 | print('Evaluation:', eval_result)
134 | model.train()
135 |
136 |
137 | # Training...
138 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
139 | epoch_start_time = time.time()
140 | if epoch != start_epoch:
141 | epoch_iter = epoch_iter % train_dataset_size
142 | if epoch > opt.niter_limit_aux:
143 | model.limit_aux_loss = True
144 | for i, data in enumerate(train_dataloader, start=epoch_iter):
145 | if end:
146 | print('exiting and saving the model at the epoch %d, iters %d' %
147 | (epoch, total_steps))
148 | model.save('latest')
149 | model.save(epoch)
150 | np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
151 | exit(0)
152 | if total_steps % opt.print_freq == print_delta:
153 | iter_start_time = time.time()
154 | total_steps += opt.batchSize
155 | epoch_iter += opt.batchSize
156 |
157 | # Whether to collect output images
158 | save_fake = total_steps % opt.display_freq == display_delta
159 |
160 | ############## Forward Pass ######################
161 | if opt.fp16:
162 | with autocast():
163 | losses, _ = model._forward(
164 | data['LR_audio'].cuda(), data['HR_audio'].cuda(), infer=False)
165 | else:
166 | losses, _ = model._forward(
167 | data['LR_audio'].cuda(), data['HR_audio'].cuda(), infer=False)
168 |
169 | # Sum per device losses
170 | losses = [torch.mean(x) if not isinstance(x, int)
171 | else x for x in losses]
172 | loss_dict = dict(zip(model.loss_names, losses))
173 |
174 | # Calculate final loss scalar
175 | loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5 + (loss_dict.get('D_fake_t', 0) + loss_dict.get(
176 | 'D_real_t', 0))*0.5 + (loss_dict.get('D_fake_mr', 0) + loss_dict.get('D_real_mr', 0))*0.5
177 | loss_G = loss_dict['G_GAN'] + loss_dict.get('G_mat', 0) + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get(
178 | 'G_VGG', 0) + loss_dict.get('G_GAN_t', 0) + loss_dict.get('G_GAN_mr', 0) + loss_dict.get('G_shift', 0)
179 |
180 | ############### Backward Pass ####################
181 | # update generator weights
182 | optimizer_G.zero_grad()
183 | if opt.fp16:
184 | #with amp.scale_loss(loss_G, optimizer_G) as scaled_loss: scaled_loss.backward()
185 | scaler.scale(loss_G).backward()
186 | scaler.step(optimizer_G)
187 | # update the scaler only once per iteration
188 | # scaler.update()
189 | else:
190 | loss_G.backward()
191 | optimizer_G.step()
192 |
193 | # update discriminator weights
194 | optimizer_D.zero_grad()
195 | if opt.fp16:
196 | #with amp.scale_loss(loss_D, optimizer_D) as scaled_loss: scaled_loss.backward()
197 | scaler.scale(loss_D).backward()
198 | scaler.step(optimizer_D)
199 | scaler.update()
200 | else:
201 | loss_D.backward()
202 | optimizer_D.step()
203 |
204 | ############## Display results and errors ##########
205 | # print out errors
206 | if total_steps % opt.print_freq == print_delta:
207 | errors = {k: v.data.item() if not isinstance(
208 | v, int) else v for k, v in loss_dict.items()}
209 | t = (time.time() - iter_start_time) / opt.print_freq
210 | visualizer.print_current_errors(epoch, epoch_iter, errors, t)
211 | visualizer.plot_current_errors(errors, total_steps)
212 | #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])
213 |
214 | # display output images
215 | if save_fake:
216 | visuals = model.get_current_visuals()
217 | visualizer.display_current_results(visuals, epoch, total_steps)
218 | del visuals
219 |
220 | # save latest model
221 | if total_steps % opt.save_latest_freq == save_delta:
222 | print('saving the latest model (epoch %d, total_steps %d)' %
223 | (epoch, total_steps))
224 | model.save('latest')
225 | np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
226 |
227 | if total_steps % opt.eval_freq == eval_delta:
228 | del losses, loss_D, loss_G, loss_dict
229 | torch.cuda.empty_cache()
230 | gc.collect()
231 | eval_model()
232 | torch.cuda.empty_cache()
233 | gc.collect()
234 | # if total_steps % opt.loss_update_freq == loss_update_delta:
235 | # if opt.use_match_loss:
236 | # model.update_match_loss_scaler()
237 | # if opt.use_time_D:
238 | # model.update_time_D_loss_scaler()
239 |
240 | if epoch_iter >= train_dataset_size:
241 | break
242 |
243 | # end of epoch
244 | iter_end_time = time.time()
245 | print('End of epoch %d / %d \t Time Taken: %d sec' %
246 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
247 |
248 | # save model for this epoch
249 | if epoch % opt.save_epoch_freq == 0:
250 | print('saving the model at the end of epoch %d, iters %d' %
251 | (epoch, total_steps))
252 | model.save('latest')
253 | model.save(epoch)
254 | np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')
255 |
256 | # instead of only training the local enhancer, train the entire network after certain iterations
257 | if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
258 | model.update_fixed_params()
259 |
260 | # linearly decay learning rate after certain iterations
261 | if epoch > opt.niter:
262 | model.update_learning_rate()
263 |
--------------------------------------------------------------------------------
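
The AMP branch above follows the PyTorch recipe it links to: a single GradScaler is shared by both optimizers, each loss is backward-ed and stepped separately, and scaler.update() runs once per iteration. Below is a minimal sketch of that pattern; model_fn is a hypothetical stand-in for the forward pass, and it assumes, as in the script above, that the discriminator loss is built from detached generator outputs so the two backward calls use disjoint graphs.

    import torch
    from torch.cuda.amp import autocast, GradScaler

    scaler = GradScaler()

    def amp_step(model_fn, optimizer_G, optimizer_D, batch):
        with autocast():
            loss_G, loss_D = model_fn(batch)  # hypothetical forward returning both losses

        optimizer_G.zero_grad()
        scaler.scale(loss_G).backward()
        scaler.step(optimizer_G)

        optimizer_D.zero_grad()
        scaler.scale(loss_D).backward()
        scaler.step(optimizer_D)

        scaler.update()  # exactly once per iteration, after all optimizer steps
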
/models/spectrogram.py:
--------------------------------------------------------------------------------
1 | """
2 | Module to transform signals
3 | from https://github.com/nils-werner/stft/
4 |
5 | """
6 | from __future__ import division, absolute_import
7 | import itertools
8 | import torch
9 | import torch.fft
10 | import torch.nn.functional
11 | import numpy as np
12 |
13 | def _pad(data, frame_length):
14 | return torch.nn.functional.pad(
15 | data,
16 | pad=(
17 | 0,
18 | int(
19 | np.ceil(
20 | len(data) / frame_length
21 | ) * frame_length - len(data)
22 | )
23 | ),
24 | mode='constant',
25 | value=0
26 | )
27 |
28 | def unpad(data, outlength):
29 | slicetuple = [slice(None)] * data.ndim
30 | slicetuple[0] = slice(None, outlength)
31 | return data[tuple(slicetuple)]
32 |
33 |
34 | def center_pad(data, frame_length):
35 | #padtuple = [(0, 0)] * data.ndim
36 | padtuple = (frame_length // 2, frame_length // 2)
37 | #print(padtuple)
38 | return torch.nn.functional.pad(
39 | data,
40 | pad = padtuple,
41 | mode = 'constant',
42 | value = 0
43 | )
44 |
45 | def center_unpad(data, frame_length):
46 | slicetuple = [slice(None)] * data.ndim
47 | slicetuple[1] = slice(frame_length // 2, -frame_length // 2)
48 | return data[tuple(slicetuple)]
49 |
50 | def process(
51 | data,
52 | window_function,
53 | halved,
54 | transform,
55 | padding=0,
56 | n_fft=1024
57 | ):
58 | """Calculate a windowed transform of a signal
59 |
60 | Parameters
61 | ----------
62 | data : array_like
63 | The signal to be calculated. Must be a 1D array.
64 | window : array_like
65 | Tapering window
66 | halved : boolean
67 | Switch for turning on signal truncation. For real signals, the fourier
68 |         transform returns a symmetrically mirrored spectrum.
69 | This additional data is not needed and can be removed.
70 | transform : callable
71 | The transform to be used.
72 | padding : int
73 | Zero-pad signal with x times the number of samples.
74 |
75 | Returns
76 | -------
77 | data : array_like
78 | The spectrum
79 |
80 | """
81 |
82 | data = data * window_function
83 |
84 | if padding > 0:
85 | data = torch.nn.functional.pad(
86 | data,
87 | pad=(
88 | 0,
89 | len(data) * padding
90 | ),
91 | mode='constant',
92 | value=0
93 | )
94 |
95 | result = transform(data,n_fft)
96 |
97 | if(halved):
98 |         result = result[0:result.size(0) // 2 + 1]
99 |
100 | return result
101 |
102 |
103 | def iprocess(
104 | data,
105 | window_function,
106 | halved,
107 | transform,
108 | padding=0,
109 | n_fft=1024
110 | ):
111 | """Calculate the inverse short time fourier transform of a spectrum
112 |
113 | Parameters
114 | ----------
115 | data : array_like
116 | The spectrum to be calculated. Must be a 1D array.
117 | window : array_like
118 | Tapering window
119 | halved : boolean
120 | Switch for turning on signal truncation. For real output signals, the
121 | inverse fourier transform consumes a symmetrically mirrored spectrum.
122 | This additional data is not needed and can be removed. Setting this
123 | value to :code:`True` will automatically create a mirrored spectrum.
124 | transform : callable
125 | The transform to be used.
126 | padding : int
127 | Signal before FFT transform was padded with x zeros.
128 |
129 |
130 | Returns
131 | -------
132 | data : array_like
133 | The signal
134 |
135 | """
136 | if halved:
137 | data = torch.nn.functional.pad(data, (0, data.shape[0] - 2), 'reflect')
138 | start = data.shape[0] // 2 + 1
139 | data[start:] = data[start:].conjugate()
140 |
141 | output = transform(data,n_fft)
142 | if torch.is_complex(output):
143 | output = torch.real(output)
144 |
145 | if padding > 0:
146 | output = output[0:-(len(data) * padding // (padding + 1))]
147 |
148 | return output * window_function
149 |
150 |
151 | def spectrogram(
152 | data,
153 | frame_length=1024,
154 | step_length=None,
155 | overlap=None,
156 | centered=True,
157 | n_fft=1024,
158 | window_function=None,
159 | halved=True,
160 | transform=None,
161 | padding=0
162 | ):
163 | """Calculate the spectrogram of a signal
164 |
165 | Parameters
166 | ----------
167 | data : array_like
168 | The signal to be transformed. May be a 1D vector for single channel or
169 | a 2D matrix for multi channel data. In case of a mono signal, the data
170 | is must be a 1D vector of length :code:`samples`. In case of a multi
171 | channel signal, the data must be in the shape of :code:`samples x
172 | channels`.
173 | frame_length : int
174 | The signal frame length. Defaults to :code:`1024`.
175 | step_length : int
176 | The signal frame step_length. Defaults to :code:`None`. Setting this
177 | value will override :code:`overlap`.
178 | overlap : int
179 | The signal frame overlap coefficient. Value :code:`x` means
180 | :code:`1/x` overlap. Defaults to :code:`2`.
181 | centered : boolean
182 | Pad input signal so that the first and last window are centered around
183 | the beginning of the signal. Defaults to true.
184 |     window_function : callable, array_like
185 |         Window to be used for deringing. Can be :code:`False` to disable
186 |         windowing. Defaults to a rectangular (all-ones) window.
187 | halved : boolean
188 | Switch for turning on signal truncation. For real signals, the fourier
189 |         transform returns a symmetrically mirrored spectrum.
190 | This additional data is not needed and can be removed. Defaults to
191 | :code:`True`.
192 | transform : callable
193 |         The transform to be used. Defaults to :code:`torch.fft.fft`.
194 | padding : int
195 | Zero-pad signal with x times the number of samples.
196 |     save_settings : boolean
197 |         Present in the original stft package, where the settings are stored in
198 |         :code:`out.stft_settings` for :func:`ispectrogram` to reuse; this port
199 |         does not store them, so pass the same parameters explicitly instead.
200 |
201 | Returns
202 | -------
203 | data : array_like
204 |         The spectrogram (or tensor of spectrograms). In case of a mono signal,
205 | the data is formatted as :code:`bins x frames`. In case of a multi
206 | channel signal, the data is formatted as :code:`bins x frames x
207 | channels`.
208 |
209 | Notes
210 | -----
211 | The data will be padded to be a multiple of the desired FFT length.
212 |
213 | See Also
214 | --------
215 | stft.stft.process : The function used to transform the data
216 |
217 | """
218 |
219 | if overlap is None:
220 | overlap = 2
221 |
222 | if step_length is None:
223 | step_length = frame_length // overlap
224 |
225 |     if halved and torch.is_complex(data):
226 | raise ValueError("You cannot treat a complex input signal as real "
227 | "valued. Please set keyword argument halved=False.")
228 |
229 | data = torch.squeeze(data)
230 |
231 | if transform is None:
232 | transform = torch.fft.fft
233 |
234 | if not isinstance(transform, (list, tuple)):
235 | transform = [transform]
236 |
237 | transforms = itertools.cycle(transform)
238 |
239 | if centered:
240 | #print("center pad")
241 | data = center_pad(data, frame_length)
242 |
243 |     if window_function is None:
244 |         window_array = torch.ones(frame_length)
245 |     elif callable(window_function):
246 |         window_array = window_function(frame_length)
247 |     else:
248 |         window_array = window_function
249 |         frame_length = len(window_array)
250 | 
251 |
252 | def traf(data):
253 | # Pad input signal so it fits into frame_length spec
254 | #print(data.size())
255 | #data = _pad(data, frame_length)
256 | #print(data.size())
257 |
258 | values = list(enumerate(
259 | range(0, len(data) - frame_length + step_length, step_length)
260 | ))
261 | #print(values)
262 | for j, i in values:
263 | sig = process(
264 | data[i:i + frame_length],
265 | window_function=window_array,
266 | halved=halved,
267 | transform=next(transforms),
268 | padding=padding,
269 | n_fft = n_fft
270 | ) / (frame_length // step_length // 2)
271 | if(i == 0):
272 | output_ = torch.zeros(
273 | (sig.shape[0], len(values)), dtype=sig.dtype
274 | )
275 |
276 | output_[:, j] = sig
277 |
278 | return output_
279 |
280 | if data.ndim > 2:
281 | raise ValueError("spectrogram: Only 1D or 2D input data allowed")
282 | if data.ndim == 1:
283 | out = traf(data)
284 | elif data.ndim == 2:
285 | #print(data.size())
286 | for i in range(data.shape[0]):
287 | tmp = traf(data[i,:])
288 | #print(tmp.size())
289 | if i == 0:
290 | out = torch.empty(
291 | ((data.shape[0],)+tmp.shape), dtype=tmp.dtype
292 | )
293 | out[i, :, :] = tmp
294 | return out
295 |
296 |
297 | def ispectrogram(
298 | data,
299 | frame_length=1024,
300 | step_length=None,
301 | overlap=None,
302 | centered=True,
303 | n_fft=1024,
304 | window_function=None,
305 | halved=True,
306 | transform=None,
307 | padding=0,
308 | out_length=None
309 | ):
310 | """Calculate the inverse spectrogram of a signal
311 |
312 | Parameters
313 | ----------
314 | data : array_like
315 | The spectrogram to be inverted. May be a 2D matrix for single channel
316 | or a 3D tensor for multi channel data. In case of a mono signal, the
317 | data must be in the shape of :code:`bins x frames`. In case of a multi
318 | channel signal, the data must be in the shape of :code:`bins x frames x
319 | channels`.
320 | frame_length : int
321 | The signal frame length. Defaults to infer from data.
322 | step_length : int
323 | The signal frame step_length. Defaults to infer from data. Setting this
324 | value will override :code:`overlap`.
325 | overlap : int
326 | The signal frame overlap coefficient. Value :code:`x` means
327 | :code:`1/x` overlap. Defaults to infer from data.
328 | centered : boolean
329 | Pad input signal so that the first and last window are centered around
330 |         the beginning of the signal. Defaults to infer from data.
331 |     window_function : callable, array_like
332 |         Window to be used for deringing. Can be :code:`False` to disable
333 |         windowing. Defaults to infer from data.
334 | halved : boolean
335 | Switch to reconstruct the other halve of the spectrum if the forward
336 |         transform has been truncated. Defaults to infer from data.
337 | transform : callable
338 | The transform to be used. Defaults to infer from data.
339 | padding : int
340 | Zero-pad signal with x times the number of samples. Defaults to infer
341 | from data.
342 |     out_length : int
343 | Crop output signal to length. Useful when input length of spectrogram
344 | did not fit into frame_length and input data had to be padded. Not
345 | setting this value will disable cropping, the output data may be
346 | longer than expected.
347 |
348 | Returns
349 | -------
350 | data : array_like
351 | The signal (or matrix of signals). In case of a mono output signal, the
352 | data is formatted as a 1D vector of length :code:`samples`. In case of
353 | a multi channel output signal, the data is formatted as :code:`samples
354 | x channels`.
355 |
356 | Notes
357 | -----
358 | By default :func:`spectrogram` saves its transformation parameters in
359 | the output array. This data is used to infer the transform parameters
360 | here. Any aspect of the settings can be overridden by passing the according
361 | parameter to this function.
362 |
363 | During transform the data will be padded to be a multiple of the desired
364 | FFT length. Hence, the result of the inverse transform might be longer
365 | than the input signal. However it is safe to remove the additional data,
366 | e.g. by using
367 |
368 | .. code:: python
369 |
370 | output.resize(input.shape)
371 |
372 | where :code:`input` is the input of :code:`stft.spectrogram()` and
373 | :code:`output` is the output of :code:`stft.ispectrogram()`
374 |
375 | See Also
376 | --------
377 | stft.stft.iprocess : The function used to transform the data
378 |
379 | """
380 |
381 | if overlap is None:
382 | overlap = 2
383 |
384 | if step_length is None:
385 | step_length = frame_length // overlap
386 |
387 |     if window_function is None:
388 |         window_array = torch.ones(frame_length)
389 |     elif callable(window_function):
390 |         window_array = window_function(frame_length)
391 |     else:
392 |         window_array = window_function
393 | 
394 |
395 | if transform is None:
396 | transform = torch.fft.ifft
397 |
398 | if not isinstance(transform, (list, tuple)):
399 | transform = [transform]
400 |
401 | transforms = itertools.cycle(transform)
402 |
403 | def traf(data):
404 | i = 0
405 | values = range(0, data.shape[1])
406 | for j in values:
407 | sig = iprocess(
408 | data[:, j],
409 | window_function=window_array,
410 | halved=halved,
411 | transform=next(transforms),
412 | padding=padding,
413 | n_fft=n_fft
414 | )
415 |
416 | if(i == 0):
417 | output = torch.zeros(
418 | frame_length + (len(values) - 1) * step_length,
419 |                     dtype=sig.dtype, device=sig.device
420 |                 )
421 |
422 | output[i:i + frame_length] += sig
423 |
424 | i += step_length
425 |
426 | return output
427 |
428 | if data.ndim == 2:
429 | out = traf(data)
430 | elif data.ndim == 3:
431 | for i in range(data.shape[0]):
432 | tmp = traf(data[i ,:, :])
433 |
434 | if i == 0:
435 | out = torch.empty(
436 | ((data.shape[0],)+tmp.shape), dtype=tmp.dtype
437 | )
438 | out[i,:] = tmp
439 | else:
440 | raise ValueError("ispectrogram: Only 2D or 3D input data allowed")
441 |
442 | if centered:
443 |         #print(out.size())
444 | out = center_unpad(out, frame_length)
445 |
446 | return unpad(out, out_length)
--------------------------------------------------------------------------------
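
A usage sketch for the two functions above (my own example, not part of the repo): a forward and inverse pass over a two-channel random signal with the KBD window from util/util.py. Passing halved=False keeps the full complex spectrum and therefore avoids the mirrored-spectrum reconstruction branch.

    import torch
    from models.spectrogram import spectrogram, ispectrogram
    from util.util import kbdwin

    x = torch.randn(2, 16384)                                  # (channels, samples)
    spec = spectrogram(x, frame_length=1024, n_fft=1024,
                       window_function=kbdwin, halved=False)   # (channels, n_fft, frames), complex
    y = ispectrogram(spec, frame_length=1024, n_fft=1024,
                     window_function=kbdwin, halved=False)
    y = y[..., :x.shape[-1]]                                   # crop any extra samples from padding
    print(spec.shape, y.shape)
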
/data/test.csv:
--------------------------------------------------------------------------------
1 | wav48/p250/p250_328.wav
2 | wav48/p310/p310_345.wav
3 | wav48/p227/p227_020.wav
4 | wav48/p285/p285_050.wav
5 | wav48/p248/p248_011.wav
6 | wav48/p246/p246_030.wav
7 | wav48/p247/p247_191.wav
8 | wav48/p287/p287_127.wav
9 | wav48/p334/p334_220.wav
10 | wav48/p340/p340_414.wav
11 | wav48/p236/p236_231.wav
12 | wav48/p301/p301_334.wav
13 | wav48/p258/p258_021.wav
14 | wav48/p374/p374_403.wav
15 | wav48/p312/p312_182.wav
16 | wav48/p232/p232_237.wav
17 | wav48/p334/p334_080.wav
18 | wav48/p347/p347_206.wav
19 | wav48/p313/p313_347.wav
20 | wav48/p304/p304_274.wav
21 | wav48/p351/p351_117.wav
22 | wav48/p351/p351_051.wav
23 | wav48/p312/p312_109.wav
24 | wav48/p323/p323_148.wav
25 | wav48/p329/p329_187.wav
26 | wav48/p308/p308_255.wav
27 | wav48/p281/p281_039.wav
28 | wav48/p294/p294_004.wav
29 | wav48/p310/p310_056.wav
30 | wav48/p271/p271_016.wav
31 | wav48/p249/p249_276.wav
32 | wav48/p285/p285_305.wav
33 | wav48/p330/p330_266.wav
34 | wav48/p229/p229_006.wav
35 | wav48/p273/p273_313.wav
36 | wav48/p329/p329_392.wav
37 | wav48/p305/p305_384.wav
38 | wav48/p256/p256_177.wav
39 | wav48/p237/p237_327.wav
40 | wav48/p308/p308_028.wav
41 | wav48/p339/p339_401.wav
42 | wav48/p298/p298_049.wav
43 | wav48/p257/p257_331.wav
44 | wav48/p301/p301_302.wav
45 | wav48/p227/p227_330.wav
46 | wav48/p255/p255_292.wav
47 | wav48/p362/p362_040.wav
48 | wav48/p237/p237_218.wav
49 | wav48/p281/p281_233.wav
50 | wav48/p312/p312_350.wav
51 | wav48/p312/p312_130.wav
52 | wav48/p287/p287_056.wav
53 | wav48/p376/p376_051.wav
54 | wav48/p253/p253_148.wav
55 | wav48/p266/p266_336.wav
56 | wav48/p288/p288_307.wav
57 | wav48/p254/p254_165.wav
58 | wav48/p333/p333_395.wav
59 | wav48/p307/p307_171.wav
60 | wav48/p243/p243_255.wav
61 | wav48/p274/p274_439.wav
62 | wav48/p302/p302_235.wav
63 | wav48/p248/p248_047.wav
64 | wav48/p363/p363_273.wav
65 | wav48/p292/p292_143.wav
66 | wav48/p286/p286_001.wav
67 | wav48/p259/p259_223.wav
68 | wav48/p233/p233_334.wav
69 | wav48/p267/p267_109.wav
70 | wav48/p339/p339_006.wav
71 | wav48/p295/p295_006.wav
72 | wav48/p281/p281_018.wav
73 | wav48/p252/p252_175.wav
74 | wav48/p297/p297_108.wav
75 | wav48/p243/p243_398.wav
76 | wav48/p226/p226_112.wav
77 | wav48/p341/p341_232.wav
78 | wav48/p363/p363_368.wav
79 | wav48/p362/p362_240.wav
80 | wav48/p306/p306_100.wav
81 | wav48/p301/p301_049.wav
82 | wav48/p257/p257_238.wav
83 | wav48/p245/p245_039.wav
84 | wav48/p313/p313_133.wav
85 | wav48/p278/p278_147.wav
86 | wav48/p280/p280_130.wav
87 | wav48/p271/p271_052.wav
88 | wav48/p341/p341_131.wav
89 | wav48/p233/p233_074.wav
90 | wav48/p283/p283_181.wav
91 | wav48/p260/p260_034.wav
92 | wav48/p273/p273_294.wav
93 | wav48/p286/p286_245.wav
94 | wav48/p314/p314_009.wav
95 | wav48/p300/p300_277.wav
96 | wav48/p271/p271_022.wav
97 | wav48/p361/p361_214.wav
98 | wav48/p271/p271_385.wav
99 | wav48/p313/p313_026.wav
100 | wav48/p277/p277_096.wav
101 | wav48/p270/p270_352.wav
102 | wav48/p279/p279_257.wav
103 | wav48/p297/p297_015.wav
104 | wav48/p241/p241_014.wav
105 | wav48/p228/p228_072.wav
106 | wav48/p251/p251_092.wav
107 | wav48/p252/p252_169.wav
108 | wav48/p249/p249_003.wav
109 | wav48/p278/p278_023.wav
110 | wav48/p339/p339_134.wav
111 | wav48/p295/p295_193.wav
112 | wav48/p283/p283_397.wav
113 | wav48/p266/p266_173.wav
114 | wav48/p276/p276_394.wav
115 | wav48/p311/p311_354.wav
116 | wav48/p246/p246_194.wav
117 | wav48/p307/p307_256.wav
118 | wav48/p306/p306_010.wav
119 | wav48/p278/p278_159.wav
120 | wav48/p229/p229_359.wav
121 | wav48/p376/p376_086.wav
122 | wav48/p263/p263_265.wav
123 | wav48/p243/p243_058.wav
124 | wav48/p244/p244_215.wav
125 | wav48/p257/p257_174.wav
126 | wav48/p300/p300_295.wav
127 | wav48/p247/p247_113.wav
128 | wav48/p286/p286_032.wav
129 | wav48/p314/p314_344.wav
130 | wav48/p347/p347_234.wav
131 | wav48/p255/p255_302.wav
132 | wav48/p376/p376_140.wav
133 | wav48/p362/p362_182.wav
134 | wav48/p294/p294_123.wav
135 | wav48/p330/p330_380.wav
136 | wav48/p294/p294_242.wav
137 | wav48/p243/p243_280.wav
138 | wav48/p298/p298_240.wav
139 | wav48/p274/p274_302.wav
140 | wav48/p281/p281_285.wav
141 | wav48/p271/p271_093.wav
142 | wav48/p323/p323_314.wav
143 | wav48/p249/p249_140.wav
144 | wav48/p333/p333_060.wav
145 | wav48/p300/p300_108.wav
146 | wav48/p236/p236_450.wav
147 | wav48/p251/p251_196.wav
148 | wav48/p262/p262_308.wav
149 | wav48/p292/p292_259.wav
150 | wav48/p311/p311_361.wav
151 | wav48/p226/p226_001.wav
152 | wav48/p239/p239_394.wav
153 | wav48/p240/p240_028.wav
154 | wav48/p304/p304_143.wav
155 | wav48/p279/p279_209.wav
156 | wav48/p286/p286_072.wav
157 | wav48/p374/p374_330.wav
158 | wav48/p302/p302_208.wav
159 | wav48/p287/p287_344.wav
160 | wav48/p261/p261_220.wav
161 | wav48/p345/p345_253.wav
162 | wav48/p326/p326_002.wav
163 | wav48/p277/p277_417.wav
164 | wav48/p249/p249_211.wav
165 | wav48/p283/p283_441.wav
166 | wav48/p351/p351_292.wav
167 | wav48/p301/p301_309.wav
168 | wav48/p363/p363_252.wav
169 | wav48/p329/p329_005.wav
170 | wav48/p229/p229_098.wav
171 | wav48/p261/p261_312.wav
172 | wav48/p334/p334_422.wav
173 | wav48/p258/p258_059.wav
174 | wav48/p262/p262_183.wav
175 | wav48/p271/p271_441.wav
176 | wav48/p276/p276_301.wav
177 | wav48/p280/p280_221.wav
178 | wav48/p245/p245_293.wav
179 | wav48/p225/p225_301.wav
180 | wav48/p275/p275_272.wav
181 | wav48/p275/p275_310.wav
182 | wav48/p263/p263_002.wav
183 | wav48/p317/p317_302.wav
184 | wav48/p248/p248_252.wav
185 | wav48/p236/p236_499.wav
186 | wav48/p266/p266_356.wav
187 | wav48/p274/p274_233.wav
188 | wav48/p275/p275_246.wav
189 | wav48/p232/p232_025.wav
190 | wav48/p240/p240_009.wav
191 | wav48/p232/p232_236.wav
192 | wav48/p227/p227_096.wav
193 | wav48/p347/p347_172.wav
194 | wav48/p305/p305_363.wav
195 | wav48/p288/p288_270.wav
196 | wav48/p258/p258_286.wav
197 | wav48/p243/p243_220.wav
198 | wav48/p252/p252_409.wav
199 | wav48/p247/p247_283.wav
200 | wav48/p254/p254_355.wav
201 | wav48/p226/p226_246.wav
202 | wav48/p283/p283_468.wav
203 | wav48/p233/p233_132.wav
204 | wav48/p343/p343_066.wav
205 | wav48/p351/p351_339.wav
206 | wav48/p313/p313_099.wav
207 | wav48/p305/p305_383.wav
208 | wav48/p234/p234_318.wav
209 | wav48/p272/p272_228.wav
210 | wav48/p239/p239_308.wav
211 | wav48/p265/p265_287.wav
212 | wav48/p250/p250_424.wav
213 | wav48/p260/p260_357.wav
214 | wav48/p239/p239_228.wav
215 | wav48/p277/p277_130.wav
216 | wav48/p317/p317_269.wav
217 | wav48/p236/p236_287.wav
218 | wav48/p310/p310_151.wav
219 | wav48/p225/p225_166.wav
220 | wav48/p341/p341_290.wav
221 | wav48/p323/p323_167.wav
222 | wav48/p245/p245_208.wav
223 | wav48/p308/p308_287.wav
224 | wav48/p307/p307_154.wav
225 | wav48/p239/p239_137.wav
226 | wav48/p326/p326_306.wav
227 | wav48/p299/p299_394.wav
228 | wav48/p270/p270_203.wav
229 | wav48/p255/p255_143.wav
230 | wav48/p286/p286_143.wav
231 | wav48/p278/p278_248.wav
232 | wav48/p232/p232_035.wav
233 | wav48/p297/p297_404.wav
234 | wav48/p345/p345_207.wav
235 | wav48/p268/p268_107.wav
236 | wav48/p362/p362_148.wav
237 | wav48/p364/p364_223.wav
238 | wav48/p252/p252_404.wav
239 | wav48/p267/p267_238.wav
240 | wav48/p268/p268_182.wav
241 | wav48/p238/p238_387.wav
242 | wav48/p230/p230_207.wav
243 | wav48/p306/p306_126.wav
244 | wav48/p310/p310_201.wav
245 | wav48/p227/p227_065.wav
246 | wav48/p228/p228_214.wav
247 | wav48/p264/p264_082.wav
248 | wav48/p257/p257_226.wav
249 | wav48/p294/p294_302.wav
250 | wav48/p312/p312_184.wav
251 | wav48/p376/p376_105.wav
252 | wav48/p364/p364_056.wav
253 | wav48/p234/p234_067.wav
254 | wav48/p252/p252_061.wav
255 | wav48/p225/p225_027.wav
256 | wav48/p254/p254_136.wav
257 | wav48/p361/p361_078.wav
258 | wav48/p256/p256_088.wav
259 | wav48/p303/p303_331.wav
260 | wav48/p334/p334_406.wav
261 | wav48/p281/p281_243.wav
262 | wav48/p343/p343_293.wav
263 | wav48/p253/p253_126.wav
264 | wav48/p297/p297_152.wav
265 | wav48/p258/p258_226.wav
266 | wav48/p255/p255_363.wav
267 | wav48/p244/p244_107.wav
268 | wav48/p233/p233_295.wav
269 | wav48/p335/p335_420.wav
270 | wav48/p271/p271_082.wav
271 | wav48/p323/p323_136.wav
272 | wav48/p280/p280_259.wav
273 | wav48/p362/p362_221.wav
274 | wav48/p246/p246_121.wav
275 | wav48/p253/p253_049.wav
276 | wav48/p333/p333_196.wav
277 | wav48/p329/p329_048.wav
278 | wav48/p305/p305_062.wav
279 | wav48/p341/p341_103.wav
280 | wav48/p283/p283_264.wav
281 | wav48/p252/p252_038.wav
282 | wav48/p260/p260_314.wav
283 | wav48/p313/p313_352.wav
284 | wav48/p295/p295_360.wav
285 | wav48/p266/p266_169.wav
286 | wav48/p279/p279_108.wav
287 | wav48/p303/p303_123.wav
288 | wav48/p282/p282_087.wav
289 | wav48/p345/p345_265.wav
290 | wav48/p306/p306_292.wav
291 | wav48/p268/p268_321.wav
292 | wav48/p286/p286_311.wav
293 | wav48/p259/p259_268.wav
294 | wav48/p295/p295_328.wav
295 | wav48/p230/p230_185.wav
296 | wav48/p226/p226_257.wav
297 | wav48/p288/p288_042.wav
298 | wav48/p254/p254_334.wav
299 | wav48/p298/p298_227.wav
300 | wav48/p299/p299_075.wav
301 | wav48/p251/p251_053.wav
302 | wav48/p301/p301_284.wav
303 | wav48/p278/p278_313.wav
304 | wav48/p264/p264_491.wav
305 | wav48/p310/p310_032.wav
306 | wav48/p292/p292_120.wav
307 | wav48/p345/p345_386.wav
308 | wav48/p253/p253_305.wav
309 | wav48/p335/p335_072.wav
310 | wav48/p241/p241_228.wav
311 | wav48/p360/p360_362.wav
312 | wav48/p333/p333_088.wav
313 | wav48/p279/p279_122.wav
314 | wav48/p336/p336_127.wav
315 | wav48/p376/p376_045.wav
316 | wav48/p250/p250_240.wav
317 | wav48/p299/p299_273.wav
318 | wav48/p231/p231_226.wav
319 | wav48/p281/p281_367.wav
320 | wav48/p293/p293_185.wav
321 | wav48/p241/p241_032.wav
322 | wav48/p279/p279_029.wav
323 | wav48/p362/p362_116.wav
324 | wav48/p225/p225_243.wav
325 | wav48/p228/p228_323.wav
326 | wav48/p265/p265_342.wav
327 | wav48/p250/p250_282.wav
328 | wav48/p339/p339_355.wav
329 | wav48/p287/p287_357.wav
330 | wav48/p275/p275_390.wav
331 | wav48/p360/p360_249.wav
332 | wav48/p341/p341_012.wav
333 | wav48/p237/p237_150.wav
334 | wav48/p253/p253_121.wav
335 | wav48/p339/p339_417.wav
336 | wav48/p238/p238_355.wav
337 | wav48/p256/p256_129.wav
338 | wav48/p304/p304_047.wav
339 | wav48/p248/p248_200.wav
340 | wav48/p274/p274_268.wav
341 | wav48/p362/p362_150.wav
342 | wav48/p234/p234_069.wav
343 | wav48/p283/p283_288.wav
344 | wav48/p243/p243_080.wav
345 | wav48/p267/p267_338.wav
346 | wav48/p299/p299_188.wav
347 | wav48/p292/p292_229.wav
348 | wav48/p256/p256_076.wav
349 | wav48/p343/p343_174.wav
350 | wav48/p345/p345_321.wav
351 | wav48/p334/p334_194.wav
352 | wav48/p305/p305_204.wav
353 | wav48/p230/p230_128.wav
354 | wav48/p360/p360_286.wav
355 | wav48/p305/p305_324.wav
356 | wav48/p252/p252_315.wav
357 | wav48/p297/p297_137.wav
358 | wav48/p326/p326_059.wav
359 | wav48/p301/p301_394.wav
360 | wav48/p306/p306_180.wav
361 | wav48/p281/p281_132.wav
362 | wav48/p268/p268_394.wav
363 | wav48/p264/p264_116.wav
364 | wav48/p239/p239_259.wav
365 | wav48/p313/p313_087.wav
366 | wav48/p264/p264_108.wav
367 | wav48/p315/p315_250.wav
368 | wav48/p306/p306_290.wav
369 | wav48/p281/p281_036.wav
370 | wav48/p248/p248_113.wav
371 | wav48/p285/p285_038.wav
372 | wav48/p238/p238_027.wav
373 | wav48/p249/p249_278.wav
374 | wav48/p267/p267_091.wav
375 | wav48/p249/p249_311.wav
376 | wav48/p247/p247_012.wav
377 | wav48/p301/p301_293.wav
378 | wav48/p257/p257_361.wav
379 | wav48/p363/p363_017.wav
380 | wav48/p301/p301_314.wav
381 | wav48/p260/p260_180.wav
382 | wav48/p247/p247_456.wav
383 | wav48/p280/p280_066.wav
384 | wav48/p234/p234_056.wav
385 | wav48/p243/p243_352.wav
386 | wav48/p298/p298_400.wav
387 | wav48/p227/p227_003.wav
388 | wav48/p226/p226_350.wav
389 | wav48/p282/p282_036.wav
390 | wav48/p271/p271_273.wav
391 | wav48/p374/p374_042.wav
392 | wav48/p374/p374_186.wav
393 | wav48/p238/p238_046.wav
394 | wav48/p293/p293_361.wav
395 | wav48/p330/p330_148.wav
396 | wav48/p334/p334_190.wav
397 | wav48/p230/p230_220.wav
398 | wav48/p232/p232_393.wav
399 | wav48/p310/p310_179.wav
400 | wav48/p252/p252_391.wav
401 | wav48/p257/p257_370.wav
402 | wav48/p340/p340_184.wav
403 | wav48/p339/p339_171.wav
404 | wav48/p272/p272_160.wav
405 | wav48/p334/p334_185.wav
406 | wav48/p258/p258_052.wav
407 | wav48/p248/p248_314.wav
408 | wav48/p310/p310_421.wav
409 | wav48/p238/p238_232.wav
410 | wav48/p298/p298_097.wav
411 | wav48/p316/p316_099.wav
412 | wav48/p287/p287_182.wav
413 | wav48/p241/p241_062.wav
414 | wav48/p317/p317_350.wav
415 | wav48/p314/p314_105.wav
416 | wav48/p333/p333_096.wav
417 | wav48/p301/p301_345.wav
418 | wav48/p243/p243_090.wav
419 | wav48/p268/p268_277.wav
420 | wav48/p237/p237_253.wav
421 | wav48/p307/p307_021.wav
422 | wav48/p307/p307_413.wav
423 | wav48/p225/p225_303.wav
424 | wav48/p312/p312_414.wav
425 | wav48/p253/p253_340.wav
426 | wav48/p345/p345_117.wav
427 | wav48/p268/p268_246.wav
428 | wav48/p263/p263_264.wav
429 | wav48/p301/p301_130.wav
430 | wav48/p271/p271_295.wav
431 | wav48/p233/p233_072.wav
432 | wav48/p333/p333_406.wav
433 | wav48/p303/p303_202.wav
434 | wav48/p307/p307_386.wav
435 | wav48/p239/p239_162.wav
436 | wav48/p227/p227_040.wav
437 | wav48/p329/p329_256.wav
438 | wav48/p275/p275_020.wav
439 | wav48/p302/p302_116.wav
440 | wav48/p340/p340_308.wav
441 | wav48/p249/p249_189.wav
442 | wav48/p306/p306_096.wav
443 | wav48/p254/p254_335.wav
444 | wav48/p272/p272_042.wav
445 | wav48/p313/p313_258.wav
446 | wav48/p253/p253_168.wav
447 | wav48/p251/p251_225.wav
448 | wav48/p293/p293_293.wav
449 | wav48/p345/p345_370.wav
450 | wav48/p280/p280_161.wav
451 | wav48/p256/p256_124.wav
452 | wav48/p313/p313_329.wav
453 | wav48/p305/p305_072.wav
454 | wav48/p363/p363_364.wav
455 | wav48/p286/p286_251.wav
456 | wav48/p288/p288_006.wav
457 | wav48/p295/p295_034.wav
458 | wav48/p336/p336_200.wav
459 | wav48/p336/p336_104.wav
460 | wav48/p295/p295_090.wav
461 | wav48/p311/p311_145.wav
462 | wav48/p264/p264_266.wav
463 | wav48/p318/p318_381.wav
464 | wav48/p238/p238_206.wav
465 | wav48/p313/p313_409.wav
466 | wav48/p272/p272_285.wav
467 | wav48/p299/p299_193.wav
468 | wav48/p279/p279_040.wav
469 | wav48/p311/p311_042.wav
470 | wav48/p313/p313_035.wav
471 | wav48/p281/p281_196.wav
472 | wav48/p274/p274_083.wav
473 | wav48/p345/p345_099.wav
474 | wav48/p230/p230_069.wav
475 | wav48/p226/p226_163.wav
476 | wav48/p275/p275_411.wav
477 | wav48/p364/p364_266.wav
478 | wav48/p286/p286_404.wav
479 | wav48/p364/p364_054.wav
480 | wav48/p244/p244_147.wav
481 | wav48/p255/p255_073.wav
482 | wav48/p298/p298_024.wav
483 | wav48/p329/p329_028.wav
484 | wav48/p314/p314_128.wav
485 | wav48/p273/p273_066.wav
486 | wav48/p360/p360_325.wav
487 | wav48/p283/p283_434.wav
488 | wav48/p281/p281_197.wav
489 | wav48/p279/p279_359.wav
490 | wav48/p252/p252_162.wav
491 | wav48/p276/p276_242.wav
492 | wav48/p252/p252_191.wav
493 | wav48/p312/p312_279.wav
494 | wav48/p308/p308_337.wav
495 | wav48/p288/p288_188.wav
496 | wav48/p272/p272_044.wav
497 | wav48/p260/p260_284.wav
498 | wav48/p316/p316_224.wav
499 | wav48/p298/p298_355.wav
500 | wav48/p230/p230_079.wav
501 | wav48/p374/p374_059.wav
502 | wav48/p247/p247_399.wav
503 | wav48/p245/p245_067.wav
504 | wav48/p250/p250_423.wav
505 | wav48/p363/p363_204.wav
506 | wav48/p318/p318_175.wav
507 | wav48/p248/p248_163.wav
508 | wav48/p248/p248_293.wav
509 | wav48/p279/p279_133.wav
510 | wav48/p244/p244_359.wav
511 | wav48/p239/p239_176.wav
512 | wav48/p335/p335_423.wav
513 | wav48/p232/p232_030.wav
514 | wav48/p314/p314_219.wav
515 | wav48/p334/p334_103.wav
516 | wav48/p252/p252_099.wav
517 | wav48/p288/p288_409.wav
518 | wav48/p288/p288_316.wav
519 | wav48/p288/p288_030.wav
520 | wav48/p340/p340_187.wav
521 | wav48/p228/p228_272.wav
522 | wav48/p230/p230_044.wav
523 | wav48/p245/p245_022.wav
524 | wav48/p266/p266_222.wav
525 | wav48/p240/p240_138.wav
526 | wav48/p334/p334_062.wav
527 | wav48/p363/p363_220.wav
528 | wav48/p298/p298_011.wav
529 | wav48/p227/p227_258.wav
530 | wav48/p360/p360_114.wav
531 | wav48/p254/p254_283.wav
532 | wav48/p236/p236_065.wav
533 | wav48/p278/p278_139.wav
534 | wav48/p298/p298_270.wav
535 | wav48/p229/p229_134.wav
536 | wav48/p339/p339_218.wav
537 | wav48/p312/p312_358.wav
538 | wav48/p284/p284_079.wav
539 | wav48/p225/p225_005.wav
540 | wav48/p282/p282_356.wav
541 | wav48/p376/p376_257.wav
542 | wav48/p227/p227_200.wav
543 | wav48/p268/p268_276.wav
544 | wav48/p268/p268_245.wav
545 | wav48/p229/p229_053.wav
546 | wav48/p231/p231_332.wav
547 | wav48/p272/p272_077.wav
548 | wav48/p264/p264_053.wav
549 | wav48/p265/p265_038.wav
550 | wav48/p303/p303_335.wav
551 | wav48/p264/p264_100.wav
552 | wav48/p351/p351_042.wav
553 | wav48/p253/p253_336.wav
554 | wav48/p225/p225_324.wav
555 | wav48/p277/p277_307.wav
556 | wav48/p238/p238_459.wav
557 | wav48/p311/p311_320.wav
558 | wav48/p271/p271_322.wav
559 | wav48/p285/p285_039.wav
560 | wav48/p266/p266_145.wav
561 | wav48/p361/p361_383.wav
562 | wav48/p237/p237_285.wav
563 | wav48/p310/p310_386.wav
564 | wav48/p240/p240_281.wav
565 | wav48/p232/p232_369.wav
566 | wav48/p362/p362_231.wav
567 | wav48/p248/p248_080.wav
568 | wav48/p255/p255_257.wav
569 | wav48/p340/p340_279.wav
570 | wav48/p323/p323_174.wav
571 | wav48/p275/p275_060.wav
572 | wav48/p364/p364_083.wav
573 | wav48/p261/p261_462.wav
574 | wav48/p226/p226_221.wav
575 | wav48/p275/p275_112.wav
576 | wav48/p288/p288_211.wav
577 | wav48/p297/p297_049.wav
578 | wav48/p263/p263_160.wav
579 | wav48/p347/p347_259.wav
580 | wav48/p299/p299_367.wav
581 | wav48/p339/p339_297.wav
582 | wav48/p276/p276_186.wav
583 | wav48/p268/p268_330.wav
584 | wav48/p312/p312_232.wav
585 | wav48/p312/p312_262.wav
586 | wav48/p238/p238_389.wav
587 | wav48/p234/p234_052.wav
588 | wav48/p294/p294_406.wav
589 | wav48/p318/p318_335.wav
590 | wav48/p333/p333_176.wav
591 | wav48/p240/p240_304.wav
592 | wav48/p330/p330_230.wav
593 | wav48/p347/p347_245.wav
594 | wav48/p326/p326_368.wav
595 | wav48/p256/p256_025.wav
596 | wav48/p282/p282_123.wav
597 | wav48/p341/p341_316.wav
598 | wav48/p282/p282_292.wav
599 | wav48/p250/p250_385.wav
600 | wav48/p306/p306_352.wav
601 | wav48/p292/p292_371.wav
602 | wav48/p259/p259_291.wav
603 | wav48/p317/p317_397.wav
604 | wav48/p275/p275_181.wav
605 | wav48/p253/p253_334.wav
606 | wav48/p361/p361_055.wav
607 | wav48/p304/p304_406.wav
608 | wav48/p351/p351_295.wav
609 | wav48/p278/p278_164.wav
610 | wav48/p236/p236_091.wav
611 | wav48/p318/p318_169.wav
612 | wav48/p236/p236_386.wav
613 | wav48/p311/p311_195.wav
614 | wav48/p294/p294_376.wav
615 | wav48/p270/p270_219.wav
616 | wav48/p276/p276_265.wav
617 | wav48/p316/p316_374.wav
618 | wav48/p302/p302_243.wav
619 | wav48/p229/p229_177.wav
620 | wav48/p315/p315_309.wav
621 | wav48/p265/p265_331.wav
622 | wav48/p323/p323_353.wav
623 | wav48/p248/p248_286.wav
624 | wav48/p303/p303_205.wav
625 | wav48/p307/p307_012.wav
626 | wav48/p249/p249_295.wav
627 | wav48/p286/p286_235.wav
628 | wav48/p292/p292_386.wav
629 | wav48/p260/p260_105.wav
630 | wav48/p250/p250_137.wav
631 | wav48/p329/p329_372.wav
632 | wav48/p329/p329_342.wav
633 | wav48/p310/p310_169.wav
634 | wav48/p310/p310_228.wav
635 | wav48/p302/p302_017.wav
636 | wav48/p226/p226_136.wav
637 | wav48/p343/p343_145.wav
638 | wav48/p335/p335_069.wav
639 | wav48/p335/p335_143.wav
640 | wav48/p262/p262_312.wav
641 | wav48/p263/p263_334.wav
642 | wav48/p231/p231_421.wav
643 | wav48/p263/p263_385.wav
644 | wav48/p343/p343_104.wav
645 | wav48/p364/p364_050.wav
646 | wav48/p236/p236_200.wav
647 | wav48/p233/p233_240.wav
648 | wav48/p317/p317_299.wav
649 | wav48/p313/p313_089.wav
650 | wav48/p284/p284_050.wav
651 | wav48/p284/p284_056.wav
652 | wav48/p266/p266_272.wav
653 | wav48/p362/p362_008.wav
654 | wav48/p265/p265_323.wav
655 | wav48/p239/p239_390.wav
656 | wav48/p313/p313_027.wav
657 | wav48/p347/p347_334.wav
658 | wav48/p312/p312_077.wav
659 | wav48/p307/p307_305.wav
660 | wav48/p347/p347_073.wav
661 | wav48/p278/p278_349.wav
662 | wav48/p266/p266_320.wav
663 | wav48/p292/p292_344.wav
664 | wav48/p326/p326_023.wav
665 | wav48/p257/p257_301.wav
666 | wav48/p312/p312_273.wav
667 | wav48/p311/p311_024.wav
668 | wav48/p254/p254_276.wav
669 | wav48/p237/p237_235.wav
670 | wav48/p250/p250_141.wav
671 | wav48/p343/p343_177.wav
672 | wav48/p245/p245_354.wav
673 | wav48/p250/p250_302.wav
674 | wav48/p362/p362_360.wav
675 | wav48/p241/p241_064.wav
676 | wav48/p255/p255_248.wav
677 | wav48/p310/p310_042.wav
678 | wav48/p314/p314_222.wav
679 | wav48/p256/p256_115.wav
680 | wav48/p254/p254_249.wav
681 | wav48/p251/p251_281.wav
682 | wav48/p257/p257_074.wav
683 | wav48/p312/p312_272.wav
684 | wav48/p270/p270_008.wav
685 | wav48/p262/p262_171.wav
686 | wav48/p361/p361_176.wav
687 | wav48/p256/p256_258.wav
688 | wav48/p234/p234_193.wav
689 | wav48/p265/p265_343.wav
690 | wav48/p287/p287_338.wav
691 | wav48/p295/p295_294.wav
692 | wav48/p298/p298_362.wav
693 | wav48/p238/p238_395.wav
694 | wav48/p318/p318_067.wav
695 | wav48/p283/p283_249.wav
696 | wav48/p300/p300_148.wav
697 | wav48/p261/p261_422.wav
698 | wav48/p261/p261_272.wav
699 | wav48/p244/p244_323.wav
700 | wav48/p317/p317_055.wav
701 | wav48/p267/p267_246.wav
702 | wav48/p360/p360_043.wav
703 | wav48/p274/p274_196.wav
704 | wav48/p256/p256_087.wav
705 | wav48/p285/p285_379.wav
706 | wav48/p258/p258_069.wav
707 | wav48/p277/p277_414.wav
708 | wav48/p255/p255_118.wav
709 | wav48/p280/p280_392.wav
710 | wav48/p374/p374_325.wav
711 | wav48/p336/p336_322.wav
712 | wav48/p361/p361_399.wav
713 | wav48/p292/p292_407.wav
714 | wav48/p351/p351_350.wav
715 | wav48/p294/p294_285.wav
716 | wav48/p259/p259_374.wav
717 | wav48/p351/p351_297.wav
718 | wav48/p280/p280_028.wav
719 | wav48/p310/p310_144.wav
720 | wav48/p245/p245_203.wav
721 | wav48/p304/p304_326.wav
722 | wav48/p266/p266_382.wav
723 | wav48/p343/p343_288.wav
724 | wav48/p284/p284_398.wav
725 | wav48/p283/p283_120.wav
726 | wav48/p268/p268_297.wav
727 | wav48/p303/p303_284.wav
728 | wav48/p251/p251_116.wav
729 | wav48/p364/p364_071.wav
730 | wav48/p268/p268_283.wav
731 | wav48/p323/p323_045.wav
732 | wav48/p243/p243_019.wav
733 | wav48/p248/p248_306.wav
734 | wav48/p297/p297_256.wav
735 | wav48/p250/p250_459.wav
736 | wav48/p266/p266_062.wav
737 | wav48/p340/p340_394.wav
738 | wav48/p240/p240_338.wav
739 | wav48/p281/p281_273.wav
740 | wav48/p245/p245_112.wav
741 | wav48/p345/p345_015.wav
742 | wav48/p259/p259_014.wav
743 | wav48/p303/p303_065.wav
744 | wav48/p275/p275_161.wav
745 | wav48/p230/p230_374.wav
746 | wav48/p269/p269_043.wav
747 | wav48/p279/p279_251.wav
748 | wav48/p341/p341_306.wav
749 | wav48/p316/p316_401.wav
750 | wav48/p273/p273_231.wav
751 | wav48/p284/p284_420.wav
752 | wav48/p247/p247_048.wav
753 | wav48/p361/p361_375.wav
754 | wav48/p330/p330_143.wav
755 | wav48/p240/p240_102.wav
756 | wav48/p230/p230_296.wav
757 | wav48/p258/p258_400.wav
758 | wav48/p245/p245_015.wav
759 | wav48/p323/p323_349.wav
760 | wav48/p233/p233_155.wav
761 | wav48/p341/p341_039.wav
762 | wav48/p360/p360_200.wav
763 | wav48/p308/p308_421.wav
764 | wav48/p317/p317_401.wav
765 | wav48/p247/p247_180.wav
766 | wav48/p240/p240_088.wav
767 | wav48/p287/p287_303.wav
768 | wav48/p314/p314_319.wav
769 | wav48/p255/p255_235.wav
770 | wav48/p287/p287_268.wav
771 | wav48/p306/p306_320.wav
772 | wav48/p231/p231_247.wav
773 | wav48/p316/p316_156.wav
774 | wav48/p310/p310_416.wav
775 | wav48/p288/p288_037.wav
776 | wav48/p228/p228_204.wav
777 | wav48/p297/p297_168.wav
778 | wav48/p340/p340_177.wav
779 | wav48/p230/p230_190.wav
780 | wav48/p330/p330_278.wav
781 | wav48/p336/p336_041.wav
782 | wav48/p275/p275_115.wav
783 | wav48/p306/p306_234.wav
784 | wav48/p237/p237_008.wav
785 | wav48/p263/p263_194.wav
786 | wav48/p240/p240_219.wav
787 | wav48/p361/p361_043.wav
788 | wav48/p299/p299_109.wav
789 | wav48/p305/p305_111.wav
790 | wav48/p341/p341_372.wav
791 | wav48/p260/p260_321.wav
792 | wav48/p247/p247_373.wav
793 | wav48/p335/p335_369.wav
794 | wav48/p308/p308_311.wav
795 | wav48/p295/p295_068.wav
796 | wav48/p292/p292_374.wav
797 | wav48/p277/p277_411.wav
798 | wav48/p294/p294_263.wav
799 | wav48/p276/p276_233.wav
800 | wav48/p360/p360_281.wav
801 | wav48/p278/p278_345.wav
802 | wav48/p269/p269_294.wav
803 | wav48/p305/p305_420.wav
804 | wav48/p227/p227_137.wav
805 | wav48/p229/p229_329.wav
806 | wav48/p360/p360_366.wav
807 | wav48/p363/p363_275.wav
808 | wav48/p287/p287_022.wav
809 | wav48/p249/p249_249.wav
810 | wav48/p295/p295_318.wav
811 | wav48/p244/p244_067.wav
812 | wav48/p310/p310_303.wav
813 | wav48/p302/p302_285.wav
814 | wav48/p260/p260_062.wav
815 | wav48/p277/p277_064.wav
816 | wav48/p326/p326_157.wav
817 | wav48/p259/p259_003.wav
818 | wav48/p275/p275_168.wav
819 | wav48/p272/p272_243.wav
820 | wav48/p304/p304_164.wav
821 | wav48/p314/p314_119.wav
822 | wav48/p241/p241_067.wav
823 | wav48/p282/p282_243.wav
824 | wav48/p305/p305_178.wav
825 | wav48/p272/p272_136.wav
826 | wav48/p271/p271_138.wav
827 | wav48/p282/p282_144.wav
828 | wav48/p299/p299_359.wav
829 | wav48/p363/p363_388.wav
830 | wav48/p281/p281_417.wav
831 | wav48/p239/p239_010.wav
832 | wav48/p295/p295_041.wav
833 | wav48/p300/p300_157.wav
834 | wav48/p275/p275_157.wav
835 | wav48/p245/p245_062.wav
836 | wav48/p292/p292_203.wav
837 | wav48/p255/p255_301.wav
838 | wav48/p334/p334_029.wav
839 | wav48/p313/p313_305.wav
840 | wav48/p293/p293_351.wav
841 | wav48/p334/p334_324.wav
842 | wav48/p274/p274_114.wav
843 | wav48/p266/p266_034.wav
844 | wav48/p307/p307_230.wav
845 | wav48/p295/p295_403.wav
846 | wav48/p227/p227_257.wav
847 | wav48/p311/p311_334.wav
848 | wav48/p288/p288_227.wav
849 | wav48/p374/p374_399.wav
850 | wav48/p301/p301_386.wav
851 | wav48/p264/p264_260.wav
852 | wav48/p333/p333_068.wav
853 | wav48/p364/p364_095.wav
854 | wav48/p286/p286_082.wav
855 | wav48/p281/p281_206.wav
856 | wav48/p240/p240_060.wav
857 | wav48/p263/p263_118.wav
858 | wav48/p245/p245_066.wav
859 | wav48/p251/p251_345.wav
860 | wav48/p270/p270_400.wav
861 | wav48/p264/p264_019.wav
862 | wav48/p278/p278_301.wav
863 | wav48/p277/p277_086.wav
864 | wav48/p257/p257_115.wav
865 | wav48/p276/p276_039.wav
866 | wav48/p329/p329_320.wav
867 | wav48/p278/p278_021.wav
868 | wav48/p268/p268_298.wav
869 | wav48/p329/p329_170.wav
870 | wav48/p236/p236_169.wav
871 | wav48/p340/p340_198.wav
872 | wav48/p306/p306_340.wav
873 | wav48/p285/p285_351.wav
874 | wav48/p247/p247_017.wav
875 | wav48/p287/p287_005.wav
876 | wav48/p288/p288_019.wav
877 | wav48/p230/p230_144.wav
878 | wav48/p376/p376_262.wav
879 | wav48/p312/p312_286.wav
880 | wav48/p258/p258_002.wav
881 | wav48/p246/p246_155.wav
882 | wav48/p268/p268_275.wav
883 | wav48/p259/p259_370.wav
884 | wav48/p298/p298_235.wav
885 | wav48/p343/p343_016.wav
886 | wav48/p323/p323_282.wav
887 | wav48/p288/p288_058.wav
888 | wav48/p260/p260_119.wav
889 | wav48/p229/p229_274.wav
890 | wav48/p318/p318_399.wav
891 | wav48/p306/p306_260.wav
892 | wav48/p313/p313_109.wav
893 | wav48/p314/p314_318.wav
894 | wav48/p226/p226_314.wav
895 | wav48/p261/p261_338.wav
896 | wav48/p247/p247_090.wav
897 | wav48/p266/p266_117.wav
898 | wav48/p279/p279_038.wav
899 | wav48/p276/p276_352.wav
900 | wav48/p268/p268_399.wav
901 | wav48/p363/p363_314.wav
902 | wav48/p277/p277_202.wav
903 | wav48/p340/p340_337.wav
904 | wav48/p288/p288_049.wav
905 | wav48/p232/p232_267.wav
906 | wav48/p294/p294_167.wav
907 | wav48/p275/p275_179.wav
908 | wav48/p312/p312_124.wav
909 | wav48/p284/p284_279.wav
910 | wav48/p267/p267_048.wav
911 | wav48/p335/p335_043.wav
912 | wav48/p265/p265_187.wav
913 | wav48/p272/p272_229.wav
914 | wav48/p347/p347_355.wav
915 | wav48/p294/p294_048.wav
916 | wav48/p298/p298_219.wav
917 | wav48/p326/p326_237.wav
918 | wav48/p304/p304_194.wav
919 | wav48/p265/p265_352.wav
920 | wav48/p275/p275_147.wav
921 | wav48/p276/p276_276.wav
922 | wav48/p231/p231_152.wav
923 | wav48/p231/p231_012.wav
924 | wav48/p265/p265_308.wav
925 | wav48/p343/p343_186.wav
926 | wav48/p225/p225_319.wav
927 | wav48/p232/p232_242.wav
928 | wav48/p293/p293_330.wav
929 | wav48/p308/p308_053.wav
930 | wav48/p279/p279_162.wav
931 | wav48/p263/p263_301.wav
932 | wav48/p297/p297_193.wav
933 | wav48/p297/p297_206.wav
934 | wav48/p292/p292_298.wav
935 | wav48/p263/p263_388.wav
936 | wav48/p280/p280_255.wav
937 | wav48/p283/p283_453.wav
938 | wav48/p287/p287_092.wav
939 | wav48/p234/p234_241.wav
940 | wav48/p255/p255_369.wav
941 | wav48/p323/p323_086.wav
942 | wav48/p304/p304_203.wav
943 | wav48/p259/p259_233.wav
944 | wav48/p285/p285_244.wav
945 | wav48/p255/p255_352.wav
946 | wav48/p269/p269_393.wav
947 | wav48/p244/p244_091.wav
948 | wav48/p229/p229_236.wav
949 | wav48/p231/p231_300.wav
950 | wav48/p258/p258_100.wav
951 | wav48/p273/p273_267.wav
952 | wav48/p283/p283_158.wav
953 | wav48/p292/p292_085.wav
954 | wav48/p347/p347_243.wav
955 | wav48/p272/p272_335.wav
956 | wav48/p312/p312_363.wav
957 | wav48/p345/p345_125.wav
958 | wav48/p376/p376_163.wav
959 | wav48/p262/p262_313.wav
960 | wav48/p240/p240_345.wav
961 | wav48/p292/p292_293.wav
962 | wav48/p263/p263_275.wav
963 | wav48/p360/p360_313.wav
964 | wav48/p251/p251_078.wav
965 | wav48/p294/p294_041.wav
966 | wav48/p246/p246_164.wav
967 | wav48/p335/p335_343.wav
968 | wav48/p301/p301_157.wav
969 | wav48/p281/p281_230.wav
970 | wav48/p329/p329_098.wav
971 | wav48/p263/p263_123.wav
972 | wav48/p240/p240_051.wav
973 | wav48/p283/p283_004.wav
974 | wav48/p316/p316_037.wav
975 | wav48/p351/p351_283.wav
976 | wav48/p268/p268_093.wav
977 | wav48/p267/p267_290.wav
978 | wav48/p228/p228_092.wav
979 | wav48/p265/p265_102.wav
980 | wav48/p286/p286_067.wav
981 | wav48/p274/p274_064.wav
982 | wav48/p374/p374_292.wav
983 | wav48/p288/p288_107.wav
984 | wav48/p341/p341_249.wav
985 | wav48/p226/p226_228.wav
986 | wav48/p351/p351_357.wav
987 | wav48/p361/p361_346.wav
988 | wav48/p334/p334_184.wav
989 | wav48/p316/p316_378.wav
990 | wav48/p258/p258_276.wav
991 | wav48/p276/p276_030.wav
992 | wav48/p266/p266_115.wav
993 | wav48/p335/p335_002.wav
994 | wav48/p232/p232_412.wav
995 | wav48/p232/p232_406.wav
996 | wav48/p310/p310_384.wav
997 | wav48/p227/p227_153.wav
998 | wav48/p363/p363_244.wav
999 | wav48/p335/p335_027.wav
1000 | wav48/p336/p336_388.wav
1001 |
--------------------------------------------------------------------------------
/models/mdct.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union, Callable
2 | from torch_scatter import scatter
3 | from einops import rearrange
4 | import numpy as np
5 | import torch
6 | import torch.nn
7 | import torch.fft
8 | from torch.nn.functional import pad, fold
9 | from . import spectrogram
10 |
11 | import functools
12 | #import debugpy
13 | #debugpy.listen(("localhost", 5678))
14 |
15 |
16 | class MDCT(torch.nn.Module):
17 | """
18 | Serial version of the MDCT.
19 | Adapted from https://github.com/nils-werner/mdct/
20 |     I modified the original code from scipy+numpy to torch and made it into torch layers with GPU support.
21 | """
22 |
23 | def __init__(self, window_function, step_length=None, n_fft=2048, center=True, device='cpu') -> None:
24 | super().__init__()
25 | self.window_function = window_function
26 | self.step_length = step_length
27 | self.n_fft = n_fft
28 | self.center = center
29 | self.device = device
30 |
31 | def mdct(
32 | self,
33 | x,
34 | odd=True,
35 | center=True,
36 | **kwargs
37 | ):
38 | """ Calculate lapped MDCT of input signal
39 | Parameters
40 | ----------
41 | x : array_like
42 | The signal to be transformed. May be a 1D vector for single channel or
43 | a 2D matrix for multi channel data. In case of a mono signal, the data
44 |             must be a 1D vector of length :code:`samples`. In case of a multi
45 | channel signal, the data must be in the shape of :code:`samples x
46 | channels`.
47 | odd : boolean, optional
48 | Switch to oddly stacked transform. Defaults to :code:`True`.
49 | framelength : int
50 | The signal frame length. Defaults to :code:`2048`.
51 | hopsize : int
52 | The signal frame hopsize. Defaults to :code:`None`. Setting this
53 | value will override :code:`overlap`.
54 | overlap : int
55 | The signal frame overlap coefficient. Value :code:`x` means
56 | :code:`1/x` overlap. Defaults to :code:`2`. Note that anything but
57 | :code:`2` will result in a filterbank without perfect reconstruction.
58 | centered : boolean
59 | Pad input signal so that the first and last window are centered around
60 | the beginning of the signal. Defaults to :code:`True`.
61 | Disabling this will result in aliasing
62 | in the first and last half-frame.
63 | window : callable, array_like
64 | Window to be used for deringing. Can be :code:`False` to disable
65 | windowing. Defaults to :code:`scipy.signal.cosine`.
66 | transforms : module, optional
67 | Module reference to core transforms. Mostly used to replace
68 | fast with slow core transforms, for testing. Defaults to
69 | :mod:`mdct.fast`
70 | padding : int
71 | Zero-pad signal with x times the number of samples.
72 | Defaults to :code:`0`.
73 | save_settings : boolean
74 | Save settings used here in attribute :code:`out.stft_settings` so that
75 | :func:`ispectrogram` can infer these settings without the developer
76 | having to pass them again.
77 | Returns
78 | -------
79 | out : array_like
80 | The signal (or matrix of signals). In case of a mono output signal, the
81 | data is formatted as a 1D vector of length :code:`samples`. In case of
82 | a multi channel output signal, the data is formatted as :code:`samples
83 | x channels`.
84 | See Also
85 | --------
86 | mdct.fast.transforms.mdct : MDCT
87 | """
88 | def cmdct(x, odd=True):
89 | """ Calculate complex MDCT/MCLT of input signal
90 | Parameters
91 | ----------
92 | x : array_like
93 | The input signal
94 | odd : boolean, optional
95 | Switch to oddly stacked transform. Defaults to :code:`True`.
96 | Returns
97 | -------
98 | out : array_like
99 | The output signal
100 | """
101 | N = len(x) // 2
102 | n0 = (N + 1) / 2
103 | if odd:
104 | outlen = N
105 | pre_twiddle = torch.exp(-1j * np.pi *
106 | torch.arange(N * 2) / (N * 2)).to(self.device)
107 | offset = 0.5
108 | else:
109 | outlen = N + 1
110 | pre_twiddle = 1.0
111 | offset = 0.0
112 |
113 | post_twiddle = torch.exp(
114 | -1j * np.pi * n0 * (torch.arange(outlen) + offset) / N
115 | ).to(self.device)
116 |
117 | X = torch.fft.fft(x * pre_twiddle)[:outlen]
118 |
119 | if not odd:
120 | X[0] *= np.sqrt(0.5)
121 | X[-1] *= np.sqrt(0.5)
122 |
123 | return X * post_twiddle * np.sqrt(1 / N)
124 |
125 | def _mdct(x, odd=True):
126 | """ Calculate modified discrete cosine transform of input signal
127 | Parameters
128 | ----------
129 | X : array_like
130 | The input signal
131 | odd : boolean, optional
132 | Switch to oddly stacked transform. Defaults to :code:`True`.
133 | Returns
134 | -------
135 | out : array_like
136 | The output signal
137 | """
138 | return torch.real(cmdct(x, odd=odd)) * np.sqrt(2)
139 |
140 | def _mdst(x, odd=True):
141 | """ Calculate modified discrete sine transform of input signal
142 | Parameters
143 | ----------
144 | X : array_like
145 | The input signal
146 | odd : boolean, optional
147 | Switch to oddly stacked transform. Defaults to :code:`True`.
148 | Returns
149 | -------
150 | out : array_like
151 | The output signal
152 | """
153 | return -1 * torch.imag(cmdct(x, odd=odd)) * np.sqrt(2)
154 |
155 | frame_length = len(self.window_function)
156 |
157 | if not odd:
158 | return spectrogram.spectrogram(
159 | x,
160 | transform=[
161 | functools.partial(_mdct, odd=False),
162 | functools.partial(_mdst, odd=False),
163 | ],
164 | halved=False,
165 | frame_length=frame_length,
166 | **kwargs
167 | )
168 | else:
169 | return spectrogram.spectrogram(
170 | x,
171 | transform=_mdct,
172 | halved=False,
173 | frame_length=frame_length,
174 | **kwargs
175 | )
176 |
177 | def forward(self, x):
178 | x = self.mdct(x=x, window_function=self.window_function,
179 | step_length=self.step_length, n_fft=self.n_fft, center=self.center, padding=0)
180 | return x
181 |
182 |
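# --- Editorial note (not part of the original mdct.py) ------------------------
# The serial MDCT above computes the complex MCLT per frame (pre-twiddle, full
# FFT, post-twiddle) and keeps its real part; framing, windowing and padding are
# delegated to spectrogram.spectrogram() in models/spectrogram.py, so this layer
# only defines the per-frame transform. IMDCT below mirrors it through
# spectrogram.ispectrogram().
# -------------------------------------------------------------------------------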
183 | class IMDCT(torch.nn.Module):
184 | def __init__(self, window_function, step_length=None, device='cuda', n_fft=2048, out_length=48000, center=True):
185 | super().__init__()
186 | self.window_function = window_function
187 | self.step_length = step_length
188 | self.n_fft = n_fft
189 | self.out_length = out_length
190 | self.device = device
191 |
192 | def imdct(
193 | self,
194 | X,
195 | odd=True,
196 | **kwargs
197 | ):
198 | """ Calculate lapped inverse MDCT of input signal
199 | Parameters
200 | ----------
201 | x : array_like
202 | The spectrogram to be inverted. May be a 2D matrix for single channel
203 | or a 3D tensor for multi channel data. In case of a mono signal, the
204 | data must be in the shape of :code:`bins x frames`. In case of a multi
205 | channel signal, the data must be in the shape of :code:`bins x frames x
206 | channels`.
207 | odd : boolean, optional
208 | Switch to oddly stacked transform. Defaults to :code:`True`.
209 | framelength : int
210 | The signal frame length. Defaults to infer from data.
211 | hopsize : int
212 | The signal frame hopsize. Defaults to infer from data. Setting this
213 | value will override :code:`overlap`.
214 | overlap : int
215 | The signal frame overlap coefficient. Value :code:`x` means
216 | :code:`1/x` overlap. Defaults to infer from data. Note that anything
217 | but :code:`2` will result in a filterbank without perfect
218 | reconstruction.
219 | centered : boolean
220 | Pad input signal so that the first and last window are centered around
221 |             the beginning of the signal. Defaults to infer from data.
222 | The first and last half-frame will have aliasing, so using
223 | centering during forward MDCT is recommended.
224 | window : callable, array_like
225 | Window to be used for deringing. Can be :code:`False` to disable
226 |             windowing. Defaults to infer from data.
227 |         halved : boolean
228 |             Switch to reconstruct the other half of the spectrum if the forward
229 |             transform has been truncated. Defaults to infer from data.
230 | transforms : module, optional
231 | Module reference to core transforms. Mostly used to replace
232 | fast with slow core transforms, for testing. Defaults to
233 | :mod:`mdct.fast`
234 | padding : int
235 | Zero-pad signal with x times the number of samples. Defaults to infer
236 | from data.
237 | outlength : int
238 | Crop output signal to length. Useful when input length of spectrogram
239 | did not fit into framelength and input data had to be padded. Not
240 |             setting this value will disable cropping; the output data may be
241 | longer than expected.
242 | Returns
243 | -------
244 | out : array_like
245 | The output signal
246 | See Also
247 | --------
248 | mdct.fast.transforms.imdct : inverse MDCT
249 | """
250 | def icmdct(X, odd=True):
251 | """ Calculate inverse complex MDCT/MCLT of input signal
252 | Parameters
253 | ----------
254 | X : array_like
255 | The input signal
256 | odd : boolean, optional
257 | Switch to oddly stacked transform. Defaults to :code:`True`.
258 | Returns
259 | -------
260 | out : array_like
261 | The output signal
262 | """
263 | if not odd and len(X) % 2 == 0:
264 | raise ValueError(
265 | "Even inverse CMDCT requires an odd number "
266 | "of coefficients"
267 | )
268 |
269 | if odd:
270 | N = len(X)
271 | n0 = (N + 1) / 2
272 |
273 | post_twiddle = torch.exp(
274 | 1j * np.pi * (torch.arange(N * 2) + n0) / (N * 2)
275 | ).to(self.device)
276 |
277 | Y = torch.zeros(N * 2, dtype=X.dtype)
278 | Y[:N] = X
279 | #Y[N:] = -1 * torch.conj(X[::-1])
280 | Y[N:] = -1 * torch.conj(X.flip(dims=(0,)))
281 | else:
282 | N = len(X) - 1
283 | n0 = (N + 1) / 2
284 |
285 | post_twiddle = 1.0
286 |
287 |                 X[0] *= np.sqrt(2)
288 |                 X[-1] *= np.sqrt(2)
289 |
290 | Y = torch.zeros(N * 2, dtype=X.dtype)
291 | Y[:N+1] = X
292 | #Y[N+1:] = -1 * torch.conj(X[-2:0:-1])
293 | Y[N+1:] = -1 * torch.conj(X[:-2].flip(dims=(0,)))
294 |
295 | pre_twiddle = (torch.exp(1j * np.pi * n0 *
296 | torch.arange(N * 2) / N)).to(self.device)
297 |
298 | y = torch.fft.ifft(Y.to(self.device) * pre_twiddle)
299 |
300 | return torch.real(y * post_twiddle) * np.sqrt(N)
301 |
302 | def _imdct(X, odd=True):
303 | """ Calculate inverse modified discrete cosine transform of input signal
304 | Parameters
305 | ----------
306 | X : array_like
307 | The input signal
308 | odd : boolean, optional
309 | Switch to oddly stacked transform. Defaults to :code:`True`.
310 | Returns
311 | -------
312 | out : array_like
313 | The output signal
314 | """
315 | return icmdct(X, odd=odd) * np.sqrt(2)
316 |
317 | def _imdst(X, odd=True):
318 | """ Calculate inverse modified discrete sine transform of input signal
319 | Parameters
320 | ----------
321 | X : array_like
322 | The input signal
323 | odd : boolean, optional
324 | Switch to oddly stacked transform. Defaults to :code:`True`.
325 | Returns
326 | -------
327 | out : array_like
328 | The output signal
329 | """
330 | return -1 * icmdct(X * 1j, odd=odd) * np.sqrt(2)
331 |
332 | frame_length = len(self.window_function)
333 |
334 | if not odd:
335 | return spectrogram.ispectrogram(
336 | X,
337 | transform=[
338 | functools.partial(_imdct, odd=False),
339 | functools.partial(_imdst, odd=False),
340 | ],
341 | halved=False,
342 | **kwargs
343 | )
344 | else:
345 | return spectrogram.ispectrogram(
346 | X,
347 | transform=_imdct,
348 | halved=False,
349 | frame_length=frame_length,
350 | **kwargs
351 | )
352 |
353 | def forward(self, X):
354 | X = self.imdct(X=X, window_function=self.window_function, step_length=self.step_length,
355 | n_fft=self.n_fft, padding=0, out_length=self.out_length)
356 | return X
357 |
358 |
359 | class MDCT4(torch.nn.Module):
360 | """
361 | The exact version of the MDCT, using modified DCT-IV.
362 | Borrowed from MATLAB implementation.
363 | """
364 |
365 | def __init__(self, n_fft=2048, hop_length=None, win_length=None, window=None, center=True, pad_mode='constant', device='cuda') -> None:
366 | super().__init__()
367 | self.n_fft = n_fft
368 | self.pad_mode = pad_mode
369 | self.device = device
370 | self.hop_length = hop_length
371 | self.center = center
372 |
373 | # making window
374 | if window is None:
375 | window = torch.ones
376 | if callable(window):
377 | self.win_length = int(win_length)
378 | self.window = window(self.win_length).to(self.device)
379 | else:
380 | self.window = window.to(self.device)
381 | self.win_length = len(window)
382 |
383 |         assert self.win_length <= self.n_fft, 'Window length %d should be no more than fft length %d' % (
384 | self.win_length, self.n_fft)
385 | assert self.hop_length <= self.win_length, 'You hopped more than one frame'
386 |
387 | self.exp1 = torch.exp(-1j*torch.pi/self.n_fft*torch.arange(start=0,
388 | end=self.n_fft, step=1, dtype=torch.float64, device=self.device))
389 | self.exp2 = torch.exp(-1j*(torch.pi/(2*self.n_fft)+torch.pi/4)*torch.arange(
390 | start=1, end=self.n_fft, step=2, dtype=torch.float64, device=self.device))
391 |
392 | def forward(self, signal, return_frames: bool = False):
393 | # Pad the signal to a proper length
394 | signal_len = int(len(signal))
395 | start_pad = 0
396 | # Pad the signal so that the t-th frame is centered at time t * hop_length. Otherwise, the t-th frame begins at time t * hop_length.
397 | if self.center:
398 | start_pad = self.hop_length
399 | additional_len = signal_len % self.hop_length
400 | end_pad = start_pad
401 | if additional_len:
402 | end_pad = start_pad + self.hop_length - additional_len
403 | signal = pad(signal, (start_pad, end_pad), mode=self.pad_mode)
404 |
405 | # Slice the signal with overlapping
406 | signal = signal.unfold(
407 | dimension=-1, size=self.win_length, step=self.hop_length)
408 |
409 |         # Apply the window to each frame
410 | signal = torch.mul(signal, self.window)
411 | if return_frames:
412 | frames = signal.clone()
413 | else:
414 | frames = torch.empty(1)
415 |
416 | # Pad zeros for DCT
417 | if self.n_fft > self.win_length:
418 | signal = pad(signal, (0, self.n_fft-self.win_length),
419 | mode='constant')
420 |
421 | signal = signal*self.exp1
422 | signal = torch.fft.fft(signal)[..., :self.n_fft//2]
423 | signal = torch.real(self.exp2*signal)
424 |
425 | return signal, frames
426 |
427 |
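# --- Editorial note (not part of the original mdct.py) ------------------------
# MDCT4 above evaluates the MDCT through a single length-N FFT (N = n_fft): each
# windowed frame x[n] is pre-twiddled by exp(-j*pi*n/N), transformed with an
# N-point FFT, the first N/2 bins are kept, and the result is post-twiddled by
# exp(-j*(pi/(2N) + pi/4)*(2k+1)) before taking the real part, matching the
# exp1/exp2 factors built in __init__. IMDCT4 below applies the corresponding
# twiddle factors and overlap-adds the windowed frames with fold().
# -------------------------------------------------------------------------------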
428 | class IMDCT4(torch.nn.Module):
429 | def __init__(self, n_fft=2048, hop_length=None, win_length=None, window=None, center=True, pad_mode='constant', out_length=None, device='cuda') -> None:
430 | super().__init__()
431 | self.n_fft = n_fft
432 | self.pad_mode = pad_mode
433 | self.device = device
434 | self.hop_length = hop_length
435 | self.center = center
436 | self.out_length = out_length
437 |
438 | # making window
439 | if window is None:
440 | window = torch.ones
441 | if callable(window):
442 | self.win_length = int(win_length)
443 | self.window = window(self.win_length).to(self.device)
444 | else:
445 | self.window = window.to(self.device)
446 | self.win_length = len(window)
447 |
448 |         assert self.win_length <= self.n_fft, 'Window length %d should be no more than fft length %d' % (
449 | self.win_length, self.n_fft)
450 | assert self.hop_length <= self.win_length, 'You hopped more than one frame'
451 |
452 | self.exp1 = torch.exp(-1j*(torch.pi/(2*self.n_fft)+torch.pi/4)*torch.arange(
453 | start=1, end=self.n_fft, step=2, dtype=torch.float64, device=self.device))
454 | self.exp2 = torch.exp(-1j*torch.pi/(2*self.n_fft)*torch.arange(
455 | start=0, end=2*self.n_fft, step=2, dtype=torch.float64, device=self.device))
456 |
457 | def forward(self, signal, return_frames: bool = False):
458 | assert signal.dim() == 3, 'Only tensors shaped in BHW are supported, got tensor of shape %s' % (
459 | str(signal.size()))
460 | assert signal.size(
461 |         )[-1] == self.n_fft//2, 'The last dim of the input should be n_fft//2. Expected %d, got %d' % (self.n_fft//2, signal.size()[-1])
462 |
463 | # Inverse transform at the last dim
464 | signal = self.exp1*signal
465 | signal = torch.fft.fft(signal, n=self.n_fft)
466 | signal = torch.real(signal*self.exp2)
467 |
468 | # Remove padded zeros when doing dct
469 | if self.n_fft > self.win_length:
470 | signal = signal[..., :self.win_length]
471 |
472 |         # Apply the window to each frame
473 | signal = torch.mul(signal, self.window)
474 | if return_frames:
475 | frames = signal.clone()
476 | else:
477 | frames = torch.zeros(1)
478 |
479 |         # Overlap-add via fold()
480 | out_len = (signal.size()[-2]-1) * self.hop_length + self.win_length
481 | signal = 4/self.n_fft*fold(signal.transpose_(-1, -2), kernel_size=(
482 | 1, self.win_length), stride=(1, self.hop_length), output_size=(1, out_len))
483 |
484 | if self.center:
485 | # extract the middle part
486 | signal = signal[..., self.win_length//2:-self.win_length//2]
487 | signal = signal if self.out_length is None else signal[...,
488 | :self.out_length]
489 | return signal, frames
490 |
491 |
492 | class FastMDCT4(torch.nn.Module):
493 | '''
494 |     Fast implementation of the MDCT, ported to PyTorch by Chenhao Shuai.
495 |     Ref: https://ccrma.stanford.edu/~bosse/proj/node28.html
496 | Sporer T, Brandenburg K, Edler B, The Use of Multirate Filter Banks for Coding of High Quality Digital Audio, 6th European Signal Processing Conference (EUSIPCO), Amsterdam, June 1992, Vol.1 pp. 211-214.
497 | '''
498 | def __init__(self, n_fft: Optional[int] = 2048, hop_length: Optional[int] = None, win_length: Optional[int] = None, window: Union[torch.Tensor, np.ndarray, list, Callable, None] = None, center: bool = True, pad_mode: str = 'constant', device: str = 'cuda') -> None:
499 | super().__init__()
500 | self.n_fft = n_fft
501 | self.pad_mode = pad_mode
502 | self.device = device
503 | self.hop_length = hop_length
504 | self.center = center
505 |
506 | if callable(window):
507 | self.win_length = int(win_length)
508 | self.window = window(self.win_length).to(
509 | device=self.device, dtype=torch.float64)
510 | elif isinstance(window, torch.Tensor):
511 | self.window = window.to(device=self.device, dtype=torch.float64)
512 | self.win_length = len(window)
513 | elif isinstance(window, np.ndarray) or isinstance(window, list):
514 | self.window = torch.tensor(
515 | window, device=self.device, dtype=torch.float64)
516 | self.win_length = len(window)
517 | elif window is None:
518 | if win_length is not None:
519 | self.win_length = win_length
520 | elif n_fft is not None:
521 | self.win_length = n_fft
522 | else:
523 | assert False, 'You should specify window length or n_fft'
524 | self.window = torch.ones(
525 | (self.win_length,), device=self.device, dtype=torch.float64)
526 | else:
527 | raise NotImplementedError
528 |
529 |         assert self.win_length <= self.n_fft, f'Window length {self.win_length} should be no more than fft length {self.n_fft}'
530 | assert self.hop_length <= self.win_length, 'You hopped more than one frame'
531 |
532 | self.idx = torch.stack((
533 | torch.arange(
534 | start=0, end=n_fft//2, step=2,
535 | dtype=torch.long, device=self.device),
536 | torch.arange(
537 | start=n_fft-1, end=n_fft//2, step=-2,
538 | dtype=torch.long, device=self.device),
539 | torch.arange(
540 | start=n_fft//2, end=n_fft, step=2,
541 | dtype=torch.long, device=self.device),
542 | torch.arange(
543 | start=n_fft//2-1, end=0, step=-2,
544 | dtype=torch.long, device=self.device)
545 | ), dim=0)
546 |
547 | # self.sqrtN = torch.sqrt(torch.tensor(
548 | # [self.n_fft], device=self.device, dtype=torch.float64))
549 | self.post_exp = torch.exp(
550 | -2j*torch.pi/self.n_fft*(
551 | torch.arange(
552 | start=0,
553 | end=self.n_fft//4,
554 | step=1,
555 | dtype=torch.float64,
556 | device=self.device
557 | )+1/8
558 | )
559 | ).to(torch.complex64)
560 |
561 | self.pre_exp = (self.make_pre_exp()*self.window).to(torch.complex64)
562 | self.pre_idx = self.make_pre_idx()
563 | self.post_idx = self.make_post_idx()
564 | # self.idx = self.idx.clone().mT.roll(self.n_fft//8,0).contiguous()
565 |
566 | def make_pre_exp(self):
567 | sgn = torch.ones(1, self.n_fft, dtype=torch.complex128,
568 | device=self.device)
569 | # Shift for Time-Domain Aliasing Cancellation (TDAC)
570 | sgn[..., -self.n_fft//4:] *= -1
571 | sgn = sgn.roll(self.n_fft//4, dims=-1)
572 | sgn[..., self.idx[0]] *= self.post_exp
573 | sgn[..., self.idx[1]] *= -self.post_exp
574 | sgn[..., self.idx[2]] *= -1j*self.post_exp
575 | sgn[..., self.idx[3]] *= 1j*self.post_exp
576 | return sgn.roll(-self.n_fft//4, dims=-1).contiguous()
577 |
578 | def make_pre_idx(self):
579 | i = torch.arange(start=0, end=self.n_fft, step=1,
580 | dtype=torch.long, device=self.device)
581 | i = i.roll(self.n_fft//4, dims=-1)
582 | idx_ = torch.stack([i[self.idx[0]], i[self.idx[1]],
583 | i[self.idx[2]], i[self.idx[3]]], dim=1)
584 | index = torch.zeros(
585 | 1, self.n_fft, device=self.device, dtype=torch.long)
586 | for i in torch.arange(0, self.n_fft//4, dtype=torch.long):
587 | index[..., idx_[i]] = i
588 | return index.squeeze().contiguous()
589 |
590 | def make_post_idx(self):
591 | idx = torch.arange(self.n_fft//2, dtype=torch.long,
592 | device=self.device).reshape(-1, 2)
593 | idx[:, 1] = idx[:, 1].flip(-1)
594 | return idx.flatten().contiguous()
595 |
596 |     def forward(self, signal: torch.Tensor, return_frames: bool = False):
597 | if signal.dim() == 2: # B T (mono)
598 | signal = signal[:, None, :]
599 | elif signal.dim() == 3: # B C T (stereo)
600 | pass
601 | else:
602 | raise NotImplementedError
603 |
604 | # Pad the signal to a proper length
605 | B, C, T = signal.shape
606 | start_pad = 0
607 | # Pad the signal so that the t-th frame is centered at time t * hop_length. Otherwise, the t-th frame begins at time t * hop_length.
608 | if self.center:
609 | start_pad = self.hop_length
610 | additional_len = T % self.hop_length
611 | end_pad = start_pad
612 | if additional_len:
613 | end_pad = start_pad + self.hop_length - additional_len
614 | signal = pad(signal, (start_pad, end_pad), mode=self.pad_mode)
615 |
616 | # Slice the signal with overlapping
617 | signal = signal.unfold(
618 | dimension=-1, size=self.win_length, step=self.hop_length)
619 | signal = signal*self.pre_exp
620 | signal = scatter(signal, self.pre_idx, dim=-1, reduce='sum')
621 | signal = torch.fft.fft(signal, dim=-1)
622 | # post-twiddle
623 | signal = torch.conj_physical(signal*self.post_exp)
624 | # rearranging
625 | signal = torch.view_as_real(signal)
626 | signal = signal.flatten(-2)[..., self.post_idx]
627 |
628 | return signal, None
629 |
630 |
631 | class FastIMDCT4(torch.nn.Module):
632 | def __init__(self, n_fft: Optional[int] = 2048, hop_length: Optional[int] = None, win_length: Optional[int] = None, window: Union[torch.Tensor, np.ndarray, list, Callable, None] = None, center: bool = True, pad_mode: str = 'constant', out_length: Optional[int] = None, device: str = 'cuda') -> None:
633 | super().__init__()
634 | self.n_fft = n_fft
635 | self.pad_mode = pad_mode
636 | self.device = device
637 | self.hop_length = hop_length
638 | self.center = center
639 | self.out_length = out_length
640 |
641 | if callable(window):
642 | self.win_length = int(win_length)
643 | self.window = window(self.win_length).to(
644 | device=self.device, dtype=torch.float64)
645 | elif isinstance(window, torch.Tensor):
646 | self.window = window.to(device=self.device, dtype=torch.float64)
647 | self.win_length = len(window)
648 | elif isinstance(window, np.ndarray) or isinstance(window, list):
649 | self.window = torch.tensor(
650 | window, device=self.device, dtype=torch.float64)
651 | self.win_length = len(window)
652 |         elif window is None:
653 | if win_length is not None:
654 | self.win_length = win_length
655 | elif n_fft is not None:
656 | self.win_length = n_fft
657 | else:
658 | assert False, 'You should specify window length or n_fft'
659 | self.window = torch.ones(
660 | (self.win_length,), device=self.device, dtype=torch.float64)
661 | else:
662 | raise NotImplementedError
663 |
664 |         assert self.win_length <= self.n_fft, f'Window length {self.win_length} should be no more than fft length {self.n_fft}'
665 | assert self.hop_length <= self.win_length, 'You hopped more than one frame'
666 |
667 | self.exp = torch.exp(
668 | -2j*torch.pi/self.n_fft*(
669 | torch.arange(
670 | start=0,
671 | end=self.n_fft//4,
672 | step=1,
673 | dtype=torch.float32,
674 | device=self.device
675 | )+1/8
676 | )
677 | ).contiguous()
678 | self.pre_idx = self.make_pre_idx()
679 | self.post_idx = self.make_post_index()
680 | self.window = (4.0*self.make_sign()*self.window /
681 | self.n_fft).to(torch.float32).contiguous()
682 |
683 | def make_pre_idx(self):
684 | a = torch.arange(self.n_fft//2, dtype=torch.long,
685 | device=self.device).unfold(-1, 2, 2)
686 | return torch.stack((a[..., 0], a[..., 1].flip(-1)), dim=-1).contiguous()
687 |
688 | def make_post_index(self):
689 | a = torch.arange(0, self.n_fft//2, 2,
690 | dtype=torch.long, device=self.device)
691 | b = torch.arange(self.n_fft//2-1, 0, -2,
692 | dtype=torch.long, device=self.device)
693 | idx = torch.empty((self.n_fft,), dtype=torch.long, device=self.device)
694 | idx[0:self.n_fft//2:2] = a
695 | idx[1:self.n_fft//2:2] = b
696 | idx[self.n_fft//2:] = idx[:self.n_fft//2].flip(0)
697 | return idx.roll(-self.n_fft//4).contiguous()
698 |
699 | def make_sign(self):
700 | sign = torch.ones((self.n_fft,), device=self.device,
701 | dtype=torch.float64)
702 | sign[1::2] *= -1
703 | sign[..., 0:self.n_fft//4] *= -1
704 | return sign.roll(-self.n_fft//4).contiguous()
705 |
706 | def forward(self, signal: torch.Tensor, return_frames: bool = False):
707 | assert signal.dim(
708 | ) <= 4, f'Only tensors shaped in BHW or BCHW are supported, got tensor of shape {signal.shape}'
709 | assert signal.shape[
710 | -1] == self.n_fft//2, f'The last dim of input tensor should match the n_fft. Expected {self.n_fft}, got {signal.shape[-1]}'
711 |
712 | if signal.dim() == 4:
713 | C = signal.shape[1]
714 | signal = rearrange(signal, 'B C T N -> (B C) T N')
715 | else:
716 | C = 1
717 |
718 | signal = signal.to(self.device)
719 |         # Inverse transform at the last dim
720 | signal = torch.view_as_complex(signal[..., self.pre_idx])
721 |
722 | signal = self.exp*signal
723 | signal = torch.fft.fft(signal)
724 | signal = self.exp*signal
725 |
726 | # [0+4j, 1+5j, 2+6j, 3+7j] -> [2,-5, 3, -4, 4, -3, 5, -2, 6, -1, 7, 0, 0, 7, -1, 6]
727 | signal = torch.view_as_real(signal).flatten(-2)[..., self.post_idx]
728 |
729 |         # Apply the window to each frame
730 | signal = self.window*signal
731 | if return_frames:
732 | frames = signal.clone()
733 | else:
734 | frames = torch.empty(1)
735 |
736 |         # Overlap-add via fold()
737 | out_len = (signal.shape[-2]-1) * self.hop_length + self.win_length
738 | signal = fold(signal.mT, kernel_size=(1, self.win_length),
739 | stride=(1, self.hop_length), output_size=(1, out_len))
740 |
741 | if self.center: # extract the middle part
742 | signal = signal[..., self.win_length//2:-self.win_length//2]
743 | if self.out_length is not None:
744 | signal = signal[..., :self.out_length]
745 | if C != 1:
746 |             signal = rearrange(signal, '(B C) ... -> B C ...', C=C)  # restore the batch/channel axes
747 | return signal, frames
748 |
--------------------------------------------------------------------------------
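A quick orientation note on the file above (editorial, not part of the repository): the MDCT4/IMDCT4 pair can be exercised end to end as below. This is a minimal sketch, assuming the ports are numerically correct; the sine window (which satisfies the Princen-Bradley condition), n_fft, hop and signal length are illustrative choices, and the script is assumed to run from the repository root so that the import from models.mdct resolves.

    import torch
    from models.mdct import MDCT4, IMDCT4

    n_fft, hop = 512, 256
    window = torch.sin(torch.pi * (torch.arange(n_fft) + 0.5) / n_fft)  # MDCT sine window

    mdct = MDCT4(n_fft=n_fft, hop_length=hop, win_length=n_fft,
                 window=window, center=True, device='cpu')
    imdct = IMDCT4(n_fft=n_fft, hop_length=hop, win_length=n_fft,
                   window=window, center=True, out_length=48000, device='cpu')

    x = torch.randn(48000, dtype=torch.float64)   # mono test signal
    coeffs, _ = mdct(x)                           # (frames, n_fft // 2)
    y, _ = imdct(coeffs.unsqueeze(0))             # (1, 1, 1, out_length)
    print((y.flatten() - x).abs().max())          # small if the round trip holds

The FastMDCT4/FastIMDCT4 pair exposes the same constructor arguments and forward signature, so it should drop into the same sketch, but it additionally depends on torch_scatter.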
/models/networks.py:
--------------------------------------------------------------------------------
1 | from torchvision import models
2 | import torch
3 | import torch.nn as nn
4 | import functools
5 | import numpy as np
6 | from torch.nn.functional import interpolate, pad
7 |
8 | ###############################################################################
9 | # Functions
10 | ###############################################################################
11 |
12 |
13 | def weights_init(m):
14 | classname = m.__class__.__name__
15 | if classname.find('Conv2d') != -1:
16 | m.weight.data.normal_(0.0, 0.02)
17 | elif classname.find('BatchNorm2d') != -1:
18 | m.weight.data.normal_(1.0, 0.02)
19 | m.bias.data.fill_(0)
20 |
21 |
22 | def get_norm_layer(norm_type='instance'):
23 | if norm_type == 'batch':
24 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
25 | elif norm_type == 'instance':
26 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
27 | else:
28 | raise NotImplementedError(
29 | 'normalization layer [%s] is not found' % norm_type)
30 | return norm_layer
31 |
32 |
33 | def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
34 | n_blocks_local=3, norm='instance', gpu_ids=[], upsample_type='transconv', downsample_type='conv', input_size=(128, 256), n_attn_g=0, n_attn_l=0, proj_factor_g=4, heads_g=4, dim_head_g=128, proj_factor_l=4, heads_l=4, dim_head_l=128):
35 | norm_layer = get_norm_layer(norm_type=norm)
36 | if netG == 'global':
37 | netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global,
38 | n_blocks_global, norm_layer, downsample_type=downsample_type, upsample_type=upsample_type,
39 | input_size=input_size,
40 | n_attn_g=n_attn_g, proj_factor_g=proj_factor_g, heads_g=heads_g, dim_head_g=dim_head_g)
41 | elif netG == 'local':
42 | netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
43 | n_local_enhancers, n_blocks_local, norm_layer, downsample_type=downsample_type, upsample_type=upsample_type,
44 | input_size=input_size,
45 | n_attn_g=n_attn_g, proj_factor_g=proj_factor_g, heads_g=heads_g, dim_head_g=dim_head_g, n_attn_l=n_attn_l, proj_factor_l=proj_factor_l, heads_l=heads_l, dim_head_l=dim_head_l)
46 | elif netG == 'encoder':
47 | netG = Encoder(input_nc, output_nc, ngf,
48 | n_downsample_global, norm_layer)
49 | else:
50 |         raise NotImplementedError('generator not implemented!')
51 | print(netG)
52 | if len(gpu_ids) > 0:
53 | assert(torch.cuda.is_available())
54 | netG.cuda(gpu_ids[0])
55 | netG.apply(weights_init)
56 | return netG
57 |
58 |
59 | def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
60 | norm_layer = get_norm_layer(norm_type=norm)
61 | netD = MultiscaleDiscriminator(
62 | input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat)
63 | print(netD)
64 | if len(gpu_ids) > 0:
65 | assert(torch.cuda.is_available())
66 | netD.cuda(gpu_ids[0])
67 | netD.apply(weights_init)
68 | return netD
69 |
70 |
71 | def define_MR_D(ndf, n_layers_D, input_nc, norm='instance', use_sigmoid=False, num_D=1, gpu_ids=[], base_nfft=2048, window=None, min_value=1e-7, mdct_type='4', normalizer=None, getIntermFeat=False, abs_spectro=False):
72 | norm_layer = get_norm_layer(norm_type=norm)
73 | netD = MultiResolutionDiscriminator(ndf=ndf, n_layers=n_layers_D, input_nc=input_nc, norm_layer=norm_layer, num_D=num_D, base_nfft=base_nfft, window=window,
74 | min_value=min_value, mdct_type=mdct_type, use_sigmoid=use_sigmoid, normalizer=normalizer, getIntermFeat=getIntermFeat, abs_spectro=abs_spectro)
75 | print(netD)
76 | if len(gpu_ids) > 0:
77 | assert(torch.cuda.is_available())
78 | netD.cuda(gpu_ids[0])
79 | netD.apply(weights_init)
80 | return netD
81 |
82 |
83 | def print_network(net):
84 | if isinstance(net, list):
85 | net = net[0]
86 | num_params = 0
87 | for param in net.parameters():
88 | num_params += param.numel()
89 | print(net)
90 | print('Total number of parameters: %d' % num_params)
91 |
92 | ##############################################################################
93 | # Losses
94 | ##############################################################################
95 |
96 |
97 | class GANLoss(nn.Module):
98 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
99 | device='cuda'):
100 | super(GANLoss, self).__init__()
101 | self.real_label = target_real_label
102 | self.fake_label = target_fake_label
103 | self.real_label_var = None
104 | self.fake_label_var = None
105 | self.device = device
106 | if use_lsgan:
107 | self.loss = nn.MSELoss()
108 | else:
109 | self.loss = nn.BCELoss()
110 |
111 | def get_target_tensor(self, input, target_is_real):
112 | target_tensor = None
113 | if target_is_real:
114 | create_label = ((self.real_label_var is None) or
115 | (self.real_label_var.shape != input.shape))
116 | if create_label:
117 | self.real_label_var = torch.full(size=input.size(),fill_value=self.real_label, device=self.device, requires_grad=False)
118 | target_tensor = self.real_label_var
119 | else:
120 | create_label = ((self.fake_label_var is None) or
121 | (self.fake_label_var.shape != input.shape))
122 | if create_label:
123 | self.fake_label_var = torch.full(size=input.size(),fill_value=self.fake_label, device=self.device, requires_grad=False)
124 | target_tensor = self.fake_label_var
125 | return target_tensor
126 |
127 | def __call__(self, input, target_is_real):
128 | if isinstance(input[0], list):
129 | loss = 0
130 | for input_i in input:
131 | pred = input_i[-1]
132 | target_tensor = self.get_target_tensor(pred, target_is_real)
133 | loss += self.loss(pred, target_tensor)
134 | return loss
135 | else:
136 | target_tensor = self.get_target_tensor(input[-1], target_is_real)
137 | return self.loss(input[-1], target_tensor)
138 |
139 |
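# --- Editorial usage sketch (not part of the original networks.py) -------------
# GANLoss accepts either one prediction (a list of per-layer feature maps whose
# last element is the final score map) or a list of such predictions, as
# returned by the multiscale discriminators below, and builds matching label
# tensors on demand. A minimal, self-contained call might look like:
#
#     criterion = GANLoss(use_lsgan=True, device='cpu')
#     pred_fake = [[torch.randn(2, 1, 8, 8)]]   # one scale, final layer only
#     loss_g = criterion(pred_fake, True)       # generator wants "real" labels
#     loss_d = criterion(pred_fake, False)      # discriminator labels it fake
# --------------------------------------------------------------------------------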
140 | class VGGLoss(nn.Module):
141 | def __init__(self, gpu_ids):
142 | super(VGGLoss, self).__init__()
143 | self.vgg = Vgg19().cuda()
144 | self.criterion = nn.L1Loss()
145 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
146 |
147 | def forward(self, x, y):
148 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
149 | loss = 0
150 | for i in range(len(x_vgg)):
151 | loss += self.weights[i] * \
152 | self.criterion(x_vgg[i], y_vgg[i].detach())
153 | return loss
154 |
155 |
156 | class SpecLoss(nn.Module):
157 | def __init__(self) -> None:
158 | super(SpecLoss, self).__init__()
159 |
160 | def forward(self, x, y):
161 | # input shape B,C,H,W
162 | N = x.shape[-1]
163 | spec_loss = torch.norm(x-y, p='fro', dim=(-1, -2)) / \
164 | torch.norm(x, p='fro', dim=(-1, -2))
165 | mag_loss = torch.norm(torch.log10(
166 | torch.abs(x)+1e-7) - torch.log10(torch.abs(y)+1e-7), p=1, dim=(-1, -2)) / N
167 | return torch.mean(spec_loss+mag_loss)
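# --- Editorial note (not part of the original networks.py) ---------------------
# SpecLoss above combines a spectral-convergence term with a log-magnitude term,
#     L(x, y) = ||x - y||_F / ||x||_F
#               + (1/N) * || log10(|x| + 1e-7) - log10(|y| + 1e-7) ||_1
# where the Frobenius and L1 norms run over the last two dimensions of each
# spectrogram, N is the size of the last dimension, and the result is averaged
# over batch and channels.
# --------------------------------------------------------------------------------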
168 | ##############################################################################
169 | # Generator
170 | ##############################################################################
171 |
172 |
173 | class LocalEnhancer(nn.Module):
174 | def __init__(self, input_nc, output_nc, ngf=32, n_downsample_global=3, n_blocks_global=9,
175 | n_local_enhancers=1, n_blocks_local=3, norm_layer=nn.BatchNorm2d, padding_type='reflect', downsample_type='conv', upsample_type='transconv', n_attn_g=0, n_attn_l=0, input_size=(128, 256), proj_factor_g=4, heads_g=4, dim_head_g=128, proj_factor_l=4, heads_l=4, dim_head_l=128):
176 | super(LocalEnhancer, self).__init__()
177 | self.n_local_enhancers = n_local_enhancers
178 |
179 | ###### global generator model #####
180 | ngf_global = ngf * (2**n_local_enhancers)
181 | model_global = GlobalGenerator(input_nc, output_nc, ngf_global, n_downsample_global, n_blocks_global, norm_layer,
182 | downsample_type=downsample_type, upsample_type=upsample_type,
183 | input_size=tuple(map(lambda x: x//2, input_size)), n_attn_g=n_attn_g, proj_factor_g=proj_factor_g, heads_g=heads_g, dim_head_g=dim_head_g).model
184 | # get rid of final convolution layers
185 | model_global = [model_global[i] for i in range(len(model_global)-3)]
186 | self.model = nn.Sequential(*model_global)
187 |
188 | # downsample
189 | if downsample_type == 'conv':
190 | downsample_layer = nn.Conv2d
191 | elif downsample_type == 'resconv':
192 | downsample_layer = ConvResBlock
193 | else:
194 | raise NotImplementedError(
195 | 'downsample layer [{:s}] is not found'.format(downsample_type))
196 | # upsample
197 | if upsample_type == 'transconv':
198 | upsample_layer = nn.ConvTranspose2d
199 | elif upsample_type == 'interpolate':
200 | upsample_layer = InterpolateUpsample
201 | else:
202 | raise NotImplementedError(
203 | 'upsample layer [{:s}] is not found'.format(upsample_type))
204 | ###### local enhancer layers #####
205 | # downsample
206 | ngf_global = ngf * (2**(n_local_enhancers-1))
207 | model_downsample = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0),
208 | norm_layer(ngf_global), nn.ReLU(True),
209 | downsample_layer(ngf_global, ngf_global * 2,
210 | kernel_size=3, stride=2, padding=1),
211 | norm_layer(ngf_global * 2), nn.ReLU(True)]
212 | # residual blocks
213 | model_upsample = []
214 | for i in range(n_blocks_local):
215 | model_upsample += [ResnetBlock(ngf_global * 2,
216 | padding_type=padding_type, norm_layer=norm_layer)]
217 | # attention bottleneck
218 | if n_attn_l > 0:
219 | middle = n_blocks_local//2
220 | # 8x downsample
221 | down = [downsample_layer(ngf_global * 2, ngf_global,
222 | kernel_size=3, stride=2, padding=1),
223 | norm_layer(ngf_global), nn.ReLU(True)]
224 | down += [downsample_layer(ngf_global, ngf_global,
225 | kernel_size=3, stride=2, padding=1),
226 | norm_layer(ngf_global), nn.ReLU(True)]*2
227 | down = nn.Sequential(*down)
228 | model_upsample.insert(middle, down)
229 |
230 | middle += 1
231 | input_size = tuple(map(lambda x: x//16, input_size))
232 | from bottleneck_transformer_pytorch import BottleStack
233 | attn_block = BottleStack(dim=ngf_global, fmap_size=input_size, dim_out=ngf_global*2, num_layers=n_attn_l, proj_factor=proj_factor_l,
234 | downsample=False, heads=heads_l, dim_head=dim_head_l, activation=nn.ReLU(True), rel_pos_emb=False)
235 | model_upsample.insert(middle, attn_block)
236 | model_upsample += [upsample_layer(in_channels=ngf_global*2, out_channels=ngf_global*2, kernel_size=3, stride=2, padding=1, output_padding=1),
237 | norm_layer(ngf_global), nn.ReLU(True)]*3
238 |
239 | model_upsample += [upsample_layer(in_channels=ngf_global*2, out_channels=ngf_global, kernel_size=3, stride=2, padding=1, output_padding=1),
240 | norm_layer(ngf_global), nn.ReLU(True)]
241 |
242 | # final convolution
243 | model_upsample += [nn.ReflectionPad2d(3), nn.Conv2d(
244 | ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
245 |
246 | self.model1_1 = nn.Sequential(*model_downsample)
247 | self.model1_2 = nn.Sequential(*model_upsample)
248 |
249 | self.downsample = nn.AvgPool2d(
250 | 3, stride=2, padding=[1, 1], count_include_pad=False)
251 | self.freeze = False
252 |
253 | def forward(self, input):
254 | # create input pyramid
255 | input_downsampled = [input]
256 | for i in range(self.n_local_enhancers):
257 | input_downsampled.append(self.downsample(input_downsampled[-1]))
258 |
259 |         # output at coarsest level
260 | output_prev = self.model(input_downsampled[-1])
261 | # build up one layer at a time
262 | model_downsample = self.model1_1
263 | model_upsample = self.model1_2
264 | input_i = input_downsampled[0]
265 | output_prev = model_upsample(
266 | model_downsample(input_i) + output_prev)
267 | return output_prev
268 |
269 | def set_freeze(self, freeze_global_d=True, freeze_global_u=False, freeze_local_d=True, freeze_local_u=False):
270 |         print("The following layers will be frozen:")
271 | '''Freeze downsample layers'''
272 | print('Global:')
273 | for name, layer in self.model.named_children():
274 | module_name = layer.__class__.__name__
275 | if 'Conv2d' in module_name or 'ConvResBlock' in module_name:
276 | if freeze_global_d:
277 | print(name, module_name)
278 | for param in layer.parameters():
279 | param.requires_grad = not freeze_global_d
280 | elif 'InterpolateUpsample' in module_name or 'ConvTranspose2d' in module_name or 'ResnetBlock' in module_name or 'BottleStack' in module_name:
281 | if freeze_global_u:
282 | print(name, module_name)
283 | for param in layer.parameters():
284 | param.requires_grad = not freeze_global_u
285 |         print('Local:')
286 | for name, layer in self.model1_1.named_children():
287 | module_name = layer.__class__.__name__
288 | for param in layer.parameters():
289 | if freeze_local_d:
290 | print(name, module_name)
291 | param.requires_grad = not freeze_local_d
292 |
293 | for name, layer in self.model1_2.named_children():
294 | module_name = layer.__class__.__name__
295 | for param in layer.parameters():
296 | if freeze_local_u:
297 | print(name, module_name)
298 | param.requires_grad = not freeze_local_u
299 |
300 |
301 | class GlobalGenerator(nn.Module):
302 | def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
303 | padding_type='reflect', upsample_type='transconv', downsample_type='conv', n_attn_g=0, input_size=(128, 256), proj_factor_g=4, heads_g=4, dim_head_g=128):
304 | assert(n_blocks >= 0)
305 | super(GlobalGenerator, self).__init__()
306 | activation = nn.ReLU(True)
307 |
308 | model = [nn.ReflectionPad2d(3), nn.Conv2d(
309 | input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
310 | # downsample
311 | if downsample_type == 'conv':
312 | downsample_layer = nn.Conv2d
313 | elif downsample_type == 'resconv':
314 | downsample_layer = ConvResBlock
315 | else:
316 | raise NotImplementedError(
317 | 'downsample layer [{:s}] is not found'.format(downsample_type))
318 | # upsample
319 | if upsample_type == 'transconv':
320 | upsample_layer = nn.ConvTranspose2d
321 | elif upsample_type == 'interpolate':
322 | upsample_layer = InterpolateUpsample
323 | else:
324 | raise NotImplementedError(
325 | 'upsample layer [{:s}] is not found'.format(upsample_type))
326 |
327 | for i in range(n_downsampling):
328 | mult = 2**i
329 | model += [downsample_layer(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
330 | norm_layer(ngf * mult * 2), activation]
331 |
332 | # resnet blocks
333 | mult = 2**n_downsampling
334 | bottle_neck = []
335 | for i in range(n_blocks):
336 | bottle_neck += [ResnetBlock(ngf * mult, padding_type=padding_type,
337 | activation=activation, norm_layer=norm_layer)]
338 | if n_attn_g > 0:
339 | middle = n_blocks//2
340 | input_size = tuple(map(lambda x: x//mult, input_size))
341 | from bottleneck_transformer_pytorch import BottleStack
342 | attn_block = BottleStack(dim=ngf * mult, fmap_size=input_size, dim_out=ngf * mult, num_layers=n_attn_g, proj_factor=proj_factor_g,
343 | downsample=False, heads=heads_g, dim_head=dim_head_g, activation=activation, rel_pos_emb=False)
344 | bottle_neck.insert(middle, attn_block)
345 | model += bottle_neck
346 |
347 | for i in range(n_downsampling):
348 | mult = 2**(n_downsampling - i)
349 | model += [upsample_layer(in_channels=ngf * mult, out_channels=int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
350 | norm_layer(int(ngf * mult / 2)), activation]
351 | model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf,
352 | output_nc, kernel_size=7, padding=0), nn.Tanh()]
353 | self.model = nn.Sequential(*model)
354 | self.freeze = False
355 |
356 | def forward(self, input):
357 | return self.model(input)
358 |
359 | def set_freeze(self, freeze=True):
360 | if self.freeze == freeze:
361 | return
362 | else:
363 | self.freeze = freeze
364 |             print("The following layers will be frozen:")
365 | '''Freeze downsample layers'''
366 | for name, layer in self.model.named_children():
367 | module_name = layer.__class__.__name__
368 | if 'ResnetBlock' in module_name or 'BottleStack' in module_name:
369 | break
370 | print(name, module_name)
371 | for param in layer.parameters():
372 | param.requires_grad = not freeze
373 |
374 |
375 | class InterpolateUpsample(nn.Module):
376 | """
377 | An upsampling layer with an optional convolution.
378 | :param channels: channels in the inputs and outputs.
379 | :param use_conv: a bool determining if a convolution is applied.
380 |
381 | """
382 |
383 | def __init__(self, *args, **kwargs):
384 | super(InterpolateUpsample, self).__init__()
385 | self.in_channels = kwargs['in_channels']
386 | self.out_channels = kwargs['out_channels']
387 | self.conv1 = nn.Conv2d(
388 | self.in_channels, self.out_channels, 5, padding=1)
389 | self.conv2 = nn.Conv2d(
390 | self.out_channels, self.out_channels, 3, padding=2)
391 | self.conv_res = nn.Conv2d(
392 | self.in_channels, self.out_channels, 3, padding=1)
393 |
394 | def forward(self, x):
395 | assert x.shape[1] == self.in_channels
396 | x = interpolate(x, scale_factor=2.0, mode="nearest")
397 | res_x = self.conv_res(x)
398 | x = self.conv1(x)
399 | x = self.conv2(x)
400 | return x+res_x
401 |
402 |
403 | class ConvResBlock(nn.Module):
404 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
405 | super(ConvResBlock, self).__init__()
406 | self.conv1 = nn.Conv2d(in_channels, in_channels,
407 | kernel_size, stride, padding)
408 | self.conv2 = nn.Conv2d(
409 | in_channels, out_channels, 5, padding=2)
410 | self.conv_res = nn.Conv2d(
411 | in_channels, out_channels, kernel_size=3, stride=1, padding=1)
412 |
413 | def forward(self, x):
414 | x = self.conv1(x)
415 | res_x = self.conv_res(x)
416 | x = self.conv2(x)
417 | return x+res_x
418 | # Define a resnet block
419 |
420 |
421 | class ResnetBlock(nn.Module):
422 | def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
423 | super(ResnetBlock, self).__init__()
424 | self.conv_block = self.build_conv_block(
425 | dim, padding_type, norm_layer, activation, use_dropout)
426 |
427 | def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
428 | conv_block = []
429 | p = 0
430 | if padding_type == 'reflect':
431 | conv_block += [nn.ReflectionPad2d(1)]
432 | elif padding_type == 'replicate':
433 | conv_block += [nn.ReplicationPad2d(1)]
434 | elif padding_type == 'zero':
435 | p = 1
436 | else:
437 | raise NotImplementedError(
438 | 'padding [%s] is not implemented' % padding_type)
439 |
440 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
441 | norm_layer(dim),
442 | activation]
443 | if use_dropout:
444 | conv_block += [nn.Dropout(0.5)]
445 |
446 | p = 0
447 | if padding_type == 'reflect':
448 | conv_block += [nn.ReflectionPad2d(1)]
449 | elif padding_type == 'replicate':
450 | conv_block += [nn.ReplicationPad2d(1)]
451 | elif padding_type == 'zero':
452 | p = 1
453 | else:
454 | raise NotImplementedError(
455 | 'padding [%s] is not implemented' % padding_type)
456 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
457 | norm_layer(dim)]
458 |
459 | return nn.Sequential(*conv_block)
460 |
461 | def forward(self, x):
462 | out = x + self.conv_block(x)
463 | return out
464 |
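# Shape note (illustrative check): ResnetBlock preserves both channel count and spatial
# size (two 3x3 convs with effective padding 1), so out = x + conv_block(x) always lines up.
#     blk = ResnetBlock(128, padding_type='reflect', norm_layer=nn.BatchNorm2d)
#     blk(torch.randn(1, 128, 32, 32)).shape   # -> torch.Size([1, 128, 32, 32])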
465 |
466 | class Encoder(nn.Module):
467 | def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
468 | super(Encoder, self).__init__()
469 | self.output_nc = output_nc
470 |
471 | model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
472 | norm_layer(ngf), nn.ReLU(True)]
473 | # downsample
474 | for i in range(n_downsampling):
475 | mult = 2**i
476 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
477 | norm_layer(ngf * mult * 2), nn.ReLU(True)]
478 |
479 | # upsample
480 | for i in range(n_downsampling):
481 | mult = 2**(n_downsampling - i)
482 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
483 | norm_layer(int(ngf * mult / 2)), nn.ReLU(True)]
484 |
485 | model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf,
486 | output_nc, kernel_size=7, padding=0), nn.Tanh()]
487 | self.model = nn.Sequential(*model)
488 |
489 | def forward(self, input, inst):
490 | outputs = self.model(input)
491 |
492 | # instance-wise average pooling
493 | outputs_mean = outputs.clone()
494 | inst_list = np.unique(inst.cpu().numpy().astype(int))
495 | for i in inst_list:
496 | for b in range(input.size()[0]):
497 | indices = (inst[b:b+1] == int(i)).nonzero() # n x 4
498 | for j in range(self.output_nc):
499 | output_ins = outputs[indices[:, 0] + b,
500 | indices[:, 1] + j, indices[:, 2], indices[:, 3]]
501 | mean_feat = torch.mean(output_ins).expand_as(output_ins)
502 | outputs_mean[indices[:, 0] + b, indices[:, 1] +
503 | j, indices[:, 2], indices[:, 3]] = mean_feat
504 | return outputs_mean
505 |
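# Note: Encoder.forward applies instance-wise average pooling in the style of pix2pixHD's
# feature encoder: the encoded features are averaged over each instance-mask region (per
# sample and per channel), so every pixel belonging to an instance carries the same feature.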
506 |
507 | class MultiscaleDiscriminator(nn.Module):
508 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
509 | use_sigmoid=False, num_D=3, getIntermFeat=False):
510 | super(MultiscaleDiscriminator, self).__init__()
511 | self.num_D = num_D
512 | self.n_layers = n_layers
513 | self.getIntermFeat = getIntermFeat
514 |
515 | for i in range(num_D):
516 | netD = NLayerDiscriminator(
517 | input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
518 | if getIntermFeat:
519 | for j in range(n_layers+2):
520 | setattr(self, 'scale'+str(i)+'_layer' +
521 | str(j), getattr(netD, 'model'+str(j)))
522 | else:
523 | setattr(self, 'layer'+str(i), netD.model)
524 |
525 | self.downsample = nn.AvgPool2d(
526 | 3, stride=2, padding=[1, 1], count_include_pad=False)
527 |
528 | def singleD_forward(self, model, input):
529 | if self.getIntermFeat:
530 | result = [input]
531 | for i in range(len(model)):
532 | result.append(model[i](result[-1]))
533 | return result[1:]
534 | else:
535 | return [model(input)]
536 |
537 | def forward(self, input):
538 | num_D = self.num_D
539 | result = []
540 | input_downsampled = input
541 | for i in range(num_D):
542 | if self.getIntermFeat:
543 | model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j))
544 | for j in range(self.n_layers+2)]
545 | else:
546 | model = getattr(self, 'layer'+str(num_D-1-i))
547 | result.append(self.singleD_forward(model, input_downsampled))
548 | if i != (num_D-1):
549 | input_downsampled = self.downsample(input_downsampled)
550 | return result
551 |
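# Usage sketch (illustrative only): each of the num_D sub-discriminators is a PatchGAN run
# on a progressively average-pooled copy of the input (full, 1/2 and 1/4 resolution for
# num_D=3); with getIntermFeat=True each per-scale result also carries the intermediate
# feature maps for feature-matching losses.
#     netD = MultiscaleDiscriminator(input_nc=2, num_D=3, getIntermFeat=True)
#     outs = netD(torch.randn(1, 2, 128, 128))   # list of 3 per-scale outputs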
552 |
553 | class MultiResolutionDiscriminator(nn.Module):
554 | def __init__(self, input_nc=2, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
555 | use_sigmoid=False, num_D=3, base_nfft=2048, window=None, min_value=1e-7, mdct_type='4', normalizer=None, getIntermFeat=False, abs_spectro=False):
556 | super(MultiResolutionDiscriminator, self).__init__()
557 | self.num_D = num_D
558 | self.n_layers = n_layers
559 | self.base_nfft = base_nfft
560 | self.window = window
561 | self.min_value = min_value
562 |         self.mdct = []  # plain Python list, so the MDCT transforms are not registered as submodules
563 | self.normalizer = normalizer
564 | self.getIntermFeat = getIntermFeat
565 | self.abs_spectro = abs_spectro
566 |
567 | if mdct_type == '4':
568 | from .mdct import MDCT4
569 | elif mdct_type == '2':
570 | from .mdct import MDCT2
571 | from dct.dct_native import DCT_2N_native
572 | else:
573 | raise NotImplementedError(
574 | 'MDCT type [%s] is not implemented' % mdct_type)
575 |
576 | for i in range(num_D):
577 | netD = NLayerDiscriminator(
578 | input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
579 | if getIntermFeat:
580 | for j in range(n_layers+2):
581 | setattr(self, 'scale'+str(i)+'_layer' +
582 | str(j), getattr(netD, 'model'+str(j)))
583 | else:
584 | setattr(self, 'layer'+str(i), netD.model)
585 |
586 | if i == 0:
587 | N = int(self.base_nfft*2)
588 | else:
589 | N = int(self.base_nfft//(2**i))
590 | if mdct_type == '4':
591 | self.mdct.append(MDCT4(n_fft=N, hop_length=N//2,
592 | win_length=N, window=self.window, center=True))
593 | elif mdct_type == '2':
594 | _dct = DCT_2N_native()
595 | self.mdct.append(MDCT2(n_fft=N, hop_length=N//2, win_length=N,
596 | window=self.window, dct_op=_dct, center=True))
597 |
598 | def singleD_forward(self, model, input):
599 | if self.getIntermFeat:
600 | result = [input]
601 | for i in range(len(model)):
602 | result.append(model[i](result[-1]))
603 | return result[1:]
604 | else:
605 | return [model(input)]
606 |
607 | def forward(self, waveform):
608 | result = []
609 | # FRAME_LENGTH = (BINS-1)*HOP_LENGTH
610 | bins = waveform.size(-1)//self.base_nfft//2 + 1
611 | for i in range(self.num_D):
612 | if i == 0:
613 | frame_len = int((bins//2-1)*self.base_nfft)
614 | else:
615 | N = int(self.base_nfft//(2**i))
616 | frame_len = int((bins*(2**i)-1)*N)
617 | len_diff = frame_len - waveform.size(-1)
618 | if len_diff < 0:
619 | waveform_ = waveform[..., :len_diff]
620 | else:
621 | waveform_ = pad(waveform, (0, len_diff))
622 | spectro = self.mdct[i](waveform_)
623 | if self.abs_spectro:
624 | # [LR, HR/SR, abs(HR/SR)]
625 | spectro = torch.cat(
626 | (spectro, spectro[:, 1, :, :].abs().unsqueeze(1)), dim=1)
627 | if callable(self.normalizer):
628 |                 # the normalizer may return several values; keep only the normalized spectrogram
629 | spectro = self.normalizer(spectro)[0]
630 | if self.getIntermFeat:
631 | model = [getattr(self, 'scale'+str(self.num_D-1-i)+'_layer'+str(j))
632 | for j in range(self.n_layers+2)]
633 | else:
634 | model = getattr(self, 'layer'+str(self.num_D-1-i))
635 | result.append(self.singleD_forward(model, spectro.float()))
636 | return result
637 |
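# Note: unlike MultiscaleDiscriminator, this variant takes a raw waveform and gives each
# sub-discriminator its own MDCT view of it: scale 0 uses n_fft = 2*base_nfft and scale
# i > 0 uses n_fft = base_nfft // 2**i, each with hop = n_fft // 2. The waveform is padded
# or truncated to the matching frame length before the MDCT, so every scale sees a whole
# number of frames.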
638 | # Defines the PatchGAN discriminator with the specified arguments.
639 |
640 |
641 | class NLayerDiscriminator(nn.Module):
642 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
643 | super(NLayerDiscriminator, self).__init__()
644 | self.getIntermFeat = getIntermFeat
645 | self.n_layers = n_layers
646 |
647 | kw = 4
648 | padw = int(np.ceil((kw-1.0)/2))
649 | sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw,
650 | stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
651 |
652 | nf = ndf
653 | for n in range(1, n_layers):
654 | nf_prev = nf
655 | nf = min(nf * 2, 512)
656 | sequence += [[
657 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
658 | norm_layer(nf), nn.LeakyReLU(0.2, True)
659 | ]]
660 |
661 | nf_prev = nf
662 | nf = min(nf * 2, 512)
663 | sequence += [[
664 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
665 | norm_layer(nf),
666 | nn.LeakyReLU(0.2, True)
667 | ]]
668 |
669 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw,
670 | stride=1, padding=padw)]]
671 |
672 | if use_sigmoid:
673 | sequence += [[nn.Sigmoid()]]
674 |
675 | if getIntermFeat:
676 | for n in range(len(sequence)):
677 | setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
678 | else:
679 | sequence_stream = []
680 | for n in range(len(sequence)):
681 | sequence_stream += sequence[n]
682 | self.model = nn.Sequential(*sequence_stream)
683 |
684 | def forward(self, input):
685 | if self.getIntermFeat:
686 | res = [input]
687 | for n in range(self.n_layers+2):
688 | model = getattr(self, 'model'+str(n))
689 | res.append(model(res[-1]))
690 | return res[1:]
691 | else:
692 | return self.model(input)
693 |
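# Usage sketch (illustrative only): with the defaults (kw=4, n_layers=3) this is the usual
# 70x70-receptive-field PatchGAN from pix2pix/pix2pixHD; the output is a patch map of
# real/fake logits (Sigmoid is appended only when use_sigmoid=True).
#     netD = NLayerDiscriminator(input_nc=2)
#     netD(torch.randn(1, 2, 128, 128)).shape   # -> torch.Size([1, 1, 19, 19])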
694 |
695 | class Vgg19(torch.nn.Module):
696 | def __init__(self, requires_grad=False):
697 | super(Vgg19, self).__init__()
698 |         vgg_pretrained_features = models.vgg19(pretrained=True).features  # 'pretrained=' is deprecated in torchvision >= 0.13 in favour of 'weights='
699 | self.slice1 = torch.nn.Sequential()
700 | self.slice2 = torch.nn.Sequential()
701 | self.slice3 = torch.nn.Sequential()
702 | self.slice4 = torch.nn.Sequential()
703 | self.slice5 = torch.nn.Sequential()
704 | for x in range(2):
705 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
706 | for x in range(2, 7):
707 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
708 | for x in range(7, 12):
709 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
710 | for x in range(12, 21):
711 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
712 | for x in range(21, 30):
713 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
714 | if not requires_grad:
715 | for param in self.parameters():
716 | param.requires_grad = False
717 |
718 | def forward(self, X):
719 | h_relu1 = self.slice1(X)
720 | h_relu2 = self.slice2(h_relu1)
721 | h_relu3 = self.slice3(h_relu2)
722 | h_relu4 = self.slice4(h_relu3)
723 | h_relu5 = self.slice5(h_relu4)
724 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
725 | return out
726 |
--------------------------------------------------------------------------------