├── predictions ├── cnbf_real_prediction.png └── cnbf_complex_prediction.png ├── cnbf.json ├── LICENSE ├── utils ├── keras_helpers.py ├── mat_helpers.py └── matplotlib_helpers.py ├── loaders ├── audio_loader.py ├── feature_generator.py └── rir_generator.py ├── README.md ├── ops ├── kernelized_layers.py ├── complex_ops.py └── complex_layers.py ├── algorithms └── audio_processing.py └── experiments ├── cnbf_complex.py └── cnbf_real.py /predictions/cnbf_real_prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rrbluke/CNBF/HEAD/predictions/cnbf_real_prediction.png -------------------------------------------------------------------------------- /predictions/cnbf_complex_prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rrbluke/CNBF/HEAD/predictions/cnbf_complex_prediction.png -------------------------------------------------------------------------------- /cnbf.json: -------------------------------------------------------------------------------- 1 | { 2 | "epochs": 100000, 3 | "fs": 16000, 4 | "duration": 5.0, 5 | "wlen": 1024, 6 | "shift": 256, 7 | "nsrc": 2, 8 | "predictions_path": "../predictions/", 9 | "weights_path": "../weights/", 10 | "train_path": "/clusterFS/project/beamforming/data/wsj0/si_tr_*/*/", 11 | "eval_path": "/clusterFS/project/beamforming/data/wsj0/si_et_*/*/", 12 | "test_path": "/clusterFS/project/beamforming/data/wsj0/si_et_*/*/" 13 | } 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 rrbluke 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without 
limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/keras_helpers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | __author__ = "Lukas Pfeifenberger" 3 | 4 | 5 | import numpy as np 6 | import time 7 | 8 | import keras.backend as K 9 | import tensorflow as tf 10 | 11 | 12 | 13 | #----------------------------------------------------- 14 | class Logger(tf.keras.callbacks.Callback): 15 | 16 | def __init__(self, name): 17 | self.name = name 18 | self.losses = [] 19 | self.iteration = 0 20 | 21 | def on_epoch_begin(self, epoch, logs=None): 22 | self.epoch_time_start = time.time() 23 | self.losses = [] 24 | 25 | def on_batch_end(self, batch, logs=None): 26 | if np.isnan(np.sum(logs['loss'])): 27 | quit() 28 | self.losses = np.append(self.losses, logs['loss']) 29 | #print('end of batch: ', logs['loss'].shape) 30 | 31 | def on_epoch_end(self, epoch, logs=None): 32 | self.epoch_time_end = time.time() 33 | duration = self.epoch_time_end-self.epoch_time_start 34 | self.iteration += 1 35 | 36 | print('model: 
%s, iteration: %d, epoch: %d, runtime: %.3fs, loss (avg/min/max): %.3f/%.3f/%.3f' % \ 37 | (self.name, self.iteration, epoch, duration, np.mean(self.losses), np.amin(self.losses), np.amax(self.losses)) ) 38 | 39 | 40 | 41 | #----------------------------------------------------- 42 | def Debug(name, x): 43 | 44 | # print the dynamic shape of tensor x during runtime 45 | print_op = tf.print(name, '.shape =', tf.shape(x), '.dtype=', x.dtype, '.value=', x) 46 | with tf.control_dependencies([print_op]): 47 | return tf.identity(x) 48 | 49 | 50 | #----------------------------------------------------- 51 | def log10(x): 52 | 53 | return tf.math.log(x) / 2.302585092994046 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /loaders/audio_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | __author__ = "Lukas Pfeifenberger" 3 | 4 | 5 | import numpy as np 6 | import glob 7 | import sys 8 | import os 9 | 10 | sys.path.append(os.path.abspath('../')) 11 | from algorithms.audio_processing import * 12 | 13 | 14 | 15 | # loader class for mono wav files, i.e. 
wsj0 16 | 17 | class audio_loader(object): 18 | 19 | # -------------------------------------------------------------------------- 20 | def __init__(self, config, set): 21 | 22 | self.fs = config['fs'] 23 | self.wlen = config['wlen'] 24 | self.shift = config['shift'] 25 | self.samples = int(self.fs*config['duration']) 26 | self.nfram = int(np.ceil( (self.samples-self.wlen+self.shift)/self.shift )) 27 | self.nbin = int(self.wlen/2+1) 28 | 29 | 30 | if set == 'train': 31 | path = config['train_path'] 32 | elif set == 'test': 33 | path = config['test_path'] 34 | elif set == 'eval': 35 | path = config['eval_path'] 36 | else: 37 | # an unknown set name is a configuration error: raise instead of exiting the interpreter with status 0 38 | raise ValueError('unknown set name: %s' % (set,)) 39 | 40 | self.file_list = glob.glob(path+'*.wav') 41 | self.numof_files = len(self.file_list) 42 | 43 | print('*** audio_loader found %d files in: %s' % (self.numof_files, path)) 44 | 45 | 46 | 47 | #------------------------------------------------------------------------- 48 | def concatenate_random_files(self,): 49 | 50 | x = np.zeros((self.samples,), dtype=np.float32) 51 | n = 0 52 | while n 16 | # from the m-file 17 | 18 | def load_numpy_from_mat(matfile, varnames=None, hdf5=False): 19 | 20 | matdata = load_dict_from_mat(matfile=matfile, hdf5=hdf5) 21 | if matdata is None: 22 | return None 23 | 24 | #return all variables from the mat-file 25 | if varnames is None: 26 | return matdata 27 | 28 | #wrap a single variable name into a one-element list 29 | if not isinstance(varnames, list): 30 | varnames = [varnames] 31 | 32 | data = {} 33 | for varname in varnames: 34 | 35 | #check if matfile contains the requested variable 36 | if varname in matdata: 37 | 38 | #load variable from matdata 39 | x = matdata[varname] 40 | 41 | #copy variable to rearrange strides 42 | y = np.copy(x) 43 | 44 | data[varname] = y 45 | 46 | return data 47 | 48 | 49 | #--------------------------------------------------------- 50 | # saves a dict of numpy arrays 51 | # to the m-file 52 | 53 | def save_numpy_to_mat(matfile, data, hdf5=False,
overwrite=False): 54 | 55 | return save_dict_to_mat(matfile=matfile, matdata=data, hdf5=hdf5, overwrite=overwrite) 56 | 57 | 58 | #--------------------------------------------------------- 59 | # saves a dictionary to a matfile 60 | # existing contents are preserved 61 | 62 | def save_dict_to_mat(matfile, matdata, hdf5=False, overwrite=False): 63 | 64 | if matfile is None: 65 | return False 66 | 67 | #create folder if it doesn't exist 68 | path = os.path.dirname(os.path.abspath(matfile)) 69 | if not os.path.exists(path): 70 | os.makedirs(path) 71 | 72 | if hdf5 is False: 73 | if overwrite is False: 74 | #to append to the file, its contents have to be read, updated and written 75 | existing_data = load_dict_from_mat(matfile, hdf5=hdf5) 76 | else: 77 | existing_data = None 78 | 79 | if existing_data is not None: 80 | existing_data.update(matdata) 81 | scipy.io.savemat(matfile, existing_data) 82 | else: 83 | scipy.io.savemat(matfile, matdata) 84 | else: 85 | #appends to the file by default 86 | hdf5storage.savemat(matfile, matdata, compress=False) 87 | 88 | return True 89 | 90 | 91 | #--------------------------------------------------------- 92 | # loads a dictionary from a matfile 93 | # the dictionary will contain ndarray objects! 
94 | # for details, see: https://docs.scipy.org/doc/scipy-0.18.1/reference/tutorial/io.html 95 | 96 | def load_dict_from_mat(matfile, hdf5=False): 97 | 98 | if matfile is None: 99 | return None 100 | 101 | if os.path.isfile(matfile): 102 | if hdf5 is False: 103 | try: 104 | matdata = scipy.io.loadmat(matfile) 105 | except: 106 | #avoid false positive corruption check, see: 107 | #https://github.com/scipy/scipy/issues/6999 108 | try: 109 | matdata = scipy.io.loadmat(matfile, verify_compressed_data_integrity=False) 110 | except: 111 | return None 112 | 113 | else: 114 | matdata = hdf5storage.loadmat(matfile) 115 | 116 | return matdata 117 | 118 | else: 119 | return None 120 | 121 | 122 | -------------------------------------------------------------------------------- /loaders/feature_generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | __author__ = "Lukas Pfeifenberger" 3 | 4 | 5 | import time 6 | import glob 7 | import argparse 8 | import json 9 | import os 10 | import sys 11 | import numpy as np 12 | import pyroomacoustics as pra 13 | 14 | sys.path.append(os.path.abspath('../')) 15 | 16 | from loaders.audio_loader import audio_loader 17 | from loaders.rir_generator import rir_generator 18 | from algorithms.audio_processing import * 19 | from utils.mat_helpers import * 20 | 21 | 22 | 23 | #-------------------------------------------------------------------------- 24 | #-------------------------------------------------------------------------- 25 | class feature_generator(object): 26 | 27 | 28 | #-------------------------------------------------------------------------- 29 | def __init__(self, config, set='train'): 30 | 31 | self.set = set 32 | self.config = config 33 | self.fs = config['fs'] 34 | self.wlen = config['wlen'] 35 | self.shift = config['shift'] 36 | self.samples = int(self.fs*config['duration']) 37 | self.nfram = int(np.ceil( (self.samples-self.wlen+self.shift)/self.shift )) 38 | 
self.nbin = int(self.wlen/2+1) 39 | 40 | self.nsrc = config['nsrc'] 41 | assert(self.nsrc == 2) # only 2 sources are supported 42 | 43 | self.audio_loader = audio_loader(config, set) 44 | self.rgen = rir_generator(config, set) 45 | self.nmic = self.rgen.nmic 46 | 47 | 48 | 49 | #--------------------------------------------------------- 50 | def generate_mixture(self,): 51 | 52 | hs, hn = self.rgen.load_rirs() 53 | s = self.audio_loader.concatenate_random_files() # shape = (samples,) 54 | n = self.audio_loader.concatenate_random_files() # shape = (samples,) 55 | 56 | Fhs = rfft(hs, n=self.samples, axis=0) # shape = (samples/2+1, nmic) 57 | Fhn = rfft(hn, n=self.samples, axis=0) # shape = (samples/2+1, nmic) 58 | 59 | Fs = rfft(s, n=self.samples, axis=0) # shape = (samples/2+1,) 60 | Fn = rfft(n, n=self.samples, axis=0) # shape = (samples/2+1,) 61 | 62 | Fs = Fhs*Fs[:,np.newaxis] 63 | Fn = Fhn*Fn[:,np.newaxis] 64 | 65 | s = irfft(Fs, n=self.samples, axis=0) # shape = (samples, nmic) 66 | n = irfft(Fn, n=self.samples, axis=0) # shape = (samples, nmic) 67 | 68 | Fs = mstft(s.T, self.wlen, self.shift) # shape = (nmic, nfram, nbin) 69 | Fs = np.transpose(Fs, (1,2,0)) # shape = (nfram, nbin, nmic) 70 | 71 | Fn = mstft(n.T, self.wlen, self.shift) # shape = (nmic, nfram, nbin) 72 | Fn = np.transpose(Fn, (1,2,0)) # shape = (nfram, nbin, nmic) 73 | 74 | Fs = self.rgen.whiten_data(Fs) 75 | Fn = self.rgen.whiten_data(Fn) 76 | 77 | return Fs, Fn 78 | 79 | 80 | 81 | #--------------------------------------------------------- 82 | def generate_mixtures(self, nbatch=10): 83 | 84 | Fs = np.zeros(shape=(nbatch, self.nfram, self.nbin, self.nmic), dtype=np.complex64) 85 | Fn = np.zeros(shape=(nbatch, self.nfram, self.nbin, self.nmic), dtype=np.complex64) 86 | for b in np.arange(nbatch): 87 | 88 | Fs[b,...], Fn[b,...] 
= self.generate_mixture() 89 | 90 | return Fs, Fn 91 | 92 | 93 | 94 | 95 | 96 | #--------------------------------------------------------- 97 | #--------------------------------------------------------- 98 | if __name__ == "__main__": 99 | 100 | 101 | parser = argparse.ArgumentParser(description='mcss feature generator') 102 | parser.add_argument('--config_file', help='name of json configuration file', default='../cnbf.json') 103 | args = parser.parse_args() 104 | 105 | 106 | with open(args.config_file, 'r') as f: 107 | config = json.load(f) 108 | 109 | 110 | fgen = feature_generator(config, set='train') 111 | 112 | 113 | t0 = time.time() 114 | Fs, Fn = fgen.generate_mixture() 115 | t1 = time.time() 116 | print(t1-t0) 117 | 118 | data = { 119 | 'Fs': Fs, 120 | 'Fn': Fn, 121 | } 122 | save_numpy_to_mat('../matlab/fgen_check.mat', data) 123 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /utils/matplotlib_helpers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | __author__ = "Lukas Pfeifenberger" 3 | 4 | 5 | import numpy as np 6 | import os 7 | import sys 8 | 9 | import matplotlib 10 | #matplotlib.use('Agg') 11 | import matplotlib.pyplot as plt 12 | from matplotlib.colors import LinearSegmentedColormap 13 | 14 | 15 | 16 | #------------------------------------------------------------------------- 17 | # data = tuple of data matrices 18 | # legend = tuple of labels 19 | # each tuple entry is plotted into a subplot 20 | # legend entries for each subplot are comma-separated entries in the name tuple 21 | 22 | def draw_subplots(data, legend, filename=None): 23 | 24 | if filename is None: 25 | plt.ion() #interactive, non-blocking plots 26 | fig = plt.gcf() 27 | if fig is None: 28 | fig = plt.figure() 29 | else: 30 | fig.clf() 31 | else: 32 | plt.switch_backend('agg') 33 | plt.ioff() 34 | fig = plt.figure() 35 | 36 | #define custom color wheel 37 | colors 
= ['b', 'g', 'r', 'c', 'm', 'y', 'k'] 38 | 39 | #create empty subplots 40 | fig, axes = plt.subplots(nrows=len(data), ncols=1) 41 | fig.set_size_inches(10, min(2*len(data),10)) 42 | 43 | #use vertical padding of 2.5x the font size between the subplots to avoid overlapping labels 44 | fig.tight_layout(h_pad=2.5, rect=(0,0,1,0.97)) 45 | 46 | #each sublabels,data tuple goes into a subplot 47 | for i, x in enumerate(data): 48 | ax = plt.subplot(len(data), 1, i+1) 49 | 50 | if x.ndim==1: 51 | x = x[:,np.newaxis] 52 | elif x.ndim==2 and x.shape[0]2: 55 | print('x.ndim must be less than 3') 56 | quit(0) 57 | 58 | #extract labels for each subplot 59 | labels = legend[i].split(',') 60 | c_idx = 0 61 | for j in range(x.shape[1]): 62 | if j W -> B 94 | cmap = LinearSegmentedColormap.from_list('meow', colors, N=64) 95 | 96 | plt.switch_backend('agg') 97 | plt.ioff() 98 | fig = plt.figure() 99 | 100 | #create empty subplots 101 | fig, axes = plt.subplots(nrows=len(data), ncols=1) 102 | fig.set_size_inches(10, min(5*len(data),20)) 103 | 104 | #use vertical padding of 2x the font size between the subplots to avoid overlapping labels 105 | fig.tight_layout(h_pad=2, rect=(0,0,0.99,0.97)) 106 | 107 | #each legend,data tuple goes into a subplot 108 | for i, x in enumerate(data): 109 | ax = plt.subplot(len(data), 1, i+1) 110 | plt.pcolor(x, cmap=cmap) 111 | plt.title(legend[i]) 112 | ax.set_xlim([0, x.shape[1]]) 113 | ax.set_ylim([0, x.shape[0]]) 114 | plt.colorbar(aspect=20, fraction=0.05) 115 | plt.clim(*clim) 116 | 117 | fig.savefig(filename, dpi=200) 118 | 119 | 120 | 121 | 122 | # --------------------------------------------------------------------- 123 | def pcolor(x, filename='pcolor', x_min=None, x_max=None): 124 | 125 | plt.switch_backend('agg') 126 | 127 | path, name = os.path.split(filename) 128 | if not os.path.exists(path) and path != '': 129 | os.makedirs(path) 130 | 131 | # set font 132 | font = {'weight': 'bold', 'size': 12} 133 | matplotlib.rc('font', **font) 134 | 
135 | plt.ioff() 136 | fig = plt.figure() 137 | 138 | plt.pcolor(x.T, cmap='jet') 139 | if x_min is not None and x_max is not None: plt.clim(x_min, x_max) 140 | plt.colorbar() 141 | plt.xlabel('nfram') 142 | plt.ylabel('nbin') 143 | plt.title(name) 144 | 145 | # save figure, scaled*1.5 146 | plt.savefig(filename, dpi=fig.dpi*1.5) 147 | plt.close(fig) 148 | 149 | -------------------------------------------------------------------------------- /ops/kernelized_layers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | __author__ = "Lukas Pfeifenberger" 3 | 4 | 5 | import os 6 | import sys 7 | 8 | import tensorflow as tf 9 | from keras.layers import Layer, RNN 10 | from keras import activations 11 | 12 | sys.path.append(os.path.abspath('../')) 13 | from utils.keras_helpers import * 14 | from ops.complex_ops import * 15 | 16 | 17 | 18 | #------------------------------------------------------------------- 19 | 20 | class Kernelized_Dense(Layer): 21 | 22 | def __init__(self, units, activation='tanh'): 23 | 24 | super(Kernelized_Dense, self).__init__() 25 | self.units = units 26 | self.activation = activations.get(activation) 27 | 28 | 29 | def build(self, input_shape): 30 | 31 | # input_shape = (..., nkernels, n_in) 32 | nkernels = input_shape[-2] 33 | n_in = input_shape[-1] 34 | 35 | self.W = self.add_weight(name='W', shape=(nkernels, n_in, self.units), initializer='random_normal', dtype=tf.float32) 36 | self.b = self.add_weight(name='b', shape=(nkernels, self.units), initializer='zeros', dtype=tf.float32) 37 | 38 | super(Kernelized_Dense, self).build(input_shape) 39 | 40 | 41 | def call(self, inputs): 42 | 43 | x = inputs # shape = (..., nkernels, n_in) 44 | 45 | z = tf.einsum('...ki,kij->...kj', x, self.W) + self.b # shape = (..., nkernels, units) 46 | 47 | if self.activation is not None: 48 | z = self.activation(z) 49 | 50 | return z 51 | 52 | 53 | def compute_output_shape(self, input_shape): 54 | 55 | 
output_shape = list(input_shape) 56 | output_shape[-1] = self.units 57 | return tuple(output_shape) 58 | 59 | 60 | 61 | 62 | 63 | #------------------------------------------------------------------- 64 | 65 | class Kernelized_LSTM(Layer): 66 | 67 | def __init__(self, units, activation='tanh', recurrent_activation='hard_sigmoid', return_sequences=True, go_backwards=False): 68 | 69 | super(Kernelized_LSTM, self).__init__() 70 | self.units = units 71 | self.activation = activation 72 | self.recurrent_activation = recurrent_activation 73 | self.return_sequences = return_sequences 74 | self.go_backwards = go_backwards 75 | 76 | 77 | def build(self, input_shape): 78 | 79 | # input to the kernelized LSTM is a 4D tensor: 80 | nbatch, nfram, kernels, n_in = input_shape 81 | 82 | cell = self.Cell(kernels, self.units, self.activation, self.recurrent_activation) 83 | self.rnn = RNN(cell, return_sequences=self.return_sequences, go_backwards=self.go_backwards) 84 | 85 | # the Keras RNN implementation does only work with 3D tensors, hence we flatten the last two dimensions of the input: 86 | self.rnn.build(input_shape=(nbatch, nfram, kernels*n_in)) 87 | self._trainable_weights = self.rnn.trainable_weights 88 | super(Kernelized_LSTM, self).build(input_shape) 89 | 90 | 91 | def call(self, inputs): 92 | 93 | x = inputs # shape = (nbatch, nfram, kernels, n_in) 94 | 95 | # reshape input to 3D 96 | nbatch = tf.shape(x)[0] 97 | nfram = tf.shape(x)[1] 98 | kernels = tf.shape(x)[2] 99 | x = tf.reshape(x, [nbatch, nfram, -1]) 100 | 101 | # reshape output to 4D (nbatch, nfram, kernels, units), or to 3D (nbatch, kernels, units) when return_sequences is False and the RNN returns only the last step 102 | y = self.rnn(x) 103 | y = tf.reshape(y, [nbatch, nfram, kernels, self.units] if self.return_sequences else [nbatch, kernels, self.units]) 104 | 105 | # reverse time axis back to normal (only a returned sequence has a time axis) 106 | if self.go_backwards is True and self.return_sequences: 107 | y = tf.reverse(y, axis=[1]) 108 | 109 | return y 110 | 111 | 112 | def compute_output_shape(self, input_shape): 113 | 114 | output_shape = list(input_shape) if self.return_sequences else [input_shape[0]] + list(input_shape[2:]) # the time axis is dropped when return_sequences is False 115 | output_shape[-1] = self.units 116 | return tuple(output_shape) 117 | 118 | 119 | 120 | class
Cell(Layer): 121 | 122 | def __init__(self, kernels, units, activation='tanh', recurrent_activation='hard_sigmoid'): 123 | 124 | super(Kernelized_LSTM.Cell, self).__init__() 125 | self.activation = activations.get(activation) 126 | self.recurrent_activation = activations.get(recurrent_activation) 127 | self.units = units # = data size of the output 128 | self.kernels = kernels # = kernel size of the output 129 | self.state_size = (kernels*units, kernels*units) # = flattened sizes of the hidden and carry state 130 | self.output_size = kernels*units # = flattened size of the output 131 | 132 | 133 | def build(self, input_shape): 134 | 135 | # the input of the Cell is a 3D tensor with shape (nbatch, nfram, kernels*n_in) 136 | n_in = int(input_shape[-1]/self.kernels) 137 | 138 | self.W = self.add_weight(shape=(self.kernels, n_in, self.units*4), name='W', initializer='glorot_uniform') 139 | self.U = self.add_weight(shape=(self.kernels, self.units, self.units*4), name='U', initializer='orthogonal') 140 | self.b = self.add_weight(shape=(self.kernels, self.units*4), name='b', initializer='zeros') 141 | 142 | super(Kernelized_LSTM.Cell, self).build(input_shape) 143 | 144 | 145 | # this function is called every time steps 146 | def call(self, inputs, states, training=None): 147 | 148 | x = inputs # shape = (nbatch, kernels*n_in) 149 | nbatch = tf.shape(x)[0] 150 | x = tf.reshape(x, [nbatch, self.kernels, -1]) # expand input to 3D 151 | h_tm1 = tf.reshape(states[0], [nbatch, self.kernels, self.units]) # expand previous hidden state to 3D 152 | c_tm1 = tf.reshape(states[1], [nbatch, self.kernels, self.units]) # expand previous carry state to 3D 153 | 154 | 155 | z = tf.einsum('...ki,kij->...kj', x, self.W) # shape = (..., kernels, units*4) 156 | z += tf.einsum('...ki,kij->...kj', h_tm1, self.U) # shape = (..., kernels, units*4) 157 | z += self.b 158 | 159 | a, i, f, o = [ z[..., i*self.units:(i+1)*self.units] for i in range(4) ] 160 | 161 | a = self.activation(a) 162 | i = 
self.recurrent_activation(i) 163 | f = self.recurrent_activation(f) 164 | o = self.recurrent_activation(o) 165 | 166 | c = a*i + f*c_tm1 167 | h = o*self.activation(c) 168 | 169 | 170 | # flatten new hidden and carry state back to 2D 171 | h = tf.reshape(h, [nbatch, -1]) # shape = (nbatch, kernels*units) 172 | c = tf.reshape(c, [nbatch, -1]) 173 | 174 | return h, [h, c] 175 | 176 | 177 | 178 | -------------------------------------------------------------------------------- /loaders/rir_generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | __author__ = "Lukas Pfeifenberger" 3 | 4 | 5 | import time 6 | import glob 7 | import argparse 8 | import json 9 | import os 10 | import sys 11 | import numpy as np 12 | import pyroomacoustics as pra 13 | 14 | sys.path.append(os.path.abspath('../')) 15 | 16 | from utils.mat_helpers import * 17 | 18 | 19 | 20 | #-------------------------------------------------------------------------- 21 | #-------------------------------------------------------------------------- 22 | class rir_generator(object): 23 | 24 | 25 | #-------------------------------------------------------------------------- 26 | def __init__(self, config, set='train'): 27 | 28 | self.set = set 29 | self.config = config 30 | self.fs = config['fs'] 31 | self.samples = int(self.fs*1.0) 32 | self.wlen = config['wlen'] 33 | self.shift = config['shift'] 34 | self.nbin = int(self.wlen/2+1) 35 | self.set = set 36 | 37 | self.nsrc = config['nsrc'] 38 | assert(self.nsrc == 2) # only 2 sources are supported 39 | 40 | self.rir_file = '../loaders/rir_cache.mat' 41 | 42 | self.define_mic_array() 43 | self.generate_whitening_matrix() 44 | self.cache_rirs() 45 | 46 | 47 | 48 | #------------------------------------------------------------------------- 49 | def define_mic_array(self): 50 | 51 | self.nmic = 6 # number of microphones 52 | self.radius = 46.3/1000 # radius of the respeaker core v2 microphone array 53 
| self.c = 343.0 # speed of sound at 20°C 54 | 55 | #mics 1..6 are on a circle 56 | self.micpos = np.zeros((self.nmic, 3)) 57 | for m in np.arange(self.nmic): 58 | a = -2*np.pi*m/self.nmic # microphones are arranged clockwise! 59 | self.micpos[m,0] = self.radius*np.cos(a) 60 | self.micpos[m,1] = self.radius*np.sin(a) 61 | self.micpos[m,2] = 0 62 | 63 | 64 | 65 | #---------------------------------------------------------------------------- 66 | def generate_whitening_matrix(self): 67 | 68 | dist = self.micpos[:,np.newaxis,:] - self.micpos[np.newaxis,:,:] # shape = (self.nmic, self.nmic, 3) 69 | dist = np.linalg.norm(dist, axis=-1) # shape = (self.nmic, self.nmic) 70 | tau = dist/self.c 71 | 72 | self.U = np.zeros((self.nbin, self.nmic, self.nmic), dtype=np.complex64) # whitening matrix 73 | for k in range(self.nbin): 74 | fc = self.fs*k/((self.nbin-1)*2) 75 | Cnn = np.sinc(2*fc*tau) # spherical coherence matrix 76 | d, E = np.linalg.eigh(Cnn) 77 | d = np.maximum(d.real, 1e-3) 78 | iD = np.diag(1/np.sqrt(d)) 79 | 80 | # ZCA whitening 81 | self.U[k,:,:] = np.dot(E, np.dot(iD, E.T.conj())) # U = E*D^-0.5*E' 82 | 83 | 84 | 85 | #---------------------------------------------------------------------------- 86 | def whiten_data(self, Fs): 87 | 88 | # U.shape = (nbin, nmic, nmic) 89 | # Fs.shape = (..., nbin, nmic) 90 | Fus = np.einsum('kdc, ...kc->...kd', self.U, Fs) # shape = (..., nbin, nmic) 91 | 92 | return Fus 93 | 94 | 95 | 96 | #---------------------------------------------------------------------------- 97 | def cache_rirs(self,): 98 | 99 | if os.path.isfile(self.rir_file): 100 | data = load_numpy_from_mat(self.rir_file) 101 | self.rir_A = data['rir_A'] # shape = (nrir, samples, nmic) 102 | self.rir_B = data['rir_B'] # shape = (nrir, samples, nmic) 103 | self.nrir = self.rir_A.shape[0] 104 | print('Loaded', self.nrir, 'RIRs from', self.rir_file) 105 | 106 | else: 107 | self.nrir = 500 # pre-calculate RIRs 108 | print('Generating', self.nrir, 'RIRs ...') 109 | 110 | 
# define room/shoebox 111 | rt60 = 0.250 # define rt60 of the generated RIRs 112 | room_dim = np.asarray([6.0, 4.0, 2.5]) # define room dimensions in [m] 113 | absorption, max_order = pra.inverse_sabine(rt60, room_dim) # invert Sabine's formula to obtain the parameters for the ISM simulator 114 | 115 | # create the room 116 | room = pra.ShoeBox(room_dim, fs=self.fs, materials=pra.Material(absorption), max_order=max_order) 117 | 118 | # place the array in the room 119 | array_center = np.asarray([2.5 , 1.5, 0.8]) 120 | pos = self.micpos.T + array_center[:,np.newaxis] 121 | room.add_microphone_array(pos) 122 | 123 | # add sources for region A and B to the room 124 | for r in range(self.nrir): 125 | 126 | # source 1 is randomly placed within region A 127 | x = np.random.uniform(1.0, 2.0) 128 | y = np.random.uniform(2.0, 3.0) 129 | z = np.random.uniform(1.5, 2.0) 130 | room.add_source([x, y, z], signal=0, delay=0) 131 | 132 | # source 2 is randomly placed within region B 133 | x = np.random.uniform(3.0, 4.0) 134 | y = np.random.uniform(2.0, 3.0) 135 | z = np.random.uniform(1.5, 2.0) 136 | room.add_source([x, y, z], signal=0, delay=0) 137 | 138 | 139 | # compute all RIRs and extend their length to 140 | t0 = time.time() 141 | room.compute_rir() 142 | t1 = time.time() 143 | print('Generated', self.nrir, 'RIRs in', t1-t0, 'seconds') 144 | 145 | 146 | self.rir_A = np.zeros((self.nrir, self.samples, self.nmic), dtype=np.float32) 147 | self.rir_B = np.zeros((self.nrir, self.samples, self.nmic), dtype=np.float32) 148 | for r in range(self.nrir): 149 | for m in range(self.nmic): 150 | 151 | h_A = room.rir[m][r*2+0] 152 | n = min(self.samples, h_A.size) 153 | self.rir_A[r,:n,m] = h_A[:n] 154 | 155 | h_B = room.rir[m][r*2+1] 156 | n = min(self.samples, h_B.size) 157 | self.rir_B[r,:n,m] = h_B[:n] 158 | 159 | 160 | data = { 161 | 'rir_A': self.rir_A, 162 | 'rir_B': self.rir_B, 163 | } 164 | save_numpy_to_mat(self.rir_file, data) 165 | 166 | 167 | 168 | 
#---------------------------------------------------------------------------- 169 | def load_rirs(self,): 170 | 171 | h_A = self.rir_A[np.random.choice(self.nrir),:,:] 172 | h_B = self.rir_B[np.random.choice(self.nrir),:,:] 173 | 174 | return h_A, h_B 175 | 176 | 177 | -------------------------------------------------------------------------------- /algorithms/audio_processing.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | __author__ = "Lukas Pfeifenberger" 3 | 4 | 5 | import os 6 | import numpy as np 7 | import soundfile as sf 8 | from scipy import signal 9 | from scipy.fftpack import dct 10 | 11 | 12 | 13 | #---------------------------------------------------------------- 14 | # read multichannel audio data 15 | # output x.shape = (samples, nmic) 16 | def audioread(filename, normalize=True): 17 | 18 | x, fs = sf.read(filename) 19 | 20 | if normalize==True: 21 | x = x*0.99/np.max(np.abs(x)) 22 | 23 | return (x, fs) 24 | 25 | 26 | #---------------------------------------------------------------- 27 | # write multichannel audio data 28 | # input x.shape = (samples, nmic) 29 | def audiowrite(x, filename, fs=16000, normalize=True): 30 | 31 | # x.shape = (samples, channels) 32 | 33 | if normalize==True: 34 | x = x*0.99/np.max(np.abs(x)) 35 | 36 | sf.write(filename, x, fs) 37 | 38 | 39 | #------------------------------------------------------------------------- 40 | # convert STFT data back to time domain, and save to WAV-files 41 | # data = tuple of STFT tensors 42 | # filenames = tuple of file names 43 | def convert_and_save_wavs(data, filenames, fs=16000): 44 | 45 | for Fz, filename in zip(data, filenames): 46 | z = mistft(Fz) # mistft expects Fz.shape = (nbin, nfram) or (nbin, nfram, nmic) 47 | 48 | os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True) # create the output folder if it does not exist yet (abspath avoids an empty dirname for bare file names) 49 | audiowrite(z, filename, fs) 50 | 51 | 52 | 53 | #---------------------------------------------------------------- 54 | # wrapper for python real fft 55 | def rfft(Bx, n=None, axis=-1): 56 | Fx =
np.fft.rfft(Bx, n=n, axis=axis) 57 | return Fx 58 | 59 | 60 | #---------------------------------------------------------------- 61 | # wrapper for python real ifft 62 | def irfft(Fx, n=None, axis=-1): 63 | Bx = np.fft.irfft(Fx, n=n, axis=axis) 64 | return Bx 65 | 66 | 67 | #---------------------------------------------------------------- 68 | # perform a multichannel STFT on audio data x (default window: scipy.signal.windows.blackman; the scipy.signal.blackman alias was removed in SciPy 1.13) 69 | # x.shape = (..., samples) 70 | # Fx.shape = (..., nfram, nbin) 71 | def mstft(x, wlen=1024, shift=256, window=signal.windows.blackman): 72 | 73 | x = np.asarray(x, dtype=np.float32) # shape = (..., samples) 74 | shape_x = tuple(x.shape[:-1]) 75 | samples = x.shape[-1] 76 | 77 | nbin = int(wlen/2+1) 78 | nfram = int(np.ceil( (samples-wlen+shift)/shift )) 79 | samples_padding = nfram*shift+wlen-shift - samples 80 | 81 | pad = np.zeros(shape_x+(samples_padding,), dtype=np.float32) 82 | x = np.concatenate([x, pad], axis=-1) 83 | 84 | analysis_window = window(wlen) 85 | 86 | Bx = np.zeros(shape_x+(nfram, wlen), dtype=np.float32) 87 | idx = np.arange(wlen) 88 | for t in range(nfram): 89 | Bx[...,t,:] = x[...,idx+t*shift] 90 | 91 | Bx = np.einsum('...tw,w->...tw', Bx, analysis_window) 92 | Fx = rfft(Bx, n=wlen, axis=-1).astype(np.complex64) # shape = (..., nfram, nbin) 93 | 94 | return Fx 95 | 96 | 97 | 98 | #---------------------------------------------------------------- 99 | # perform a multichannel inverse STFT on audio data Fx (use the same window, wlen and shift as in mstft) 100 | # Fx.shape = (nbin, nfram, nmic) 101 | # output x.shape = (samples, nmic) 102 | def mistft(Fx, wlen=1024, shift=256, window=signal.windows.blackman): 103 | 104 | assert (Fx.ndim == 2 or Fx.ndim == 3), 'Fx must have either 2 or 3 dimensions' 105 | 106 | Fx = np.asarray(Fx, dtype=np.complex64) 107 | nbin = Fx.shape[0] 108 | nfram = Fx.shape[1] 109 | samples = nfram*shift+wlen-shift 110 | 111 | analysis_window = window(wlen) 112 | assert np.mod(wlen, shift) == 0 113 | number_of_shifts = int(wlen/shift) 114 | 115 | sum_of_squares = np.zeros(shift) 116 | for i in
range(number_of_shifts): 117 | idx = np.arange(shift) + i*shift 118 | sum_of_squares = sum_of_squares + np.abs(analysis_window[idx])**2 119 | 120 | sum_of_squares = np.kron(np.ones(number_of_shifts), sum_of_squares) 121 | synthesis_window = analysis_window / sum_of_squares 122 | 123 | if Fx.ndim == 2: 124 | x = np.zeros((samples,), dtype=np.float32) 125 | for t in range(nfram): 126 | Bx = np.real(np.fft.irfft(Fx[:,t])) 127 | idx = np.arange(wlen) + t*shift 128 | x[idx] += Bx * synthesis_window 129 | 130 | 131 | if Fx.ndim == 3: 132 | nmic = Fx.shape[2] 133 | x = np.zeros((samples, nmic), dtype=np.float32) 134 | for c in range(nmic): 135 | for t in range(nfram): 136 | Bx = np.real(np.fft.irfft(Fx[:,t,c])) 137 | idx = np.arange(wlen) + t*shift 138 | x[idx,c] += Bx * synthesis_window 139 | 140 | 141 | return x 142 | 143 | 144 | 145 | #------------------------------------------------------------------------------ 146 | # get amplitude response of a highpass filter with 147 | # fs = samplerate 148 | # fc = corner frequency 149 | # response at H(fc) = 1/sqrt(2) 150 | # nbin = number of frequency bins 151 | def get_highpass_filter(nbin=513, fs=16e3, fc=100, order=2): 152 | 153 | fvect = np.arange(nbin)*fs/(2*(nbin-1)) 154 | k = np.power( np.sqrt(2)-1 , -1/order) 155 | tmp = np.power( (fvect/fc)*k , order ) 156 | H = np.maximum( 1-1/(1+tmp) , 1e-6 ) 157 | 158 | return H 159 | 160 | 161 | 162 | #------------------------------------------------------------------------------ 163 | def apply_highpass_filter(Fx, fs=16e3, fc=100, order=2): 164 | 165 | nbin = Fx.shape[0] 166 | 167 | H = get_highpass_filter(nbin=nbin, fs=fs, fc=fc, order=order) 168 | Fz = np.zeros_like(Fx) 169 | for k in range(nbin): 170 | Fz[k,...] 
= Fx[k,...]*H[k] 171 | 172 | return Fz 173 | 174 | 175 | 176 | #------------------------------------------------------------------------------ 177 | def hz_to_mel(hz): 178 | 179 | return 2595*np.log10(1+hz/700) 180 | 181 | 182 | 183 | #------------------------------------------------------------------------------ 184 | def mel_to_hz(mel): 185 | 186 | return 700*(10**(mel/2595)-1) 187 | 188 | 189 | 190 | #------------------------------------------------------------------------------ 191 | def create_mel_filterbank(nbin=513, fs=16e3, nband=40): 192 | 193 | low_freq_mel = hz_to_mel(100) 194 | high_freq_mel = hz_to_mel(fs/2) 195 | 196 | mel_points = np.linspace(low_freq_mel, high_freq_mel, nband+2, dtype=np.float32) # equally spaced mel scale with kernels 197 | hz_points = mel_to_hz(mel_points) 198 | 199 | bin_index = np.asarray(np.floor(2*nbin*hz_points/fs), dtype=np.int32) 200 | filterbank = np.zeros((nband, nbin), dtype=np.float32) 201 | 202 | for m in range(1, nband+1): 203 | 204 | # create triangular kernels 205 | f_m_left = bin_index[m-1] 206 | f_m_center = bin_index[m] 207 | f_m_right = bin_index[m+1] 208 | 209 | for k in range(f_m_left, f_m_center): 210 | filterbank[m-1, k] = (k-f_m_left) / (f_m_center-f_m_left) 211 | 212 | for k in range(f_m_center, f_m_right): 213 | filterbank[m-1, k] = (f_m_right-k) / (f_m_right-f_m_center) 214 | 215 | return filterbank # shape = (nband, nbin) 216 | 217 | 218 | 219 | 220 | #------------------------------------------------------------------------------ 221 | def convert_to_mel(Fx, filterbank): 222 | 223 | Px = np.abs(Fx)**2 224 | Mx = np.dot(Px, filterbank.T) 225 | mel = np.log(Mx + 1e-3) 226 | 227 | return mel 228 | 229 | 230 | 231 | #------------------------------------------------------------------------------ 232 | def convert_to_mfcc(Fx, filterbank): 233 | 234 | Px = np.abs(Fx)**2 235 | Mx = np.dot(Px, filterbank.T) 236 | Mx = np.log(Mx + 1e-3) 237 | 238 | mfcc = dct(Mx, axis=-1, type=2, norm='ortho') 239 | 240 | return 
mfcc 241 | 242 | 243 | 244 | #------------------------------------------------------------------------------ 245 | def mkdir(path): 246 | 247 | if not os.path.exists(os.path.dirname(path)): 248 | os.makedirs(os.path.dirname(path)) 249 | 250 | 251 | -------------------------------------------------------------------------------- /ops/complex_ops.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | __author__ = "Lukas Pfeifenberger" 3 | 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | import os 8 | import sys 9 | sys.path.append(os.path.abspath('../')) 10 | 11 | from utils.keras_helpers import * 12 | 13 | 14 | 15 | #------------------------------------------------------------------- 16 | def safe_conj(z): 17 | return tf.complex(tf.math.real(z), -tf.math.imag(z)) 18 | 19 | 20 | 21 | #------------------------------------------------------------------- 22 | @tf.custom_gradient 23 | def mean_square_error(z, c): 24 | 25 | s = tf.reduce_mean(tf.abs(z-c)**2) 26 | 27 | def grad(grad_s): 28 | 29 | grad_s = tf.cast(grad_s, tf.complex64) 30 | Nz = tf.cast(tf.reduce_prod(tf.shape(z)), tf.complex64) 31 | Nc = tf.cast(tf.reduce_prod(tf.shape(c)), tf.complex64) 32 | grad_z = 2*grad_s*(z-c)/Nz 33 | grad_c = 2*grad_s*(c-z)/Nc 34 | 35 | return grad_z, grad_c 36 | 37 | return s, grad 38 | 39 | 40 | 41 | #------------------------------------------------------------------- 42 | @tf.custom_gradient 43 | def vector_conj_inner(z, c): 44 | 45 | s = tf.reduce_sum(z*tf.math.conj(c), -1) 46 | 47 | def grad(grad_s): 48 | 49 | grad_s = tf.expand_dims(grad_s, axis=-1) 50 | #grad_s = Debug('grad_s', grad_s) 51 | grad_z = grad_s*c 52 | grad_c = tf.math.conj(grad_s)*z 53 | 54 | return grad_z, grad_c 55 | 56 | return s, grad 57 | 58 | 59 | 60 | #------------------------------------------------------------------- 61 | @tf.custom_gradient 62 | def cast_to_complex(z_real, z_imag): 63 | 64 | s = tf.complex(z_real, z_imag) 65 | 66 | 
def grad(grad_s): 67 | 68 | return tf.math.real(grad_s), tf.math.imag(grad_s) 69 | 70 | return s, grad 71 | 72 | 73 | 74 | #------------------------------------------------------------------- 75 | @tf.custom_gradient 76 | def cast_to_float(z): 77 | 78 | s = tf.stack([tf.math.real(z), tf.math.imag(z)], axis=-1) 79 | 80 | def grad(grad_s): 81 | 82 | return tf.complex(grad_s[...,0], grad_s[...,1]) 83 | 84 | return s, grad 85 | 86 | 87 | 88 | #------------------------------------------------------------------- 89 | @tf.custom_gradient 90 | def elementwise_real(z): 91 | 92 | s = tf.math.real(z) 93 | 94 | def grad(grad_s): 95 | 96 | return tf.complex(grad_s, tf.zeros_like(grad_s)) 97 | 98 | return s, grad 99 | 100 | 101 | 102 | #------------------------------------------------------------------- 103 | @tf.custom_gradient 104 | def elementwise_complex(z_real): 105 | 106 | s = tf.cast(z_real, tf.complex64) 107 | 108 | def grad(grad_s): 109 | 110 | return tf.math.real(grad_s) 111 | 112 | return s, grad 113 | 114 | 115 | 116 | #------------------------------------------------------------------- 117 | @tf.custom_gradient 118 | def elementwise_abs(z): 119 | 120 | s = tf.abs(z) 121 | 122 | def grad(grad_s): 123 | 124 | grad_s = tf.cast(tf.math.real(grad_s), tf.complex64) 125 | az = tf.cast(tf.abs(z)+1e-6, tf.complex64) 126 | gs = tf.cast(tf.math.real(grad_s), tf.complex64) 127 | grad_z = gs*z/az 128 | 129 | return grad_z 130 | 131 | return s, grad 132 | 133 | 134 | 135 | #------------------------------------------------------------------- 136 | @tf.custom_gradient 137 | def elementwise_abs2(z): 138 | 139 | s = tf.abs(z)**2 140 | 141 | def grad(grad_s): 142 | 143 | grad_s = tf.cast(tf.math.real(grad_s), tf.complex64) 144 | grad_z = 2*grad_s*z 145 | 146 | return grad_z 147 | 148 | return s, grad 149 | 150 | 151 | 152 | #------------------------------------------------------------------- 153 | @tf.custom_gradient 154 | def vector_normalize_magnitude(z): 155 | 156 | #num = 
tf.linalg.norm(z, axis=-1, keepdims=True) + 1e-6 # norm over last axis 157 | num = tf.sqrt(tf.reduce_sum(tf.abs(z)**2, axis=-1, keepdims=True)) + 1e-6 # norm over last axis 158 | s = z/tf.cast(num, tf.complex64) 159 | 160 | def grad(grad_s): 161 | 162 | tmp = tf.reduce_sum(tf.math.real(z)*tf.math.real(grad_s) + tf.math.imag(z)*tf.math.imag(grad_s), axis=-1, keepdims=True) 163 | tmp /= num*num*num + 1e-6 164 | grad_z = grad_s/tf.cast(num, tf.complex64) - z*tf.cast(tmp, tf.complex64) 165 | 166 | return grad_z 167 | 168 | return s, grad 169 | 170 | 171 | 172 | #------------------------------------------------------------------- 173 | @tf.custom_gradient 174 | def vector_normalize_phase(z): 175 | 176 | num = tf.abs(z[...,0]) + 1e-6 177 | num = tf.cast(num, tf.complex64) 178 | phi = z[...,0] / num 179 | s = z*tf.math.conj(phi)[...,tf.newaxis] 180 | 181 | 182 | def grad(grad_s): 183 | 184 | grad_z_1 = tf.cast(tf.math.real(grad_s[...,0]), tf.complex64)*phi 185 | 186 | tmp = tf.einsum('...i,...i->...', safe_conj(grad_s)[...,1:], z[...,1:]) 187 | tmp -= tf.math.conj(tmp)*phi*phi 188 | grad_z_1 += 0.5*tmp / num 189 | 190 | grad_z_2 = grad_s*phi[...,tf.newaxis] 191 | grad_z = tf.concat([grad_z_1[...,tf.newaxis], grad_z_2[...,1:]], axis=-1) 192 | 193 | return grad_z 194 | 195 | return s, grad 196 | 197 | 198 | 199 | #------------------------------------------------------------------- 200 | @tf.custom_gradient 201 | def elementwise_normalize(z): 202 | 203 | num = tf.abs(z) + 1e-6 204 | num = tf.cast(num, tf.complex64) 205 | s = z/num 206 | 207 | def grad(grad_s): 208 | 209 | # gradient from division 210 | grad_z1 = grad_s / tf.math.conj(num) 211 | grad_num = -grad_s*tf.math.conj(z) / (tf.math.conj(num)**2) 212 | 213 | # gradient of abs 214 | az = tf.cast(tf.abs(z)+1e-6, tf.complex64) 215 | grad_num = tf.cast(tf.math.real(grad_num), tf.complex64) 216 | grad_z2 = grad_num*z/az 217 | 218 | return grad_z1+grad_z2 219 | 220 | return s, grad 221 | 222 | 223 | 
#-------------------------------------------------------------------
# elementwise complex product z*c with a hand-written backward pass
@tf.custom_gradient
def elementwise_mul(z, c):

    s = z*c

    def grad(grad_s):

        grad_z = grad_s*tf.math.conj(c)
        grad_c = grad_s*tf.math.conj(z)

        return grad_z, grad_c

    return s, grad



