├── src
│   ├── __init__.py
│   ├── utils.py
│   ├── training.py
│   ├── data.py
│   ├── rnn.py
│   ├── gcn.py
│   └── gcntfilm.py
├── results
│   └── results.txt
├── images
│   └── waveforms.png
├── requirements.txt
├── data
│   └── data.txt
├── LICENSE
├── proc_audio.py
├── README.md
└── train.py
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/results/results.txt:
--------------------------------------------------------------------------------
1 | here go the results from training
--------------------------------------------------------------------------------
/images/waveforms.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mcomunita/gcn-tfilm/HEAD/images/waveforms.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | auraloss==0.2.2
2 | librosa==0.9.2
3 | numba==0.56.0
4 | numpy==1.22.3
5 | scipy==1.9.3
6 | tensorboard==2.10.1
7 | torch==1.12.1
8 | torchinfo==1.7.1
--------------------------------------------------------------------------------
/data/data.txt:
--------------------------------------------------------------------------------
1 | here goes the data organized in 3 folders:
2 | - test
3 | - train
4 | - val
5 | each folder must contain a single pair of input and target files (wav - 16bit); each file will be split into samples during training
6 | and the files must be named DEVICE-input.wav and DEVICE-target.wav (e.g., la2a-input.wav and la2a-target.wav)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 mcomunita
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/proc_audio.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | from scipy.io.wavfile import write
5 |
6 | import src.utils as utils
7 | import src.data as data
8 |
9 |
10 | def parse_args():
11 | parser = argparse.ArgumentParser(
12 |         description='''This script takes an input .wav file, loads it and processes it with a neural network model of a
13 |         device, i.e. a guitar amp or pedal, and saves the output as a .wav file''')
14 |
15 | parser.add_argument('--input_file', type=str, default='data/test/ht1-input.wav', help='input file')
16 | parser.add_argument('--output_file', type=str, default='output.wav', help='output file')
17 | parser.add_argument('--model_file', type=str, default='results/ht1-ht11/model_best.json', help='model file')
18 | parser.add_argument('--chunk_length', type=int, default=16384, help='chunk length')
19 | return parser.parse_args()
20 |
21 |
22 | def proc_audio(args):
23 | model_data = utils.json_load(args.model_file)
24 | model = utils.load_model(model_data)
25 |
26 | dataset = data.DataSet(data_dir='', extensions='')
27 | dataset.create_subset('data')
28 | dataset.load_file(args.input_file, set_names='data')
29 |
30 | # cuda
31 | if torch.cuda.is_available():
32 | device = torch.device("cuda:0")
33 | torch.set_default_tensor_type("torch.cuda.FloatTensor")
34 | # torch.cuda.set_device(0)
35 | print("\ncuda device available")
36 | else:
37 | device = torch.device("cpu")
38 | torch.set_default_tensor_type("torch.FloatTensor")
39 | print("\ncuda device NOT available")
40 |
41 | model = model.to(device)
42 | model.eval()
43 |
44 | with torch.no_grad():
45 | input_data = dataset.subsets['data'].data['data'][0]
46 |
47 | input_data = input_data.to(device)
48 | # output = network(input_data)
49 |
50 | _, _, output = model.process_data(input_data,
51 | chunk=args.chunk_length)
52 |
53 | write(args.output_file, dataset.subsets['data'].fs, output.cpu().numpy()[:, 0, 0])
54 |
55 |
56 | def main():
57 | args = parse_args()
58 | proc_audio(args)
59 |
60 |
61 | if __name__ == '__main__':
62 | main()
63 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | from src.rnn import RNN
5 | from src.gcn import GCN
6 | from src.gcntfilm import GCNTF
7 |
8 |
9 | # Function that checks if a directory exists, and creates it if it doesn't. If dir_name is a list of strings, the
10 | # elements are joined into a path, e.g. dir_name = ['directory', 'subdir'] refers to the directory 'directory/subdir'
11 | def dir_check(dir_name):
12 | dir_name = [dir_name] if not type(dir_name) == list else dir_name
13 | dir_path = os.path.join(*dir_name)
14 | if os.path.isdir(dir_path):
15 | pass
16 | else:
17 | os.mkdir(dir_path)
18 |
19 |
20 | # Function that takes a file_name and optionally the path to the directory the file is expected to be in. Returns True if
21 | # the file is found in the stated directory (or in the current directory if dir_name = ''), and False if the dir/file isn't found
22 | def file_check(file_name, dir_name=''):
23 | assert type(file_name) == str
24 | dir_name = [dir_name] if ((type(dir_name) != list) and (dir_name)) else dir_name
25 | full_path = os.path.join(*dir_name, file_name)
26 | return os.path.isfile(full_path)
27 |
28 |
29 | # Function that saves 'data' to a json file. Constructs a file path if dir_name is provided.
30 | def json_save(data, file_name, dir_name='', indent=0):
31 | dir_name = [dir_name] if ((type(dir_name) != list) and (dir_name)) else dir_name
32 | assert type(file_name) == str
33 | file_name = file_name + '.json' if not file_name.endswith('.json') else file_name
34 | full_path = os.path.join(*dir_name, file_name)
35 | with open(full_path, 'w') as fp:
36 | json.dump(data, fp, indent=indent)
37 |
38 |
39 | def json_load(file_name, dir_name=''):
40 | dir_name = [dir_name] if ((type(dir_name) != list) and (dir_name)) else dir_name
41 | file_name = file_name + '.json' if not file_name.endswith('.json') else file_name
42 | full_path = os.path.join(*dir_name, file_name)
43 | with open(full_path) as fp:
44 | return json.load(fp)
45 |
46 |
47 | def load_model(model_data):
48 | model_meta = model_data.pop('model_data')
49 | # print(model_meta)
50 |
51 | if model_meta["model_type"] == "rnn":
52 | model = RNN(**model_meta)
53 | elif model_meta["model_type"] == "gcn":
54 | model = GCN(**model_meta)
55 | elif model_meta["model_type"] == "gcntf":
56 | model = GCNTF(**model_meta)
57 |
58 | if 'state_dict' in model_data:
59 | state_dict = model.state_dict()
60 | for each in model_data['state_dict']:
61 | state_dict[each] = torch.tensor(model_data['state_dict'][each])
62 | model.load_state_dict(state_dict)
63 |
64 | return model
65 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # GCN-TFiLM
5 | **Modelling black-box audio effects with time-varying feature modulation**
6 |
7 | [Paper](https://arxiv.org/abs/2211.00497) | [Webpage](https://mcomunita.github.io/gcn-tfilm_page/)
8 |
9 |
10 |
11 |
12 |

13 |
14 |
15 | ## Setup
16 |
17 | Install the requirements.
18 | ```
19 | python3 -m venv .venv
20 | source .venv/bin/activate
21 | python3 -m pip install -r requirements.txt
22 | ```
23 |
24 | ## Dataset
25 | You can find our dataset of fuzz and compressor effects on [Zenodo](https://zenodo.org/record/7271558#.Y2I_6OzP0-R). Once downloaded, you can replace the ```data``` folder in this repo. You can also use your own data, organized into 3 subfolders: ```test```, ```train``` and ```val```. Each folder should contain a single pair of input and target files (wav - 16bit); each file will be split into samples during training. The files should be named DEVICE-input.wav and DEVICE-target.wav (e.g., la2a-input.wav and la2a-target.wav), as in the layout below.
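For example, with the la2a device the ```data``` folder would be laid out as:
```
data
├── test
│   ├── la2a-input.wav
│   └── la2a-target.wav
├── train
│   ├── la2a-input.wav
│   └── la2a-target.wav
└── val
    ├── la2a-input.wav
    └── la2a-target.wav
```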
26 |
27 | ## Training
28 |
29 | If you would like to re-train the models from the paper, you can run the training script, which will train all the models one by one. The training results will be saved in the ```results``` folder.
30 |
31 | ```
32 | python train.py
33 | ```
34 |
35 | ## Process Audio
36 |
37 | To process audio using a trained model you can run the ```proc_audio.py``` script.
38 |
39 | ```
40 | python proc_audio.py \
41 | --input_file data/test/facebender-input.wav \
42 | --output_file facebender-output.wav \
43 |     --model_file results/MODEL_FOLDER/model_best.json \
44 | --chunk_length 32768
45 | ```
46 |
47 | ## Credits
48 | [https://github.com/Alec-Wright/Automated-GuitarAmpModelling](https://github.com/Alec-Wright/Automated-GuitarAmpModelling)
49 |
50 | [https://github.com/csteinmetz1/micro-tcn](https://github.com/csteinmetz1/micro-tcn)
51 |
52 | ## Citation
53 | If you use any of this code in your work, please consider citing us.
54 | ```
55 | @misc{https://doi.org/10.48550/arxiv.2211.00497,
56 | doi = {10.48550/ARXIV.2211.00497},
57 | url = {https://arxiv.org/abs/2211.00497},
58 | author = {Comunità, Marco and Steinmetz, Christian J. and Phan, Huy and Reiss, Joshua D.},
59 | keywords = {Sound (cs.SD), Artificial Intelligence (cs.AI), Machine Learning (cs.LG), Audio and Speech Processing (eess.AS), FOS: Computer and information sciences, FOS: Computer and information sciences, FOS: Electrical engineering, electronic engineering, information engineering, FOS: Electrical engineering, electronic engineering, information engineering},
60 | title = {Modelling black-box audio effects with time-varying feature modulation},
61 | publisher = {arXiv},
62 | year = {2022},
63 | copyright = {Creative Commons Attribution 4.0 International}
64 | }
65 | ```
66 |
--------------------------------------------------------------------------------
/src/training.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import auraloss
4 |
5 | # ESR loss calculates the error-to-signal ratio between output and target: mean((target - output)^2) / (mean(target^2) + epsilon)
6 |
7 |
8 | class ESRLoss(nn.Module):
9 | def __init__(self):
10 | super(ESRLoss, self).__init__()
11 | self.epsilon = 0.00001
12 |
13 | def forward(self, output, target):
14 | loss = torch.add(target, -output)
15 | loss = torch.pow(loss, 2)
16 | loss = torch.mean(loss)
17 | energy = torch.mean(torch.pow(target, 2)) + self.epsilon
18 | loss = torch.div(loss, energy)
19 | return loss
20 |
21 |
22 | class DCLoss(nn.Module):
23 | def __init__(self):
24 | super(DCLoss, self).__init__()
25 | self.epsilon = 0.00001
26 |
27 | def forward(self, output, target):
28 | loss = torch.pow(torch.add(torch.mean(target, 0), -torch.mean(output, 0)), 2)
29 | loss = torch.mean(loss)
30 | energy = torch.mean(torch.pow(target, 2)) + self.epsilon
31 | loss = torch.div(loss, energy)
32 | return loss
33 |
34 |
35 | # PreEmph is a class that applies an FIR pre-emphasis filter to the signal. The filter coefficients are given in the
36 | # filter_cfs argument; if low_pass is set, a low-pass filter is also applied.
37 | # Only supported for single-channel audio!
38 | class PreEmph(nn.Module):
39 | def __init__(self, filter_cfs, low_pass=0):
40 | super(PreEmph, self).__init__()
41 | self.epsilon = 0.00001
42 | self.zPad = len(filter_cfs) - 1
43 |
44 | self.conv_filter = nn.Conv1d(1, 1, 2, bias=False)
45 | self.conv_filter.weight.data = torch.tensor([[filter_cfs]], requires_grad=False)
46 |
47 | self.low_pass = low_pass
48 | if self.low_pass:
49 | self.lp_filter = nn.Conv1d(1, 1, 2, bias=False)
50 | self.lp_filter.weight.data = torch.tensor([[[0.85, 1]]], requires_grad=False)
51 |
52 | def forward(self, output, target):
53 | # zero pad the input/target so the filtered signal is the same length
54 | output = torch.cat((torch.zeros(self.zPad, output.shape[1], 1), output))
55 | target = torch.cat((torch.zeros(self.zPad, target.shape[1], 1), target))
56 | # Apply pre-emph filter, permute because the dimension order is different for RNNs and Convs in pytorch...
57 | output = self.conv_filter(output.permute(1, 2, 0))
58 | target = self.conv_filter(target.permute(1, 2, 0))
59 |
60 | if self.low_pass:
61 | output = self.lp_filter(output)
62 | target = self.lp_filter(target)
63 |
64 | return output.permute(2, 0, 1), target.permute(2, 0, 1)
65 |
66 |
67 | class LossWrapper(nn.Module):
68 | def __init__(self, losses, pre_filt=None):
69 | super(LossWrapper, self).__init__()
70 | self.losses = losses
71 | self.loss_dict = {
72 | 'ESR': ESRLoss(),
73 | 'DC': DCLoss(),
74 | 'L1': torch.nn.L1Loss(),
75 | 'STFT': auraloss.freq.STFTLoss(),
76 | 'MSTFT': auraloss.freq.MultiResolutionSTFTLoss()
77 | }
78 | if pre_filt:
79 | pre_filt = PreEmph(pre_filt)
80 | self.loss_dict['ESRPre'] = lambda output, target: self.loss_dict['ESR'].forward(*pre_filt(output, target))
81 | loss_functions = [[self.loss_dict[key], value] for key, value in losses.items()]
82 |
83 | self.loss_functions = tuple([items[0] for items in loss_functions])
84 | try:
85 | self.loss_factors = tuple(torch.Tensor([items[1] for items in loss_functions]))
86 | except IndexError:
87 | self.loss_factors = torch.ones(len(self.loss_functions))
88 |
89 | def forward(self, output, target):
90 | all_losses = {}
91 | for i, loss in enumerate(self.losses):
92 | # original shape: length x batch x 1
93 | # auraloss needs: batch x 1 x length
94 | loss_fcn = self.loss_functions[i]
95 | loss_factor = self.loss_factors[i]
96 |             if isinstance(loss_fcn, (auraloss.freq.STFTLoss, auraloss.freq.MultiResolutionSTFTLoss)):
97 |                 all_losses[loss] = torch.mul(loss_fcn(output.permute(1, 2, 0), target.permute(1, 2, 0)), loss_factor)
98 |             else:
99 |                 all_losses[loss] = torch.mul(loss_fcn(output, target), loss_factor)
100 | return all_losses
101 |
102 |
103 | class TrainTrack(dict):
104 | def __init__(self):
105 | self.update({'current_epoch': 0,
106 |
107 | 'tot_train_losses': [],
108 | 'train_losses': [],
109 |
110 | 'tot_val_losses': [],
111 | 'val_losses': [],
112 |
113 | 'train_av_time': 0.0,
114 | 'val_av_time': 0.0,
115 | 'total_time': 0.0,
116 |
117 | 'val_loss_best': 1e12,
118 | 'val_losses_best': 1e12,
119 |
120 | 'test_loss_final': 0,
121 | 'test_losses_final': {},
122 |
123 | 'test_loss_best': 0,
124 | 'test_losses_best': {}})
125 |
126 | def restore_data(self, training_info):
127 | self.update(training_info)
128 |
129 | def train_epoch_update(self, epoch_loss, epoch_losses, ep_st_time, ep_end_time, init_time, current_ep):
130 | self['current_epoch'] = current_ep
131 | self['tot_train_losses'].append(epoch_loss)
132 | self['train_losses'].append(epoch_losses)
133 |
134 | if self['train_av_time']:
135 | self['train_av_time'] = (self['train_av_time'] + ep_end_time - ep_st_time) / 2
136 | else:
137 | self['train_av_time'] = ep_end_time - ep_st_time
138 |
139 | self['total_time'] += ((init_time + ep_end_time - ep_st_time)/3600)
140 |
141 | def val_epoch_update(self, val_loss, val_losses, ep_st_time, ep_end_time):
142 | self['tot_val_losses'].append(val_loss)
143 | self['val_losses'].append(val_losses)
144 |
145 | if self['val_av_time']:
146 | self['val_av_time'] = (self['val_av_time'] + ep_end_time - ep_st_time) / 2
147 | else:
148 | self['val_av_time'] = ep_end_time - ep_st_time
149 |
150 | if val_loss < self['val_loss_best']:
151 | self['val_loss_best'] = val_loss
152 | self['val_losses_best'] = val_losses
153 |
--------------------------------------------------------------------------------
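A minimal usage sketch for LossWrapper (the keys must match entries of loss_dict; the weights here are illustrative, not the values used for the paper's experiments):
```
import torch
from src.training import LossWrapper

# illustrative combination: 0.75 * ESR + 0.25 * DC
loss_fcn = LossWrapper({'ESR': 0.75, 'DC': 0.25})

output = torch.rand(16384, 4, 1)   # [length, batch, channels]
target = torch.rand(16384, 4, 1)

losses = loss_fcn(output, target)  # dict with one weighted term per loss name
total_loss = sum(losses.values())  # scalar tensor used for backprop
```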
/src/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.io import wavfile
3 | import torch
4 | import math
5 | import warnings
6 | import os
7 |
8 |
9 | # Function converting np read audio to range of -1 to +1
10 | def audio_converter(audio):
11 | if audio.dtype == 'int16':
12 | return audio.astype(np.float32, order='C') / 32768.0
13 | else:
14 | print('unimplemented audio data type conversion...')
15 |
16 |
17 | # Splits audio, each split marker determines the fraction of the total audio in that split, i.e [0.75, 0.25] will put
18 | # 75% in the first split and 25% in the second
19 | def audio_splitter(audio, split_markers):
20 | assert sum(split_markers) <= 1.0
21 | if sum(split_markers) < 0.999:
22 | warnings.warn("sum of split markers is less than 1, so not all audio will be included in dataset")
23 | start = 0
24 | slices = []
25 | # convert split markers to samples
26 | split_bounds = [int(x * audio.shape[0]) for x in split_markers]
27 | for n in split_bounds:
28 | end = start + n
29 | slices.append(audio[start:end])
30 | start = end
31 | return slices
32 |
33 |
34 | # converts numpy audio into frames, and creates a torch tensor from them, frame_len = 0 just converts to a torch tensor
35 | def framify(audio, frame_len):
36 | # If audio is mono, add a dummy dimension, so the same operations can be applied to mono/multichannel audio
37 | audio = np.expand_dims(audio, 1) if len(audio.shape) == 1 else audio
38 |     # Calculate the number of segments the training data will be split into if frame_len is not 0
39 | seg_num = math.floor(audio.shape[0] / frame_len) if frame_len else 1
40 | # If no frame_len is provided, set frame_len to be equal to length of the input audio
41 | frame_len = audio.shape[0] if not frame_len else frame_len
42 | # Find the number of channels
43 | channels = audio.shape[1]
44 | # Initialise tensor matrices
45 | dataset = torch.empty((frame_len, seg_num, channels))
46 | # Load the audio for the training set
47 | for i in range(seg_num):
48 | dataset[:, i, :] = torch.from_numpy(audio[i * frame_len:(i + 1) * frame_len, :])
49 | return dataset
50 |
51 |
52 | """This is the main DataSet class, it can hold any number of subsets, which could be, e.g, the training/test/val sets.
53 | The subsets are created by the create_subset method and stored in the DataSet.subsets dictionary.
54 |
55 | datadir: is the default location where the DataSet instance will look when told to load an audio file
56 | extensions: is the default ends of the paired data files, when using paired data, so by default when loading the file
57 | 'wicked_guitar', the load_file method with look for 'wicked_guitar-input.wav' and 'wicked_guitar-target.wav'
58 | to disable this behaviour, enter extensions = '', or None, or anything that evaluates to false in python """
59 |
60 |
61 | class DataSet:
62 | def __init__(self, data_dir='../Dataset/', extensions=('input', 'target')):
63 | self.extensions = extensions if extensions else ['']
64 | self.subsets = {}
65 |         assert type(data_dir) == str, "data_dir should be a string, not %r" % type(data_dir)
66 | self.data_dir = data_dir
67 |
68 | # add a subset called 'name', desired 'frame_len' is given in seconds, or 0 for just one long frame
69 | def create_subset(self, name, frame_len=0):
70 |         assert type(name) == str, "data subset name must be a string, not %r" % type(name)
71 | assert not (name in self.subsets), "subset %r already exists" %name
72 | self.subsets[name] = SubSet(frame_len)
73 |
74 | # load a file of 'filename' into existing subset/s 'set_names', split fractionally as specified by 'splits',
75 | # if 'cond_val' is provided the conditioning value will be saved along with the frames of the loaded data
76 | def load_file(self, filename, set_names='train', splits=None, cond_val=None):
77 | # Assertions and checks
78 | if type(set_names) == str:
79 | set_names = [set_names]
80 | assert len(set_names) == 1 or len(set_names) == len(splits), "number of subset names must equal number of " \
81 | "split markers"
82 | assert [self.subsets.get(each) for each in set_names], "set_names contains subsets that don't exist yet"
83 |
84 | # Load each of the 'extensions'
85 | for i, ext in enumerate(self.extensions):
86 |             try:
87 |                 file_loc = os.path.join(self.data_dir, filename + '-' + ext)
88 |                 file_loc = file_loc + '.wav' if not file_loc.endswith('.wav') else file_loc
89 |                 np_data = wavfile.read(file_loc)
90 |             except FileNotFoundError:
91 |                 try:  # fall back to the un-suffixed file name; nested so a second miss is still caught
92 |                     file_loc = os.path.join(self.data_dir, filename + ext)
93 |                     np_data = wavfile.read(file_loc if file_loc.endswith('.wav') else file_loc + '.wav')
94 |                 except FileNotFoundError:
95 |                     print("File Not Found At: " + os.path.join(self.data_dir, filename))
96 |                     return
97 | raw_audio = audio_converter(np_data[1])
98 | # Split the audio if the set_names were provided
99 | if len(set_names) > 1:
100 | raw_audio = audio_splitter(raw_audio, splits)
101 | for n, sets in enumerate(set_names):
102 | self.subsets[set_names[n]].add_data(np_data[0], raw_audio[n], ext, cond_val)
103 | elif len(set_names) == 1:
104 | self.subsets[set_names[0]].add_data(np_data[0], raw_audio, ext, cond_val)
105 |
106 |
107 | # The SubSet class holds a subset of data,
108 | # frame_len sets the length of audio per frame (in s), if set to 0 a single frame is used instead
109 | class SubSet:
110 | def __init__(self, frame_len):
111 | self.data = {}
112 | self.cond_data = {}
113 | self.frame_len = frame_len
114 | self.conditioning = None
115 | self.fs = None
116 |
117 | # Add 'audio' data, in the data dictionary at the key 'ext', if cond_val is provided save the cond_val of each frame
118 | def add_data(self, fs, audio, ext, cond_val):
119 | if not self.fs:
120 | self.fs = fs
121 | assert self.fs == fs, "data with different sample rate provided to subset"
122 | # if no 'ext' is provided, all the subsets data will be stored at the 'data' key of the 'data' dict
123 | ext = 'data' if not ext else ext
124 | # Frame the data and optionally create a tensor of the conditioning values of each frame
125 | framed_data = framify(audio, self.frame_len)
126 | cond_data = cond_val * torch.ones(framed_data.shape[1]) if isinstance(cond_val, (float, int)) else None
127 |
128 | try:
129 | # Convert data from tuple to list and concatenate new data onto the data tensor
130 | data = list(self.data[ext])
131 | self.data[ext] = (torch.cat((data[0], framed_data), 1),)
132 | # If cond_val is provided add it to the cond_val tensor, note all frames or no frames must have cond vals
133 | if isinstance(cond_val, (float, int)):
134 | assert torch.is_tensor(self.cond_data[ext][0]), 'cond val provided, but previous data has no cond val'
135 | c_data = list(self.cond_data[ext])
136 | self.cond_data[ext] = (torch.cat((c_data[0], cond_data), 0),)
137 | # If this is the first data to be loaded into the subset, create the data and cond_data tuples
138 | except KeyError:
139 | self.data[ext] = (framed_data,)
140 | self.cond_data[ext] = (cond_data,)
141 |
--------------------------------------------------------------------------------
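A minimal sketch of loading a paired file with the DataSet class (paths, frame length and split fractions are illustrative):
```
import src.data as data

dataset = data.DataSet(data_dir='data/train', extensions=('input', 'target'))
dataset.create_subset('train', frame_len=16384)
dataset.create_subset('val', frame_len=16384)

# looks for data/train/la2a-input.wav and data/train/la2a-target.wav,
# then splits the audio 80/20 between the 'train' and 'val' subsets
dataset.load_file('la2a', set_names=['train', 'val'], splits=[0.8, 0.2])

x = dataset.subsets['train'].data['input'][0]   # [frame_len, nframes, channels]
y = dataset.subsets['train'].data['target'][0]
```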
/src/rnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import argparse
4 | import src.utils as utils
5 | # from contextlib import nullcontext
6 |
7 |
8 | def wrapperkwargs(func, kwargs):
9 | return func(**kwargs)
10 |
11 |
12 | def wrapperargs(func, args):
13 | return func(*args)
14 |
15 |
16 | """
17 | Simple RNN class made of a single recurrent unit of type LSTM, GRU or Elman,
18 | followed by a fully connected layer
19 | """
20 |
21 |
22 | class RNN(torch.nn.Module):
23 | def __init__(self,
24 | input_size=1,
25 | output_size=1,
26 | unit_type="LSTM",
27 | hidden_size=32,
28 | skip=1,
29 | bias_fl=True,
30 | nlayers=1,
31 | **kwargs):
32 | super(RNN, self).__init__()
33 | self.input_size = input_size
34 | self.output_size = output_size
35 | self.unit_type = unit_type
36 | self.hidden_size = hidden_size
37 | self.skip = skip
38 | self.bias_fl = bias_fl
39 | self.nlayers = nlayers
40 | self.save_state = True
41 | self.hidden = None # hidden state
42 |
43 | # create dictionary of possible block types
44 | self.rec = wrapperargs(getattr(torch.nn, unit_type),
45 | [input_size, hidden_size, nlayers])
46 |
47 | self.lin = torch.nn.Linear(hidden_size, output_size, bias=bias_fl)
48 |
49 | def forward(self, x):
50 | if self.skip:
51 | # save the residual for the skip connection
52 | res = x[:, :, :self.skip]
53 | x, self.hidden = self.rec(x, self.hidden)
54 | return self.lin(x) + res
55 | else:
56 | x, self.hidden = self.rec(x, self.hidden)
57 | return self.lin(x)
58 |
59 | # detach hidden state, this resets gradient tracking on the hidden state
60 | def detach_hidden(self):
61 | if self.hidden.__class__ == tuple:
62 | self.hidden = tuple([h.clone().detach() for h in self.hidden])
63 | else:
64 | self.hidden = self.hidden.clone().detach()
65 |
66 | # reset hidden state
67 | def reset_hidden(self):
68 | self.hidden = None
69 |
70 | # train_epoch runs one epoch of training
71 | def train_epoch(self,
72 | input_data,
73 | target_data,
74 | loss_fcn,
75 | optim,
76 | batch_size,
77 | init_len=200,
78 | up_fr=1000):
79 | # shuffle the segments at the start of the epoch
80 | shuffle = torch.randperm(input_data.shape[1])
81 |
82 | # iterate over the batches
83 | ep_losses = {}
84 |
85 | for batch_i in range(math.ceil(shuffle.shape[0] / batch_size)):
86 | # load batch
87 | input_batch = input_data[:,
88 | shuffle[batch_i * batch_size:
89 | (batch_i + 1) * batch_size],
90 | :]
91 | target_batch = target_data[:,
92 | shuffle[batch_i * batch_size:
93 | (batch_i + 1) * batch_size],
94 | :]
95 |
96 | # initialise network hidden state by processing some samples then zero the gradient buffers
97 | self(input_batch[0:init_len, :, :])
98 | self.zero_grad()
99 |
100 |             # choose the starting index for processing the rest of the batch sequence, in chunks of up_fr samples
101 | start_i = init_len
102 | tot_batch_losses = {}
103 |
104 | # iterate over the remaining samples in the mini batch
105 | for k in range(math.ceil((input_batch.shape[0] - init_len) / up_fr)):
106 | # process batch
107 | output = self(input_batch[start_i:start_i + up_fr, :, :])
108 |
109 | # loss and backprop
110 | partial_batch_losses = loss_fcn(output, target_batch[start_i:start_i + up_fr, :, :])
111 |
112 | partial_batch_loss = 0
113 | for loss in partial_batch_losses:
114 | partial_batch_loss += partial_batch_losses[loss]
115 |
116 | partial_batch_loss.backward()
117 | optim.step()
118 |
119 | # detach hidden state for truncated BPTT
120 | self.detach_hidden()
121 | self.zero_grad()
122 |
123 | # update the start index for next iteration
124 | start_i += up_fr
125 |
126 | # add partial batch losses to total batch losses
127 | if tot_batch_losses == {}:
128 | tot_batch_losses = partial_batch_losses
129 | else:
130 | for loss in partial_batch_losses:
131 | tot_batch_losses[loss] += partial_batch_losses[loss]
132 |
133 | # add average batch losses to epoch losses
134 | if ep_losses == {}:
135 | for loss in tot_batch_losses:
136 | ep_losses[loss] = tot_batch_losses[loss] / (k + 1)
137 | else:
138 | for loss in tot_batch_losses:
139 | ep_losses[loss] += tot_batch_losses[loss] / (k + 1)
140 |
141 | # # add the average batch loss to the epoch loss and reset the hidden states to zeros
142 | # ep_loss += batch_loss / (k + 1)
143 |
144 | # reset hidden state before next batch
145 | self.reset_hidden()
146 |
147 | # mean epoch losses
148 | for loss in ep_losses:
149 | ep_losses[loss] /= (batch_i + 1)
150 |
151 | return ep_losses
152 |
153 | def process_data(self,
154 | input_data,
155 | target_data=None,
156 | chunk=16384,
157 | loss_fcn=None,
158 | grad=False):
159 |
160 | if not (input_data.shape[0] / chunk).is_integer():
161 | # round to next chunk size
162 | padding = chunk - (input_data.shape[0] % chunk)
163 | input_data = torch.nn.functional.pad(input_data,
164 | (0, 0, 0, 0, 0, padding),
165 | mode='constant',
166 | value=0)
167 | if target_data != None:
168 | target_data = torch.nn.functional.pad(target_data,
169 | (0, 0, 0, 0, 0, padding),
170 | mode='constant',
171 | value=0)
172 |
173 | with torch.no_grad():
174 | # reset state before processing
175 | self.reset_hidden()
176 |
177 | # process input
178 | output_data = torch.empty_like(input_data)
179 |
180 | for l in range(int(output_data.size()[0] / chunk)):
181 | output_data[l * chunk:(l + 1) * chunk] = \
182 | self(input_data[l * chunk:(l + 1) * chunk])
183 | self.detach_hidden()
184 |
185 | # reset state before other computations
186 | self.reset_hidden()
187 |
188 | if loss_fcn != None and target_data != None:
189 | losses = loss_fcn(output_data, target_data)
190 | return input_data, target_data, output_data, losses
191 |
192 | return input_data, target_data, output_data
193 |
194 | def save_model(self,
195 | file_name,
196 | direc=''):
197 | if direc:
198 | utils.dir_check(direc)
199 |
200 | model_data = {'model_data': {'model_type': 'rnn',
201 | 'input_size': self.input_size,
202 | 'output_size': self.output_size,
203 | 'unit_type': self.unit_type,
204 | 'hidden_size': self.hidden_size,
205 | 'skip': self.skip,
206 | 'bias_fl': self.bias_fl,
207 | 'nlayers': self.nlayers}}
208 |
209 | if self.save_state:
210 | model_state = self.state_dict()
211 | for each in model_state:
212 | model_state[each] = model_state[each].tolist()
213 | model_data['state_dict'] = model_state
214 |
215 | utils.json_save(model_data, file_name, direc)
216 |
217 | # add any model hyperparameters here
218 | @staticmethod
219 | def add_model_specific_args(parent_parser):
220 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
221 | # --- model related ---
222 | parser.add_argument('--input_size', type=int, default=1)
223 | parser.add_argument('--output_size', type=int, default=1)
224 | parser.add_argument('--unit_type', type=str, default="LSTM")
225 | parser.add_argument('--hidden_size', type=int, default=32)
226 | parser.add_argument('--skip', type=int, default=1)
227 | parser.add_argument('--bias_fl', default=True, action='store_true')
228 | parser.add_argument('--nlayers', type=int, default=1)
229 |
230 | return parser
231 |
--------------------------------------------------------------------------------
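A minimal sketch of the tensor layout the RNN model expects (shapes are illustrative):
```
import torch
from src.rnn import RNN

model = RNN(input_size=1, output_size=1, unit_type="LSTM", hidden_size=32, skip=1)

x = torch.rand(16384, 8, 1)   # [length, batch, channels]
y = model(x)                  # same shape as the input; the hidden state is kept inside the module

model.reset_hidden()          # clear the state before processing an unrelated signal
```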
/src/gcn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import argparse
4 | import src.utils as utils
5 | # from contextlib import nullcontext
6 |
7 |
8 | def wrapperkwargs(func, kwargs):
9 | return func(**kwargs)
10 |
11 |
12 | def wrapperargs(func, args):
13 | return func(*args)
14 |
15 |
16 | # def load_model(model_data):
17 | # model_types = {'GCN': GCN}
18 |
19 | # model_meta = model_data.pop('model_data')
20 |
21 | # if model_meta['model'] == 'SimpleRNN' or model_meta['model'] == 'GatedConvNet':
22 | # network = wrapperkwargs(
23 | # model_types[model_meta.pop('model')], model_meta)
24 | # if 'state_dict' in model_data:
25 | # state_dict = network.state_dict()
26 | # for each in model_data['state_dict']:
27 | # state_dict[each] = torch.tensor(model_data['state_dict'][each])
28 | # network.load_state_dict(state_dict)
29 |
30 | # elif model_meta['model'] == 'RecNet':
31 | # model_meta['blocks'] = []
32 | # network = wrapperkwargs(
33 | # model_types[model_meta.pop('model')], model_meta)
34 | # for i in range(len(model_data['blocks'])):
35 | # network.add_layer(model_data['blocks'][str(i)])
36 |
37 | # # Get the state dict from the newly created model and load the saved states, if states were saved
38 | # if 'state_dict' in model_data:
39 | # state_dict = network.state_dict()
40 | # for each in model_data['state_dict']:
41 | # state_dict[each] = torch.tensor(model_data['state_dict'][each])
42 | # network.load_state_dict(state_dict)
43 |
44 | # if 'training_info' in model_data.keys():
45 | # network.training_info = model_data['training_info']
46 |
47 | # return network
48 |
49 |
50 | # This is a function for taking the old json config file format I used to use and converting it to the new format
51 | # def legacy_load(legacy_data):
52 | # if legacy_data['unit_type'] == 'GRU' or legacy_data['unit_type'] == 'LSTM':
53 | # model_data = {'model_data': {
54 | # 'model': 'RecNet', 'skip': 0}, 'blocks': {}}
55 | # model_data['blocks']['0'] = {
56 | # 'block_type': legacy_data['unit_type'],
57 | # 'input_size': legacy_data['in_size'],
58 | # 'hidden_size': legacy_data['hidden_size'],
59 | # 'output_size': 1,
60 | # 'lin_bias': True
61 | # }
62 | # if legacy_data['cur_epoch']:
63 | # training_info = {
64 | # 'current_epoch': legacy_data['cur_epoch'],
65 | # 'training_losses': legacy_data['tloss_list'],
66 | # 'val_losses': legacy_data['vloss_list'],
67 | # 'load_config': legacy_data['load_config'],
68 | # 'low_pass': legacy_data['low_pass'],
69 | # 'val_freq': legacy_data['val_freq'],
70 | # 'device': legacy_data['pedal'],
71 | # 'seg_length': legacy_data['seg_len'],
72 | # 'learning_rate': legacy_data['learn_rate'],
73 | # 'batch_size': legacy_data['batch_size'],
74 | # 'loss_func': legacy_data['loss_fcn'],
75 | # 'update_freq': legacy_data['up_fr'],
76 | # 'init_length': legacy_data['init_len'],
77 | # 'pre_filter': legacy_data['pre_filt']
78 | # }
79 | # model_data['training_info'] = training_info
80 |
81 | # if 'state_dict' in legacy_data:
82 | # state_dict = legacy_data['state_dict']
83 | # state_dict = dict(state_dict)
84 | # new_state_dict = {}
85 | # for each in state_dict:
86 | # new_name = each[0:7] + 'block_1.' + each[9:]
87 | # new_state_dict[new_name] = state_dict[each]
88 | # model_data['state_dict'] = new_state_dict
89 | # return model_data
90 | # else:
91 | # print('format not recognised')
92 |
93 |
94 | """
95 | Gated convolutional layer, zero pads and then applies a causal convolution to the input
96 | """
97 |
98 |
99 | class GatedConv1d(torch.nn.Module):
100 |
101 | def __init__(self,
102 | in_ch,
103 | out_ch,
104 | dilation,
105 | kernel_size):
106 | super(GatedConv1d, self).__init__()
107 | self.in_ch = in_ch
108 | self.out_ch = out_ch
109 | self.dilation = dilation
110 |         self.kernel_size = kernel_size
111 |
112 | # Layers: Conv1D -> Activations -> Mix + Residual
113 |
114 | self.conv = torch.nn.Conv1d(in_channels=in_ch,
115 | out_channels=out_ch * 2,
116 | kernel_size=kernel_size,
117 | stride=1,
118 | padding=0,
119 | dilation=dilation)
120 |
121 | self.mix = torch.nn.Conv1d(in_channels=out_ch,
122 | out_channels=out_ch,
123 | kernel_size=1,
124 | stride=1,
125 | padding=0)
126 |
127 | def forward(self, x):
128 | # print("GatedConv1d: ", x.shape)
129 | residual = x
130 |
131 | # dilated conv
132 | y = self.conv(x)
133 |
134 | # gated activation
135 | z = torch.tanh(y[:, :self.out_ch, :]) * \
136 | torch.sigmoid(y[:, self.out_ch:, :])
137 |
138 | # zero pad on the left side, so that z is the same length as x
139 | z = torch.cat((torch.zeros(residual.shape[0],
140 | self.out_ch,
141 | residual.shape[2] - z.shape[2]),
142 | z),
143 | dim=2)
144 |
145 | x = self.mix(z) + residual
146 |
147 | return x, z
148 |
149 |
150 | """
151 | Gated convolutional neural net block, applies successive gated convolutional layers to the input. A total of 'nlayers'
152 | layers are applied, with the filter size 'kernel_size' and the dilation increasing by a factor of 'dilation_growth' for
153 | each successive layer.
154 | """
155 |
156 |
157 | class GCNBlock(torch.nn.Module):
158 | def __init__(self,
159 | in_ch,
160 | out_ch,
161 | nlayers,
162 | kernel_size,
163 | dilation_growth):
164 | super(GCNBlock, self).__init__()
165 | self.in_ch = in_ch
166 | self.out_ch = out_ch
167 | self.nlayers = nlayers
168 | self.kernel_size = kernel_size
169 | self.dilation_growth = dilation_growth
170 |
171 | dilations = [dilation_growth ** l for l in range(nlayers)]
172 |
173 | self.layers = torch.nn.ModuleList()
174 |
175 | for d in dilations:
176 | self.layers.append(GatedConv1d(in_ch=in_ch,
177 | out_ch=out_ch,
178 | dilation=d,
179 | kernel_size=kernel_size))
180 | in_ch = out_ch
181 |
182 | def forward(self, x):
183 | # print("GCNBlock: ", x.shape)
184 | # [batch, channels, length]
185 | z = torch.empty([x.shape[0],
186 | self.nlayers * self.out_ch,
187 | x.shape[2]])
188 |
189 | for n, layer in enumerate(self.layers):
190 | x, zn = layer(x)
191 | z[:, n * self.out_ch: (n + 1) * self.out_ch, :] = zn
192 |
193 | return x, z
194 |
195 |
196 | """
197 | Gated Convolutional Neural Net class, based on the 'WaveNet' architecture, takes a single channel of audio as input and
198 | produces a single channel of audio of equal length as output. one-sided zero-padding is used to ensure the network is
199 | causal and doesn't reduce the length of the audio.
200 |
201 | Made up of 'blocks', each one applying a series of dilated convolutions, with the dilation of each successive layer
202 | increasing by a factor of 'dilation_growth'. 'layers' determines how many convolutional layers are in each block,
203 | 'kernel_size' is the size of the filters. Channels is the number of convolutional channels.
204 |
205 | The output of the model is creating by the linear mixer, which sums weighted outputs from each of the layers in the
206 | model
207 | """
208 |
209 |
210 | class GCN(torch.nn.Module):
211 | def __init__(self,
212 | nblocks=2,
213 | nlayers=9,
214 | nchannels=8,
215 | kernel_size=3,
216 | dilation_growth=2,
217 | **kwargs):
218 | super(GCN, self).__init__()
219 | self.nblocks = nblocks
220 | self.nlayers = nlayers
221 | self.nchannels = nchannels
222 | self.kernel_size = kernel_size
223 | self.dilation_growth = dilation_growth
224 |
225 | self.blocks = torch.nn.ModuleList()
226 | for b in range(nblocks):
227 | self.blocks.append(GCNBlock(in_ch=1 if b == 0 else nchannels,
228 | out_ch=nchannels,
229 | nlayers=nlayers,
230 | kernel_size=kernel_size,
231 | dilation_growth=dilation_growth))
232 |
233 | # output mixing layer
234 | self.blocks.append(
235 | torch.nn.Conv1d(in_channels=nchannels * nlayers * nblocks,
236 | out_channels=1,
237 | kernel_size=1,
238 | stride=1,
239 | padding=0))
240 |
241 | def forward(self, x):
242 | # print("GCN: ", x.shape)
243 | # x.shape = [length, batch, channels]
244 | x = x.permute(1, 2, 0) # change to [batch, channels, length]
245 | z = torch.empty([x.shape[0], self.blocks[-1].in_channels, x.shape[2]])
246 |
247 | for n, block in enumerate(self.blocks[:-1]):
248 | x, zn = block(x)
249 | z[:,
250 | n * self.nchannels * self.nlayers:
251 | (n + 1) * self.nchannels * self.nlayers,
252 | :] = zn
253 |
254 | # back to [length, batch, channels]
255 | return self.blocks[-1](z).permute(2, 0, 1)
256 |
257 | # train_epoch runs one epoch of training
258 | def train_epoch(self,
259 | input_data,
260 | target_data,
261 | loss_fcn,
262 | optim,
263 | batch_size):
264 | # shuffle the segments at the start of the epoch
265 | shuffle = torch.randperm(input_data.shape[1])
266 |
267 | # iterate over the batches
268 | ep_losses = None
269 |
270 | for batch_i in range(math.ceil(shuffle.shape[0] / batch_size)):
271 | # zero all gradients
272 | self.zero_grad()
273 |
274 | # load batch
275 | input_batch = input_data[:,
276 | shuffle[batch_i * batch_size:
277 | (batch_i + 1) * batch_size],
278 | :]
279 | target_batch = target_data[:,
280 | shuffle[batch_i * batch_size:
281 | (batch_i + 1) * batch_size],
282 | :]
283 |
284 | # process batch
285 | output = self(input_batch)
286 |
287 | # loss and backprop
288 | batch_losses = loss_fcn(output, target_batch)
289 |
290 | tot_batch_loss = 0
291 | for loss in batch_losses:
292 | tot_batch_loss += batch_losses[loss]
293 |
294 | tot_batch_loss.backward()
295 | optim.step()
296 |
297 |             # add batch losses to epoch losses
298 |             if ep_losses == None:
299 |                 ep_losses = dict(batch_losses)  # copy so we don't alias batch_losses and double-count
300 |             else:
301 |                 for loss in batch_losses:
302 |                     ep_losses[loss] += batch_losses[loss]
303 |
304 | # mean epoch losses
305 | for loss in ep_losses:
306 | ep_losses[loss] /= (batch_i + 1)
307 |
308 | return ep_losses
309 |
310 | def process_data(self,
311 | input_data,
312 | target_data=None,
313 | chunk=16384,
314 | loss_fcn=None,
315 | grad=False):
316 |
317 | rf = self.compute_receptive_field()
318 |
319 | if not (input_data.shape[0] / chunk).is_integer():
320 | # round to next chunk size
321 | padding = chunk - (input_data.shape[0] % chunk)
322 | input_data = torch.nn.functional.pad(input_data,
323 | (0, 0, 0, 0, 0, padding),
324 | mode='constant',
325 | value=0)
326 | if target_data != None:
327 | target_data = torch.nn.functional.pad(target_data,
328 | (0, 0, 0, 0, 0, padding),
329 | mode='constant',
330 | value=0)
331 |
332 | with torch.no_grad():
333 | # process input
334 | output_data = torch.empty_like(input_data)
335 |
336 | for l in range(int(output_data.size()[0] / chunk)):
337 | input_chunk = input_data[l * chunk: (l + 1) * chunk]
338 | if l == 0: # first chunk
339 | padding = torch.zeros([rf, input_chunk.shape[1], input_chunk.shape[2]])
340 | else:
341 | padding = input_data[(l * chunk) - rf: l * chunk]
342 | input_chunk = torch.cat([padding, input_chunk])
343 | output_chunk = self(input_chunk)
344 | output_data[l * chunk: (l + 1) * chunk] = \
345 | output_chunk[rf:, :, :]
346 |
347 | if loss_fcn != None and target_data != None:
348 | losses = loss_fcn(output_data, target_data)
349 | return input_data, target_data, output_data, losses
350 |
351 | return input_data, target_data, output_data
352 |
353 | def save_model(self,
354 | file_name,
355 | direc=""):
356 | if direc:
357 | utils.dir_check(direc)
358 |
359 | model_data = {"model_data": {"model_type": "gcn",
360 | "nblocks": self.nblocks,
361 | "nlayers": self.nlayers,
362 | "nchannels": self.nchannels,
363 | "kernel_size": self.kernel_size,
364 | "dilation_growth": self.dilation_growth}}
365 |
366 |         if getattr(self, "save_state", True):  # save_state is not set in __init__, so default to saving the weights
367 | model_state = self.state_dict()
368 | for each in model_state:
369 | model_state[each] = model_state[each].tolist()
370 | model_data["state_dict"] = model_state
371 |
372 | utils.json_save(model_data, file_name, direc)
373 |
374 | def compute_receptive_field(self):
375 | """ Compute the receptive field in samples."""
376 | rf = self.kernel_size
377 | for n in range(1, self.nblocks * self.nlayers):
378 | dilation = self.dilation_growth ** (n % self.nlayers)
379 | rf = rf + ((self.kernel_size-1) * dilation)
380 | return rf
381 |
382 | # add any model hyperparameters here
383 | @staticmethod
384 | def add_model_specific_args(parent_parser):
385 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
386 | # --- model related ---
387 | parser.add_argument('--nblocks', type=int, default=2)
388 | parser.add_argument('--nlayers', type=int, default=9)
389 | parser.add_argument('--nchannels', type=int, default=16)
390 | parser.add_argument('--kernel_size', type=int, default=3)
391 | parser.add_argument('--dilation_growth', type=int, default=2)
392 |
393 | return parser
394 |
--------------------------------------------------------------------------------
/src/gcntfilm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import argparse
4 | import src.utils as utils
5 | from contextlib import nullcontext
6 |
7 |
8 | def wrapperkwargs(func, kwargs):
9 | return func(**kwargs)
10 |
11 |
12 | def wrapperargs(func, args):
13 | return func(*args)
14 |
15 |
16 | # def load_model(model_data):
17 | # model_types = {"GCN": GCN}
18 |
19 | # model_meta = model_data.pop("model_data")
20 |
21 | # if model_meta["model"] == "SimpleRNN" or model_meta["model"] == "GCN":
22 | # network = wrapperkwargs(
23 | # model_types[model_meta.pop("model")], model_meta)
24 | # if "state_dict" in model_data:
25 | # state_dict = network.state_dict()
26 | # for each in model_data["state_dict"]:
27 | # state_dict[each] = torch.tensor(model_data["state_dict"][each])
28 | # network.load_state_dict(state_dict)
29 |
30 | # elif model_meta["model"] == "RecNet":
31 | # model_meta["blocks"] = []
32 | # network = wrapperkwargs(
33 | # model_types[model_meta.pop("model")], model_meta)
34 | # for i in range(len(model_data["blocks"])):
35 | # network.add_layer(model_data["blocks"][str(i)])
36 |
37 | # # Get the state dict from the newly created model and load the saved states, if states were saved
38 | # if "state_dict" in model_data:
39 | # state_dict = network.state_dict()
40 | # for each in model_data["state_dict"]:
41 | # state_dict[each] = torch.tensor(model_data["state_dict"][each])
42 | # network.load_state_dict(state_dict)
43 |
44 | # if "training_info" in model_data.keys():
45 | # network.training_info = model_data["training_info"]
46 |
47 | # return network
48 |
49 |
50 | # This is a function for taking the old json config file format I used to use and converting it to the new format
51 | # def legacy_load(legacy_data):
52 | # if legacy_data["unit_type"] == "GRU" or legacy_data["unit_type"] == "LSTM":
53 | # model_data = {"model_data": {
54 | # "model": "RecNet", "skip": 0}, "blocks": {}}
55 | # model_data["blocks"]["0"] = {
56 | # "block_type": legacy_data["unit_type"],
57 | # "input_size": legacy_data["in_size"],
58 | # "hidden_size": legacy_data["hidden_size"],
59 | # "output_size": 1,
60 | # "lin_bias": True,
61 | # }
62 | # if legacy_data["cur_epoch"]:
63 | # training_info = {
64 | # "current_epoch": legacy_data["cur_epoch"],
65 | # "training_losses": legacy_data["tloss_list"],
66 | # "val_losses": legacy_data["vloss_list"],
67 | # "load_config": legacy_data["load_config"],
68 | # "low_pass": legacy_data["low_pass"],
69 | # "val_freq": legacy_data["val_freq"],
70 | # "device": legacy_data["pedal"],
71 | # "seg_length": legacy_data["seg_len"],
72 | # "learning_rate": legacy_data["learn_rate"],
73 | # "batch_size": legacy_data["batch_size"],
74 | # "loss_func": legacy_data["loss_fcn"],
75 | # "update_freq": legacy_data["up_fr"],
76 | # "init_length": legacy_data["init_len"],
77 | # "pre_filter": legacy_data["pre_filt"],
78 | # }
79 | # model_data["training_info"] = training_info
80 |
81 | # if "state_dict" in legacy_data:
82 | # state_dict = legacy_data["state_dict"]
83 | # state_dict = dict(state_dict)
84 | # new_state_dict = {}
85 | # for each in state_dict:
86 | # new_name = each[0:7] + "block_1." + each[9:]
87 | # new_state_dict[new_name] = state_dict[each]
88 | # model_data["state_dict"] = new_state_dict
89 | # return model_data
90 | # else:
91 | # print("format not recognised")
92 |
93 |
94 | """
95 | Temporal FiLM layer
96 | """
97 |
98 |
99 | class TFiLM(torch.nn.Module):
100 | def __init__(self,
101 | nchannels,
102 | block_size=128):
103 | super(TFiLM, self).__init__()
104 | self.nchannels = nchannels
105 | self.block_size = block_size
106 | self.num_layers = 1
107 | self.hidden_state = None # (hidden_state, cell_state)
108 |
109 | # used to downsample input
110 | self.maxpool = torch.nn.MaxPool1d(kernel_size=block_size,
111 | stride=None,
112 | padding=0,
113 | dilation=1,
114 | return_indices=False,
115 | ceil_mode=False)
116 |
117 | self.lstm = torch.nn.LSTM(input_size=nchannels,
118 | hidden_size=nchannels,
119 | num_layers=self.num_layers,
120 | batch_first=False,
121 | bidirectional=False)
122 |
123 | def forward(self, x):
124 | # print("TFiLM: ", x.shape)
125 | # x = [batch, channels, length]
126 | x_shape = x.shape
127 | nsteps = int(x_shape[-1] / self.block_size)
128 |
129 | # downsample
130 | x_down = self.maxpool(x)
131 |
132 | # shape for LSTM (length, batch, channels)
133 | x_down = x_down.permute(2, 0, 1)
134 |
135 | # modulation sequence
136 | if self.hidden_state == None: # state was reset
137 | # init hidden and cell states with zeros
138 | h0 = torch.zeros(self.num_layers, x.size(0), self.nchannels).requires_grad_()
139 | c0 = torch.zeros(self.num_layers, x.size(0), self.nchannels).requires_grad_()
140 | x_norm, self.hidden_state = self.lstm(x_down, (h0.detach(), c0.detach())) # detach for truncated BPTT
141 | else:
142 | x_norm, self.hidden_state = self.lstm(x_down, self.hidden_state)
143 |
144 | # put shape back (batch, channels, length)
145 | x_norm = x_norm.permute(1, 2, 0)
146 |
147 | # reshape input and modulation sequence into blocks
148 | x_in = torch.reshape(
149 | x, shape=(-1, self.nchannels, nsteps, self.block_size))
150 | x_norm = torch.reshape(
151 | x_norm, shape=(-1, self.nchannels, nsteps, 1))
152 |
153 | # multiply
154 | x_out = x_norm * x_in
155 |
156 | # return to original shape
157 | x_out = torch.reshape(x_out, shape=(x_shape))
158 |
159 | return x_out
160 |
161 | def detach_state(self):
162 | if self.hidden_state.__class__ == tuple:
163 | self.hidden_state = tuple([h.clone().detach() for h in self.hidden_state])
164 | else:
165 | self.hidden_state = self.hidden_state.clone().detach()
166 |
167 | def reset_state(self):
168 | self.hidden_state = None
169 |
170 |
171 | """
172 | Gated convolutional layer, zero pads and then applies a causal convolution to the input
173 | """
174 |
175 |
176 | class GatedConv1d(torch.nn.Module):
177 | def __init__(self,
178 | in_ch,
179 | out_ch,
180 | dilation,
181 | kernel_size,
182 | tfilm_block_size):
183 | super(GatedConv1d, self).__init__()
184 | self.in_ch = in_ch
185 | self.out_ch = out_ch
186 | self.dilation = dilation
187 |         self.kernel_size = kernel_size
188 | self.tfilm_block_size = tfilm_block_size
189 |
190 | # Layers: Conv1D -> Activations -> TFiLM -> Mix + Residual
191 |
192 | self.conv = torch.nn.Conv1d(in_channels=in_ch,
193 | out_channels=out_ch * 2,
194 | kernel_size=kernel_size,
195 | stride=1,
196 | padding=0,
197 | dilation=dilation)
198 |
199 | self.tfilm = TFiLM(nchannels=out_ch,
200 | block_size=tfilm_block_size)
201 |
202 | self.mix = torch.nn.Conv1d(in_channels=out_ch,
203 | out_channels=out_ch,
204 | kernel_size=1,
205 | stride=1,
206 | padding=0)
207 |
208 | def forward(self, x):
209 | # print("GatedConv1d: ", x.shape)
210 | residual = x
211 |
212 | # dilated conv
213 | y = self.conv(x)
214 |
215 | # gated activation
216 | z = torch.tanh(y[:, :self.out_ch, :]) * \
217 | torch.sigmoid(y[:, self.out_ch:, :])
218 |
219 | # zero pad on the left side, so that z is the same length as x
220 | z = torch.cat((torch.zeros(residual.shape[0],
221 | self.out_ch,
222 | residual.shape[2] - z.shape[2]),
223 | z),
224 | dim=2)
225 |
226 | # modulation
227 | z = self.tfilm(z)
228 |
229 | x = self.mix(z) + residual
230 |
231 | return x, z
232 |
233 |
234 | """
235 | Gated convolutional neural net block, applies successive gated convolutional layers to the input. A total of 'nlayers'
236 | layers are applied, with the filter size 'kernel_size' and the dilation increasing by a factor of 'dilation_growth' for
237 | each successive layer.
238 | """
239 |
240 |
241 | class GCNBlock(torch.nn.Module):
242 | def __init__(self,
243 | in_ch,
244 | out_ch,
245 | nlayers,
246 | kernel_size,
247 | dilation_growth,
248 | tfilm_block_size):
249 | super(GCNBlock, self).__init__()
250 | self.in_ch = in_ch
251 | self.out_ch = out_ch
252 | self.nlayers = nlayers
253 | self.kernel_size = kernel_size
254 | self.dilation_growth = dilation_growth
255 | self.tfilm_block_size = tfilm_block_size
256 |
257 | dilations = [dilation_growth ** l for l in range(nlayers)]
258 |
259 | self.layers = torch.nn.ModuleList()
260 |
261 | for d in dilations:
262 | self.layers.append(GatedConv1d(in_ch=in_ch,
263 | out_ch=out_ch,
264 | dilation=d,
265 | kernel_size=kernel_size,
266 | tfilm_block_size=tfilm_block_size))
267 | in_ch = out_ch
268 |
269 | def forward(self, x):
270 | # print("GCNBlock: ", x.shape)
271 | # [batch, channels, length]
272 | z = torch.empty([x.shape[0],
273 | self.nlayers * self.out_ch,
274 | x.shape[2]])
275 |
276 | for n, layer in enumerate(self.layers):
277 | x, zn = layer(x)
278 | z[:, n * self.out_ch: (n + 1) * self.out_ch, :] = zn
279 |
280 | return x, z
281 |
282 |
283 | """
284 | Gated Convolutional Neural Net class, based on the 'WaveNet' architecture, takes a single channel of audio as input and
285 | produces a single channel of audio of equal length as output. one-sided zero-padding is used to ensure the network is
286 | causal and doesn't reduce the length of the audio.
287 |
288 | Made up of 'blocks', each one applying a series of dilated convolutions, with the dilation of each successive layer
289 | increasing by a factor of 'dilation_growth'. 'layers' determines how many convolutional layers are in each block,
290 | 'kernel_size' is the size of the filters. Channels is the number of convolutional channels.
291 |
292 | The output of the model is creating by the linear mixer, which sums weighted outputs from each of the layers in the
293 | model
294 | """
295 |
296 |
297 | class GCNTF(torch.nn.Module):
298 | def __init__(self,
299 | nblocks=2,
300 | nlayers=9,
301 | nchannels=8,
302 | kernel_size=3,
303 | dilation_growth=2,
304 | tfilm_block_size=128,
305 | **kwargs):
306 | super(GCNTF, self).__init__()
307 | self.nblocks = nblocks
308 | self.nlayers = nlayers
309 | self.nchannels = nchannels
310 | self.kernel_size = kernel_size
311 | self.dilation_growth = dilation_growth
312 | self.tfilm_block_size = tfilm_block_size
313 |
314 | self.blocks = torch.nn.ModuleList()
315 | for b in range(nblocks):
316 | self.blocks.append(GCNBlock(in_ch=1 if b == 0 else nchannels,
317 | out_ch=nchannels,
318 | nlayers=nlayers,
319 | kernel_size=kernel_size,
320 | dilation_growth=dilation_growth,
321 | tfilm_block_size=tfilm_block_size))
322 |
323 | # output mixing layer
324 | self.blocks.append(
325 | torch.nn.Conv1d(in_channels=nchannels * nlayers * nblocks,
326 | out_channels=1,
327 | kernel_size=1,
328 | stride=1,
329 | padding=0))
330 |
331 | def forward(self, x):
332 | # print("GCN: ", x.shape)
333 | # x.shape = [length, batch, channels]
334 | x = x.permute(1, 2, 0) # change to [batch, channels, length]
335 | z = torch.empty([x.shape[0], self.blocks[-1].in_channels, x.shape[2]])
336 |
337 | for n, block in enumerate(self.blocks[:-1]):
338 | x, zn = block(x)
339 | z[:,
340 | n * self.nchannels * self.nlayers:
341 | (n + 1) * self.nchannels * self.nlayers,
342 | :] = zn
343 |
344 | # back to [length, batch, channels]
345 | return self.blocks[-1](z).permute(2, 0, 1)
346 |
347 | def detach_states(self):
348 | # print("DETACH STATES")
349 | for layer in self.modules():
350 | if isinstance(layer, TFiLM):
351 | layer.detach_state()
352 |
353 | # reset state for all TFiLM layers
354 | def reset_states(self):
355 | # print("RESET STATES")
356 | for layer in self.modules():
357 | if isinstance(layer, TFiLM):
358 | layer.reset_state()
359 |
360 | # train_epoch runs one epoch of training
361 | def train_epoch(self,
362 | input_data,
363 | target_data,
364 | loss_fcn,
365 | optim,
366 | batch_size):
367 | # shuffle the segments at the start of the epoch
368 | shuffle = torch.randperm(input_data.shape[1])
369 |
370 | # iterate over the batches
371 | ep_losses = None
372 |
373 | for batch_i in range(math.ceil(shuffle.shape[0] / batch_size)):
374 | # reset states before starting new batch
375 | self.reset_states()
376 |
377 | # zero all gradients
378 | self.zero_grad()
379 |
380 | # load batch
381 | input_batch = input_data[:,
382 | shuffle[batch_i * batch_size:
383 | (batch_i + 1) * batch_size],
384 | :]
385 | target_batch = target_data[:,
386 | shuffle[batch_i * batch_size:
387 | (batch_i + 1) * batch_size],
388 | :]
389 |
390 | # process batch
391 | output = self(input_batch)
392 |
393 | # loss and backprop
394 | batch_losses = loss_fcn(output, target_batch)
395 |
396 | tot_batch_loss = 0
397 | for loss in batch_losses:
398 | tot_batch_loss += batch_losses[loss]
399 |
400 | tot_batch_loss.backward()
401 | optim.step()
402 |
403 |             # add batch losses to epoch losses
404 |             # (copy the dict on the first batch so we don't alias batch_losses and double-count)
405 |             if ep_losses == None:
406 |                 ep_losses = dict(batch_losses)
407 |             else:
408 |                 for loss in batch_losses:
409 |                     ep_losses[loss] += batch_losses[loss]
410 |
411 | # mean epoch losses
412 | for loss in ep_losses:
413 | ep_losses[loss] /= (batch_i + 1)
414 |
415 | return ep_losses
416 |
417 | def process_data(self,
418 | input_data,
419 | target_data=None,
420 | chunk=16384,
421 | loss_fcn=None,
422 | grad=False):
423 |
424 | rf = self.compute_receptive_field()
425 | # round to next block size
426 | rf = rf + (self.tfilm_block_size - rf % self.tfilm_block_size)
427 |
428 | if not (input_data.shape[0] / chunk).is_integer():
429 | # round to next chunk size
430 | padding = chunk - (input_data.shape[0] % chunk)
431 | input_data = torch.nn.functional.pad(input_data,
432 | (0, 0, 0, 0, 0, padding),
433 | mode='constant',
434 | value=0)
435 |             if target_data is not None:
436 | target_data = torch.nn.functional.pad(target_data,
437 | (0, 0, 0, 0, 0, padding),
438 | mode='constant',
439 | value=0)
440 |
441 | with torch.no_grad():
442 | # reset states before processing
443 | self.reset_states()
444 |
445 | # process input
446 | output_data = torch.empty_like(input_data)
447 |
448 |             for n in range(int(output_data.size()[0] / chunk)):
449 |                 input_chunk = input_data[n * chunk: (n + 1) * chunk]
450 |                 if n == 0:  # first chunk: zero-pad to cover the receptive field
451 |                     padding = torch.zeros([rf, input_chunk.shape[1], input_chunk.shape[2]], device=input_data.device)
452 |                 else:  # later chunks: reuse the preceding rf samples as context
453 |                     padding = input_data[(n * chunk) - rf: n * chunk]
454 |                 input_chunk = torch.cat([padding, input_chunk])
455 |                 output_chunk = self(input_chunk)
456 |                 output_data[n * chunk: (n + 1) * chunk] = \
457 |                     output_chunk[rf:, :, :]
458 |
459 |             if loss_fcn is not None and target_data is not None:
460 | losses = loss_fcn(output_data, target_data)
461 | return input_data, target_data, output_data, losses
462 |
463 | # reset states before other computations
464 | self.reset_states()
465 |
466 | return input_data, target_data, output_data
467 |
468 |     # this function saves the model and all its parameters to a json file, so it can be loaded by a JUCE plugin
469 | def save_model(self,
470 | file_name,
471 | direc=""):
472 | if direc:
473 | utils.dir_check(direc)
474 |
475 | model_data = {"model_data": {"model_type": "gcntf",
476 | "nblocks": self.nblocks,
477 | "nlayers": self.nlayers,
478 | "nchannels": self.nchannels,
479 | "kernel_size": self.kernel_size,
480 | "dilation_growth": self.dilation_growth,
481 | "tfilm_block_size": self.tfilm_block_size}}
482 |
483 | if self.save_state:
484 | model_state = self.state_dict()
485 | for each in model_state:
486 | model_state[each] = model_state[each].tolist()
487 | model_data["state_dict"] = model_state
488 |
489 | utils.json_save(model_data, file_name, direc)
490 |
491 | def compute_receptive_field(self):
492 | """ Compute the receptive field in samples."""
493 | rf = self.kernel_size
494 | for n in range(1, self.nblocks * self.nlayers):
495 | dilation = self.dilation_growth ** (n % self.nlayers)
496 | rf = rf + ((self.kernel_size-1) * dilation)
497 | return rf
498 |
499 | # add any model hyperparameters here
500 | @staticmethod
501 | def add_model_specific_args(parent_parser):
502 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
503 | # --- model related ---
504 | parser.add_argument('--nblocks', type=int, default=2)
505 | parser.add_argument('--nlayers', type=int, default=9)
506 | parser.add_argument('--nchannels', type=int, default=16)
507 | parser.add_argument('--kernel_size', type=int, default=3)
508 | parser.add_argument('--dilation_growth', type=int, default=2)
509 | parser.add_argument('--tfilm_block_size', type=int, default=128)
510 |
511 | return parser
512 |
--------------------------------------------------------------------------------
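
Note (not part of the repository files): a quick sanity check of GCNTF.compute_receptive_field above, using the same hyperparameter values as the argparse defaults (nblocks=2, nlayers=9, kernel_size=3, dilation_growth=2). The loop starts from rf = kernel_size = 3 and adds (kernel_size - 1) * dilation for the remaining 17 layers, with dilations 2, 4, ..., 256, then 1, 2, ..., 256, giving rf = 3 + 2 * (510 + 1 + 510) = 2045 samples, roughly 46 ms at 44.1 kHz. The sketch below just reproduces that number:

    from src.gcntfilm import GCNTF

    # build the model with the argparse default hyperparameters
    model = GCNTF(nblocks=2, nlayers=9, nchannels=16,
                  kernel_size=3, dilation_growth=2, tfilm_block_size=128)

    print(model.compute_receptive_field())  # expected: 2045 (samples)
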
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import os
4 | import scipy.io.wavfile  # explicit submodule import so scipy.io.wavfile.write is available
5 | import torch
6 | import torchinfo
7 | import torch.utils.tensorboard as tensorboard
8 |
9 | import src.utils as utils
10 | import src.training as training
11 | import src.data as data
12 |
13 | from src.rnn import RNN
14 | from src.gcn import GCN
15 | from src.gcntfilm import GCNTF
16 |
17 | torch.backends.cudnn.benchmark = True
18 |
19 | train_configs = [
20 | # { # from Wright et al. - 2020
21 | # "name": "LSTM32",
22 | # "model_type": "rnn",
23 | # "hidden_size": 32,
24 | # "unit_type": "LSTM",
25 | # "loss_fcns": {"ESRPre": 0.75, "DC": 0.25},
26 | # "prefilt": "a-weighting",
27 | # "device": "fuzz-rndamp-G5S10A1msR2500ms",
28 | # "file_name": "fuzz-rndamp-G5S10A1msR2500ms",
29 | # "train_length": 132300,
30 | # "val_chunk": 16384,
31 | # "test_chunk": 16384,
32 | # "validation_p": 20,
33 | # "batch_size": 40,
34 | # "up_fr": 1000,
35 | # },
36 | # { # from Wright et al. - 2020
37 | # "name": "LSTM96",
38 | # "model_type": "rnn",
39 | # "hidden_size": 96,
40 | # "unit_type": "LSTM",
41 | # "loss_fcns": {"ESRPre": 0.75, "DC": 0.25},
42 | # "prefilt": "a-weighting",
43 | # "device": "fuzz-rndamp-G5S10A1msR2500ms",
44 | # "file_name": "fuzz-rndamp-G5S10A1msR2500ms",
45 | # "train_length": 132300,
46 | # "val_chunk": 16384,
47 | # "test_chunk": 16384,
48 | # "validation_p": 20,
49 | # "batch_size": 40,
50 | # "up_fr": 1000,
51 | # },
52 | # { # from Wright et al. - 2020
53 | # "name": "GCN1",
54 | # "model_type": "gcn",
55 | # "nblocks": 1,
56 | # "nlayers": 10,
57 | # "nchannels": 16,
58 | # "kernel_size": 3,
59 | # "dilation_growth": 2,
60 | # "loss_fcns": {"ESRPre": 1.0, "DC": 0.0},
61 | # "prefilt": "high_pass",
62 | # "device": "fuzz-rndamp-G5S10A1msR2500ms",
63 | # "file_name": "fuzz-rndamp-G5S10A1msR2500ms",
64 | # "train_length": 4410,
65 | # "val_chunk": 441000,
66 | # "test_chunk": 441000,
67 | # "validation_p": 20,
68 | # "batch_size": 40
69 | # },
70 | # { # from Wright et al. - 2020
71 | # "name": "GCN3",
72 | # "model_type": "gcn",
73 | # "nblocks": 2,
74 | # "nlayers": 9,
75 | # "nchannels": 16,
76 | # "kernel_size": 3,
77 | # "dilation_growth": 2,
78 | # "loss_fcns": {"ESRPre": 1.0, "DC": 0.0},
79 | # "prefilt": "high_pass",
80 | # "device": "fuzz-rndamp-G5S10A1msR2500ms",
81 | # "file_name": "fuzz-rndamp-G5S10A1msR2500ms",
82 | # "train_length": 4410,
83 | # "val_chunk": 441000,
84 | # "test_chunk": 441000,
85 | # "validation_p": 20,
86 | # "batch_size": 40
87 | # },
88 | {
89 | "name": "LSTM32",
90 | "model_type": "rnn",
91 | "hidden_size": 32,
92 | "unit_type": "LSTM",
93 | "loss_fcns": {"L1": 0.5, "MSTFT": 0.5},
94 | "prefilt": None,
95 | "device": "fuzz-rndamp-G5S10A1msR2500ms",
96 | "file_name": "fuzz-rndamp-G5S10A1msR2500ms",
97 | "train_length": 112640,
98 | "val_chunk": 112640,
99 | "test_chunk": 112640,
100 | "validation_p": 20,
101 | "batch_size": 6,
102 | "up_fr": 2048,
103 | },
104 | {
105 | "name": "LSTM96",
106 | "model_type": "rnn",
107 | "hidden_size": 96,
108 | "unit_type": "LSTM",
109 | "loss_fcns": {"L1": 0.5, "MSTFT": 0.5},
110 | "prefilt": None,
111 | "device": "fuzz-rndamp-G5S10A1msR2500ms",
112 | "file_name": "fuzz-rndamp-G5S10A1msR2500ms",
113 | "train_length": 112640,
114 | "val_chunk": 112640,
115 | "test_chunk": 112640,
116 | "validation_p": 20,
117 | "batch_size": 6,
118 | "up_fr": 2048,
119 | },
120 | {
121 | "name": "GCN1",
122 | "model_type": "gcn",
123 | "nblocks": 1,
124 | "nlayers": 10,
125 | "nchannels": 16,
126 | "kernel_size": 3,
127 | "dilation_growth": 2,
128 | "loss_fcns": {"L1": 0.5, "MSTFT": 0.5},
129 | "prefilt": None,
130 | "device": "fuzz-rndamp-G5S10A1msR2500ms",
131 | "file_name": "fuzz-rndamp-G5S10A1msR2500ms",
132 | "train_length": 112640,
133 | "val_chunk": 112640,
134 | "test_chunk": 112640,
135 | "validation_p": 20,
136 | "batch_size": 6,
137 | },
138 | {
139 | "name": "GCN3",
140 | "model_type": "gcn",
141 | "nblocks": 2,
142 | "nlayers": 9,
143 | "nchannels": 16,
144 | "kernel_size": 3,
145 | "dilation_growth": 2,
146 | "loss_fcns": {"L1": 0.5, "MSTFT": 0.5},
147 | "prefilt": None,
148 | "device": "fuzz-rndamp-G5S10A1msR2500ms",
149 | "file_name": "fuzz-rndamp-G5S10A1msR2500ms",
150 | "train_length": 112640,
151 | "val_chunk": 112640,
152 | "test_chunk": 112640,
153 | "validation_p": 20,
154 | "batch_size": 6,
155 | },
156 | {
157 | "name": "GCNTF1",
158 | "model_type": "gcntf",
159 | "nblocks": 1,
160 | "nlayers": 10,
161 | "nchannels": 16,
162 | "kernel_size": 3,
163 | "dilation_growth": 2,
164 | "tfilm_block_size": 128,
165 | "loss_fcns": {"L1": 0.5, "MSTFT": 0.5},
166 | "prefilt": None,
167 | "device": "fuzz-rndamp-G5S10A1msR2500ms",
168 | "file_name": "fuzz-rndamp-G5S10A1msR2500ms",
169 | "train_length": 112640,
170 | "val_chunk": 112640,
171 | "test_chunk": 112640,
172 | "validation_p": 20,
173 | "batch_size": 6
174 | },
175 | {
176 | "name": "GCNTF3",
177 | "model_type": "gcntf",
178 | "nblocks": 2,
179 | "nlayers": 9,
180 | "nchannels": 16,
181 | "kernel_size": 3,
182 | "dilation_growth": 2,
183 | "tfilm_block_size": 128,
184 | "loss_fcns": {"L1": 0.5, "MSTFT": 0.5},
185 | "prefilt": None,
186 | "device": "fuzz-rndamp-G5S10A1msR2500ms",
187 | "file_name": "fuzz-rndamp-G5S10A1msR2500ms",
188 | "train_length": 112640,
189 | "val_chunk": 112640,
190 | "test_chunk": 112640,
191 | "validation_p": 40,
192 | "batch_size": 6
193 | },
194 | {
195 | "name": "GCNTF2500",
196 | "model_type": "gcntf",
197 | "nblocks": 1,
198 | "nlayers": 10,
199 | "nchannels": 16,
200 | "kernel_size": 5,
201 | "dilation_growth": 3,
202 | "tfilm_block_size": 128,
203 | "loss_fcns": {"L1": 0.5, "MSTFT": 0.5},
204 | "prefilt": None,
205 | "device": "fuzz-rndamp-G5S10A1msR2500ms",
206 | "file_name": "fuzz-rndamp-G5S10A1msR2500ms",
207 | "train_length": 118400,
208 | "val_chunk": 118400,
209 | "test_chunk": 118400,
210 | "validation_p": 20,
211 | "batch_size": 6
212 | },
213 | {
214 | "name": "GCNTF250",
215 | "model_type": "gcntf",
216 | "nblocks": 1,
217 | "nlayers": 4,
218 | "nchannels": 16,
219 | "kernel_size": 41,
220 | "dilation_growth": 6,
221 | "loss_fcns": {"L1": 0.5, "MSTFT": 0.5},
222 | "prefilt": None,
223 | "device": "fuzz-rndamp-G5S10A1msR2500ms",
224 | "file_name": "fuzz-rndamp-G5S10A1msR2500ms",
225 | "train_length": 112640,
226 | "val_chunk": 112640,
227 | "test_chunk": 112640,
228 | "validation_p": 20,
229 | "batch_size": 6
230 | },
231 | # {
232 | # "name": "GCNTF1000",
233 | # "model_type": "gcntf",
234 | # "nblocks": 1,
235 | # "nlayers": 4,
236 | # "nchannels": 16,
237 | # "kernel_size": 41,
238 | # "dilation_growth": 10,
239 | # "loss_fcns": {"L1": 0.5, "MSTFT": 0.5},
240 | # "prefilt": None,
241 | # "device": "fuzz-rndamp-G5S10A1msR2500ms",
242 | # "file_name": "fuzz-rndamp-G5S10A1msR2500ms",
243 | # "train_length": 112640,
244 | # "val_chunk": 112640,
245 | # "test_chunk": 112640,
246 | # "validation_p": 20,
247 | # "batch_size": 6
248 | # },
249 | # {
250 | # "name": "GCNTF500",
251 | # "model_type": "gcntf",
252 | # "nblocks": 1,
253 | # "nlayers": 4,
254 | # "nchannels": 16,
255 | # "kernel_size": 41,
256 | # "dilation_growth": 8,
257 | # "loss_fcns": {"L1": 0.5, "MSTFT": 0.5},
258 | # "prefilt": None,
259 | # "device": "fuzz-rndamp-G5S10A1msR2500ms",
260 | # "file_name": "fuzz-rndamp-G5S10A1msR2500ms",
261 | # "train_length": 112640,
262 | # "val_chunk": 112640,
263 | # "test_chunk": 112640,
264 | # "validation_p": 20,
265 | # "batch_size": 6
266 | # }
267 | ]
268 |
269 | n_configs = len(train_configs)
270 |
271 | for idx, tconf in enumerate(train_configs):
272 |
273 | parser = argparse.ArgumentParser()
274 |
275 | # add PROGRAM level args
276 | parser.add_argument("--model_type", type=str, default="gcntf", help="rnn, gcn, gcntf")
277 | # data locations, file names and config
278 | parser.add_argument("--device", "-d", default="ht1", help="device label")
279 | parser.add_argument("--data_rootdir", "-dr", default="./data", help="data directory")
280 |     parser.add_argument("--file_name", "-fn", default="ht1", help="base name: expects <file_name>-input.wav and <file_name>-target.wav")
281 | # parser.add_argument('--load_config', '-l', help="config file path")
282 | # parser.add_argument('--config_location', '-cl', default='configs', help='configs directory')
283 | parser.add_argument("--save_location", "-sloc", default="results", help="trained models directory")
284 | parser.add_argument("--load_model", "-lm", action="store_true", help="load pretrained model")
285 |
286 | # pre-processing of the training/val/test data
287 | parser.add_argument("--train_length", "-trlen", type=int, default=16384, help="training frame length in samples")
288 |     parser.add_argument("--val_length", "-vllen", type=int, default=0, help="validation frame length in samples")
289 |     parser.add_argument("--test_length", "-tslen", type=int, default=0, help="test frame length in samples")
290 |
291 | # number of epochs and validation
292 | parser.add_argument("--epochs", "-eps", type=int, default=2000, help="max epochs")
293 | parser.add_argument("--validation_f", "-vfr", type=int, default=2, help="validation frequency (in epochs)")
294 | parser.add_argument("--validation_p", "-vp", type=int, default=25, help="validation patience or None")
295 |
296 | # settings for the training epoch
297 | parser.add_argument("--batch_size", "-bs", type=int, default=40, help="mini-batch size")
298 |     parser.add_argument("--iter_num", "-it", type=int, default=None, help="if set, batch_size is derived so that each epoch runs --iter_num batches")
299 | parser.add_argument("--learn_rate", "-lr", type=float, default=0.005, help="initial learning rate")
300 |     parser.add_argument("--cuda", "-cu", type=int, default=1, help="use GPU if available (1) or force CPU (0)")
301 |
302 | # loss function/s
303 | parser.add_argument("--loss_fcns", "-lf", default={"ESRPre": 0.75, "DC": 0.25}, help="loss functions and weights")
304 | parser.add_argument("--prefilt", "-pf", default="high_pass", help="pre-emphasis filter coefficients, can also read in a csv file")
305 |
306 | # validation and test sets chunk size
307 | parser.add_argument("--val_chunk", "-vs", type=int, default=100000, help="validation chunk length")
308 | parser.add_argument("--test_chunk", "-tc", type=int, default=100000, help="test chunk length")
309 |
310 | # parse general args
311 | args = parser.parse_args()
312 |
313 | # add model specific args
314 | if tconf["model_type"] == "rnn":
315 | parser = RNN.add_model_specific_args(parser)
316 | elif tconf["model_type"] == "gcn":
317 | parser = GCN.add_model_specific_args(parser)
318 | elif tconf["model_type"] == "gcntf":
319 | parser = GCNTF.add_model_specific_args(parser)
320 |
321 | # parse general + model args
322 | args = parser.parse_args()
323 |
324 | # create dictionary with args
325 | dict_args = vars(args)
326 |
327 | # overwrite with temporary configuration
328 | dict_args.update(tconf)
329 |
330 | # set filter args
331 | if dict_args["prefilt"] == "a-weighting":
332 |         # as reported in https://ieeexplore.ieee.org/abstract/document/9052944
333 | dict_args["prefilt"] = [0.85, 1]
334 | elif dict_args["prefilt"] == "high_pass":
335 | # args.prefilt = [-0.85, 1] # as reported in https://ieeexplore.ieee.org/abstract/document/9052944
336 | # as reported in (https://www.mdpi.com/2076-3417/10/3/766/htm)
337 | dict_args["prefilt"] = [-0.95, 1]
338 | else:
339 | dict_args["prefilt"] = None
340 |
341 | # directory where results will be saved
342 | if dict_args["model_type"] == "rnn":
343 | specifier = f"{idx+1}-{dict_args['name']}-{dict_args['device']}"
344 | specifier += f"__{dict_args['nlayers']}-{dict_args['hidden_size']}"
345 | specifier += "-skip" if dict_args["skip"] else "-noskip"
346 | specifier += f"__prefilt-{dict_args['prefilt']}-bs{dict_args['batch_size']}"
347 | elif dict_args["model_type"] == "gcn":
348 | specifier = f"{idx+1}-{dict_args['name']}-{dict_args['device']}"
349 | specifier += f"__{dict_args['nblocks']}-{dict_args['nlayers']}-{dict_args['nchannels']}"
350 | specifier += f"-{dict_args['kernel_size']}-{dict_args['dilation_growth']}"
351 | specifier += f"__prefilt-{dict_args['prefilt']}-bs{dict_args['batch_size']}"
352 | elif dict_args["model_type"] == "gcntf":
353 | specifier = f"{idx+1}-{dict_args['name']}-{dict_args['device']}"
354 | specifier += f"__{dict_args['nblocks']}-{dict_args['nlayers']}-{dict_args['nchannels']}"
355 | specifier += f"-{dict_args['kernel_size']}-{dict_args['dilation_growth']}-{dict_args['tfilm_block_size']}"
356 | specifier += f"__prefilt-{dict_args['prefilt']}-bs{dict_args['batch_size']}"
357 |
358 | # results directory
359 | save_path = os.path.join(dict_args["save_location"], specifier)
360 | utils.dir_check(save_path)
361 |
362 | # set the seed
363 | # TODO
364 |
365 | # init model
366 | if dict_args["model_type"] == "rnn":
367 | model = RNN(**dict_args)
368 | elif dict_args["model_type"] == "gcn":
369 | model = GCN(**dict_args)
370 | elif dict_args["model_type"] == "gcntf":
371 | model = GCNTF(**dict_args)
372 |
373 | # compute rf
374 | if dict_args["model_type"] in ["gcn", "gcntf"]:
375 | dict_args["rf"] = model.compute_receptive_field()
376 |
377 | model.save_state = False
378 | model.save_model("model", save_path)
379 |
380 | # save settings
381 | utils.json_save(dict_args, "config", save_path, indent=4)
382 | print(f"\n* Training config {idx+1}/{n_configs}")
383 | print(dict_args)
384 |
385 | # cuda
386 | if not torch.cuda.is_available() or dict_args["cuda"] == 0:
387 | print("\ncuda device not available/not selected")
388 | cuda = 0
389 | else:
390 | torch.set_default_tensor_type("torch.cuda.FloatTensor")
391 | torch.cuda.set_device(0)
392 | print("\ncuda device available")
393 | model = model.cuda()
394 | cuda = 1
395 |
396 | # optimiser + scheduler + loss fcns
397 | optimiser = torch.optim.Adam(model.parameters(),
398 | lr=dict_args["learn_rate"],
399 | weight_decay=1e-4)
400 |
401 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser,
402 | "min",
403 | factor=0.5,
404 | patience=5,
405 | verbose=True)
406 |
407 | loss_functions = training.LossWrapper(dict_args["loss_fcns"], dict_args["prefilt"])
408 |
409 | # training tracker
410 | train_track = training.TrainTrack()
411 |     writer = tensorboard.SummaryWriter(save_path)  # log into the same run directory as the saved models
412 |
413 | # dataset
414 | dataset = data.DataSet(data_dir=dict_args["data_rootdir"])
415 |
416 | dataset.create_subset("train", frame_len=dict_args["train_length"])
417 | dataset.load_file(os.path.join("train", dict_args["file_name"]), "train")
418 | print("\ntrain dataset: ", dataset.subsets["train"].data["input"][0].shape)
419 |
420 | dataset.create_subset("val", frame_len=dict_args["val_length"])
421 | dataset.load_file(os.path.join("val", dict_args["file_name"]), "val")
422 | print("val dataset: ", dataset.subsets["val"].data["input"][0].shape)
423 |
424 | dataset.create_subset("test", frame_len=dict_args["test_length"])
425 | dataset.load_file(os.path.join("test", dict_args["file_name"]), "test")
426 | print("test dataset: ", dataset.subsets["test"].data["input"][0].shape)
427 |
428 | # summary
429 | print()
430 | torchinfo.summary(model,
431 | input_size=(dict_args["train_length"], dict_args["batch_size"], 1),
432 | device=None)
433 |
434 | # ===== TRAIN ===== #
435 | start_time = time.time()
436 |
437 | model.save_state = True
438 | patience_counter = 0
439 | init_time = 0
440 |
441 | for epoch in range(train_track['current_epoch'] + 1, dict_args["epochs"] + 1):
442 | ep_st_time = time.time()
443 |
444 | # run 1 epoch of training
445 | if dict_args["model_type"] == "rnn":
446 | epoch_losses = model.train_epoch(dataset.subsets['train'].data['input'][0],
447 | dataset.subsets['train'].data['target'][0],
448 | loss_functions,
449 | optimiser,
450 | dict_args["batch_size"],
451 | up_fr=dict_args["up_fr"])
452 | else:
453 | epoch_losses = model.train_epoch(dataset.subsets['train'].data['input'][0],
454 | dataset.subsets['train'].data['target'][0],
455 | loss_functions,
456 | optimiser,
457 | dict_args["batch_size"])
458 | epoch_loss = 0
459 | for loss in epoch_losses:
460 | epoch_loss += epoch_losses[loss]
461 | print(f"epoch {epoch} | \ttrain loss: \t{epoch_loss:0.4f}", end="")
462 | for loss in epoch_losses:
463 | print(f" | \t{loss}: \t{epoch_losses[loss]:0.4f}", end="")
464 | print()
465 |
466 | # ===== VALIDATION ===== #
467 | if epoch % dict_args["validation_f"] == 0:
468 | val_ep_st_time = time.time()
469 | val_input, val_target, val_output, val_losses = \
470 | model.process_data(dataset.subsets['val'].data['input'][0],
471 | dataset.subsets['val'].data['target'][0],
472 | loss_fcn=loss_functions,
473 | chunk=dict_args["val_chunk"])
474 |
475 | # val losses
476 | val_loss = 0
477 | for loss in val_losses:
478 | val_loss += val_losses[loss]
479 | print(f"\t\tval loss: \t{val_loss:0.4f}", end="")
480 | for loss in val_losses:
481 | print(f" | \t{loss}: \t{val_losses[loss]:0.4f}", end="")
482 | print()
483 |
484 | # update lr
485 | scheduler.step(val_loss)
486 |
487 | # save best model
488 | if val_loss < train_track['val_loss_best']:
489 | patience_counter = 0
490 | model.save_model('model_best', save_path)
491 | scipy.io.wavfile.write(os.path.join(save_path, "best_val_out.wav"),
492 | dataset.subsets['test'].fs,
493 | val_output.cpu().numpy()[:, 0, 0])
494 | else:
495 | patience_counter += 1
496 |
497 | # log validation losses
498 | for loss in val_losses:
499 | val_losses[loss] = val_losses[loss].item()
500 |
501 | train_track.val_epoch_update(val_loss=val_loss.item(),
502 | val_losses=val_losses,
503 | ep_st_time=val_ep_st_time,
504 | ep_end_time=time.time())
505 |
506 | writer.add_scalar('Loss/Val (Tot)', val_loss, epoch)
507 | for loss in val_losses:
508 | writer.add_scalar(f"Loss/Val ({loss})", val_losses[loss], epoch)
509 |
510 | # log training losses
511 | for loss in epoch_losses:
512 | epoch_losses[loss] = epoch_losses[loss].item()
513 |
514 | train_track.train_epoch_update(epoch_loss=epoch_loss.item(),
515 | epoch_losses=epoch_losses,
516 | ep_st_time=ep_st_time,
517 | ep_end_time=time.time(),
518 | init_time=init_time,
519 | current_ep=epoch)
520 |
521 | writer.add_scalar('Loss/Train (Tot)', epoch_loss, epoch)
522 | for loss in epoch_losses:
523 | writer.add_scalar(f"Loss/Train ({loss})", epoch_losses[loss], epoch)
524 |
525 | # log learning rate
526 |         writer.add_scalar('LR/current', optimiser.param_groups[0]['lr'], epoch)
527 |
528 | # save model
529 | model.save_model('model', save_path)
530 |
531 | # log training stats to json
532 | utils.json_save(train_track, 'training_stats', save_path, indent=4)
533 |
534 | # check early stopping
535 | if dict_args["validation_p"] and patience_counter > dict_args["validation_p"]:
536 | print('\nvalidation patience limit reached at epoch ' + str(epoch))
537 | break
538 |
539 | # ===== TEST (last model) ===== #
540 | test_input, test_target, test_output, test_losses = \
541 | model.process_data(dataset.subsets['test'].data['input'][0],
542 | dataset.subsets['test'].data['target'][0],
543 | loss_fcn=loss_functions,
544 | chunk=dict_args["test_chunk"])
545 |
546 | # test losses (last)
547 | test_loss = 0
548 | for loss in test_losses:
549 | test_loss += test_losses[loss]
550 | print(f"\ttest loss (last): \t{test_loss:0.4f}", end="")
551 | for loss in test_losses:
552 | print(f" | \t{loss}: \t{test_losses[loss]:0.4f}", end="")
553 | print()
554 |
555 | lossESR = training.ESRLoss() # include ESR loss
556 | test_loss_ESR = lossESR(test_output, test_target)
557 |
558 | # save output audio
559 | scipy.io.wavfile.write(os.path.join(save_path, "test_out_final.wav"),
560 | dataset.subsets['test'].fs,
561 | test_output.cpu().numpy()[:, 0, 0])
562 |
563 | # log test losses
564 | for loss in test_losses:
565 | test_losses[loss] = test_losses[loss].item()
566 |
567 | train_track['test_loss_final'] = test_loss.item()
568 | train_track['test_losses_final'] = test_losses
569 | train_track['test_lossESR_final'] = test_loss_ESR.item()
570 |
571 | writer.add_scalar('Loss/Test/Last (Tot)', test_loss, 0)
572 | for loss in test_losses:
573 | writer.add_scalar(f"Loss/Test/Last ({loss})", test_losses[loss], 0)
574 |
575 | # ===== TEST (best validation model) ===== #
576 | best_val_net = utils.json_load('model_best', save_path)
577 | model = utils.load_model(best_val_net)
578 |
579 | test_input, test_target, test_output, test_losses = \
580 | model.process_data(dataset.subsets['test'].data['input'][0],
581 | dataset.subsets['test'].data['target'][0],
582 | loss_fcn=loss_functions,
583 | chunk=dict_args["test_chunk"])
584 |
585 | # test losses (best)
586 | test_loss = 0
587 | for loss in test_losses:
588 | test_loss += test_losses[loss]
589 | print(f"\ttest loss (best): \t{test_loss:0.4f}", end="")
590 | for loss in test_losses:
591 | print(f" | \t{loss}: \t{test_losses[loss]:0.4f}", end="")
592 | print()
593 |
594 | test_loss_ESR = lossESR(test_output, test_target)
595 |
596 | # save output audio
597 | scipy.io.wavfile.write(os.path.join(save_path, "test_out_bestv.wav"),
598 | dataset.subsets['test'].fs,
599 | test_output.cpu().numpy()[:, 0, 0])
600 |
601 | # log test losses
602 | for loss in test_losses:
603 | test_losses[loss] = test_losses[loss].item()
604 |
605 | train_track['test_loss_best'] = test_loss.item()
606 | train_track['test_losses_best'] = test_losses
607 | train_track['test_lossESR_best'] = test_loss_ESR.item()
608 |
609 | writer.add_scalar('Loss/Test/Best (Tot)', test_loss, 0)
610 | for loss in test_losses:
611 | writer.add_scalar(f"Loss/Test/Best ({loss})", test_losses[loss], 0)
612 |
613 | # log training stats to json
614 | utils.json_save(train_track, 'training_stats', save_path, indent=4)
615 |
616 | if cuda:
617 | with open(os.path.join(save_path, 'maxmemusage.txt'), 'w') as f:
618 | f.write(str(torch.cuda.max_memory_allocated()))
619 |
620 | stop_time = time.time()
621 | print(f"\ntraining time: {(stop_time-start_time)/60:0.2f} min")
622 |
--------------------------------------------------------------------------------
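
Note (not part of the repository files): a minimal offline sketch of GCNTF.process_data, useful for debugging chunked processing outside of train.py. It only assumes the GCNTF class defined in src/gcntfilm.py; the [length, batch, channels] tensor layout, the receptive-field context padding and the zero-padding up to a whole number of chunks are taken directly from the code above, so the output is trimmed back to the input length before use:

    import torch
    from src.gcntfilm import GCNTF

    model = GCNTF(nblocks=2, nlayers=9, nchannels=16,
                  kernel_size=3, dilation_growth=2, tfilm_block_size=128)

    x = torch.randn(44100, 1, 1)   # one second at 44.1 kHz, shape [length, batch, channels]
    _, _, y = model.process_data(x, chunk=16384)

    y = y[:x.shape[0]]             # drop the padding added to reach a multiple of the chunk size
    print(x.shape, y.shape)        # torch.Size([44100, 1, 1]) torch.Size([44100, 1, 1])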