├── src ├── __init__.py ├── utils.py ├── training.py ├── data.py ├── rnn.py ├── gcn.py └── gcntfilm.py ├── results └── results.txt ├── images └── waveforms.png ├── requirements.txt ├── data └── data.txt ├── LICENSE ├── proc_audio.py ├── README.md └── train.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /results/results.txt: -------------------------------------------------------------------------------- 1 | here go the results from training -------------------------------------------------------------------------------- /images/waveforms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mcomunita/gcn-tfilm/HEAD/images/waveforms.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | auraloss==0.2.2 2 | librosa==0.9.2 3 | numba==0.56.0 4 | numpy==1.22.3 5 | scipy==1.9.3 6 | tensorboard==2.10.1 7 | torch==1.12.1 8 | torchinfo==1.7.1 -------------------------------------------------------------------------------- /data/data.txt: -------------------------------------------------------------------------------- 1 | here goes the data organized in 3 folders: 2 | - test 3 | - train 4 | - val 5 | each folder must contain the input and target files (wav - 16bit) (one file only - it will be split into samples during training) 6 | and they must be named DEVICE-input.wav and DEVICE-target.wav (e.g., la2a-input.wav and la2a-target.wav) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 mcomunita 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /proc_audio.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from scipy.io.wavfile import write 5 | 6 | import src.utils as utils 7 | import src.data as data 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser( 12 | description='''This script takes an input .wav file, loads it and processes it with a neural network model of a 13 | device, i.e guitar amp/pedal, and saves the output as a .wav file''') 14 | 15 | parser.add_argument('--input_file', type=str, default='data/test/ht1-input.wav', help='input file') 16 | parser.add_argument('--output_file', type=str, default='output.wav', help='output file') 17 | parser.add_argument('--model_file', type=str, default='results/ht1-ht11/model_best.json', help='model file') 18 | parser.add_argument('--chunk_length', type=int, default=16384, help='chunk length') 19 | return parser.parse_args() 20 | 21 | 22 | def proc_audio(args): 23 | model_data = utils.json_load(args.model_file) 24 | model = utils.load_model(model_data) 25 | 26 | dataset = data.DataSet(data_dir='', extensions='') 27 | dataset.create_subset('data') 28 | dataset.load_file(args.input_file, set_names='data') 29 | 30 | # cuda 31 | if torch.cuda.is_available(): 32 | device = torch.device("cuda:0") 33 | torch.set_default_tensor_type("torch.cuda.FloatTensor") 34 | # torch.cuda.set_device(0) 35 | print("\ncuda device available") 36 | else: 37 | device = torch.device("cpu") 38 | torch.set_default_tensor_type("torch.FloatTensor") 39 | print("\ncuda device NOT available") 40 | 41 | model = model.to(device) 42 | model.eval() 43 | 44 | with torch.no_grad(): 45 | input_data = dataset.subsets['data'].data['data'][0] 46 | 47 | input_data = input_data.to(device) 48 | # output = network(input_data) 49 | 50 | _, _, output = model.process_data(input_data, 51 | chunk=args.chunk_length) 52 | 53 | write(args.output_file, dataset.subsets['data'].fs, output.cpu().numpy()[:, 0, 0]) 54 | 55 | 56 | def main(): 57 | args = parse_args() 58 | proc_audio(args) 59 | 60 | 61 | if __name__ == '__main__': 62 | main() 63 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | from src.rnn import RNN 5 | from src.gcn import GCN 6 | from src.gcntfilm import GCNTF 7 | 8 | 9 | # Function that checks if a directory exists, and creates it if it doesn't, if dir_name is a list of strings, it will 10 | # create a search path, i.e dir_name = ['directory', 'subdir'] will search for directory 'directory/subdir' 11 | def dir_check(dir_name): 12 | dir_name = [dir_name] if not type(dir_name) == list else dir_name 13 | dir_path = os.path.join(*dir_name) 14 | if os.path.isdir(dir_path): 15 | pass 16 | else: 17 | os.mkdir(dir_path) 18 | 19 | 20 | # Function that takes a file_name and optionally a path to the directory the file is expected to be, returns true if 21 | # the file is found in the stated directory (or the current directory is dir_name = '') or False is dir/file isn't found 22 | def file_check(file_name, dir_name=''): 23 | assert type(file_name) == str 24 | dir_name = [dir_name] if ((type(dir_name) != list) and (dir_name)) else dir_name 25 | full_path = os.path.join(*dir_name, file_name) 26 | return os.path.isfile(full_path) 27 | 28 | 29 | # Function that saves 'data' to a json 
file. Constructs a file path is dir_name is provided. 30 | def json_save(data, file_name, dir_name='', indent=0): 31 | dir_name = [dir_name] if ((type(dir_name) != list) and (dir_name)) else dir_name 32 | assert type(file_name) == str 33 | file_name = file_name + '.json' if not file_name.endswith('.json') else file_name 34 | full_path = os.path.join(*dir_name, file_name) 35 | with open(full_path, 'w') as fp: 36 | json.dump(data, fp, indent=indent) 37 | 38 | 39 | def json_load(file_name, dir_name=''): 40 | dir_name = [dir_name] if ((type(dir_name) != list) and (dir_name)) else dir_name 41 | file_name = file_name + '.json' if not file_name.endswith('.json') else file_name 42 | full_path = os.path.join(*dir_name, file_name) 43 | with open(full_path) as fp: 44 | return json.load(fp) 45 | 46 | 47 | def load_model(model_data): 48 | model_meta = model_data.pop('model_data') 49 | # print(model_meta) 50 | 51 | if model_meta["model_type"] == "rnn": 52 | model = RNN(**model_meta) 53 | elif model_meta["model_type"] == "gcn": 54 | model = GCN(**model_meta) 55 | elif model_meta["model_type"] == "gcntf": 56 | model = GCNTF(**model_meta) 57 | 58 | if 'state_dict' in model_data: 59 | state_dict = model.state_dict() 60 | for each in model_data['state_dict']: 61 | state_dict[each] = torch.tensor(model_data['state_dict'][each]) 62 | model.load_state_dict(state_dict) 63 | 64 | return model 65 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
3 | 4 | # GCN-TFiLM 5 | **Modelling black-box audio effects with time-varying feature modulation** 6 | 7 | [Paper](https://arxiv.org/abs/2211.00497) | [Webpage](https://mcomunita.github.io/gcn-tfilm_page/) 8 | 9 |
10 | 11 |
12 | ![waveforms](images/waveforms.png) 13 |
14 | 15 | ## Setup 16 | 17 | Install the requirements. 18 | ``` 19 | python3 -m venv .venv 20 | source .venv/bin/activate 21 | python3 -m pip install -r requirements.txt 22 | ``` 23 | 24 | ## Dataset 25 | You can find our dataset of fuzz and compressor effects on [Zenodo](https://zenodo.org/record/7271558#.Y2I_6OzP0-R). Once downloaded, you can replace the ```data``` folder in this repo. You can also use your own data, organized into 3 subfolders: ```test```, ```train``` and ```val```. Each folder should contain one input and one target file (16-bit wav, a single file each - it will be split into segments during training), named DEVICE-input.wav and DEVICE-target.wav (e.g., la2a-input.wav and la2a-target.wav). 26 | 27 | ## Training 28 | 29 | If you would like to re-train the models in the paper, you can run the training script, which will train all the models one by one. The training results will be saved in the ```results``` folder. 30 | 31 | ``` 32 | python train.py 33 | ``` 34 | 35 | ## Process Audio 36 | 37 | To process audio with a trained model, run the ```proc_audio.py``` script. 38 | 39 | ``` 40 | python proc_audio.py \ 41 | --input_file data/test/facebender-input.wav \ 42 | --output_file facebender-output.wav \ 43 | --model_file results/MODEL_FOLDER/model_best.json \ 44 | --chunk_length 32768 45 | ``` 46 | 47 | ## Credits 48 | [https://github.com/Alec-Wright/Automated-GuitarAmpModelling](https://github.com/Alec-Wright/Automated-GuitarAmpModelling) 49 | 50 | [https://github.com/csteinmetz1/micro-tcn](https://github.com/csteinmetz1/micro-tcn) 51 | 52 | ## Citation 53 | If you use any of this code in your work, please consider citing us. 54 | ``` 55 | @misc{https://doi.org/10.48550/arxiv.2211.00497, 56 | doi = {10.48550/ARXIV.2211.00497}, 57 | url = {https://arxiv.org/abs/2211.00497}, 58 | author = {Comunità, Marco and Steinmetz, Christian J.
and Phan, Huy and Reiss, Joshua D.}, 59 | keywords = {Sound (cs.SD), Artificial Intelligence (cs.AI), Machine Learning (cs.LG), Audio and Speech Processing (eess.AS), FOS: Computer and information sciences, FOS: Computer and information sciences, FOS: Electrical engineering, electronic engineering, information engineering, FOS: Electrical engineering, electronic engineering, information engineering}, 60 | title = {Modelling black-box audio effects with time-varying feature modulation}, 61 | publisher = {arXiv}, 62 | year = {2022}, 63 | copyright = {Creative Commons Attribution 4.0 International} 64 | } 65 | ``` 66 | -------------------------------------------------------------------------------- /src/training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import auraloss 4 | 5 | # ESR loss calculates the Error-to-signal between the output/target 6 | 7 | 8 | class ESRLoss(nn.Module): 9 | def __init__(self): 10 | super(ESRLoss, self).__init__() 11 | self.epsilon = 0.00001 12 | 13 | def forward(self, output, target): 14 | loss = torch.add(target, -output) 15 | loss = torch.pow(loss, 2) 16 | loss = torch.mean(loss) 17 | energy = torch.mean(torch.pow(target, 2)) + self.epsilon 18 | loss = torch.div(loss, energy) 19 | return loss 20 | 21 | 22 | class DCLoss(nn.Module): 23 | def __init__(self): 24 | super(DCLoss, self).__init__() 25 | self.epsilon = 0.00001 26 | 27 | def forward(self, output, target): 28 | loss = torch.pow(torch.add(torch.mean(target, 0), -torch.mean(output, 0)), 2) 29 | loss = torch.mean(loss) 30 | energy = torch.mean(torch.pow(target, 2)) + self.epsilon 31 | loss = torch.div(loss, energy) 32 | return loss 33 | 34 | 35 | # PreEmph is a class that applies an FIR pre-emphasis filter to the signal, the filter coefficients are in the 36 | # filter_cfs argument, and lp is a flag that also applies a low pass filter 37 | # Only supported for single-channel! 38 | class PreEmph(nn.Module): 39 | def __init__(self, filter_cfs, low_pass=0): 40 | super(PreEmph, self).__init__() 41 | self.epsilon = 0.00001 42 | self.zPad = len(filter_cfs) - 1 43 | 44 | self.conv_filter = nn.Conv1d(1, 1, 2, bias=False) 45 | self.conv_filter.weight.data = torch.tensor([[filter_cfs]], requires_grad=False) 46 | 47 | self.low_pass = low_pass 48 | if self.low_pass: 49 | self.lp_filter = nn.Conv1d(1, 1, 2, bias=False) 50 | self.lp_filter.weight.data = torch.tensor([[[0.85, 1]]], requires_grad=False) 51 | 52 | def forward(self, output, target): 53 | # zero pad the input/target so the filtered signal is the same length 54 | output = torch.cat((torch.zeros(self.zPad, output.shape[1], 1), output)) 55 | target = torch.cat((torch.zeros(self.zPad, target.shape[1], 1), target)) 56 | # Apply pre-emph filter, permute because the dimension order is different for RNNs and Convs in pytorch... 
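# (signals are stored as [length, batch, channels] throughout this codebase, while nn.Conv1d expects [batch, channels, length], hence the permute before filtering and the permute back before returning)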
57 | output = self.conv_filter(output.permute(1, 2, 0)) 58 | target = self.conv_filter(target.permute(1, 2, 0)) 59 | 60 | if self.low_pass: 61 | output = self.lp_filter(output) 62 | target = self.lp_filter(target) 63 | 64 | return output.permute(2, 0, 1), target.permute(2, 0, 1) 65 | 66 | 67 | class LossWrapper(nn.Module): 68 | def __init__(self, losses, pre_filt=None): 69 | super(LossWrapper, self).__init__() 70 | self.losses = losses 71 | self.loss_dict = { 72 | 'ESR': ESRLoss(), 73 | 'DC': DCLoss(), 74 | 'L1': torch.nn.L1Loss(), 75 | 'STFT': auraloss.freq.STFTLoss(), 76 | 'MSTFT': auraloss.freq.MultiResolutionSTFTLoss() 77 | } 78 | if pre_filt: 79 | pre_filt = PreEmph(pre_filt) 80 | self.loss_dict['ESRPre'] = lambda output, target: self.loss_dict['ESR'].forward(*pre_filt(output, target)) 81 | loss_functions = [[self.loss_dict[key], value] for key, value in losses.items()] 82 | 83 | self.loss_functions = tuple([items[0] for items in loss_functions]) 84 | try: 85 | self.loss_factors = tuple(torch.Tensor([items[1] for items in loss_functions])) 86 | except IndexError: 87 | self.loss_factors = torch.ones(len(self.loss_functions)) 88 | 89 | def forward(self, output, target): 90 | all_losses = {} 91 | for i, loss in enumerate(self.losses): 92 | # original shape: length x batch x 1 93 | # auraloss needs: batch x 1 x length 94 | loss_fcn = self.loss_functions[i] 95 | loss_factor = self.loss_factors[i] 96 | if isinstance(loss_fcn, auraloss.freq.STFTLoss) or isinstance(loss_fcn, auraloss.freq.MultiResolutionSTFTLoss): 97 | output = torch.permute(output, (1, 2, 0)) 98 | target = torch.permute(target, (1, 2, 0)) 99 | all_losses[loss] = torch.mul(loss_fcn(output, target), loss_factor) 100 | return all_losses 101 | 102 | 103 | class TrainTrack(dict): 104 | def __init__(self): 105 | self.update({'current_epoch': 0, 106 | 107 | 'tot_train_losses': [], 108 | 'train_losses': [], 109 | 110 | 'tot_val_losses': [], 111 | 'val_losses': [], 112 | 113 | 'train_av_time': 0.0, 114 | 'val_av_time': 0.0, 115 | 'total_time': 0.0, 116 | 117 | 'val_loss_best': 1e12, 118 | 'val_losses_best': 1e12, 119 | 120 | 'test_loss_final': 0, 121 | 'test_losses_final': {}, 122 | 123 | 'test_loss_best': 0, 124 | 'test_losses_best': {}}) 125 | 126 | def restore_data(self, training_info): 127 | self.update(training_info) 128 | 129 | def train_epoch_update(self, epoch_loss, epoch_losses, ep_st_time, ep_end_time, init_time, current_ep): 130 | self['current_epoch'] = current_ep 131 | self['tot_train_losses'].append(epoch_loss) 132 | self['train_losses'].append(epoch_losses) 133 | 134 | if self['train_av_time']: 135 | self['train_av_time'] = (self['train_av_time'] + ep_end_time - ep_st_time) / 2 136 | else: 137 | self['train_av_time'] = ep_end_time - ep_st_time 138 | 139 | self['total_time'] += ((init_time + ep_end_time - ep_st_time)/3600) 140 | 141 | def val_epoch_update(self, val_loss, val_losses, ep_st_time, ep_end_time): 142 | self['tot_val_losses'].append(val_loss) 143 | self['val_losses'].append(val_losses) 144 | 145 | if self['val_av_time']: 146 | self['val_av_time'] = (self['val_av_time'] + ep_end_time - ep_st_time) / 2 147 | else: 148 | self['val_av_time'] = ep_end_time - ep_st_time 149 | 150 | if val_loss < self['val_loss_best']: 151 | self['val_loss_best'] = val_loss 152 | self['val_losses_best'] = val_losses 153 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from 
scipy.io import wavfile 3 | import torch 4 | import math 5 | import warnings 6 | import os 7 | 8 | 9 | # Function converting np read audio to range of -1 to +1 10 | def audio_converter(audio): 11 | if audio.dtype == 'int16': 12 | return audio.astype(np.float32, order='C') / 32768.0 13 | else: 14 | print('unimplemented audio data type conversion...') 15 | 16 | 17 | # Splits audio, each split marker determines the fraction of the total audio in that split, i.e [0.75, 0.25] will put 18 | # 75% in the first split and 25% in the second 19 | def audio_splitter(audio, split_markers): 20 | assert sum(split_markers) <= 1.0 21 | if sum(split_markers) < 0.999: 22 | warnings.warn("sum of split markers is less than 1, so not all audio will be included in dataset") 23 | start = 0 24 | slices = [] 25 | # convert split markers to samples 26 | split_bounds = [int(x * audio.shape[0]) for x in split_markers] 27 | for n in split_bounds: 28 | end = start + n 29 | slices.append(audio[start:end]) 30 | start = end 31 | return slices 32 | 33 | 34 | # converts numpy audio into frames, and creates a torch tensor from them, frame_len = 0 just converts to a torch tensor 35 | def framify(audio, frame_len): 36 | # If audio is mono, add a dummy dimension, so the same operations can be applied to mono/multichannel audio 37 | audio = np.expand_dims(audio, 1) if len(audio.shape) == 1 else audio 38 | # Calculate the number of segments the training data will be split into in frame_len is not 0 39 | seg_num = math.floor(audio.shape[0] / frame_len) if frame_len else 1 40 | # If no frame_len is provided, set frame_len to be equal to length of the input audio 41 | frame_len = audio.shape[0] if not frame_len else frame_len 42 | # Find the number of channels 43 | channels = audio.shape[1] 44 | # Initialise tensor matrices 45 | dataset = torch.empty((frame_len, seg_num, channels)) 46 | # Load the audio for the training set 47 | for i in range(seg_num): 48 | dataset[:, i, :] = torch.from_numpy(audio[i * frame_len:(i + 1) * frame_len, :]) 49 | return dataset 50 | 51 | 52 | """This is the main DataSet class, it can hold any number of subsets, which could be, e.g, the training/test/val sets. 53 | The subsets are created by the create_subset method and stored in the DataSet.subsets dictionary. 
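A minimal usage sketch (mirroring proc_audio.py): dataset = DataSet(data_dir='', extensions=''); dataset.create_subset('data'); dataset.load_file('data/test/ht1-input.wav', set_names='data') loads a single unpaired file into the 'data' subset, after which the framed audio is available as dataset.subsets['data'].data['data'][0].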
54 | 55 | datadir: is the default location where the DataSet instance will look when told to load an audio file 56 | extensions: is the default ends of the paired data files, when using paired data, so by default when loading the file 57 | 'wicked_guitar', the load_file method with look for 'wicked_guitar-input.wav' and 'wicked_guitar-target.wav' 58 | to disable this behaviour, enter extensions = '', or None, or anything that evaluates to false in python """ 59 | 60 | 61 | class DataSet: 62 | def __init__(self, data_dir='../Dataset/', extensions=('input', 'target')): 63 | self.extensions = extensions if extensions else [''] 64 | self.subsets = {} 65 | assert type(data_dir) == str, "data_dir should be string,not %r" % {type(data_dir)} 66 | self.data_dir = data_dir 67 | 68 | # add a subset called 'name', desired 'frame_len' is given in seconds, or 0 for just one long frame 69 | def create_subset(self, name, frame_len=0): 70 | assert type(name) == str, "data subset name must be a string, not %r" %{type(name)} 71 | assert not (name in self.subsets), "subset %r already exists" %name 72 | self.subsets[name] = SubSet(frame_len) 73 | 74 | # load a file of 'filename' into existing subset/s 'set_names', split fractionally as specified by 'splits', 75 | # if 'cond_val' is provided the conditioning value will be saved along with the frames of the loaded data 76 | def load_file(self, filename, set_names='train', splits=None, cond_val=None): 77 | # Assertions and checks 78 | if type(set_names) == str: 79 | set_names = [set_names] 80 | assert len(set_names) == 1 or len(set_names) == len(splits), "number of subset names must equal number of " \ 81 | "split markers" 82 | assert [self.subsets.get(each) for each in set_names], "set_names contains subsets that don't exist yet" 83 | 84 | # Load each of the 'extensions' 85 | for i, ext in enumerate(self.extensions): 86 | try: 87 | file_loc = os.path.join(self.data_dir, filename + '-' + ext) 88 | file_loc = file_loc + '.wav' if not file_loc.endswith('.wav') else file_loc 89 | np_data = wavfile.read(file_loc) 90 | except FileNotFoundError: 91 | file_loc = os.path.join(self.data_dir, filename + ext) 92 | file_loc = file_loc + '.wav' if not file_loc.endswith('.wav') else file_loc 93 | np_data = wavfile.read(file_loc) 94 | except FileNotFoundError: 95 | print(["File Not Found At: " + self.data_dir + filename]) 96 | return 97 | raw_audio = audio_converter(np_data[1]) 98 | # Split the audio if the set_names were provided 99 | if len(set_names) > 1: 100 | raw_audio = audio_splitter(raw_audio, splits) 101 | for n, sets in enumerate(set_names): 102 | self.subsets[set_names[n]].add_data(np_data[0], raw_audio[n], ext, cond_val) 103 | elif len(set_names) == 1: 104 | self.subsets[set_names[0]].add_data(np_data[0], raw_audio, ext, cond_val) 105 | 106 | 107 | # The SubSet class holds a subset of data, 108 | # frame_len sets the length of audio per frame (in s), if set to 0 a single frame is used instead 109 | class SubSet: 110 | def __init__(self, frame_len): 111 | self.data = {} 112 | self.cond_data = {} 113 | self.frame_len = frame_len 114 | self.conditioning = None 115 | self.fs = None 116 | 117 | # Add 'audio' data, in the data dictionary at the key 'ext', if cond_val is provided save the cond_val of each frame 118 | def add_data(self, fs, audio, ext, cond_val): 119 | if not self.fs: 120 | self.fs = fs 121 | assert self.fs == fs, "data with different sample rate provided to subset" 122 | # if no 'ext' is provided, all the subsets data will be stored at the 'data' key of 
the 'data' dict 123 | ext = 'data' if not ext else ext 124 | # Frame the data and optionally create a tensor of the conditioning values of each frame 125 | framed_data = framify(audio, self.frame_len) 126 | cond_data = cond_val * torch.ones(framed_data.shape[1]) if isinstance(cond_val, (float, int)) else None 127 | 128 | try: 129 | # Convert data from tuple to list and concatenate new data onto the data tensor 130 | data = list(self.data[ext]) 131 | self.data[ext] = (torch.cat((data[0], framed_data), 1),) 132 | # If cond_val is provided add it to the cond_val tensor, note all frames or no frames must have cond vals 133 | if isinstance(cond_val, (float, int)): 134 | assert torch.is_tensor(self.cond_data[ext][0]), 'cond val provided, but previous data has no cond val' 135 | c_data = list(self.cond_data[ext]) 136 | self.cond_data[ext] = (torch.cat((c_data[0], cond_data), 0),) 137 | # If this is the first data to be loaded into the subset, create the data and cond_data tuples 138 | except KeyError: 139 | self.data[ext] = (framed_data,) 140 | self.cond_data[ext] = (cond_data,) 141 | -------------------------------------------------------------------------------- /src/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import argparse 4 | import src.utils as utils 5 | # from contextlib import nullcontext 6 | 7 | 8 | def wrapperkwargs(func, kwargs): 9 | return func(**kwargs) 10 | 11 | 12 | def wrapperargs(func, args): 13 | return func(*args) 14 | 15 | 16 | """ 17 | Simple RNN class made of a single recurrent unit of type LSTM, GRU or Elman, 18 | followed by a fully connected layer 19 | """ 20 | 21 | 22 | class RNN(torch.nn.Module): 23 | def __init__(self, 24 | input_size=1, 25 | output_size=1, 26 | unit_type="LSTM", 27 | hidden_size=32, 28 | skip=1, 29 | bias_fl=True, 30 | nlayers=1, 31 | **kwargs): 32 | super(RNN, self).__init__() 33 | self.input_size = input_size 34 | self.output_size = output_size 35 | self.unit_type = unit_type 36 | self.hidden_size = hidden_size 37 | self.skip = skip 38 | self.bias_fl = bias_fl 39 | self.nlayers = nlayers 40 | self.save_state = True 41 | self.hidden = None # hidden state 42 | 43 | # create dictionary of possible block types 44 | self.rec = wrapperargs(getattr(torch.nn, unit_type), 45 | [input_size, hidden_size, nlayers]) 46 | 47 | self.lin = torch.nn.Linear(hidden_size, output_size, bias=bias_fl) 48 | 49 | def forward(self, x): 50 | if self.skip: 51 | # save the residual for the skip connection 52 | res = x[:, :, :self.skip] 53 | x, self.hidden = self.rec(x, self.hidden) 54 | return self.lin(x) + res 55 | else: 56 | x, self.hidden = self.rec(x, self.hidden) 57 | return self.lin(x) 58 | 59 | # detach hidden state, this resets gradient tracking on the hidden state 60 | def detach_hidden(self): 61 | if self.hidden.__class__ == tuple: 62 | self.hidden = tuple([h.clone().detach() for h in self.hidden]) 63 | else: 64 | self.hidden = self.hidden.clone().detach() 65 | 66 | # reset hidden state 67 | def reset_hidden(self): 68 | self.hidden = None 69 | 70 | # train_epoch runs one epoch of training 71 | def train_epoch(self, 72 | input_data, 73 | target_data, 74 | loss_fcn, 75 | optim, 76 | batch_size, 77 | init_len=200, 78 | up_fr=1000): 79 | # shuffle the segments at the start of the epoch 80 | shuffle = torch.randperm(input_data.shape[1]) 81 | 82 | # iterate over the batches 83 | ep_losses = {} 84 | 85 | for batch_i in range(math.ceil(shuffle.shape[0] / batch_size)): 86 | # load batch 87 | 
input_batch = input_data[:, 88 | shuffle[batch_i * batch_size: 89 | (batch_i + 1) * batch_size], 90 | :] 91 | target_batch = target_data[:, 92 | shuffle[batch_i * batch_size: 93 | (batch_i + 1) * batch_size], 94 | :] 95 | 96 | # initialise network hidden state by processing some samples then zero the gradient buffers 97 | self(input_batch[0:init_len, :, :]) 98 | self.zero_grad() 99 | 100 | # choose the starting index for processing the rest of the batch sequence, in chunks of args.up_fr 101 | start_i = init_len 102 | tot_batch_losses = {} 103 | 104 | # iterate over the remaining samples in the mini batch 105 | for k in range(math.ceil((input_batch.shape[0] - init_len) / up_fr)): 106 | # process batch 107 | output = self(input_batch[start_i:start_i + up_fr, :, :]) 108 | 109 | # loss and backprop 110 | partial_batch_losses = loss_fcn(output, target_batch[start_i:start_i + up_fr, :, :]) 111 | 112 | partial_batch_loss = 0 113 | for loss in partial_batch_losses: 114 | partial_batch_loss += partial_batch_losses[loss] 115 | 116 | partial_batch_loss.backward() 117 | optim.step() 118 | 119 | # detach hidden state for truncated BPTT 120 | self.detach_hidden() 121 | self.zero_grad() 122 | 123 | # update the start index for next iteration 124 | start_i += up_fr 125 | 126 | # add partial batch losses to total batch losses 127 | if tot_batch_losses == {}: 128 | tot_batch_losses = partial_batch_losses 129 | else: 130 | for loss in partial_batch_losses: 131 | tot_batch_losses[loss] += partial_batch_losses[loss] 132 | 133 | # add average batch losses to epoch losses 134 | if ep_losses == {}: 135 | for loss in tot_batch_losses: 136 | ep_losses[loss] = tot_batch_losses[loss] / (k + 1) 137 | else: 138 | for loss in tot_batch_losses: 139 | ep_losses[loss] += tot_batch_losses[loss] / (k + 1) 140 | 141 | # # add the average batch loss to the epoch loss and reset the hidden states to zeros 142 | # ep_loss += batch_loss / (k + 1) 143 | 144 | # reset hidden state before next batch 145 | self.reset_hidden() 146 | 147 | # mean epoch losses 148 | for loss in ep_losses: 149 | ep_losses[loss] /= (batch_i + 1) 150 | 151 | return ep_losses 152 | 153 | def process_data(self, 154 | input_data, 155 | target_data=None, 156 | chunk=16384, 157 | loss_fcn=None, 158 | grad=False): 159 | 160 | if not (input_data.shape[0] / chunk).is_integer(): 161 | # round to next chunk size 162 | padding = chunk - (input_data.shape[0] % chunk) 163 | input_data = torch.nn.functional.pad(input_data, 164 | (0, 0, 0, 0, 0, padding), 165 | mode='constant', 166 | value=0) 167 | if target_data != None: 168 | target_data = torch.nn.functional.pad(target_data, 169 | (0, 0, 0, 0, 0, padding), 170 | mode='constant', 171 | value=0) 172 | 173 | with torch.no_grad(): 174 | # reset state before processing 175 | self.reset_hidden() 176 | 177 | # process input 178 | output_data = torch.empty_like(input_data) 179 | 180 | for l in range(int(output_data.size()[0] / chunk)): 181 | output_data[l * chunk:(l + 1) * chunk] = \ 182 | self(input_data[l * chunk:(l + 1) * chunk]) 183 | self.detach_hidden() 184 | 185 | # reset state before other computations 186 | self.reset_hidden() 187 | 188 | if loss_fcn != None and target_data != None: 189 | losses = loss_fcn(output_data, target_data) 190 | return input_data, target_data, output_data, losses 191 | 192 | return input_data, target_data, output_data 193 | 194 | def save_model(self, 195 | file_name, 196 | direc=''): 197 | if direc: 198 | utils.dir_check(direc) 199 | 200 | model_data = {'model_data': {'model_type': 'rnn', 201 
| 'input_size': self.input_size, 202 | 'output_size': self.output_size, 203 | 'unit_type': self.unit_type, 204 | 'hidden_size': self.hidden_size, 205 | 'skip': self.skip, 206 | 'bias_fl': self.bias_fl, 207 | 'nlayers': self.nlayers}} 208 | 209 | if self.save_state: 210 | model_state = self.state_dict() 211 | for each in model_state: 212 | model_state[each] = model_state[each].tolist() 213 | model_data['state_dict'] = model_state 214 | 215 | utils.json_save(model_data, file_name, direc) 216 | 217 | # add any model hyperparameters here 218 | @staticmethod 219 | def add_model_specific_args(parent_parser): 220 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) 221 | # --- model related --- 222 | parser.add_argument('--input_size', type=int, default=1) 223 | parser.add_argument('--output_size', type=int, default=1) 224 | parser.add_argument('--unit_type', type=str, default="LSTM") 225 | parser.add_argument('--hidden_size', type=int, default=32) 226 | parser.add_argument('--skip', type=int, default=1) 227 | parser.add_argument('--bias_fl', default=True, action='store_true') 228 | parser.add_argument('--nlayers', type=int, default=1) 229 | 230 | return parser 231 | -------------------------------------------------------------------------------- /src/gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import argparse 4 | import src.utils as utils 5 | # from contextlib import nullcontext 6 | 7 | 8 | def wrapperkwargs(func, kwargs): 9 | return func(**kwargs) 10 | 11 | 12 | def wrapperargs(func, args): 13 | return func(*args) 14 | 15 | 16 | # def load_model(model_data): 17 | # model_types = {'GCN': GCN} 18 | 19 | # model_meta = model_data.pop('model_data') 20 | 21 | # if model_meta['model'] == 'SimpleRNN' or model_meta['model'] == 'GatedConvNet': 22 | # network = wrapperkwargs( 23 | # model_types[model_meta.pop('model')], model_meta) 24 | # if 'state_dict' in model_data: 25 | # state_dict = network.state_dict() 26 | # for each in model_data['state_dict']: 27 | # state_dict[each] = torch.tensor(model_data['state_dict'][each]) 28 | # network.load_state_dict(state_dict) 29 | 30 | # elif model_meta['model'] == 'RecNet': 31 | # model_meta['blocks'] = [] 32 | # network = wrapperkwargs( 33 | # model_types[model_meta.pop('model')], model_meta) 34 | # for i in range(len(model_data['blocks'])): 35 | # network.add_layer(model_data['blocks'][str(i)]) 36 | 37 | # # Get the state dict from the newly created model and load the saved states, if states were saved 38 | # if 'state_dict' in model_data: 39 | # state_dict = network.state_dict() 40 | # for each in model_data['state_dict']: 41 | # state_dict[each] = torch.tensor(model_data['state_dict'][each]) 42 | # network.load_state_dict(state_dict) 43 | 44 | # if 'training_info' in model_data.keys(): 45 | # network.training_info = model_data['training_info'] 46 | 47 | # return network 48 | 49 | 50 | # This is a function for taking the old json config file format I used to use and converting it to the new format 51 | # def legacy_load(legacy_data): 52 | # if legacy_data['unit_type'] == 'GRU' or legacy_data['unit_type'] == 'LSTM': 53 | # model_data = {'model_data': { 54 | # 'model': 'RecNet', 'skip': 0}, 'blocks': {}} 55 | # model_data['blocks']['0'] = { 56 | # 'block_type': legacy_data['unit_type'], 57 | # 'input_size': legacy_data['in_size'], 58 | # 'hidden_size': legacy_data['hidden_size'], 59 | # 'output_size': 1, 60 | # 'lin_bias': True 61 | # } 62 | # if 
legacy_data['cur_epoch']: 63 | # training_info = { 64 | # 'current_epoch': legacy_data['cur_epoch'], 65 | # 'training_losses': legacy_data['tloss_list'], 66 | # 'val_losses': legacy_data['vloss_list'], 67 | # 'load_config': legacy_data['load_config'], 68 | # 'low_pass': legacy_data['low_pass'], 69 | # 'val_freq': legacy_data['val_freq'], 70 | # 'device': legacy_data['pedal'], 71 | # 'seg_length': legacy_data['seg_len'], 72 | # 'learning_rate': legacy_data['learn_rate'], 73 | # 'batch_size': legacy_data['batch_size'], 74 | # 'loss_func': legacy_data['loss_fcn'], 75 | # 'update_freq': legacy_data['up_fr'], 76 | # 'init_length': legacy_data['init_len'], 77 | # 'pre_filter': legacy_data['pre_filt'] 78 | # } 79 | # model_data['training_info'] = training_info 80 | 81 | # if 'state_dict' in legacy_data: 82 | # state_dict = legacy_data['state_dict'] 83 | # state_dict = dict(state_dict) 84 | # new_state_dict = {} 85 | # for each in state_dict: 86 | # new_name = each[0:7] + 'block_1.' + each[9:] 87 | # new_state_dict[new_name] = state_dict[each] 88 | # model_data['state_dict'] = new_state_dict 89 | # return model_data 90 | # else: 91 | # print('format not recognised') 92 | 93 | 94 | """ 95 | Gated convolutional layer, zero pads and then applies a causal convolution to the input 96 | """ 97 | 98 | 99 | class GatedConv1d(torch.nn.Module): 100 | 101 | def __init__(self, 102 | in_ch, 103 | out_ch, 104 | dilation, 105 | kernel_size): 106 | super(GatedConv1d, self).__init__() 107 | self.in_ch = in_ch 108 | self.out_ch = out_ch 109 | self.dilation = dilation 110 | self.kernal_size = kernel_size 111 | 112 | # Layers: Conv1D -> Activations -> Mix + Residual 113 | 114 | self.conv = torch.nn.Conv1d(in_channels=in_ch, 115 | out_channels=out_ch * 2, 116 | kernel_size=kernel_size, 117 | stride=1, 118 | padding=0, 119 | dilation=dilation) 120 | 121 | self.mix = torch.nn.Conv1d(in_channels=out_ch, 122 | out_channels=out_ch, 123 | kernel_size=1, 124 | stride=1, 125 | padding=0) 126 | 127 | def forward(self, x): 128 | # print("GatedConv1d: ", x.shape) 129 | residual = x 130 | 131 | # dilated conv 132 | y = self.conv(x) 133 | 134 | # gated activation 135 | z = torch.tanh(y[:, :self.out_ch, :]) * \ 136 | torch.sigmoid(y[:, self.out_ch:, :]) 137 | 138 | # zero pad on the left side, so that z is the same length as x 139 | z = torch.cat((torch.zeros(residual.shape[0], 140 | self.out_ch, 141 | residual.shape[2] - z.shape[2]), 142 | z), 143 | dim=2) 144 | 145 | x = self.mix(z) + residual 146 | 147 | return x, z 148 | 149 | 150 | """ 151 | Gated convolutional neural net block, applies successive gated convolutional layers to the input, a total of 'layers' 152 | layers are applied, with the filter size 'kernel_size' and the dilation increasing by a factor of 'dilation_growth' for 153 | each successive layer. 
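For example, with kernel_size=3, dilation_growth=2 and nlayers=9 (the GCN defaults), the layer dilations are 1, 2, 4, 8, 16, 32, 64, 128, 256.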
154 | """ 155 | 156 | 157 | class GCNBlock(torch.nn.Module): 158 | def __init__(self, 159 | in_ch, 160 | out_ch, 161 | nlayers, 162 | kernel_size, 163 | dilation_growth): 164 | super(GCNBlock, self).__init__() 165 | self.in_ch = in_ch 166 | self.out_ch = out_ch 167 | self.nlayers = nlayers 168 | self.kernel_size = kernel_size 169 | self.dilation_growth = dilation_growth 170 | 171 | dilations = [dilation_growth ** l for l in range(nlayers)] 172 | 173 | self.layers = torch.nn.ModuleList() 174 | 175 | for d in dilations: 176 | self.layers.append(GatedConv1d(in_ch=in_ch, 177 | out_ch=out_ch, 178 | dilation=d, 179 | kernel_size=kernel_size)) 180 | in_ch = out_ch 181 | 182 | def forward(self, x): 183 | # print("GCNBlock: ", x.shape) 184 | # [batch, channels, length] 185 | z = torch.empty([x.shape[0], 186 | self.nlayers * self.out_ch, 187 | x.shape[2]]) 188 | 189 | for n, layer in enumerate(self.layers): 190 | x, zn = layer(x) 191 | z[:, n * self.out_ch: (n + 1) * self.out_ch, :] = zn 192 | 193 | return x, z 194 | 195 | 196 | """ 197 | Gated Convolutional Neural Net class, based on the 'WaveNet' architecture, takes a single channel of audio as input and 198 | produces a single channel of audio of equal length as output. one-sided zero-padding is used to ensure the network is 199 | causal and doesn't reduce the length of the audio. 200 | 201 | Made up of 'blocks', each one applying a series of dilated convolutions, with the dilation of each successive layer 202 | increasing by a factor of 'dilation_growth'. 'layers' determines how many convolutional layers are in each block, 203 | 'kernel_size' is the size of the filters. Channels is the number of convolutional channels. 204 | 205 | The output of the model is creating by the linear mixer, which sums weighted outputs from each of the layers in the 206 | model 207 | """ 208 | 209 | 210 | class GCN(torch.nn.Module): 211 | def __init__(self, 212 | nblocks=2, 213 | nlayers=9, 214 | nchannels=8, 215 | kernel_size=3, 216 | dilation_growth=2, 217 | **kwargs): 218 | super(GCN, self).__init__() 219 | self.nblocks = nblocks 220 | self.nlayers = nlayers 221 | self.nchannels = nchannels 222 | self.kernel_size = kernel_size 223 | self.dilation_growth = dilation_growth 224 | 225 | self.blocks = torch.nn.ModuleList() 226 | for b in range(nblocks): 227 | self.blocks.append(GCNBlock(in_ch=1 if b == 0 else nchannels, 228 | out_ch=nchannels, 229 | nlayers=nlayers, 230 | kernel_size=kernel_size, 231 | dilation_growth=dilation_growth)) 232 | 233 | # output mixing layer 234 | self.blocks.append( 235 | torch.nn.Conv1d(in_channels=nchannels * nlayers * nblocks, 236 | out_channels=1, 237 | kernel_size=1, 238 | stride=1, 239 | padding=0)) 240 | 241 | def forward(self, x): 242 | # print("GCN: ", x.shape) 243 | # x.shape = [length, batch, channels] 244 | x = x.permute(1, 2, 0) # change to [batch, channels, length] 245 | z = torch.empty([x.shape[0], self.blocks[-1].in_channels, x.shape[2]]) 246 | 247 | for n, block in enumerate(self.blocks[:-1]): 248 | x, zn = block(x) 249 | z[:, 250 | n * self.nchannels * self.nlayers: 251 | (n + 1) * self.nchannels * self.nlayers, 252 | :] = zn 253 | 254 | # back to [length, batch, channels] 255 | return self.blocks[-1](z).permute(2, 0, 1) 256 | 257 | # train_epoch runs one epoch of training 258 | def train_epoch(self, 259 | input_data, 260 | target_data, 261 | loss_fcn, 262 | optim, 263 | batch_size): 264 | # shuffle the segments at the start of the epoch 265 | shuffle = torch.randperm(input_data.shape[1]) 266 | 267 | # iterate over the 
batches 268 | ep_losses = None 269 | 270 | for batch_i in range(math.ceil(shuffle.shape[0] / batch_size)): 271 | # zero all gradients 272 | self.zero_grad() 273 | 274 | # load batch 275 | input_batch = input_data[:, 276 | shuffle[batch_i * batch_size: 277 | (batch_i + 1) * batch_size], 278 | :] 279 | target_batch = target_data[:, 280 | shuffle[batch_i * batch_size: 281 | (batch_i + 1) * batch_size], 282 | :] 283 | 284 | # process batch 285 | output = self(input_batch) 286 | 287 | # loss and backprop 288 | batch_losses = loss_fcn(output, target_batch) 289 | 290 | tot_batch_loss = 0 291 | for loss in batch_losses: 292 | tot_batch_loss += batch_losses[loss] 293 | 294 | tot_batch_loss.backward() 295 | optim.step() 296 | 297 | # add batch losses to epoch losses 298 | for loss in batch_losses: 299 | if ep_losses == None: 300 | ep_losses = batch_losses 301 | else: 302 | ep_losses[loss] += batch_losses[loss] 303 | 304 | # mean epoch losses 305 | for loss in ep_losses: 306 | ep_losses[loss] /= (batch_i + 1) 307 | 308 | return ep_losses 309 | 310 | def process_data(self, 311 | input_data, 312 | target_data=None, 313 | chunk=16384, 314 | loss_fcn=None, 315 | grad=False): 316 | 317 | rf = self.compute_receptive_field() 318 | 319 | if not (input_data.shape[0] / chunk).is_integer(): 320 | # round to next chunk size 321 | padding = chunk - (input_data.shape[0] % chunk) 322 | input_data = torch.nn.functional.pad(input_data, 323 | (0, 0, 0, 0, 0, padding), 324 | mode='constant', 325 | value=0) 326 | if target_data != None: 327 | target_data = torch.nn.functional.pad(target_data, 328 | (0, 0, 0, 0, 0, padding), 329 | mode='constant', 330 | value=0) 331 | 332 | with torch.no_grad(): 333 | # process input 334 | output_data = torch.empty_like(input_data) 335 | 336 | for l in range(int(output_data.size()[0] / chunk)): 337 | input_chunk = input_data[l * chunk: (l + 1) * chunk] 338 | if l == 0: # first chunk 339 | padding = torch.zeros([rf, input_chunk.shape[1], input_chunk.shape[2]]) 340 | else: 341 | padding = input_data[(l * chunk) - rf: l * chunk] 342 | input_chunk = torch.cat([padding, input_chunk]) 343 | output_chunk = self(input_chunk) 344 | output_data[l * chunk: (l + 1) * chunk] = \ 345 | output_chunk[rf:, :, :] 346 | 347 | if loss_fcn != None and target_data != None: 348 | losses = loss_fcn(output_data, target_data) 349 | return input_data, target_data, output_data, losses 350 | 351 | return input_data, target_data, output_data 352 | 353 | def save_model(self, 354 | file_name, 355 | direc=""): 356 | if direc: 357 | utils.dir_check(direc) 358 | 359 | model_data = {"model_data": {"model_type": "gcn", 360 | "nblocks": self.nblocks, 361 | "nlayers": self.nlayers, 362 | "nchannels": self.nchannels, 363 | "kernel_size": self.kernel_size, 364 | "dilation_growth": self.dilation_growth}} 365 | 366 | if self.save_state: 367 | model_state = self.state_dict() 368 | for each in model_state: 369 | model_state[each] = model_state[each].tolist() 370 | model_data["state_dict"] = model_state 371 | 372 | utils.json_save(model_data, file_name, direc) 373 | 374 | def compute_receptive_field(self): 375 | """ Compute the receptive field in samples.""" 376 | rf = self.kernel_size 377 | for n in range(1, self.nblocks * self.nlayers): 378 | dilation = self.dilation_growth ** (n % self.nlayers) 379 | rf = rf + ((self.kernel_size-1) * dilation) 380 | return rf 381 | 382 | # add any model hyperparameters here 383 | @staticmethod 384 | def add_model_specific_args(parent_parser): 385 | parser = 
argparse.ArgumentParser(parents=[parent_parser], add_help=False) 386 | # --- model related --- 387 | parser.add_argument('--nblocks', type=int, default=2) 388 | parser.add_argument('--nlayers', type=int, default=9) 389 | parser.add_argument('--nchannels', type=int, default=16) 390 | parser.add_argument('--kernel_size', type=int, default=3) 391 | parser.add_argument('--dilation_growth', type=int, default=2) 392 | 393 | return parser 394 | -------------------------------------------------------------------------------- /src/gcntfilm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import argparse 4 | import src.utils as utils 5 | from contextlib import nullcontext 6 | 7 | 8 | def wrapperkwargs(func, kwargs): 9 | return func(**kwargs) 10 | 11 | 12 | def wrapperargs(func, args): 13 | return func(*args) 14 | 15 | 16 | # def load_model(model_data): 17 | # model_types = {"GCN": GCN} 18 | 19 | # model_meta = model_data.pop("model_data") 20 | 21 | # if model_meta["model"] == "SimpleRNN" or model_meta["model"] == "GCN": 22 | # network = wrapperkwargs( 23 | # model_types[model_meta.pop("model")], model_meta) 24 | # if "state_dict" in model_data: 25 | # state_dict = network.state_dict() 26 | # for each in model_data["state_dict"]: 27 | # state_dict[each] = torch.tensor(model_data["state_dict"][each]) 28 | # network.load_state_dict(state_dict) 29 | 30 | # elif model_meta["model"] == "RecNet": 31 | # model_meta["blocks"] = [] 32 | # network = wrapperkwargs( 33 | # model_types[model_meta.pop("model")], model_meta) 34 | # for i in range(len(model_data["blocks"])): 35 | # network.add_layer(model_data["blocks"][str(i)]) 36 | 37 | # # Get the state dict from the newly created model and load the saved states, if states were saved 38 | # if "state_dict" in model_data: 39 | # state_dict = network.state_dict() 40 | # for each in model_data["state_dict"]: 41 | # state_dict[each] = torch.tensor(model_data["state_dict"][each]) 42 | # network.load_state_dict(state_dict) 43 | 44 | # if "training_info" in model_data.keys(): 45 | # network.training_info = model_data["training_info"] 46 | 47 | # return network 48 | 49 | 50 | # This is a function for taking the old json config file format I used to use and converting it to the new format 51 | # def legacy_load(legacy_data): 52 | # if legacy_data["unit_type"] == "GRU" or legacy_data["unit_type"] == "LSTM": 53 | # model_data = {"model_data": { 54 | # "model": "RecNet", "skip": 0}, "blocks": {}} 55 | # model_data["blocks"]["0"] = { 56 | # "block_type": legacy_data["unit_type"], 57 | # "input_size": legacy_data["in_size"], 58 | # "hidden_size": legacy_data["hidden_size"], 59 | # "output_size": 1, 60 | # "lin_bias": True, 61 | # } 62 | # if legacy_data["cur_epoch"]: 63 | # training_info = { 64 | # "current_epoch": legacy_data["cur_epoch"], 65 | # "training_losses": legacy_data["tloss_list"], 66 | # "val_losses": legacy_data["vloss_list"], 67 | # "load_config": legacy_data["load_config"], 68 | # "low_pass": legacy_data["low_pass"], 69 | # "val_freq": legacy_data["val_freq"], 70 | # "device": legacy_data["pedal"], 71 | # "seg_length": legacy_data["seg_len"], 72 | # "learning_rate": legacy_data["learn_rate"], 73 | # "batch_size": legacy_data["batch_size"], 74 | # "loss_func": legacy_data["loss_fcn"], 75 | # "update_freq": legacy_data["up_fr"], 76 | # "init_length": legacy_data["init_len"], 77 | # "pre_filter": legacy_data["pre_filt"], 78 | # } 79 | # model_data["training_info"] = training_info 80 | 81 | # 
if "state_dict" in legacy_data: 82 | # state_dict = legacy_data["state_dict"] 83 | # state_dict = dict(state_dict) 84 | # new_state_dict = {} 85 | # for each in state_dict: 86 | # new_name = each[0:7] + "block_1." + each[9:] 87 | # new_state_dict[new_name] = state_dict[each] 88 | # model_data["state_dict"] = new_state_dict 89 | # return model_data 90 | # else: 91 | # print("format not recognised") 92 | 93 | 94 | """ 95 | Temporal FiLM layer 96 | """ 97 | 98 | 99 | class TFiLM(torch.nn.Module): 100 | def __init__(self, 101 | nchannels, 102 | block_size=128): 103 | super(TFiLM, self).__init__() 104 | self.nchannels = nchannels 105 | self.block_size = block_size 106 | self.num_layers = 1 107 | self.hidden_state = None # (hidden_state, cell_state) 108 | 109 | # used to downsample input 110 | self.maxpool = torch.nn.MaxPool1d(kernel_size=block_size, 111 | stride=None, 112 | padding=0, 113 | dilation=1, 114 | return_indices=False, 115 | ceil_mode=False) 116 | 117 | self.lstm = torch.nn.LSTM(input_size=nchannels, 118 | hidden_size=nchannels, 119 | num_layers=self.num_layers, 120 | batch_first=False, 121 | bidirectional=False) 122 | 123 | def forward(self, x): 124 | # print("TFiLM: ", x.shape) 125 | # x = [batch, channels, length] 126 | x_shape = x.shape 127 | nsteps = int(x_shape[-1] / self.block_size) 128 | 129 | # downsample 130 | x_down = self.maxpool(x) 131 | 132 | # shape for LSTM (length, batch, channels) 133 | x_down = x_down.permute(2, 0, 1) 134 | 135 | # modulation sequence 136 | if self.hidden_state == None: # state was reset 137 | # init hidden and cell states with zeros 138 | h0 = torch.zeros(self.num_layers, x.size(0), self.nchannels).requires_grad_() 139 | c0 = torch.zeros(self.num_layers, x.size(0), self.nchannels).requires_grad_() 140 | x_norm, self.hidden_state = self.lstm(x_down, (h0.detach(), c0.detach())) # detach for truncated BPTT 141 | else: 142 | x_norm, self.hidden_state = self.lstm(x_down, self.hidden_state) 143 | 144 | # put shape back (batch, channels, length) 145 | x_norm = x_norm.permute(1, 2, 0) 146 | 147 | # reshape input and modulation sequence into blocks 148 | x_in = torch.reshape( 149 | x, shape=(-1, self.nchannels, nsteps, self.block_size)) 150 | x_norm = torch.reshape( 151 | x_norm, shape=(-1, self.nchannels, nsteps, 1)) 152 | 153 | # multiply 154 | x_out = x_norm * x_in 155 | 156 | # return to original shape 157 | x_out = torch.reshape(x_out, shape=(x_shape)) 158 | 159 | return x_out 160 | 161 | def detach_state(self): 162 | if self.hidden_state.__class__ == tuple: 163 | self.hidden_state = tuple([h.clone().detach() for h in self.hidden_state]) 164 | else: 165 | self.hidden_state = self.hidden_state.clone().detach() 166 | 167 | def reset_state(self): 168 | self.hidden_state = None 169 | 170 | 171 | """ 172 | Gated convolutional layer, zero pads and then applies a causal convolution to the input 173 | """ 174 | 175 | 176 | class GatedConv1d(torch.nn.Module): 177 | def __init__(self, 178 | in_ch, 179 | out_ch, 180 | dilation, 181 | kernel_size, 182 | tfilm_block_size): 183 | super(GatedConv1d, self).__init__() 184 | self.in_ch = in_ch 185 | self.out_ch = out_ch 186 | self.dilation = dilation 187 | self.kernal_size = kernel_size 188 | self.tfilm_block_size = tfilm_block_size 189 | 190 | # Layers: Conv1D -> Activations -> TFiLM -> Mix + Residual 191 | 192 | self.conv = torch.nn.Conv1d(in_channels=in_ch, 193 | out_channels=out_ch * 2, 194 | kernel_size=kernel_size, 195 | stride=1, 196 | padding=0, 197 | dilation=dilation) 198 | 199 | self.tfilm = 
TFiLM(nchannels=out_ch, 200 | block_size=tfilm_block_size) 201 | 202 | self.mix = torch.nn.Conv1d(in_channels=out_ch, 203 | out_channels=out_ch, 204 | kernel_size=1, 205 | stride=1, 206 | padding=0) 207 | 208 | def forward(self, x): 209 | # print("GatedConv1d: ", x.shape) 210 | residual = x 211 | 212 | # dilated conv 213 | y = self.conv(x) 214 | 215 | # gated activation 216 | z = torch.tanh(y[:, :self.out_ch, :]) * \ 217 | torch.sigmoid(y[:, self.out_ch:, :]) 218 | 219 | # zero pad on the left side, so that z is the same length as x 220 | z = torch.cat((torch.zeros(residual.shape[0], 221 | self.out_ch, 222 | residual.shape[2] - z.shape[2]), 223 | z), 224 | dim=2) 225 | 226 | # modulation 227 | z = self.tfilm(z) 228 | 229 | x = self.mix(z) + residual 230 | 231 | return x, z 232 | 233 | 234 | """ 235 | Gated convolutional neural net block, applies successive gated convolutional layers to the input, a total of 'layers' 236 | layers are applied, with the filter size 'kernel_size' and the dilation increasing by a factor of 'dilation_growth' for 237 | each successive layer. 238 | """ 239 | 240 | 241 | class GCNBlock(torch.nn.Module): 242 | def __init__(self, 243 | in_ch, 244 | out_ch, 245 | nlayers, 246 | kernel_size, 247 | dilation_growth, 248 | tfilm_block_size): 249 | super(GCNBlock, self).__init__() 250 | self.in_ch = in_ch 251 | self.out_ch = out_ch 252 | self.nlayers = nlayers 253 | self.kernel_size = kernel_size 254 | self.dilation_growth = dilation_growth 255 | self.tfilm_block_size = tfilm_block_size 256 | 257 | dilations = [dilation_growth ** l for l in range(nlayers)] 258 | 259 | self.layers = torch.nn.ModuleList() 260 | 261 | for d in dilations: 262 | self.layers.append(GatedConv1d(in_ch=in_ch, 263 | out_ch=out_ch, 264 | dilation=d, 265 | kernel_size=kernel_size, 266 | tfilm_block_size=tfilm_block_size)) 267 | in_ch = out_ch 268 | 269 | def forward(self, x): 270 | # print("GCNBlock: ", x.shape) 271 | # [batch, channels, length] 272 | z = torch.empty([x.shape[0], 273 | self.nlayers * self.out_ch, 274 | x.shape[2]]) 275 | 276 | for n, layer in enumerate(self.layers): 277 | x, zn = layer(x) 278 | z[:, n * self.out_ch: (n + 1) * self.out_ch, :] = zn 279 | 280 | return x, z 281 | 282 | 283 | """ 284 | Gated Convolutional Neural Net class, based on the 'WaveNet' architecture, takes a single channel of audio as input and 285 | produces a single channel of audio of equal length as output. one-sided zero-padding is used to ensure the network is 286 | causal and doesn't reduce the length of the audio. 287 | 288 | Made up of 'blocks', each one applying a series of dilated convolutions, with the dilation of each successive layer 289 | increasing by a factor of 'dilation_growth'. 'layers' determines how many convolutional layers are in each block, 290 | 'kernel_size' is the size of the filters. Channels is the number of convolutional channels. 
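With the default configuration (nblocks=2, nlayers=9, kernel_size=3, dilation_growth=2) the receptive field returned by compute_receptive_field() is 2045 samples, i.e. roughly 46 ms at a 44.1 kHz sample rate.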
291 | 292 | The output of the model is creating by the linear mixer, which sums weighted outputs from each of the layers in the 293 | model 294 | """ 295 | 296 | 297 | class GCNTF(torch.nn.Module): 298 | def __init__(self, 299 | nblocks=2, 300 | nlayers=9, 301 | nchannels=8, 302 | kernel_size=3, 303 | dilation_growth=2, 304 | tfilm_block_size=128, 305 | **kwargs): 306 | super(GCNTF, self).__init__() 307 | self.nblocks = nblocks 308 | self.nlayers = nlayers 309 | self.nchannels = nchannels 310 | self.kernel_size = kernel_size 311 | self.dilation_growth = dilation_growth 312 | self.tfilm_block_size = tfilm_block_size 313 | 314 | self.blocks = torch.nn.ModuleList() 315 | for b in range(nblocks): 316 | self.blocks.append(GCNBlock(in_ch=1 if b == 0 else nchannels, 317 | out_ch=nchannels, 318 | nlayers=nlayers, 319 | kernel_size=kernel_size, 320 | dilation_growth=dilation_growth, 321 | tfilm_block_size=tfilm_block_size)) 322 | 323 | # output mixing layer 324 | self.blocks.append( 325 | torch.nn.Conv1d(in_channels=nchannels * nlayers * nblocks, 326 | out_channels=1, 327 | kernel_size=1, 328 | stride=1, 329 | padding=0)) 330 | 331 | def forward(self, x): 332 | # print("GCN: ", x.shape) 333 | # x.shape = [length, batch, channels] 334 | x = x.permute(1, 2, 0) # change to [batch, channels, length] 335 | z = torch.empty([x.shape[0], self.blocks[-1].in_channels, x.shape[2]]) 336 | 337 | for n, block in enumerate(self.blocks[:-1]): 338 | x, zn = block(x) 339 | z[:, 340 | n * self.nchannels * self.nlayers: 341 | (n + 1) * self.nchannels * self.nlayers, 342 | :] = zn 343 | 344 | # back to [length, batch, channels] 345 | return self.blocks[-1](z).permute(2, 0, 1) 346 | 347 | def detach_states(self): 348 | # print("DETACH STATES") 349 | for layer in self.modules(): 350 | if isinstance(layer, TFiLM): 351 | layer.detach_state() 352 | 353 | # reset state for all TFiLM layers 354 | def reset_states(self): 355 | # print("RESET STATES") 356 | for layer in self.modules(): 357 | if isinstance(layer, TFiLM): 358 | layer.reset_state() 359 | 360 | # train_epoch runs one epoch of training 361 | def train_epoch(self, 362 | input_data, 363 | target_data, 364 | loss_fcn, 365 | optim, 366 | batch_size): 367 | # shuffle the segments at the start of the epoch 368 | shuffle = torch.randperm(input_data.shape[1]) 369 | 370 | # iterate over the batches 371 | ep_losses = None 372 | 373 | for batch_i in range(math.ceil(shuffle.shape[0] / batch_size)): 374 | # reset states before starting new batch 375 | self.reset_states() 376 | 377 | # zero all gradients 378 | self.zero_grad() 379 | 380 | # load batch 381 | input_batch = input_data[:, 382 | shuffle[batch_i * batch_size: 383 | (batch_i + 1) * batch_size], 384 | :] 385 | target_batch = target_data[:, 386 | shuffle[batch_i * batch_size: 387 | (batch_i + 1) * batch_size], 388 | :] 389 | 390 | # process batch 391 | output = self(input_batch) 392 | 393 | # loss and backprop 394 | batch_losses = loss_fcn(output, target_batch) 395 | 396 | tot_batch_loss = 0 397 | for loss in batch_losses: 398 | tot_batch_loss += batch_losses[loss] 399 | 400 | tot_batch_loss.backward() 401 | optim.step() 402 | 403 | # add batch losses to epoch losses 404 | # ep_loss += loss 405 | for loss in batch_losses: 406 | if ep_losses == None: 407 | ep_losses = batch_losses 408 | else: 409 | ep_losses[loss] += batch_losses[loss] 410 | 411 | # mean epoch losses 412 | for loss in ep_losses: 413 | ep_losses[loss] /= (batch_i + 1) 414 | 415 | return ep_losses 416 | 417 | def process_data(self, 418 | input_data, 419 | 
target_data=None, 420 | chunk=16384, 421 | loss_fcn=None, 422 | grad=False): 423 | 424 | rf = self.compute_receptive_field() 425 | # round to next block size 426 | rf = rf + (self.tfilm_block_size - rf % self.tfilm_block_size) 427 | 428 | if not (input_data.shape[0] / chunk).is_integer(): 429 | # round to next chunk size 430 | padding = chunk - (input_data.shape[0] % chunk) 431 | input_data = torch.nn.functional.pad(input_data, 432 | (0, 0, 0, 0, 0, padding), 433 | mode='constant', 434 | value=0) 435 | if target_data != None: 436 | target_data = torch.nn.functional.pad(target_data, 437 | (0, 0, 0, 0, 0, padding), 438 | mode='constant', 439 | value=0) 440 | 441 | with torch.no_grad(): 442 | # reset states before processing 443 | self.reset_states() 444 | 445 | # process input 446 | output_data = torch.empty_like(input_data) 447 | 448 | for l in range(int(output_data.size()[0] / chunk)): 449 | input_chunk = input_data[l * chunk: (l + 1) * chunk] 450 | if l == 0: # first chunk 451 | padding = torch.zeros([rf, input_chunk.shape[1], input_chunk.shape[2]]) 452 | else: 453 | padding = input_data[(l * chunk) - rf: l * chunk] 454 | input_chunk = torch.cat([padding, input_chunk]) 455 | output_chunk = self(input_chunk) 456 | output_data[l * chunk: (l + 1) * chunk] = \ 457 | output_chunk[rf:, :, :] 458 | 459 | if loss_fcn != None and target_data != None: 460 | losses = loss_fcn(output_data, target_data) 461 | return input_data, target_data, output_data, losses 462 | 463 | # reset states before other computations 464 | self.reset_states() 465 | 466 | return input_data, target_data, output_data 467 | 468 | # this functions saves the model and all its paraemters to a json file, so it can be loaded by a JUCE plugin 469 | def save_model(self, 470 | file_name, 471 | direc=""): 472 | if direc: 473 | utils.dir_check(direc) 474 | 475 | model_data = {"model_data": {"model_type": "gcntf", 476 | "nblocks": self.nblocks, 477 | "nlayers": self.nlayers, 478 | "nchannels": self.nchannels, 479 | "kernel_size": self.kernel_size, 480 | "dilation_growth": self.dilation_growth, 481 | "tfilm_block_size": self.tfilm_block_size}} 482 | 483 | if self.save_state: 484 | model_state = self.state_dict() 485 | for each in model_state: 486 | model_state[each] = model_state[each].tolist() 487 | model_data["state_dict"] = model_state 488 | 489 | utils.json_save(model_data, file_name, direc) 490 | 491 | def compute_receptive_field(self): 492 | """ Compute the receptive field in samples.""" 493 | rf = self.kernel_size 494 | for n in range(1, self.nblocks * self.nlayers): 495 | dilation = self.dilation_growth ** (n % self.nlayers) 496 | rf = rf + ((self.kernel_size-1) * dilation) 497 | return rf 498 | 499 | # add any model hyperparameters here 500 | @staticmethod 501 | def add_model_specific_args(parent_parser): 502 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) 503 | # --- model related --- 504 | parser.add_argument('--nblocks', type=int, default=2) 505 | parser.add_argument('--nlayers', type=int, default=9) 506 | parser.add_argument('--nchannels', type=int, default=16) 507 | parser.add_argument('--kernel_size', type=int, default=3) 508 | parser.add_argument('--dilation_growth', type=int, default=2) 509 | parser.add_argument('--tfilm_block_size', type=int, default=128) 510 | 511 | return parser 512 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 
| import os 4 | import scipy 5 | import torch 6 | import torchinfo 7 | import torch.utils.tensorboard as tensorboard 8 | 9 | import src.utils as utils 10 | import src.training as training 11 | import src.data as data 12 | 13 | from src.rnn import RNN 14 | from src.gcn import GCN 15 | from src.gcntfilm import GCNTF 16 | 17 | torch.backends.cudnn.benchmark = True 18 | 19 | train_configs = [ 20 | # { # from Wright et al. - 2020 21 | # "name": "LSTM32", 22 | # "model_type": "rnn", 23 | # "hidden_size": 32, 24 | # "unit_type": "LSTM", 25 | # "loss_fcns": {"ESRPre": 0.75, "DC": 0.25}, 26 | # "prefilt": "a-weighting", 27 | # "device": "fuzz-rndamp-G5S10A1msR2500ms", 28 | # "file_name": "fuzz-rndamp-G5S10A1msR2500ms", 29 | # "train_length": 132300, 30 | # "val_chunk": 16384, 31 | # "test_chunk": 16384, 32 | # "validation_p": 20, 33 | # "batch_size": 40, 34 | # "up_fr": 1000, 35 | # }, 36 | # { # from Wright et al. - 2020 37 | # "name": "LSTM96", 38 | # "model_type": "rnn", 39 | # "hidden_size": 96, 40 | # "unit_type": "LSTM", 41 | # "loss_fcns": {"ESRPre": 0.75, "DC": 0.25}, 42 | # "prefilt": "a-weighting", 43 | # "device": "fuzz-rndamp-G5S10A1msR2500ms", 44 | # "file_name": "fuzz-rndamp-G5S10A1msR2500ms", 45 | # "train_length": 132300, 46 | # "val_chunk": 16384, 47 | # "test_chunk": 16384, 48 | # "validation_p": 20, 49 | # "batch_size": 40, 50 | # "up_fr": 1000, 51 | # }, 52 | # { # from Wright et al. - 2020 53 | # "name": "GCN1", 54 | # "model_type": "gcn", 55 | # "nblocks": 1, 56 | # "nlayers": 10, 57 | # "nchannels": 16, 58 | # "kernel_size": 3, 59 | # "dilation_growth": 2, 60 | # "loss_fcns": {"ESRPre": 1.0, "DC": 0.0}, 61 | # "prefilt": "high_pass", 62 | # "device": "fuzz-rndamp-G5S10A1msR2500ms", 63 | # "file_name": "fuzz-rndamp-G5S10A1msR2500ms", 64 | # "train_length": 4410, 65 | # "val_chunk": 441000, 66 | # "test_chunk": 441000, 67 | # "validation_p": 20, 68 | # "batch_size": 40 69 | # }, 70 | # { # from Wright et al. 
- 2020 71 | # "name": "GCN3", 72 | # "model_type": "gcn", 73 | # "nblocks": 2, 74 | # "nlayers": 9, 75 | # "nchannels": 16, 76 | # "kernel_size": 3, 77 | # "dilation_growth": 2, 78 | # "loss_fcns": {"ESRPre": 1.0, "DC": 0.0}, 79 | # "prefilt": "high_pass", 80 | # "device": "fuzz-rndamp-G5S10A1msR2500ms", 81 | # "file_name": "fuzz-rndamp-G5S10A1msR2500ms", 82 | # "train_length": 4410, 83 | # "val_chunk": 441000, 84 | # "test_chunk": 441000, 85 | # "validation_p": 20, 86 | # "batch_size": 40 87 | # }, 88 | { 89 | "name": "LSTM32", 90 | "model_type": "rnn", 91 | "hidden_size": 32, 92 | "unit_type": "LSTM", 93 | "loss_fcns": {"L1": 0.5, "MSTFT": 0.5}, 94 | "prefilt": None, 95 | "device": "fuzz-rndamp-G5S10A1msR2500ms", 96 | "file_name": "fuzz-rndamp-G5S10A1msR2500ms", 97 | "train_length": 112640, 98 | "val_chunk": 112640, 99 | "test_chunk": 112640, 100 | "validation_p": 20, 101 | "batch_size": 6, 102 | "up_fr": 2048, 103 | }, 104 | { 105 | "name": "LSTM96", 106 | "model_type": "rnn", 107 | "hidden_size": 96, 108 | "unit_type": "LSTM", 109 | "loss_fcns": {"L1": 0.5, "MSTFT": 0.5}, 110 | "prefilt": None, 111 | "device": "fuzz-rndamp-G5S10A1msR2500ms", 112 | "file_name": "fuzz-rndamp-G5S10A1msR2500ms", 113 | "train_length": 112640, 114 | "val_chunk": 112640, 115 | "test_chunk": 112640, 116 | "validation_p": 20, 117 | "batch_size": 6, 118 | "up_fr": 2048, 119 | }, 120 | { 121 | "name": "GCN1", 122 | "model_type": "gcn", 123 | "nblocks": 1, 124 | "nlayers": 10, 125 | "nchannels": 16, 126 | "kernel_size": 3, 127 | "dilation_growth": 2, 128 | "loss_fcns": {"L1": 0.5, "MSTFT": 0.5}, 129 | "prefilt": None, 130 | "device": "fuzz-rndamp-G5S10A1msR2500ms", 131 | "file_name": "fuzz-rndamp-G5S10A1msR2500ms", 132 | "train_length": 112640, 133 | "val_chunk": 112640, 134 | "test_chunk": 112640, 135 | "validation_p": 20, 136 | "batch_size": 6, 137 | }, 138 | { 139 | "name": "GCN3", 140 | "model_type": "gcn", 141 | "nblocks": 2, 142 | "nlayers": 9, 143 | "nchannels": 16, 144 | "kernel_size": 3, 145 | "dilation_growth": 2, 146 | "loss_fcns": {"L1": 0.5, "MSTFT": 0.5}, 147 | "prefilt": None, 148 | "device": "fuzz-rndamp-G5S10A1msR2500ms", 149 | "file_name": "fuzz-rndamp-G5S10A1msR2500ms", 150 | "train_length": 112640, 151 | "val_chunk": 112640, 152 | "test_chunk": 112640, 153 | "validation_p": 20, 154 | "batch_size": 6, 155 | }, 156 | { 157 | "name": "GCNTF1", 158 | "model_type": "gcntf", 159 | "nblocks": 1, 160 | "nlayers": 10, 161 | "nchannels": 16, 162 | "kernel_size": 3, 163 | "dilation_growth": 2, 164 | "tfilm_block_size": 128, 165 | "loss_fcns": {"L1": 0.5, "MSTFT": 0.5}, 166 | "prefilt": None, 167 | "device": "fuzz-rndamp-G5S10A1msR2500ms", 168 | "file_name": "fuzz-rndamp-G5S10A1msR2500ms", 169 | "train_length": 112640, 170 | "val_chunk": 112640, 171 | "test_chunk": 112640, 172 | "validation_p": 20, 173 | "batch_size": 6 174 | }, 175 | { 176 | "name": "GCNTF3", 177 | "model_type": "gcntf", 178 | "nblocks": 2, 179 | "nlayers": 9, 180 | "nchannels": 16, 181 | "kernel_size": 3, 182 | "dilation_growth": 2, 183 | "tfilm_block_size": 128, 184 | "loss_fcns": {"L1": 0.5, "MSTFT": 0.5}, 185 | "prefilt": None, 186 | "device": "fuzz-rndamp-G5S10A1msR2500ms", 187 | "file_name": "fuzz-rndamp-G5S10A1msR2500ms", 188 | "train_length": 112640, 189 | "val_chunk": 112640, 190 | "test_chunk": 112640, 191 | "validation_p": 40, 192 | "batch_size": 6 193 | }, 194 | { 195 | "name": "GCNTF2500", 196 | "model_type": "gcntf", 197 | "nblocks": 1, 198 | "nlayers": 10, 199 | "nchannels": 16, 200 | "kernel_size": 5, 201 | 
"dilation_growth": 3, 202 | "tfilm_block_size": 128, 203 | "loss_fcns": {"L1": 0.5, "MSTFT": 0.5}, 204 | "prefilt": None, 205 | "device": "fuzz-rndamp-G5S10A1msR2500ms", 206 | "file_name": "fuzz-rndamp-G5S10A1msR2500ms", 207 | "train_length": 118400, 208 | "val_chunk": 118400, 209 | "test_chunk": 118400, 210 | "validation_p": 20, 211 | "batch_size": 6 212 | }, 213 | { 214 | "name": "GCNTF250", 215 | "model_type": "gcntf", 216 | "nblocks": 1, 217 | "nlayers": 4, 218 | "nchannels": 16, 219 | "kernel_size": 41, 220 | "dilation_growth": 6, 221 | "loss_fcns": {"L1": 0.5, "MSTFT": 0.5}, 222 | "prefilt": None, 223 | "device": "fuzz-rndamp-G5S10A1msR2500ms", 224 | "file_name": "fuzz-rndamp-G5S10A1msR2500ms", 225 | "train_length": 112640, 226 | "val_chunk": 112640, 227 | "test_chunk": 112640, 228 | "validation_p": 20, 229 | "batch_size": 6 230 | }, 231 | # { 232 | # "name": "GCNTF1000", 233 | # "model_type": "gcntf", 234 | # "nblocks": 1, 235 | # "nlayers": 4, 236 | # "nchannels": 16, 237 | # "kernel_size": 41, 238 | # "dilation_growth": 10, 239 | # "loss_fcns": {"L1": 0.5, "MSTFT": 0.5}, 240 | # "prefilt": None, 241 | # "device": "fuzz-rndamp-G5S10A1msR2500ms", 242 | # "file_name": "fuzz-rndamp-G5S10A1msR2500ms", 243 | # "train_length": 112640, 244 | # "val_chunk": 112640, 245 | # "test_chunk": 112640, 246 | # "validation_p": 20, 247 | # "batch_size": 6 248 | # }, 249 | # { 250 | # "name": "GCNTF500", 251 | # "model_type": "gcntf", 252 | # "nblocks": 1, 253 | # "nlayers": 4, 254 | # "nchannels": 16, 255 | # "kernel_size": 41, 256 | # "dilation_growth": 8, 257 | # "loss_fcns": {"L1": 0.5, "MSTFT": 0.5}, 258 | # "prefilt": None, 259 | # "device": "fuzz-rndamp-G5S10A1msR2500ms", 260 | # "file_name": "fuzz-rndamp-G5S10A1msR2500ms", 261 | # "train_length": 112640, 262 | # "val_chunk": 112640, 263 | # "test_chunk": 112640, 264 | # "validation_p": 20, 265 | # "batch_size": 6 266 | # } 267 | ] 268 | 269 | n_configs = len(train_configs) 270 | 271 | for idx, tconf in enumerate(train_configs): 272 | 273 | parser = argparse.ArgumentParser() 274 | 275 | # add PROGRAM level args 276 | parser.add_argument("--model_type", type=str, default="gcntf", help="rnn, gcn, gcntf") 277 | # data locations, file names and config 278 | parser.add_argument("--device", "-d", default="ht1", help="device label") 279 | parser.add_argument("--data_rootdir", "-dr", default="./data", help="data directory") 280 | parser.add_argument("--file_name", "-fn", default="ht1", help="filename-input.wav and -target.wav") 281 | # parser.add_argument('--load_config', '-l', help="config file path") 282 | # parser.add_argument('--config_location', '-cl', default='configs', help='configs directory') 283 | parser.add_argument("--save_location", "-sloc", default="results", help="trained models directory") 284 | parser.add_argument("--load_model", "-lm", action="store_true", help="load pretrained model") 285 | 286 | # pre-processing of the training/val/test data 287 | parser.add_argument("--train_length", "-trlen", type=int, default=16384, help="training frame length in samples") 288 | parser.add_argument("--val_length", "-vllen", type=int, default=0, help="training frame length in samples") 289 | parser.add_argument("--test_length", "-tslen", type=int, default=0, help="training frame length in samples") 290 | 291 | # number of epochs and validation 292 | parser.add_argument("--epochs", "-eps", type=int, default=2000, help="max epochs") 293 | parser.add_argument("--validation_f", "-vfr", type=int, default=2, help="validation frequency (in epochs)") 294 
| parser.add_argument("--validation_p", "-vp", type=int, default=25, help="validation patience or None") 295 | 296 | # settings for the training epoch 297 | parser.add_argument("--batch_size", "-bs", type=int, default=40, help="mini-batch size") 298 | parser.add_argument("--iter_num", "-it", type=int, default=None, help="set batch_size so an epoch runs --iter_num batches") 299 | parser.add_argument("--learn_rate", "-lr", type=float, default=0.005, help="initial learning rate") 300 | parser.add_argument("--cuda", "-cu", default=1, help="use GPU if available") 301 | 302 | # loss function/s 303 | parser.add_argument("--loss_fcns", "-lf", default={"ESRPre": 0.75, "DC": 0.25}, help="loss functions and weights") 304 | parser.add_argument("--prefilt", "-pf", default="high_pass", help="pre-emphasis filter coefficients; can also be read from a csv file") 305 | 306 | # validation and test sets chunk size 307 | parser.add_argument("--val_chunk", "-vs", type=int, default=100000, help="validation chunk length") 308 | parser.add_argument("--test_chunk", "-tc", type=int, default=100000, help="test chunk length") 309 | 310 | # parse general args 311 | args = parser.parse_args() 312 | 313 | # add model specific args 314 | if tconf["model_type"] == "rnn": 315 | parser = RNN.add_model_specific_args(parser) 316 | elif tconf["model_type"] == "gcn": 317 | parser = GCN.add_model_specific_args(parser) 318 | elif tconf["model_type"] == "gcntf": 319 | parser = GCNTF.add_model_specific_args(parser) 320 | 321 | # parse general + model args 322 | args = parser.parse_args() 323 | 324 | # create dictionary with args 325 | dict_args = vars(args) 326 | 327 | # overwrite with temporary configuration 328 | dict_args.update(tconf) 329 | 330 | # set filter args 331 | if dict_args["prefilt"] == "a-weighting": 332 | # as reported in https://ieeexplore.ieee.org/abstract/document/9052944 333 | dict_args["prefilt"] = [0.85, 1] 334 | elif dict_args["prefilt"] == "high_pass": 335 | # args.prefilt = [-0.85, 1] # as reported in https://ieeexplore.ieee.org/abstract/document/9052944 336 | # as reported in (https://www.mdpi.com/2076-3417/10/3/766/htm) 337 | dict_args["prefilt"] = [-0.95, 1] 338 | else: 339 | dict_args["prefilt"] = None 340 | 341 | # directory where results will be saved 342 | if dict_args["model_type"] == "rnn": 343 | specifier = f"{idx+1}-{dict_args['name']}-{dict_args['device']}" 344 | specifier += f"__{dict_args['nlayers']}-{dict_args['hidden_size']}" 345 | specifier += "-skip" if dict_args["skip"] else "-noskip" 346 | specifier += f"__prefilt-{dict_args['prefilt']}-bs{dict_args['batch_size']}" 347 | elif dict_args["model_type"] == "gcn": 348 | specifier = f"{idx+1}-{dict_args['name']}-{dict_args['device']}" 349 | specifier += f"__{dict_args['nblocks']}-{dict_args['nlayers']}-{dict_args['nchannels']}" 350 | specifier += f"-{dict_args['kernel_size']}-{dict_args['dilation_growth']}" 351 | specifier += f"__prefilt-{dict_args['prefilt']}-bs{dict_args['batch_size']}" 352 | elif dict_args["model_type"] == "gcntf": 353 | specifier = f"{idx+1}-{dict_args['name']}-{dict_args['device']}" 354 | specifier += f"__{dict_args['nblocks']}-{dict_args['nlayers']}-{dict_args['nchannels']}" 355 | specifier += f"-{dict_args['kernel_size']}-{dict_args['dilation_growth']}-{dict_args['tfilm_block_size']}" 356 | specifier += f"__prefilt-{dict_args['prefilt']}-bs{dict_args['batch_size']}" 357 | 358 | # results directory 359 | save_path = os.path.join(dict_args["save_location"], specifier) 360 | utils.dir_check(save_path) 361 | 362 | # set the seed 363 | 
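# A generic pre-emphasis sketch for the prefilt coefficients set above (a minimal,
# hypothetical helper: how training.LossWrapper actually consumes the two-element
# coefficient list is defined in src/training.py, not shown here). The papers cited
# above filter both prediction and target with a first-order FIR such as
# H(z) = 1 - 0.85*z^-1 or 1 - 0.95*z^-1 before computing the loss. Tensors are
# assumed time-major, (seq_len, batch, channels), as in process_data(); torch is
# already imported at the top of this file.
def _example_preemphasis(x, a=0.95):
    y = x.clone()
    y[1:] = x[1:] - a * x[:-1]  # y[n] = x[n] - a * x[n-1]
    return y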
# TODO 364 | 365 | # init model 366 | if dict_args["model_type"] == "rnn": 367 | model = RNN(**dict_args) 368 | elif dict_args["model_type"] == "gcn": 369 | model = GCN(**dict_args) 370 | elif dict_args["model_type"] == "gcntf": 371 | model = GCNTF(**dict_args) 372 | 373 | # compute rf 374 | if dict_args["model_type"] in ["gcn", "gcntf"]: 375 | dict_args["rf"] = model.compute_receptive_field() 376 | 377 | model.save_state = False 378 | model.save_model("model", save_path) 379 | 380 | # save settings 381 | utils.json_save(dict_args, "config", save_path, indent=4) 382 | print(f"\n* Training config {idx+1}/{n_configs}") 383 | print(dict_args) 384 | 385 | # cuda 386 | if not torch.cuda.is_available() or dict_args["cuda"] == 0: 387 | print("\ncuda device not available/not selected") 388 | cuda = 0 389 | else: 390 | torch.set_default_tensor_type("torch.cuda.FloatTensor") 391 | torch.cuda.set_device(0) 392 | print("\ncuda device available") 393 | model = model.cuda() 394 | cuda = 1 395 | 396 | # optimiser + scheduler + loss fcns 397 | optimiser = torch.optim.Adam(model.parameters(), 398 | lr=dict_args["learn_rate"], 399 | weight_decay=1e-4) 400 | 401 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, 402 | "min", 403 | factor=0.5, 404 | patience=5, 405 | verbose=True) 406 | 407 | loss_functions = training.LossWrapper(dict_args["loss_fcns"], dict_args["prefilt"]) 408 | 409 | # training tracker 410 | train_track = training.TrainTrack() 411 | writer = tensorboard.SummaryWriter(os.path.join("results", specifier)) 412 | 413 | # dataset 414 | dataset = data.DataSet(data_dir=dict_args["data_rootdir"]) 415 | 416 | dataset.create_subset("train", frame_len=dict_args["train_length"]) 417 | dataset.load_file(os.path.join("train", dict_args["file_name"]), "train") 418 | print("\ntrain dataset: ", dataset.subsets["train"].data["input"][0].shape) 419 | 420 | dataset.create_subset("val", frame_len=dict_args["val_length"]) 421 | dataset.load_file(os.path.join("val", dict_args["file_name"]), "val") 422 | print("val dataset: ", dataset.subsets["val"].data["input"][0].shape) 423 | 424 | dataset.create_subset("test", frame_len=dict_args["test_length"]) 425 | dataset.load_file(os.path.join("test", dict_args["file_name"]), "test") 426 | print("test dataset: ", dataset.subsets["test"].data["input"][0].shape) 427 | 428 | # summary 429 | print() 430 | torchinfo.summary(model, 431 | input_size=(dict_args["train_length"], dict_args["batch_size"], 1), 432 | device=None) 433 | 434 | # ===== TRAIN ===== # 435 | start_time = time.time() 436 | 437 | model.save_state = True 438 | patience_counter = 0 439 | init_time = 0 440 | 441 | for epoch in range(train_track['current_epoch'] + 1, dict_args["epochs"] + 1): 442 | ep_st_time = time.time() 443 | 444 | # run 1 epoch of training 445 | if dict_args["model_type"] == "rnn": 446 | epoch_losses = model.train_epoch(dataset.subsets['train'].data['input'][0], 447 | dataset.subsets['train'].data['target'][0], 448 | loss_functions, 449 | optimiser, 450 | dict_args["batch_size"], 451 | up_fr=dict_args["up_fr"]) 452 | else: 453 | epoch_losses = model.train_epoch(dataset.subsets['train'].data['input'][0], 454 | dataset.subsets['train'].data['target'][0], 455 | loss_functions, 456 | optimiser, 457 | dict_args["batch_size"]) 458 | epoch_loss = 0 459 | for loss in epoch_losses: 460 | epoch_loss += epoch_losses[loss] 461 | print(f"epoch {epoch} | \ttrain loss: \t{epoch_loss:0.4f}", end="") 462 | for loss in epoch_losses: 463 | print(f" | \t{loss}: \t{epoch_losses[loss]:0.4f}", 
end="") 464 | print() 465 | 466 | # ===== VALIDATION ===== # 467 | if epoch % dict_args["validation_f"] == 0: 468 | val_ep_st_time = time.time() 469 | val_input, val_target, val_output, val_losses = \ 470 | model.process_data(dataset.subsets['val'].data['input'][0], 471 | dataset.subsets['val'].data['target'][0], 472 | loss_fcn=loss_functions, 473 | chunk=dict_args["val_chunk"]) 474 | 475 | # val losses 476 | val_loss = 0 477 | for loss in val_losses: 478 | val_loss += val_losses[loss] 479 | print(f"\t\tval loss: \t{val_loss:0.4f}", end="") 480 | for loss in val_losses: 481 | print(f" | \t{loss}: \t{val_losses[loss]:0.4f}", end="") 482 | print() 483 | 484 | # update lr 485 | scheduler.step(val_loss) 486 | 487 | # save best model 488 | if val_loss < train_track['val_loss_best']: 489 | patience_counter = 0 490 | model.save_model('model_best', save_path) 491 | scipy.io.wavfile.write(os.path.join(save_path, "best_val_out.wav"), 492 | dataset.subsets['test'].fs, 493 | val_output.cpu().numpy()[:, 0, 0]) 494 | else: 495 | patience_counter += 1 496 | 497 | # log validation losses 498 | for loss in val_losses: 499 | val_losses[loss] = val_losses[loss].item() 500 | 501 | train_track.val_epoch_update(val_loss=val_loss.item(), 502 | val_losses=val_losses, 503 | ep_st_time=val_ep_st_time, 504 | ep_end_time=time.time()) 505 | 506 | writer.add_scalar('Loss/Val (Tot)', val_loss, epoch) 507 | for loss in val_losses: 508 | writer.add_scalar(f"Loss/Val ({loss})", val_losses[loss], epoch) 509 | 510 | # log training losses 511 | for loss in epoch_losses: 512 | epoch_losses[loss] = epoch_losses[loss].item() 513 | 514 | train_track.train_epoch_update(epoch_loss=epoch_loss.item(), 515 | epoch_losses=epoch_losses, 516 | ep_st_time=ep_st_time, 517 | ep_end_time=time.time(), 518 | init_time=init_time, 519 | current_ep=epoch) 520 | 521 | writer.add_scalar('Loss/Train (Tot)', epoch_loss, epoch) 522 | for loss in epoch_losses: 523 | writer.add_scalar(f"Loss/Train ({loss})", epoch_losses[loss], epoch) 524 | 525 | # log learning rate 526 | writer.add_scalar('LR/current', optimiser.param_groups[0]['lr']) 527 | 528 | # save model 529 | model.save_model('model', save_path) 530 | 531 | # log training stats to json 532 | utils.json_save(train_track, 'training_stats', save_path, indent=4) 533 | 534 | # check early stopping 535 | if dict_args["validation_p"] and patience_counter > dict_args["validation_p"]: 536 | print('\nvalidation patience limit reached at epoch ' + str(epoch)) 537 | break 538 | 539 | # ===== TEST (last model) ===== # 540 | test_input, test_target, test_output, test_losses = \ 541 | model.process_data(dataset.subsets['test'].data['input'][0], 542 | dataset.subsets['test'].data['target'][0], 543 | loss_fcn=loss_functions, 544 | chunk=dict_args["test_chunk"]) 545 | 546 | # test losses (last) 547 | test_loss = 0 548 | for loss in test_losses: 549 | test_loss += test_losses[loss] 550 | print(f"\ttest loss (last): \t{test_loss:0.4f}", end="") 551 | for loss in test_losses: 552 | print(f" | \t{loss}: \t{test_losses[loss]:0.4f}", end="") 553 | print() 554 | 555 | lossESR = training.ESRLoss() # include ESR loss 556 | test_loss_ESR = lossESR(test_output, test_target) 557 | 558 | # save output audio 559 | scipy.io.wavfile.write(os.path.join(save_path, "test_out_final.wav"), 560 | dataset.subsets['test'].fs, 561 | test_output.cpu().numpy()[:, 0, 0]) 562 | 563 | # log test losses 564 | for loss in test_losses: 565 | test_losses[loss] = test_losses[loss].item() 566 | 567 | train_track['test_loss_final'] = test_loss.item() 
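# Note on the early-stopping check above: patience_counter is incremented once per
# validation round rather than once per epoch, so with validation_f=2 and
# validation_p=20 training stops roughly 2 * 20 = 40 epochs after the last
# improvement in validation loss.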
568 | train_track['test_losses_final'] = test_losses 569 | train_track['test_lossESR_final'] = test_loss_ESR.item() 570 | 571 | writer.add_scalar('Loss/Test/Last (Tot)', test_loss, 0) 572 | for loss in test_losses: 573 | writer.add_scalar(f"Loss/Test/Last ({loss})", test_losses[loss], 0) 574 | 575 | # ===== TEST (best validation model) ===== # 576 | best_val_net = utils.json_load('model_best', save_path) 577 | model = utils.load_model(best_val_net) 578 | 579 | test_input, test_target, test_output, test_losses = \ 580 | model.process_data(dataset.subsets['test'].data['input'][0], 581 | dataset.subsets['test'].data['target'][0], 582 | loss_fcn=loss_functions, 583 | chunk=dict_args["test_chunk"]) 584 | 585 | # test losses (best) 586 | test_loss = 0 587 | for loss in test_losses: 588 | test_loss += test_losses[loss] 589 | print(f"\ttest loss (best): \t{test_loss:0.4f}", end="") 590 | for loss in test_losses: 591 | print(f" | \t{loss}: \t{test_losses[loss]:0.4f}", end="") 592 | print() 593 | 594 | test_loss_ESR = lossESR(test_output, test_target) 595 | 596 | # save output audio 597 | scipy.io.wavfile.write(os.path.join(save_path, "test_out_bestv.wav"), 598 | dataset.subsets['test'].fs, 599 | test_output.cpu().numpy()[:, 0, 0]) 600 | 601 | # log test losses 602 | for loss in test_losses: 603 | test_losses[loss] = test_losses[loss].item() 604 | 605 | train_track['test_loss_best'] = test_loss.item() 606 | train_track['test_losses_best'] = test_losses 607 | train_track['test_lossESR_best'] = test_loss_ESR.item() 608 | 609 | writer.add_scalar('Loss/Test/Best (Tot)', test_loss, 0) 610 | for loss in test_losses: 611 | writer.add_scalar(f"Loss/Test/Best ({loss})", test_losses[loss], 0) 612 | 613 | # log training stats to json 614 | utils.json_save(train_track, 'training_stats', save_path, indent=4) 615 | 616 | if cuda: 617 | with open(os.path.join(save_path, 'maxmemusage.txt'), 'w') as f: 618 | f.write(str(torch.cuda.max_memory_allocated())) 619 | 620 | stop_time = time.time() 621 | print(f"\ntraining time: {(stop_time-start_time)/60:0.2f} min") 622 | --------------------------------------------------------------------------------
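# For each entry in train_configs, a run writes its artifacts to results/<specifier>/
# (the default --save_location): the resolved config and training_stats as JSON
# (utils.json_save), the "model" and "model_best" checkpoints written by save_model(),
# the rendered audio files best_val_out.wav, test_out_final.wav and test_out_bestv.wav,
# the TensorBoard event files, and maxmemusage.txt when running on CUDA.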