#-------------------------------------------------------------------
# elementwise complex quotient z/c with a hand-written backward pass
@tf.custom_gradient
def elementwise_div(z, c):

    s = z/c

    def grad(grad_s):

        grad_z = grad_s / tf.math.conj(c)
        grad_c = -grad_s*tf.math.conj(z) / (tf.math.conj(c)**2)

        return grad_z, grad_c

    return s, grad



#-------------------------------------------------------------------
# complex conjugate; the gradient is conjugated as well
@tf.custom_gradient
def elementwise_conj(z):

    s = safe_conj(z)

    def grad(grad_s):

        grad_z = safe_conj(grad_s)

        return grad_z

    return s, grad



#-------------------------------------------------------------------
# magnitude-squashing tanh: s = z * tanh(|z|)/|z| (phase kept, magnitude compressed)
@tf.custom_gradient
def elementwise_tanh(z):

    num = tf.abs(z) + 1e-6
    ta = tf.tanh(num)
    mag = ta/num
    s = z*tf.cast(mag, tf.complex64)

    def grad(grad_s):

        sa = 0.5-0.5*ta*ta

        tmp1 = tf.cast(sa + 0.5*mag, tf.complex64)
        tmp2 = tf.cast(sa - 0.5*mag, tf.complex64)
        tmp3 = z*z/tf.cast(num*num, tf.complex64)       # = (z/abs(z))**2
        grad_z = tf.math.conj(grad_s)*tmp2*tmp3 + grad_s*tmp1

        return grad_z

    return s, grad



#-------------------------------------------------------------------
# logistic sigmoid of the real part (written as 0.5*tanh(x/2)+0.5), returned as complex64
@tf.custom_gradient
def elementwise_sigmoid(z):

    s_real = tf.tanh(tf.math.real(z)*0.5)*0.5+0.5
    s = tf.cast(s_real, tf.complex64)

    def grad(grad_s):

        grad_z_real = tf.math.real(grad_s)*s_real*(1-s_real)
        grad_z = tf.cast(grad_z_real, tf.complex64)

        return grad_z

    return s, grad



#-------------------------------------------------------------------
# einsum with complex arguments causes adjoint errors in CUDNN, workaround:
# evaluate four real einsums and recombine them as
# (zr + i*zi)(cr + i*ci) = (zr*cr - zi*ci) + i*(zr*ci + zi*cr)
def einsum(subscripts, z, c):

    s1 = tf.einsum(subscripts, tf.math.real(z), tf.math.real(c))
    s2 = tf.einsum(subscripts, tf.math.real(z), tf.math.imag(c))
    s3 = tf.einsum(subscripts, tf.math.imag(z), tf.math.real(c))
    s4 = tf.einsum(subscripts, tf.math.imag(z), tf.math.imag(c))

    s = cast_to_complex(s1-s4, s2+s3)

    return s


# =============================================================================
# /experiments/cnbf_complex.py
# =============================================================================
# -*- coding: utf-8 -*-
__author__ = "Lukas Pfeifenberger"


import time
import numpy as np
import argparse
import json
import os
import sys
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
sys.path.append(os.path.abspath('../'))

from keras.models import Model
from keras.layers import Dense, Activation, LSTM, Input, Lambda
import keras.backend as K
import tensorflow as tf

from loaders.feature_generator import feature_generator
from utils.mat_helpers import *
from utils.keras_helpers import *
from ops.complex_ops import *
from ops.complex_layers import *
from ops.kernelized_layers import *
from algorithms.audio_processing import *
from utils.matplotlib_helpers import *


np.set_printoptions(precision=3, threshold=3, edgeitems=3)



#-------------------------------------------------------------------------
#-------------------------------------------------------------------------


class cnbf(object):
    """Complex-valued neural beamformer experiment ('cnbf_complex').

    Builds the Keras model on construction, then trains on mixtures from fgen and
    saves weights, .mat predictions, WAV files and a prediction plot.
    """

    def __init__(self, config, fgen):

        self.config = config
        self.fgen = fgen
        self.name = 'cnbf_complex'
        self.logger = Logger(self.name)

        self.creation_date = os.path.getmtime(self.name+'.py')         # timestamp of this script
        self.weights_file = self.config['weights_path'] + self.name + '.h5'
        self.predictions_file = self.config['predictions_path'] + self.name + '.mat'

        self.nbatch = 5
        self.nfram = self.fgen.nfram
        self.nbin = self.fgen.nbin
        self.nmic = self.fgen.nmic

        self.create_model()



    #---------------------------------------------------------
    def layer0(self, inp):
        """Feature extraction from the mixture.

        Concatenates the magnitude- and phase-normalized multichannel vectors with
        the log power averaged over mics. Returns X (per-bin features) and Y (the
        same features flattened per frame).
        """
        Fz = tf.cast(inp, tf.complex64)                         # shape = (nbatch, nfram, nbin, nmic)

        Pz = elementwise_abs2(Fz)                               # shape = (nbatch, nfram, nbin, nmic)
        Pz = tf.reduce_mean(Pz, axis=-1)                        # shape = (nbatch, nfram, nbin)
        Lz = tf.math.log(Pz+1e-3)[...,tf.newaxis]               # shape = (nbatch, nfram, nbin, 1)
        Lz = elementwise_complex(Lz)

        vz = vector_normalize_magnitude(Fz)                     # shape = (nbatch, nfram, nbin, nmic)
        vz = vector_normalize_phase(vz)                         # shape = (nbatch, nfram, nbin, nmic)

        X = tf.concat([vz, Lz], axis=-1)                        # shape = (nbatch, nfram, nbin, nmic+1)
        Y = tf.reshape(X, [self.nbatch, self.nfram, self.nbin*(self.nmic+1)])    # shape = (nbatch, nfram, nbin*(nmic+1))

        return [X,Y]



    #---------------------------------------------------------
    def layer1(self, inp):
        """Append the per-bin output of the dense branch to the per-bin features."""
        X = tf.cast(inp[0], tf.complex64)                       # shape = (nbatch, nfram, nbin, nmic+1)
        Y = tf.cast(inp[1], tf.complex64)                       # shape = (nbatch, nfram, nbin)

        X = tf.concat([X,Y[...,tf.newaxis]], axis=-1)           # shape = (nbatch, nfram, nbin, nmic+2)

        return X



    #---------------------------------------------------------
    def layer2(self, inp):
        """Apply the beamformer weights W to both sources and compute the cost.

        The cost is the negative SNR improvement in dB (beamformed output SNR
        minus input SNR), averaged over frames and bins.
        NOTE(review): log10 comes from a star import that is not visible here.
        """
        Fs = tf.cast(inp[0], tf.complex64)                      # shape = (nbatch, nfram, nbin, nmic)
        Fn = tf.cast(inp[1], tf.complex64)                      # shape = (nbatch, nfram, nbin, nmic)
        W = tf.cast(inp[2], tf.complex64)                       # shape = (nbatch, nfram, nbin, nmic)

        # beamforming
        W = vector_normalize_magnitude(W)                       # shape = (nbatch, nfram, nbin, nmic)
        W = vector_normalize_phase(W)                           # shape = (nbatch, nfram, nbin, nmic)
        Fys = vector_conj_inner(Fs, W)                          # shape = (nbatch, nfram, nbin)
        Fyn = vector_conj_inner(Fn, W)                          # shape = (nbatch, nfram, nbin)

        # energy of the input
        Ps = tf.reduce_mean(elementwise_abs2(Fs), axis=-1)      # input (desired source)
        Pn = tf.reduce_mean(elementwise_abs2(Fn), axis=-1)      # input (unwanted source)
        Ls = 10*log10(Ps + 1e-2)
        Ln = 10*log10(Pn + 1e-2)

        # energy of the beamformed outputs
        Pys = elementwise_abs2(Fys)                             # output (desired source)
        Pyn = elementwise_abs2(Fyn)                             # output (unwanted source)
        Lys = 10*log10(Pys + 1e-2)
        Lyn = 10*log10(Pyn + 1e-2)

        delta_snr = Lys-Lyn - (Ls-Ln)

        cost = -tf.reduce_mean(delta_snr, axis=(1,2))

        return [Fys, Fyn, cost]




    #---------------------------------------------------------
    def create_model(self):
        """Build and compile the model, then try to restore weights from self.weights_file."""
        print('*** creating model: %s' % self.name)

        # shape definitions: (nbatch, ...)
        Fs = Input(batch_shape=(self.nbatch, self.nfram, self.nbin, self.nmic), dtype=tf.complex64)
        Fn = Input(batch_shape=(self.nbatch, self.nfram, self.nbin, self.nmic), dtype=tf.complex64)

        Fz = Fs+Fn
        X,Y = Lambda(self.layer0)(Fz)
        Y = Complex_Dense(units=50, activation='tanh')(Y)                          # shape = (nbatch, nfram, 50)
        Y = Complex_Dense(units=self.nbin, activation='tanh')(Y)                   # shape = (nbatch, nfram, nbin)
        X = Lambda(self.layer1)([X,Y])
        X = Kernelized_Complex_LSTM(units=self.nmic*2)(X)                          # shape = (nbatch, nfram, nbin, nmic*2)
        X = Kernelized_Complex_Dense(units=self.nmic*2, activation='tanh')(X)      # shape = (nbatch, nfram, nbin, nmic*2)
        W = Kernelized_Complex_Dense(units=self.nmic, activation='linear')(X)      # shape = (nbatch, nfram, nbin, nmic)

        Fys, Fyn, cost = Lambda(self.layer2)([Fs, Fn, W])

        self.model = Model(inputs=[Fs, Fn], outputs=[Fys, Fyn])
        self.model.add_loss(cost)
        self.model.compile(loss=None, optimizer='adam')

        # summary() prints itself and returns None; wrapping it in print() printed "None"
        self.model.summary()
        try:
            self.model.load_weights(self.weights_file)
        except Exception:
            # best effort: start from scratch when no usable weights file exists,
            # but let KeyboardInterrupt/SystemExit propagate
            print('error loading weights file: %s' % self.weights_file)



    #---------------------------------------------------------
    def train(self):
        """Run one training epoch on a freshly generated batch of mixtures."""
        Fs, Fn = self.fgen.generate_mixtures(self.nbatch)
        self.model.fit([Fs, Fn], None, batch_size=self.nbatch, epochs=1, verbose=0, shuffle=False, callbacks=[self.logger])



    #---------------------------------------------------------
    def save_weights(self):

        self.model.save_weights(self.weights_file)



    #---------------------------------------------------------
    def save_prediction(self):
        """Predict one batch and save the first example (first mic for inputs) as .mat."""
        Fs, Fn = self.fgen.generate_mixtures(self.nbatch)
        Fys, Fyn = self.model.predict([Fs, Fn])

        data = {
                'Fs': np.transpose(Fs, [0,2,1,3])[0,:,:,0],       # shape = (nbin, nfram)
                'Fn': np.transpose(Fn, [0,2,1,3])[0,:,:,0],       # shape = (nbin, nfram)
                'Fys': np.transpose(Fys, [0,2,1])[0,:,:],         # shape = (nbin, nfram)
                'Fyn': np.transpose(Fyn, [0,2,1])[0,:,:],         # shape = (nbin, nfram)
               }
        save_numpy_to_mat(self.predictions_file, data)



    #---------------------------------------------------------
    def check_date(self):
        """Return True while this script file has not been modified since start-up."""
        return self.creation_date == os.path.getmtime(self.name+'.py')



    #---------------------------------------------------------
    def inference(self):
        """Predict one batch, write noisy/enhanced WAVs and an SNR-map plot."""
        Fs, Fn = self.fgen.generate_mixtures(self.nbatch)
        Fys, Fyn = self.model.predict([Fs, Fn])

        Fs = Fs[0,...,0].T                  # input (desired source)
        Fn = Fn[0,...,0].T                  # input (unwanted source)
        Fys = Fys[0,...].T                  # output (desired source)
        Fyn = Fyn[0,...].T                  # output (unwanted source)
        Fz = Fs+Fn                          # noisy mixture
        Fy = Fys+Fyn                        # enhanced output

        data = (Fz, Fy)
        filenames = (
                     self.config['predictions_path'] + self.name + '_noisy.wav',
                     self.config['predictions_path'] + self.name + '_enhanced.wav',
                    )
        convert_and_save_wavs(data, filenames)

        # per-bin SNR in dB (scaled by 1/30) before and after beamforming
        Lz = ( 20*np.log10(np.abs(Fs)+1e-1) - 20*np.log10(np.abs(Fn)+1e-1) )/30
        Ly = ( 20*np.log10(np.abs(Fys)+1e-1) - 20*np.log10(np.abs(Fyn)+1e-1) )/30
        legend = ('noisy', 'enhanced')
        clim = (-1, +1)
        filename = self.config['predictions_path'] + self.name + '_prediction.png'
        draw_subpcolor((Lz, Ly), legend, clim, filename)




#---------------------------------------------------------
#---------------------------------------------------------
if __name__ == "__main__":


    # parse command line args
    parser = argparse.ArgumentParser(description='CNBF')
    parser.add_argument('--config_file', help='name of json configuration file', default='../cnbf.json')
238 | parser.add_argument('--predict', help='inference', action='store_true') 239 | args = parser.parse_args() 240 | 241 | 242 | # load config file 243 | try: 244 | print('*** loading config file: %s' % args.config_file ) 245 | with open(args.config_file, 'r') as f: 246 | config = json.load(f) 247 | except: 248 | print('*** could not load config file: %s' % args.config_file) 249 | quit(0) 250 | 251 | 252 | 253 | if args.predict is False: 254 | fgen = feature_generator(config, 'train') 255 | bf = cnbf(config, fgen) 256 | print('training') 257 | i = 0 258 | while (i...j', x, W) + b # shape = (..., units) 50 | 51 | if self.activation == 'tanh': 52 | z = elementwise_tanh(z) 53 | elif self.activation == 'norm': 54 | z = vector_normalize_magnitude(z) 55 | 56 | return z 57 | 58 | 59 | def compute_output_shape(self, input_shape): 60 | 61 | output_shape = list(input_shape) 62 | output_shape[-1] = self.units 63 | return tuple(output_shape) 64 | 65 | 66 | 67 | 68 | #------------------------------------------------------------------- 69 | 70 | class Complex_LSTM(Layer): 71 | 72 | def __init__(self, units, activation='tanh', return_sequences=True, go_backwards=False): 73 | 74 | super(Complex_LSTM, self).__init__() 75 | self.units = units 76 | self.activation = activation 77 | self.return_sequences = return_sequences 78 | self.go_backwards = go_backwards 79 | 80 | 81 | def build(self, input_shape): 82 | 83 | cell = self.Cell(self.units) 84 | self.rnn = RNN(cell, return_sequences=self.return_sequences, go_backwards=self.go_backwards) 85 | self.rnn.build(input_shape=input_shape) 86 | self._trainable_weights = self.rnn.trainable_weights 87 | super(Complex_LSTM, self).build(input_shape) 88 | 89 | 90 | def call(self, inputs): 91 | 92 | x = inputs # shape = (nbatch, nfram, n_in) 93 | 94 | # reshape input to 3D 95 | nbatch = tf.shape(x)[0] 96 | nfram = tf.shape(x)[1] 97 | 98 | # reshape output to 4D 99 | y = self.rnn(x) 100 | y = tf.reshape(y, [nbatch, nfram, self.units]) 101 | 102 | # 
reverse time axis back to normal 103 | if self.go_backwards is True: 104 | y = tf.reverse(y, axis=[1]) 105 | 106 | return y 107 | 108 | 109 | def compute_output_shape(self, input_shape): 110 | 111 | output_shape = list(input_shape) 112 | output_shape[-1] = self.units 113 | return tuple(output_shape) 114 | 115 | 116 | 117 | class Cell(Layer): 118 | 119 | def __init__(self, units, activation='tanh', recurrent_activation='sigmoid'): 120 | 121 | super(Complex_LSTM.Cell, self).__init__() 122 | self.units = units # = kernel size of the output 123 | self.activation = activation 124 | self.state_size = (units, units) # = flattened sizes of the hidden and carry state 125 | self.output_size = units # = flattened size of the output 126 | 127 | 128 | def build(self, input_shape): 129 | 130 | # the input of the Cell is a 3D tensor with shape (nbatch, nfram, n_in) 131 | n_in = input_shape[-1] 132 | 133 | self.W_real = self.add_weight(shape=(n_in, self.units*4), name='W_real', initializer='glorot_uniform') 134 | self.U_real = self.add_weight(shape=(self.units, self.units*4), name='U_real', initializer='orthogonal') 135 | self.b_real = self.add_weight(shape=(self.units*4), name='b_real', initializer='zeros') 136 | 137 | self.W_imag = self.add_weight(shape=(n_in, self.units*4), name='W_imag', initializer='glorot_uniform') 138 | self.U_imag = self.add_weight(shape=(self.units, self.units*4), name='U_imag', initializer='orthogonal') 139 | self.b_imag = self.add_weight(shape=(self.units*4), name='b_imag', initializer='zeros') 140 | 141 | super(Complex_LSTM.Cell, self).build(input_shape) 142 | 143 | 144 | # this function is called every time steps 145 | def call(self, inputs, states, training=None): 146 | 147 | x = inputs # shape = (nbatch, n_in) 148 | h_tm1 = states[0] # shape = (nbatch, units) 149 | c_tm1 = states[1] # shape = (nbatch, units) 150 | 151 | 152 | W = cast_to_complex(self.W_real, self.W_imag) 153 | U = cast_to_complex(self.U_real, self.U_imag) 154 | b = 
cast_to_complex(self.b_real, self.b_imag) 155 | 156 | z = einsum('bi,ij->bj', x, W) # shape = (nbatch, units*4) 157 | z += einsum('bi,ij->bj', h_tm1, U) # shape = (nbatch, units*4) 158 | z += b 159 | 160 | a, i, f, o = [ z[:,i*self.units:(i+1)*self.units] for i in range(4) ] 161 | 162 | a = elementwise_tanh(a) 163 | i = elementwise_sigmoid(i) 164 | f = elementwise_sigmoid(f) 165 | o = elementwise_sigmoid(o) 166 | 167 | c = a*i + f*c_tm1 168 | 169 | if self.activation == 'tanh': 170 | h = o*elementwise_tanh(c) 171 | elif self.activation == 'norm': 172 | h = o*vector_normalize_magnitude(c) 173 | else: 174 | h = o*c 175 | 176 | return h, [h, c] 177 | 178 | 179 | 180 | 181 | #------------------------------------------------------------------- 182 | 183 | class Kernelized_Complex_Dense(Layer): 184 | 185 | def __init__(self, units, activation=None): 186 | 187 | super(Kernelized_Complex_Dense, self).__init__() 188 | self.units = units 189 | self.activation = activation 190 | 191 | 192 | def build(self, input_shape): 193 | 194 | # input_shape = (..., kernels, n_in) 195 | 196 | # keras does not allow dtype=tf.complex64 on trainable weights, workaround: 197 | self.W_real = self.add_weight(name='W_real', shape=(input_shape[-2], input_shape[-1], self.units), initializer='random_normal', dtype=tf.float32) 198 | self.W_imag = self.add_weight(name='W_imag', shape=(input_shape[-2], input_shape[-1], self.units), initializer='random_normal', dtype=tf.float32) 199 | 200 | self.b_real = self.add_weight(name='b_real', shape=(input_shape[-2], self.units), initializer='zeros', dtype=tf.float32) 201 | self.b_imag = self.add_weight(name='b_imag', shape=(input_shape[-2], self.units), initializer='zeros', dtype=tf.float32) 202 | 203 | super(Kernelized_Complex_Dense, self).build(input_shape) 204 | 205 | 206 | def call(self, inputs): 207 | 208 | x = inputs 209 | 210 | W = cast_to_complex(self.W_real, self.W_imag) 211 | b = cast_to_complex(self.b_real, self.b_imag) 212 | 213 | z = 
einsum('...ki,kij->...kj', x, W) + b # shape = (..., kernels, units) 214 | 215 | if self.activation == 'tanh': 216 | z = elementwise_tanh(z) 217 | elif self.activation == 'norm': 218 | z = vector_normalize_magnitude(z) 219 | 220 | return z 221 | 222 | 223 | def compute_output_shape(self, input_shape): 224 | 225 | output_shape = list(input_shape) 226 | output_shape[-1] = self.units 227 | return tuple(output_shape) 228 | 229 | 230 | 231 | #------------------------------------------------------------------- 232 | 233 | class Kernelized_Complex_LSTM(Layer): 234 | 235 | def __init__(self, units, return_sequences=True, go_backwards=False): 236 | 237 | super(Kernelized_Complex_LSTM, self).__init__() 238 | self.units = units 239 | self.return_sequences = return_sequences 240 | self.go_backwards = go_backwards 241 | 242 | 243 | def build(self, input_shape): 244 | 245 | # input to the kernelized LSTM is a 4D tensor: 246 | nbatch, nfram, kernels, n_in = input_shape 247 | 248 | cell = self.Cell(kernels, self.units) 249 | self.rnn = RNN(cell, return_sequences=self.return_sequences, go_backwards=self.go_backwards) 250 | 251 | # the Keras RNN implementation does only work with 3D tensors, hence we flatten the last two dimensions of the input: 252 | self.rnn.build(input_shape=(nbatch, nfram, kernels*n_in)) 253 | self._trainable_weights = self.rnn.trainable_weights 254 | super(Kernelized_Complex_LSTM, self).build(input_shape) 255 | 256 | 257 | def call(self, inputs): 258 | 259 | x = inputs # shape = (nbatch, nfram, kernels, n_in) 260 | 261 | # reshape input to 3D 262 | nbatch = tf.shape(x)[0] 263 | nfram = tf.shape(x)[1] 264 | kernels = tf.shape(x)[2] 265 | x = tf.reshape(x, [nbatch, nfram, -1]) 266 | 267 | # reshape output to 4D 268 | y = self.rnn(x) 269 | y = tf.reshape(y, [nbatch, nfram, kernels, self.units]) 270 | 271 | # reverse time axis back to normal 272 | if self.go_backwards is True: 273 | y = tf.reverse(y, axis=[1]) 274 | 275 | return y 276 | 277 | 278 | def 
compute_output_shape(self, input_shape): 279 | 280 | output_shape = list(input_shape) 281 | output_shape[-1] = self.units 282 | return tuple(output_shape) 283 | 284 | 285 | 286 | class Cell(Layer): 287 | 288 | def __init__(self, kernels, units, activation='tanh', recurrent_activation='hard_sigmoid'): 289 | 290 | super(Kernelized_Complex_LSTM.Cell, self).__init__() 291 | self.units = units # = kernel size of the output 292 | self.kernels = kernels # = data size of the output 293 | self.state_size = (kernels*units, kernels*units) # = flattened sizes of the hidden and carry state 294 | self.output_size = kernels*units # = flattened size of the output 295 | 296 | 297 | def build(self, input_shape): 298 | 299 | # the input of the Cell is a 3D tensor with shape (nbatch, nfram, kernels*n_in) 300 | n_in = int(input_shape[-1]/self.kernels) 301 | 302 | self.W_real = self.add_weight(shape=(self.kernels, n_in, self.units*4), name='W_real', initializer='glorot_uniform') 303 | self.U_real = self.add_weight(shape=(self.kernels, self.units, self.units*4), name='U_real', initializer='orthogonal') 304 | self.b_real = self.add_weight(shape=(self.kernels, self.units*4), name='b_real', initializer='zeros') 305 | 306 | self.W_imag = self.add_weight(shape=(self.kernels, n_in, self.units*4), name='W_imag', initializer='glorot_uniform') 307 | self.U_imag = self.add_weight(shape=(self.kernels, self.units, self.units*4), name='U_imag', initializer='orthogonal') 308 | self.b_imag = self.add_weight(shape=(self.kernels, self.units*4), name='b_imag', initializer='zeros') 309 | 310 | super(Kernelized_Complex_LSTM.Cell, self).build(input_shape) 311 | 312 | 313 | # this function is called every time steps 314 | def call(self, inputs, states, training=None): 315 | 316 | x = inputs # shape = (nbatch, kernels*n_in) 317 | nbatch = tf.shape(x)[0] 318 | x = tf.reshape(x, [nbatch, self.kernels, -1]) # expand input to 3D 319 | h_tm1 = tf.reshape(states[0], [nbatch, self.kernels, self.units]) # expand 
previous hidden state to 3D 320 | c_tm1 = tf.reshape(states[1], [nbatch, self.kernels, self.units]) # expand previous carry state to 3D 321 | 322 | 323 | W = cast_to_complex(self.W_real, self.W_imag) 324 | U = cast_to_complex(self.U_real, self.U_imag) 325 | b = cast_to_complex(self.b_real, self.b_imag) 326 | 327 | z = einsum('...ki,kij->...kj', x, W) # shape = (..., kernels, units*4) 328 | z += einsum('...ki,kij->...kj', h_tm1, U) # shape = (..., kernels, units*4) 329 | z += b 330 | 331 | a, i, f, o = [ z[..., i*self.units:(i+1)*self.units] for i in range(4) ] 332 | 333 | a = elementwise_tanh(a) 334 | i = elementwise_sigmoid(i) 335 | f = elementwise_sigmoid(f) 336 | o = elementwise_sigmoid(o) 337 | 338 | c = a*i + f*c_tm1 339 | h = o*elementwise_tanh(c) 340 | 341 | 342 | # flatten new hidden and carry state back to 2D 343 | h = tf.reshape(h, [nbatch, -1]) # shape = (nbatch, kernels*units) 344 | c = tf.reshape(c, [nbatch, -1]) 345 | 346 | return h, [h, c] 347 | 348 | 349 | 350 | --------------------------------------------------------------------------------