├── img ├── mixed.png ├── unmixed.png ├── adap_spa_net.png ├── adapt_spa_layer.png └── adapt_spa_layer_.png ├── README.md ├── ica_poc_data.py ├── ica_poc_train.py └── model.py /img/mixed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drasros/neuralICA/HEAD/img/mixed.png -------------------------------------------------------------------------------- /img/unmixed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drasros/neuralICA/HEAD/img/unmixed.png -------------------------------------------------------------------------------- /img/adap_spa_net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drasros/neuralICA/HEAD/img/adap_spa_net.png -------------------------------------------------------------------------------- /img/adapt_spa_layer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drasros/neuralICA/HEAD/img/adapt_spa_layer.png -------------------------------------------------------------------------------- /img/adapt_spa_layer_.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drasros/neuralICA/HEAD/img/adapt_spa_layer_.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Neural ICA 2 | 3 | This repository provides a simple toy example of neural ICA, i.e. ICA embedded in neural network layers. The basic building block is an adaptive channel recomposition layer. We call it an 'Adaptive Spatial Layer' because we developed it for EEG analysis, where channels are spread over space. Of course, it can also be used when channels have no notion of 'space'. 4 | 5 | ## The Adaptive Spatial Layer 6 | 7 | See `model.AdaptiveSpacialLayer`. 8 | An adaptive spatial layer applies adaptive weights to its input channels and returns transformed channels. The weights depend on the input, i.e. each element of a training minibatch uses a different set of weights. Weights are determined based on spatial features of the input obtained with fully connected layers as in [TCL](https://arxiv.org/abs/1605.06336). All time samples of the same example are transformed with the same weights. (An extension for _non-factorized_ spatio-temporal analysis could be to use evolving weights...) 9 | 10 | ![](img/adapt_spa_layer.png) 11 | 12 | ## The Multilayer Adaptive Spatial Net 13 | 14 | A multilayer adaptive spatial net is obtained by stacking such adaptive spatial layers (linearly or non-linearly). Here we present a linear stack, i.e. the unmixing in each adaptive layer is linear, but the spatial features used to determine the adaptive weights are non-linear. The net is trained with an _ICA cost_, i.e. it is trained to maximize the non-Gaussianity and asymmetry of its output. 15 | 16 | ![](img/adap_spa_net.png) 17 | 18 | `ica_poc_train.py` shows a basic proof of concept on toy signals mixed with random coefficients. At training time, many mixed examples (all with different mixing coefficients) are presented. After training, the network is able to separate sources _in a single forward pass_.
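For reference, the _ICA cost_ is a negentropy approximation J(y) = [E(G(y)) - E(G(g))]^2 with G(y) = -exp(-y^2/2), for which E(G(g)) = -sqrt(1/2) under a standard Gaussian g, plus a penalty pushing off-diagonal channel covariances to zero. Here is a minimal NumPy sketch of what the TensorFlow graph in `model.ICAModelBase` computes, assuming outputs `y` already centered and scaled to unit variance per channel:

```python
import numpy as np

def ica_cost(y, alpha_decorr=0.001):
    # y: (batch, time, chan), centered, unit variance per channel
    G = -np.exp(-y ** 2 / 2.)                      # non-quadratic contrast
    J = (G.mean(axis=1) + np.sqrt(0.5)) ** 2       # negentropy approx., (batch, chan)
    cov = np.einsum('bti,btj->bij', y, y) / y.shape[1]
    off = cov * (1. - np.eye(y.shape[2]))          # off-diagonal covariances
    cost_ica = -J.mean()                           # minimizing this maximizes non-Gaussianity
    cost_decorr = alpha_decorr * (off ** 2).sum(axis=(1, 2)).mean()
    return cost_ica + cost_decorr
```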
19 | 20 | Mixed signals: 21 | 22 | ![](img/mixed.png) 23 | 24 | Unmixed signals: 25 | 26 | ![](img/unmixed.png) 27 | 28 | ## Other analyses based on co-statistics between channels 29 | 30 | Note that while we trained the system to perform ICA here as a simple first application, it can also be used for other purposes. For example, in a spatiotemporal analysis system it can be used as a first spatial analysis brick, followed by a temporal analysis brick, for example temporal CNNs. The whole system can be trained with backprop, and the spatial analysis layer will recompose input channels into channels that are the most suitable for the subsequent temporal analysis. We expect that non-linear recomposition is also possible. -------------------------------------------------------------------------------- /ica_poc_data.py: -------------------------------------------------------------------------------- 1 | # Generate synthetic signals for ICA demo 2 | 3 | # Two base waveforms (sine wave + sawtooth) 4 | # 5 | # * Choose phases 6 | # * Mix them with a random matrix 7 | # * Normalize EACH channel (contrary to usual ICA where 8 | # PCA is performed) and feed to neural net 9 | # * train to MAXIMIZE non-gaussianity 10 | 11 | # * See if we get independent components 12 | # * See if this trained model works on other waveforms 13 | 14 | # See if it works on a training set of larger diversity 15 | # (more frequencies and shapes) 16 | 17 | # in_size=512 18 | 19 | # TODO: TEST normalizing data by /(2*std) rather than /std 20 | # Too high input variance may bring instability! (NaNs for 'normal' 21 | # learning rates). This could allow higher LRs... TO TRY 22 | 23 | import numpy as np 24 | 25 | import matplotlib 26 | matplotlib.use('Agg') 27 | import matplotlib.pyplot as plt 28 | 29 | # TODO: vectorize on BATCH 30 | def sine_wave(k, period, phase): 31 | # k: list of ints 32 | # period: (batch,), in number of samples 33 | # phase: (batch,), in number of samples 34 | signal_len = len(k) 35 | batch_size = len(period) 36 | 37 | k = np.reshape(k, (1, signal_len)) 38 | period = np.reshape(period, (batch_size, 1)) 39 | phase = np.reshape(phase, (batch_size, 1)) 40 | if np.any(phase > period): 41 | raise ValueError('phases and periods must ' + \ 42 | 'verify: 0 <= phase < period') 43 | 44 | y = np.sin(2. * np.pi * (k - phase) / period) 45 | return y 46 | 47 | 48 | def sawtooth_wave(k, period, phase): 49 | # k: list of ints 50 | # period: (batch,), in number of samples 51 | # phase: (batch,), in number of samples 52 | signal_len = len(k) 53 | batch_size = len(period) 54 | 55 | k = np.reshape(k, (1, signal_len)) 56 | period = np.reshape(period, (batch_size, 1)) 57 | phase = np.reshape(phase, (batch_size, 1)) 58 | 59 | 60 | if np.any(phase > period): 61 | raise ValueError('phases and periods must ' + \ 62 | 'verify: 0 <= phase < period') 63 | 64 | y = 2. * np.mod(k - phase, period) / period - 1. 65 | return y 66 | 67 | 68 | def get_random_M(batch_size, min_mix_coeff, 69 | min_Mlines_angle, loops_done=0): 70 | # draws one random 2x2 mixing matrix per element of the batch, 71 | # with coeffs in [-1, -min_mix_coeff]U[min_mix_coeff, 1] and an 72 | # angle of at least min_Mlines_angle between the two lines of M 90 | if loops_done >= 3: 91 | raise ValueError("Warning: already 3 loops done. " 92 | " The chosen min_Mlines_angle is probably too high...") 93 | 94 | batch_size_temp = (batch_size if min_Mlines_angle==0.
\ 95 | else 2*batch_size) 96 | M_ = np.random.uniform(min_mix_coeff, 1., size=(batch_size_temp, 1, 2, 2)) 97 | M_vals_list = np.array([-1, 1]) 98 | M_sign = np.random.randint(0, 2, size=(batch_size_temp, 1, 2, 2)) 99 | # low=0, high=2 because high is exclusive 100 | M_sign = M_vals_list[M_sign] 101 | M = np.multiply(M_, M_sign) 102 | 103 | # now make sure that there is at least min_Mlines_angle between 104 | # the two lines 105 | # norms of lines of M 106 | norms_Mlines = np.linalg.norm(M, axis=-1) # (batch_size_temp, 1, 2) 107 | # dot products 108 | dotprod_Mlines = np.sum( 109 | np.prod(M, axis=2), 110 | axis=-1) # (batch_size_temp, 1) 111 | cos_angles = np.divide( 112 | dotprod_Mlines, 113 | np.prod(norms_Mlines, axis=-1)) 114 | valid_idxs = np.where(np.abs(cos_angles) < np.cos(min_Mlines_angle)) 115 | valid_idxs = valid_idxs[0] 116 | 117 | M = M[valid_idxs, :, :, :] 118 | if len(M) >= batch_size: 119 | return M[:batch_size] 120 | else: return get_random_M( 121 | batch_size, min_mix_coeff, min_Mlines_angle, loops_done+1) 122 | 123 | 124 | 125 | def sine_sawtooth_iterator_fixedperiods(batch_size, 126 | in_size, 127 | sine_period, 128 | sawtooth_period, 129 | min_mix_coeff, 130 | min_Mlines_angle): 131 | # this iterator is for the experiment where the mixing is 132 | # variable but components keep the same period across examples 133 | 134 | # To avoid mixing very little of one input component in 135 | # output components, one can use the parameter 'min_mix_coeff' (in 136 | # the interval [0., 1.]) 137 | # in which case coeffs of the mixing matrix will be drawn from 138 | # the interval [-1, -min_mix_coeff]U[min_mix_coeff, 1], instead of 139 | # the interval [-1, 1] 140 | 141 | k = np.arange(in_size) 142 | sine_period = [sine_period] * batch_size 143 | sawtooth_period = [sawtooth_period] * batch_size 144 | while True: 145 | sine_phase = np.random.randint( 146 | sine_period[0], size=(batch_size,)) 147 | sawtooth_phase = np.random.randint( 148 | sawtooth_period[0], size=(batch_size,)) 149 | sine_y = sine_wave(k, sine_period, sine_phase) 150 | sawtooth_y = sawtooth_wave(k, sawtooth_period, sawtooth_phase) 151 | in_chan = np.stack([sine_y, sawtooth_y], axis=-1) # shape (batch, in_size, 2) 152 | in_chan = np.reshape(in_chan, (batch_size, in_size, 2, 1)) 153 | # Mixing matrix 154 | M = get_random_M( 155 | batch_size, min_mix_coeff, min_Mlines_angle) 156 | # Mix 157 | out_chan = np.matmul(M, in_chan) 158 | # center to zero mean and scale to unit variance, PER CHANNEL 159 | # (contrary to normal ICA where a PCA is done) 160 | m = np.mean(out_chan, axis=1, keepdims=True) 161 | v = np.var(out_chan, axis=1, keepdims=True) 162 | out_chan = (out_chan - m) / np.sqrt(1e-8 + v) 163 | out_chan = np.reshape(out_chan, (batch_size, in_size, 2)) 164 | yield out_chan 165 | 166 | def plot_example_2ch(expl, save_name=None): 167 | f, axarr = plt.subplots(2, sharex=True) 168 | axarr[0].plot(expl[:, 0]) 169 | axarr[1].plot(expl[:, 1]) 170 | if save_name is not None: 171 | plt.savefig(save_name) 172 | plt.close() 173 | 174 | if __name__ == "__main__": 175 | # ### TEST 176 | import os 177 | if not os.path.exists('ica_demo_results'): 178 | os.makedirs('ica_demo_results') 179 | k = np.arange(512) 180 | s = sine_wave(k, [32]*64, [10]*64)[0] 181 | import matplotlib 182 | matplotlib.use('Agg') 183 | import matplotlib.pyplot as plt 184 | plt.plot(s) 185 | plt.savefig('ica_demo_results/test.png') 186 | 187 | ### TEST ITERATOR 188 | batch_size = 128 189 | in_size = 512 190 | sine_period = 30 191 | sawtooth_period = 50 192 | b_it =
sine_sawtooth_iterator_fixedperiods( 193 | batch_size, 194 | in_size, 195 | sine_period, 196 | sawtooth_period, 197 | min_mix_coeff=0.1, 198 | min_Mlines_angle=np.pi/8.) 199 | b = next(b_it) 200 | print(b.shape) 201 | 202 | plot_example_2ch(b[0], "ica_demo_results/mix_0.png") 203 | plot_example_2ch(b[1], "ica_demo_results/mix_1.png") 204 | # do loop to estimate speed 205 | print('testing speed...') 206 | import time 207 | t = time.time() 208 | for _ in range(1000): 209 | b = next(b_it) 210 | t = time.time() - t 211 | print('Iterated 1000 batches in %.1f seconds' % t) 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | -------------------------------------------------------------------------------- /ica_poc_train.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------- 2 | # POC (Proof of concept) on 'toy' ICA 3 | # ----------------------------------- 4 | 5 | # train Multilayer Adaptive Spacial Filter ICA model 6 | # (with no nonlinearity on main filter) 7 | # on simple 'ica_poc_data' 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | import argparse 13 | import os 14 | import sys 15 | 16 | import ica_poc_data 17 | 18 | from model import MLAdaptiveSF_ICA, MLSF_ICA, lrelu 19 | 20 | import matplotlib 21 | matplotlib.use('Agg') 22 | import matplotlib.pyplot as plt 23 | 24 | ############################################# 25 | ########## CONFIG ########################## 26 | results_dir = '/datadrive1/ica_demo/' 27 | #results_dir = '/home/arnaud/data_these/multichan_cnn' 28 | 29 | in_size = 512 30 | 31 | sine_period = 16 32 | sawtooth_period = 20 33 | 34 | sine_period_valid = 20 35 | sawtooth_period_valid = 30 36 | 37 | # disable this or adapt it to the API of your provider 38 | send_txtmsg_when_done = True 39 | if send_txtmsg_when_done: 40 | import textmsg 41 | 42 | ############################################# 43 | 44 | def write_to_comment_file(comment_file, text): 45 | with open(comment_file, "a") as f: 46 | f.write(text) 47 | 48 | def get_exp_name(exp_type, exp_num, min_mix_coeff, min_Mlines_angle, 49 | batch_size, training_batches, learning_rates, 50 | layer_sizes, alpha_decorr, use_backconnects): 51 | name_common = "_exp" + str(exp_num) \ 52 | + "b_" + str(training_batches) \ 53 | + "_minmix" + str(min_mix_coeff) \ 54 | + "_minangle" + str(min_Mlines_angle) \ 55 | + '_lr' + str(learning_rates) \ 56 | + "_lyr" + str(layer_sizes) \ 57 | + "_alph" + str(alpha_decorr) 58 | if exp_type == 'adaptive': 59 | name = "ica_adapt_exp" \ 60 | + name_common \ 61 | + "_bckcon" + str(use_backconnects) 62 | # VERY IMPORTANT: If your storage volume (unfortunately) uses 63 | # an NTFS filesystem, AVOID forbidden characters such as [, ] and , 64 | # otherwise very weird things will happen: tensorflow save will 65 | # work but restore not always, depending on the length of the path... 66 | elif exp_type == 'nonadaptive': 67 | name = "ica_nonadapt_exp" \ 68 | + name_common 69 | else: 70 | raise ValueError('exp_type must be \'adaptive\' or \'nonadaptive\'. 
') 71 | name = name.replace('[', '') 72 | name = name.replace(']', '') 73 | name = name.replace(', ', '_') 74 | return name 75 | 76 | def str2bool(v): 77 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 78 | return True 79 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 80 | return False 81 | else: 82 | raise argparse.ArgumentTypeError('Boolean value expected.') 83 | 84 | def plot_train_costs(save_dir): 85 | costs_train = np.load(os.path.join(save_dir, 'costs_train.npy')) 86 | costs_ica = np.load(os.path.join(save_dir, 'costs_ica_train.npy')) 87 | costs_decorr = np.load(os.path.join(save_dir, 'costs_decorr_train.npy')) 88 | r = range(len(costs_train[1000:])) 89 | plt.figure() 90 | plt.plot(r, costs_train[1000:], label='total') 91 | plt.plot(r, costs_ica[1000:], label='ica') 92 | plt.plot(r, costs_decorr[1000:], label='decorr') 93 | plt.legend() 94 | plt.savefig(os.path.join(save_dir, 'costs.png')) 95 | plt.close() 96 | 97 | 98 | 99 | def main(): 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument('-exp_type', type=str, default='adaptive') 102 | parser.add_argument('-exp_num', type=int, default=-1) 103 | parser.add_argument('-min_mix_coeff', type=float, default=0.1) 104 | parser.add_argument('-min_Mlines_angle', type=float, default=np.pi/8.) 105 | parser.add_argument('-batch_size', type=int, default=128) 106 | parser.add_argument('-training_batches', nargs='+', type=int, 107 | default=[20000]) 108 | parser.add_argument('-learning_rates', nargs='+', type=float, 109 | default=[5e-5]) 110 | parser.add_argument('-layer_sizes', nargs='+', type=int, 111 | default=[32, 32, 32, 32, 32, 32, 32, 32]) 112 | parser.add_argument('-alpha_decorr', type=float, default=0.001) 113 | parser.add_argument('-use_backconnects', type=str2bool, 114 | default=str2bool('False')) 115 | args = parser.parse_args() 116 | train(args) 117 | 118 | def train(args): 119 | print(args) 120 | 121 | ###### prepare experiment ###### 122 | if not os.path.exists(results_dir): 123 | os.makedirs(results_dir) 124 | 125 | # NO: see note in exp_name definition. 
126 | exp_name = get_exp_name( 127 | exp_type=args.exp_type, 128 | exp_num=args.exp_num, 129 | min_mix_coeff=args.min_mix_coeff, 130 | min_Mlines_angle=args.min_Mlines_angle, 131 | batch_size=args.batch_size, 132 | training_batches=args.training_batches, 133 | learning_rates=args.learning_rates, 134 | layer_sizes=args.layer_sizes, 135 | alpha_decorr=args.alpha_decorr, 136 | use_backconnects=args.use_backconnects) 137 | 138 | exp_dir = os.path.join(results_dir, exp_name) 139 | if not os.path.exists(exp_dir): 140 | os.makedirs(exp_dir) 141 | 142 | comment_text = ("Mixing coeffs drawn from [-1., -min_coeff]U[min_coeff, 1]" + '\n' + 143 | "Min mix coeff: " + str(args.min_mix_coeff) + '\n' + 144 | "Min angle between the lines of M: " + str(args.min_Mlines_angle) + '\n' + 145 | "Training batches: " + str(args.training_batches) + '\n' + 146 | "Batch size: " + str(args.batch_size) + '\n' + 147 | "Learning rates: " + str(args.learning_rates) + '\n' + 148 | "Layer sizes: " + str(args.layer_sizes) + '\n' + 149 | "Alpha_decorr: " + str(args.alpha_decorr) + '\n') 150 | if args.exp_type == "adaptive": 151 | comment_text = (comment_text + 152 | "Use backconnects: " + str(args.use_backconnects) + '\n') 153 | 154 | write_to_comment_file( 155 | os.path.join(exp_dir, "comments.txt"), comment_text) 156 | 157 | ###### define model ###### 158 | print('DEFINING MODEL...') 159 | 160 | if args.exp_type == "adaptive": 161 | modelICA = MLAdaptiveSF_ICA( 162 | in_size=in_size, 163 | batch_size=args.batch_size, 164 | n_chan_in=2, 165 | alpha_decorr=args.alpha_decorr, 166 | layer_sizes=args.layer_sizes, 167 | activations_conv11=lrelu, 168 | activations_to_W=lrelu, 169 | activations_to_out=None, 170 | use_backconnects=args.use_backconnects, 171 | use_var_feat_only_net=False) 172 | else: 173 | modelICA = MLSF_ICA( 174 | in_size=in_size, 175 | batch_size=args.batch_size, 176 | n_chan_in=2, 177 | alpha_decorr=args.alpha_decorr, 178 | layer_sizes=args.layer_sizes, 179 | activations=lrelu) 180 | 181 | modelICA.init() 182 | 183 | # load previously saved model if there is one 184 | if tf.train.get_checkpoint_state(exp_dir) is not None: 185 | modelICA.load_model(exp_dir) 186 | # also load sequence of previous costs 187 | costs_train = np.load( 188 | os.path.join(exp_dir, 'costs_train.npy')).tolist() 189 | costs_ica_train = np.load( 190 | os.path.join(exp_dir, 'costs_ica_train.npy')).tolist() 191 | costs_decorr_train = np.load( 192 | os.path.join(exp_dir, 'costs_decorr_train.npy')).tolist() 193 | #costs_valid = np.load(os.path.join(exp_dir, 'costs_valid.npy')) 194 | #costs_ica_valid 195 | #costs_decorr_valid 196 | else: 197 | costs_train = [] 198 | costs_ica_train = [] 199 | costs_decorr_train = [] 200 | #costs_valid 201 | #costs_ica_valid 202 | #costs_decorr_valid 203 | 204 | ########## TRAIN ####################################### 205 | print('TRAINING...') 206 | 207 | try: 208 | 209 | it_train = ica_poc_data.sine_sawtooth_iterator_fixedperiods( 210 | batch_size=args.batch_size, 211 | in_size=in_size, 212 | sine_period=sine_period, 213 | sawtooth_period=sawtooth_period, 214 | min_mix_coeff=args.min_mix_coeff, 215 | min_Mlines_angle=args.min_Mlines_angle) 216 | 217 | # it_valid = ica_poc_data.sine_sawtooth_iterator_fixedperiods( 218 | # batch_size=args.batch_size, 219 | # in_size=in_size, 220 | # sine_period=sine_period_valid, 221 | # sawtooth_period=sawtooth_period_valid, 222 | # min_mix_coeff=args.min_mix_coeff, 223 | # min_Mlines_angle=args.min_Mlines_angle) 224 | 225 | # Training loop 226 | for b in range(np.sum(args.training_batches)): 227 | if b > 0 and 
b % 1000 == 0: 228 | print(str(100*b/np.sum(args.training_batches)) + \ 229 | ' percent done...') 230 | print("average train cost_total over the last 1000 evals: ", 231 | np.mean(costs_train[-1000:])) 232 | print("average train cost_ica over the last 1000 evals: ", 233 | np.mean(costs_ica_train[-1000:])) 234 | print("average train cost_decorr over the last 1000 evals: ", 235 | np.mean(costs_decorr_train[-1000:])) 236 | print("current learning rate: ", lr) 237 | sys.stdout.flush() 238 | 239 | if b > 0 and b % 10000 == 0: 240 | modelICA.save_model( 241 | os.path.join(exp_dir, 'model')) 242 | 243 | # For variable learning rate: 244 | c01 = b > np.cumsum(args.training_batches) 245 | idx_lr = np.where(c01==1)[0] 246 | if len(idx_lr) == 0: 247 | idx_lr = 0 248 | else: 249 | idx_lr = idx_lr[-1] + 1 250 | lr = args.learning_rates[idx_lr] 251 | 252 | # train 253 | examples = next(it_train) 254 | numeric_in = { 255 | 'in_X': examples, 256 | 'lr': lr, 257 | } 258 | cost_value, cost_ica_value, cost_decorr_value = \ 259 | modelICA.train_model(numeric_in) 260 | 261 | costs_train += [cost_value] 262 | costs_ica_train += [cost_ica_value] 263 | costs_decorr_train += [cost_decorr_value] 264 | 265 | if b % 5000 == 0: 266 | np.save( 267 | os.path.join(exp_dir, 'costs_train.npy'), 268 | np.array([*costs_train])) 269 | np.save( 270 | os.path.join(exp_dir, 'costs_ica_train.npy'), 271 | np.array([*costs_ica_train])) 272 | np.save( 273 | os.path.join(exp_dir, 'costs_decorr_train.npy'), 274 | np.array([*costs_decorr_train])) 275 | 276 | # also save after training 277 | np.save(os.path.join(exp_dir, 'costs_train.npy'), 278 | np.array([*costs_train])) 279 | np.save( 280 | os.path.join(exp_dir, 'costs_ica_train.npy'), 281 | np.array([*costs_ica_train])) 282 | np.save( 283 | os.path.join(exp_dir, 'costs_decorr_train.npy'), 284 | np.array([*costs_decorr_train])) 285 | 286 | modelICA.save_model( 287 | os.path.join(exp_dir, 'model')) 288 | 289 | except KeyboardInterrupt: 290 | print(' !!!!!!!! 
TRAINING INTERRUPTED !!!!!!!!') 291 | 292 | it_train = ica_poc_data.sine_sawtooth_iterator_fixedperiods( 293 | batch_size=args.batch_size, 294 | in_size=in_size, 295 | sine_period=sine_period, 296 | sawtooth_period=sawtooth_period, 297 | min_mix_coeff=args.min_mix_coeff, 298 | min_Mlines_angle=args.min_Mlines_angle) 299 | 300 | print('USING MODEL TO UNMIX SOME MIXED SAMPLES...') 301 | examples = next(it_train) 302 | numeric_in = {'in_X': examples} 303 | transformed_signals = modelICA.use_model(numeric_in) 304 | 305 | fig_save_dir = os.path.join( 306 | exp_dir, 'ica_demo_results', 'samples') 307 | if not os.path.exists(fig_save_dir): 308 | os.makedirs(fig_save_dir) 309 | print('saving example unmixed signals to: ', fig_save_dir) 310 | for idx in range(10): 311 | ica_poc_data.plot_example_2ch( 312 | examples[idx], 313 | os.path.join(fig_save_dir, 'mixed_example_%d.png' %idx)) 314 | ica_poc_data.plot_example_2ch( 315 | transformed_signals[idx], 316 | os.path.join(fig_save_dir, 'unmixed_example_%d.png' %idx)) 317 | 318 | # See values of alphas to see whether backconnects are 319 | # effectively used or not: 320 | # (Put this in model if we need to use it more often) 321 | if args.use_backconnects: 322 | alpha_vars = [v for v in tf.global_variables() 323 | if 'alpha' in v.name 324 | and not 'Adam' in v.name] 325 | alpha_vars_names = [v.name for v in alpha_vars] 326 | print(alpha_vars_names) 327 | alpha_vals = [modelICA.session.run(v) for v in alpha_vars] 328 | print(alpha_vals) 329 | 330 | modelICA.close() 331 | tf.reset_default_graph() 332 | 333 | plot_train_costs(save_dir=exp_dir) 334 | 335 | if send_txtmsg_when_done: 336 | textmsg.send_trainingdone_notif(exp_name) 337 | print('######## DONE. #########') 338 | 339 | if __name__ == '__main__': 340 | main() 341 | 342 | 343 | 344 | 345 | 346 | 347 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Progressive model 2 | # (i.e. unmixing happens 'progressively' in each layer) 3 | 4 | ''' 5 | * Common abstract base class for all architectures 6 | This common class implements common things such as 7 | placeholder definition, step counter, acc, cost, 8 | optimization, init, train_model, estimate_model, 9 | save_model, load_model, close 10 | * Each model inherits from common class 11 | Child classes implement specific architectures 12 | ''' 13 | 14 | ##### TODOS MAYBE ? ##### 15 | # For nonlinear recomposition: use more 'progressive' nonlinearities 16 | # than relu/lrelu... For example ELU with some learnable scaling 17 | # factors ... ??
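# A possible sketch of such a 'progressive' nonlinearity (hypothetical,
# not used anywhere in this repo): an ELU with a learnable input scale s,
# which tends towards the identity (i.e. a linear layer) as s -> 0:
# def scaled_elu(x, name='scaled_elu'):
#     s = tf.get_variable(name + '_s', initializer=1.0)
#     return tf.nn.elu(s * x) / (1e-8 + s)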
18 | 19 | 20 | ##### TODOS LATER ####### 21 | # TODO: TRY tf.contrib.layers.variance_scaling_initializer 22 | # TODO: Think whether batch normalization is adapted or not 23 | 24 | 25 | import tensorflow as tf 26 | import numpy as np 27 | 28 | from synth_eeg.synth_eeg_graph import create_synth_data_graph 29 | from synth_eeg import mix_mat 30 | 31 | class ModelBase(): 32 | # Abstract class for code and functionalities common to 33 | # all architectures and modes (ICA/classif) 34 | def __init__(self, 35 | in_size=512, 36 | batch_size=128, 37 | n_chan_in=19): 38 | 39 | self.in_size = in_size 40 | self.batch_size = batch_size 41 | self.n_chan_in = n_chan_in 42 | 43 | # placeholders 44 | # placeholders for in_X (and possibly target_Y) 45 | # will be defined in child classes 46 | self.lr = tf.placeholder(tf.float32, []) 47 | self.phase = tf.placeholder(tf.bool) # in case child class uses BN 48 | 49 | # step counter to give to saver, increment in optimizer 50 | self.global_step = tf.Variable( 51 | 0, name='global_step', trainable=False) 52 | 53 | def init(self): 54 | self.session = tf.Session() 55 | init_op = tf.global_variables_initializer() 56 | self.session.run(init_op) 57 | 58 | def save_model(self, checkpoint_path): 59 | print('Saving model...') 60 | self.saver.save( 61 | self.session, checkpoint_path) 62 | # let's use this only for keeping the best model. 63 | # global_step saved as a variable, not passed 64 | # in .ckpt name 65 | 66 | def load_model(self, checkpoint_dir): 67 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 68 | #msg = ("total path is longer than 255 characters, " 69 | # "not ok on NTFS volume") 70 | #assert len(ckpt.model_checkpoint_path) <= 255, msg 71 | print("loading model: ", ckpt.model_checkpoint_path) 72 | self.saver.restore(self.session, ckpt.model_checkpoint_path) 73 | 74 | print('Model loaded. ') 75 | 76 | def close(self): 77 | self.session.close() 78 | 79 | 80 | class ICAModelBase(ModelBase): 81 | 82 | # common functions for training a decomposition 83 | # (cost etc. differs from the classif model but 84 | # architectures can be the same) 85 | 86 | # Measure of non-Gaussianity: 87 | # For a random variable y, an approximation of negentropy is: 88 | # J(y) = [E(G(y)) - E(G(g))]^2 89 | # where g is a Gaussian variable with zero mean and unit variance 90 | # and G is some non-quadratic function such as 91 | # G1(y) = 1/a * log cosh(ay), 1<=a<=2 92 | # or G2(y) = -exp(-y^2 / 2) 93 | # REM: y is assumed zero mean and unit variance so if we 94 | # do not have the guarantee that it is, let's center and scale it. 95 | 96 | # let's use G2. 
Then the term E(G(g)) is -sqrt(1/2) 97 | 98 | def __init__(self, 99 | alpha_decorr, 100 | **kwargs): 101 | super().__init__(**kwargs) 102 | self.alpha_decorr = alpha_decorr 103 | 104 | ####### BUILD GRAPH ###### 105 | # placeholder for input 106 | self.in_X = tf.placeholder( 107 | tf.float32, [self.batch_size, self.in_size, self.n_chan_in]) 108 | 109 | transformed_X = self.get_transformed() 110 | # shape (batch_size, signal_len, n_chan) 111 | n_chan = transformed_X.get_shape()[2].value 112 | 113 | # center and normalize channels 114 | mean, var = tf.nn.moments(transformed_X, [1], keep_dims=True) #(batch, 1, n_chan) 115 | #in some cases n_chan can be different from self.n_chan_in 116 | self.transformed_X = (transformed_X - mean) / tf.sqrt(1e-8 + var) 117 | 118 | # We need not only the moments on transformed channels 119 | # (to center them) but also the covariances between channels 120 | # in order to force them to zero, break symmetry and avoid 121 | # a solution with identical output channels 122 | # Let's calculate covariances 123 | covmats_tr_X = get_covmat(self.transformed_X, reduce=True) 124 | _, offdiag_covs = get_diag_outdiag(covmats_tr_X) 125 | 126 | # Main 'NON-gaussianity' term (we want to maximize it) 127 | G2 = -tf.exp(-tf.square(self.transformed_X) / 2.) 128 | E_G2 = tf.reduce_mean(G2, axis=1) 129 | J = tf.square(E_G2 + np.sqrt(0.5).astype(np.float32)) 130 | # And decorrelation term: #REM: Use SQRT of sum or not ??? 131 | C_decorr = tf.reduce_sum( 132 | tf.square(offdiag_covs), axis=1) 133 | 134 | ############## COST AND OPTIMIZATION ################ 135 | self.COST_ICA = tf.reduce_mean(-J) # mean on batch and channels 136 | self.COST_DECORR = self.alpha_decorr * tf.reduce_mean(C_decorr) 137 | self.COST = self.COST_ICA + self.COST_DECORR 138 | # optim 139 | optimizer = tf.train.AdamOptimizer(learning_rate=self.lr) 140 | # tie BN statistics update to optimizer step 141 | # (in case BN is used in child class) 142 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 143 | with tf.control_dependencies(update_ops): 144 | self.OP = optimizer.minimize( 145 | self.COST, global_step=self.global_step) 146 | 147 | # saver 148 | self.saver = tf.train.Saver(max_to_keep=1) 149 | # (let's keep only the best model) 150 | 151 | def get_transformed(self): 152 | # to be implemented in child classes 153 | raise NotImplementedError 154 | 155 | def train_model(self, numeric_in): 156 | _, cost_value, cost_ica_value, cost_decorr_value = \ 157 | self.session.run( 158 | [self.OP, self.COST, self.COST_ICA, self.COST_DECORR], 159 | feed_dict={ 160 | self.in_X: numeric_in['in_X'], 161 | self.lr: numeric_in['lr'], 162 | self.phase: 1, 163 | }) 164 | return cost_value, cost_ica_value, cost_decorr_value 165 | 166 | def estimate_model(self, numeric_in): 167 | cost_value, cost_ica_value, cost_decorr_value = \ 168 | self.session.run( 169 | [self.COST, self.COST_ICA, self.COST_DECORR], 170 | feed_dict={ 171 | self.in_X: numeric_in['in_X'], 172 | self.phase: 0, 173 | }) 174 | return cost_value, cost_ica_value, cost_decorr_value 175 | 176 | def use_model(self, numeric_in): 177 | transformed_signals = self.session.run( 178 | self.transformed_X, 179 | feed_dict={ 180 | self.in_X: numeric_in['in_X'], 181 | self.phase: 0, 182 | }) 183 | return transformed_signals 184 | 185 | 186 | class ClassifModelBase(ModelBase): 187 | 188 | # common functions for training a classification model 189 | 190 | def __init__(self, 191 | n_totalsphere_sources, 192 | r_sources, 193 | r_zone, 194 | radial_only, 195 | fixed_ampl, 196 | 
different_zone_distrib, 197 | center_sources_idxs, 198 | r_geom, 199 | sigmas, 200 | **kwargs): 201 | # For this classif model, data will be generated 202 | # in-graph on the fly, so class first has to create 203 | # data graph, and then network graph 204 | 205 | # center_sources_idxs: a list of int, 206 | # the possible source indices of zone centers 207 | 208 | super().__init__(**kwargs) 209 | self.center_sources_idxs = center_sources_idxs 210 | 211 | ###### BUILD DATA GRAPH ####### 212 | 213 | n_sources = len(mix_mat.get_spread_points_on_sphere( 214 | n=n_totalsphere_sources, r=r_sources)) 215 | # These are the idxs of sources amongst which 216 | # 'zone center' sources can be drawn: 217 | self.possible_center_idxs = tf.placeholder_with_default( 218 | np.arange(n_sources).astype(np.int32), shape=[None]) 219 | 220 | common_args = { 221 | 'n_totalsphere_sources': n_totalsphere_sources, 222 | 'r_sources': r_sources, 223 | 'in_size': self.in_size, 224 | 'r_zone': r_zone, 225 | 'radial_only': radial_only, 226 | 'fixed_ampl': fixed_ampl, 227 | 'different_zone_distrib': different_zone_distrib, 228 | 'r_geom': r_geom, 229 | 'sigmas': sigmas, 230 | } 231 | 232 | if (len(center_sources_idxs)>=2) \ 233 | and (-1 not in center_sources_idxs): 234 | # Here we do k-class classification 235 | # BCI-like task, differentiate spacial patterns 236 | # For k=2 also compare with CSP 237 | print('Task: k-class classification') 238 | self.n_classes = len(center_sources_idxs) 239 | batch_size_ = self.batch_size 240 | _, act_elec_clz, csource_idxs_amgst_possible = \ 241 | create_synth_data_graph( 242 | possible_center_idxs=center_sources_idxs, 243 | batch_size=batch_size_, 244 | **common_args) 245 | self.in_X = act_elec_clz 246 | self.target_Y = tf.one_hot( 247 | csource_idxs_amgst_possible, depth=self.n_classes) 248 | else: 249 | if (len(center_sources_idxs)>=2) \ 250 | and (-1 in center_sources_idxs): 251 | # Here we do 2-class classification 252 | # k-class + background 253 | # Task: background vs. rest 254 | print(('Task: 2-class classification: ' 255 | '\'background\' vs \'rest in k-classes\'')) 256 | elif center_sources_idxs==[-1]: 257 | # Here also 2 class classification 258 | # background vs. rest, in rest any source 259 | # can serve as center. 260 | print(('Task: 2-class classification: ' 261 | '\'background\' vs \'rest amongst all possible sources\'')) 262 | else: 263 | raise ValueError(('Check center_sources_idxs? ' 264 | 'This case was not envisioned.')) 265 | # in both of these cases: 266 | self.n_classes = 2 267 | if self.batch_size % 2 != 0: 268 | raise ValueError( 269 | '2 must divide batch size. 
') 270 | batch_size_ = self.batch_size // 2 271 | act_elec_clb, act_elec_clz, _ = \ 272 | create_synth_data_graph( 273 | possible_center_idxs=center_sources_idxs, 274 | batch_size=batch_size_, 275 | **common_args) 276 | self.in_X = tf.concat( 277 | [act_elec_clb, act_elec_clz], axis=0) 278 | # let's use label 0 for background, 1 for zone 279 | self.target_Y = tf.concat( 280 | [tf.one_hot([0]*batch_size_, depth=self.n_classes), 281 | tf.one_hot([1]*batch_size_, depth=self.n_classes)], 282 | axis=0) 283 | 284 | 285 | ####### BUILD NETWORK GRAPH ####### 286 | # architecture 287 | out_vals = self.get_out_vals() 288 | # (batch_size, layer_sizes[-1]) 289 | # a last (fully-connected) layer for classification: 290 | out_logits = tf.layers.dense( 291 | out_vals, 292 | units=self.n_classes, 293 | activation=None, 294 | name='dense_classif') 295 | 296 | ############ ACC, COST, OPTIMIZATION ################ 297 | self.COST = tf.reduce_mean( 298 | tf.nn.softmax_cross_entropy_with_logits( 299 | labels=self.target_Y, 300 | logits=out_logits)) 301 | # for external eval 302 | self.preds = tf.nn.softmax(out_logits, dim=-1) 303 | self.preds_int = tf.argmax(self.preds, axis=1) 304 | self.preds_true = tf.equal( 305 | self.preds_int, tf.argmax(self.target_Y, axis=1)) 306 | self.acc = tf.reduce_mean( 307 | tf.cast(self.preds_true, tf.float32)) 308 | # optimization 309 | optimizer = tf.train.AdamOptimizer(learning_rate=self.lr) 310 | # tie BN statistics update to optimizer step 311 | # (in case BN is used in child class) 312 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 313 | with tf.control_dependencies(update_ops): 314 | self.OP = optimizer.minimize( 315 | self.COST, global_step=self.global_step) 316 | 317 | # saver 318 | self.saver = tf.train.Saver( 319 | tf.all_variables(), max_to_keep=1) 320 | # (let's keep only the best model) 321 | 322 | def get_out_vals(self): 323 | # to be implemented in child classes 324 | raise NotImplementedError 325 | 326 | def train_model(self, numeric_in): 327 | d = {self.lr: numeric_in['lr'], self.phase: 1,} 328 | if self.center_sources_idxs != [-1]: 329 | d_ = {self.possible_center_idxs: self.center_sources_idxs} 330 | d.update(d_) 331 | 332 | _, acc_value, cost_value = self.session.run( 333 | [self.OP, self.acc, self.COST], 334 | feed_dict=d) 335 | 336 | return acc_value, cost_value 337 | 338 | def estimate_model(self): 339 | d = {self.phase: 0,} 340 | if self.center_sources_idxs != [-1]: 341 | d_ = {self.possible_center_idxs: self.center_sources_idxs} 342 | d.update(d_) 343 | 344 | pred_values, acc_value, cost_value = self.session.run( 345 | [self.preds_int, self.acc, self.COST], 346 | feed_dict=d) 347 | 348 | return pred_values, acc_value, cost_value 349 | 350 | 351 | class MLAdaptiveSF_ICA(ICAModelBase): 352 | # Multilayer Adaptive Spacial Filter for ICA comparison 353 | 354 | def __init__(self, 355 | layer_sizes, 356 | activations_conv11, 357 | activations_to_W, 358 | activations_to_out, 359 | use_backconnects, 360 | use_var_feat_only_net, 361 | **kwargs): 362 | 363 | self.layer_sizes = layer_sizes 364 | self.activations_conv11 = activations_conv11 365 | self.activations_to_W = activations_to_W 366 | self.activations_to_out = activations_to_out 367 | self.use_backconnects = use_backconnects 368 | self.use_var_feat_only_net = use_var_feat_only_net 369 | super().__init__(**kwargs) 370 | 371 | def get_transformed(self): 372 | 373 | c = MultiLayerAdaptiveSpacialFilter( 374 | self.in_X, 375 | layer_sizes = self.layer_sizes, 376 | activations_conv11 = 
self.activations_conv11, 377 | activations_to_W = self.activations_to_W, 378 | activations_to_out = self.activations_to_out, 379 | use_backconnects = self.use_backconnects, 380 | use_var_feat_only = self.use_var_feat_only_net) 381 | return c 382 | 383 | 384 | class MLSF_ICA(ICAModelBase): 385 | # Multilayer (non-adaptive) Spatial Filter for ICA 386 | 387 | def __init__(self, 388 | layer_sizes, 389 | activations, 390 | **kwargs): 391 | 392 | self.layer_sizes = layer_sizes 393 | self.activations = activations 394 | super().__init__(**kwargs) 395 | 396 | def get_transformed(self): 397 | 398 | c = MultiLayerSpacialFilter( 399 | self.in_X, 400 | layer_sizes = self.layer_sizes, 401 | activations = self.activations) 402 | return c 403 | 404 | 405 | class MLSF_classif(ClassifModelBase): 406 | # MultiLayer Spacial Filter, non-adaptive, 407 | # used for classification 408 | 409 | def __init__(self, 410 | layer_sizes, 411 | activations, 412 | use_var_feat_only_classif, 413 | **kwargs): 414 | 415 | self.layer_sizes = layer_sizes 416 | self.activations = activations 417 | self.use_var_feat_only_classif = use_var_feat_only_classif 418 | super().__init__(**kwargs) 419 | 420 | def get_out_vals(self): 421 | 422 | c = MultiLayerSpacialFilter( 423 | self.in_X, 424 | layer_sizes = self.layer_sizes, 425 | activations = self.activations) 426 | # (batch_size, in_size, n_chan) 427 | 428 | # Actually, if using only an overall linear transform, 429 | # it is probably not a good idea to use mean, var... 430 | mean, var = tf.nn.moments(c, [1]) # (batch_size, n_chan) 431 | if not self.use_var_feat_only_classif: 432 | spa_feats = tf.concat([mean, var], axis=1) 433 | else: spa_feats = var 434 | return spa_feats 435 | # TRY to use a non-gaussianity measure instead 436 | # (it better simulates what a temporal detector would do) 437 | # n_chan = c.get_shape()[2].value 438 | # # center and normalize channels (because measure of NG needs this) 439 | # mean, var = tf.nn.moments(c, [1], keep_dims=True) 440 | # c = (c - mean) / tf.sqrt(1e-8 + var) 441 | # # Main 'NON-gaussianity' term 442 | # G2 = -tf.exp(-tf.square(c) / 2.)
443 | # E_G2 = tf.reduce_mean(G2, axis=1) 444 | # J = tf.square(E_G2 + np.sqrt(0.5).astype(np.float32)) 445 | # J = tf.multiply(J, 10000) 446 | # # (batch_size, n_chan) 447 | # return J 448 | 449 | 450 | 451 | 452 | class MLAdaptiveSF_classif(ClassifModelBase): 453 | # MultiLayer Adaptive Spacial Filter used for classification 454 | 455 | def __init__(self, 456 | layer_sizes, 457 | activations_conv11, 458 | activations_to_W, 459 | activations_to_out, 460 | use_backconnects, 461 | use_var_feat_only_net, 462 | use_var_feat_only_classif, 463 | **kwargs): 464 | 465 | self.layer_sizes = layer_sizes 466 | self.activations_conv11 = activations_conv11 467 | self.activations_to_W = activations_to_W 468 | self.activations_to_out = activations_to_out 469 | self.use_backconnects = use_backconnects 470 | self.use_var_feat_only_net = use_var_feat_only_net 471 | self.use_var_feat_only_classif = use_var_feat_only_classif 472 | super().__init__(**kwargs) 473 | 474 | 475 | def get_out_vals(self): 476 | 477 | c = MultiLayerAdaptiveSpacialFilter( 478 | self.in_X, 479 | layer_sizes = self.layer_sizes, 480 | activations_conv11 = self.activations_conv11, 481 | activations_to_W = self.activations_to_W, 482 | activations_to_out = self.activations_to_out, 483 | use_backconnects = self.use_backconnects, 484 | use_var_feat_only = self.use_var_feat_only_net) 485 | 486 | # Actually, if using only an overall linear transform, 487 | # it is probably not a good idea to use mean, var... 488 | mean, var = tf.nn.moments(c, [1]) # (batch_size, n_chan) 489 | if not self.use_var_feat_only_classif: 490 | spa_feats = tf.concat([mean, var], axis=1) 491 | else: spa_feats = var 492 | return spa_feats 493 | 494 | # # TRY to use a non-gaussianity measure instead 495 | # # (it simulates better what a temporal detector would do) 496 | # n_chan = c.get_shape()[2].value 497 | # # center and normalize channels (because measure of NG needs this) 498 | # mean, var = tf.nn.moments(c, [1], keep_dims=True) 499 | # c = (c - mean) / tf.sqrt(1e-8 + var) 500 | # # Main 'NON-gaussianity' term 501 | # G2 = -tf.exp(-tf.square(c) / 2.) 502 | # E_G2 = tf.reduce_mean(G2, axis=1) 503 | # J = tf.square(E_G2 + np.sqrt(0.5).astype(np.float32)) 504 | # # normalize along channels 505 | # Jv = tf.reduce_mean(J, axis=-1, keep_dims=True) 506 | # J = tf.divide(J, tf.sqrt(1e-8 + Jv)) # replace by BN ?? 507 | # # (batch_size, n_chan) 508 | # #J = tf.multiply(J, 10000) 509 | # return J 510 | 511 | 512 | def lrelu(x, leak=0.1, name='lrelu'): 513 | return tf.maximum(x, leak*x, name) 514 | 515 | 516 | def get_covmat(tsr, reduce=False): 517 | # tensor tsr is assumed to have shape (batch_size, in_size, n_chan) 518 | # covariances are calculated between the n_chan channels 519 | # and along the in_size dimension. 520 | # other dimensions are kept. 521 | 522 | # ALL CHANNELS ARE EXPECTED TO BE CENTERED ! 
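# (i.e., with reduce=True: covs[b, i, j] = mean over t of
#  tsr[b, t, i] * tsr[b, t, j], an outer product of the channel
#  vectors averaged over the time dimension)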
523 | 524 | # broadcast channel dims before element-wise multiply 525 | # in order to calculate pairwise combinations 526 | batch_size = tsr.get_shape()[0].value 527 | in_size = tsr.get_shape()[1].value 528 | n_chan = tsr.get_shape()[2].value 529 | 530 | tsr_0 = tf.reshape(tsr, [batch_size, in_size, 1, n_chan]) 531 | tsr_1 = tf.reshape(tsr, [batch_size, in_size, n_chan, 1]) 532 | covs = tf.multiply(tsr_0, tsr_1) 533 | # shape (.., in_size, n_chan, n_chan) 534 | if reduce: 535 | covs = tf.reduce_mean(covs, axis=1) 536 | # shape (.., n_chan, n_chan) 537 | return covs 538 | 539 | 540 | def get_diag_outdiag(covs): 541 | # covs: a covmat of shape (batch_size, n_chan, n_chan) 542 | # OR (batch_size, in_size, n_chan, n_chan) 543 | # 544 | # tensorflow does not have an option for indexing along 545 | # a diagonal 546 | # This function gets diagonal and off-diag terms from 547 | # a covmat 548 | # -------------------------------------------------------- 549 | # EXAMPLE WITH DIAG TERMS: 550 | # ------------------------ 551 | # So let's do this by 'offsetting' the matrix, for example 552 | # (not showing batch dim) 553 | # covs = [[d, n, n, n], 554 | # [n, d, n, n], 555 | # [n, n, d, n], 556 | # [n, n, n, d]] 557 | # tf.reshape(covs, [4*4]) 558 | # covs = [d, n, n, n, n, d, n, n, n, n, d, n, n, n, n, d] 559 | # covs_allm1 = covs[0:(4*4-1)] 560 | # covs_last = covs[-1] 561 | # tf.reshape(covs_allm1, [3, 5]) 562 | # [[d, n, n, n, n], 563 | # [d, n, n, n, n], 564 | # [d, n, n, n, n]] 565 | # now diagonal elements are the first column + covs_last 566 | # and non-diagonal elements are the other columns 567 | 568 | n_dim = len(covs.get_shape()) 569 | batch_size = covs.get_shape()[0].value 570 | n_chan = covs.get_shape()[-1].value 571 | if n_dim==3: 572 | covs = tf.reshape(covs, [batch_size, n_chan*n_chan]) 573 | covs_allm1 = covs[:, 0:n_chan*n_chan-1] 574 | covs_last = tf.reshape(covs[:, -1], 575 | [batch_size, 1]) 576 | covs_allm1 = tf.reshape(covs_allm1, 577 | [batch_size, n_chan-1, n_chan+1]) 578 | diag_terms = covs_allm1[:, :, 0] 579 | diag_terms = tf.concat([diag_terms, covs_last], axis=1) 580 | offdiag_terms = covs_allm1[:, :, 1:] 581 | offdiag_terms = tf.reshape(offdiag_terms, 582 | [batch_size, n_chan*(n_chan-1)]) 583 | # REM: we lose the order and we have duplicates (because 584 | # cov is symmetric... but let's not care for now. ) 585 | elif n_dim==4: 586 | # is there a way to avoid this duplicate code 587 | # with tensorflow indexing ? Is .. supported in tf ?
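# Possible alternative (untested sketch): tf.matrix_diag_part(covs)
# extracts the diagonal for any leading batch dims, so both branches
# could share one code path; the off-diagonal terms could then be
# obtained by masking covs with (1 - tf.eye(n_chan)) before reshaping.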
588 | in_size = covs.get_shape()[1].value 589 | covs = tf.reshape(covs, [batch_size, in_size, n_chan*n_chan]) 590 | covs_allm1 = covs[:, :, 0:n_chan*n_chan-1] 591 | covs_last = tf.reshape(covs[:, :, -1], 592 | [batch_size, in_size, 1]) 593 | covs_allm1 = tf.reshape(covs_allm1, 594 | [batch_size, in_size, n_chan-1, n_chan+1]) 595 | diag_terms = covs_allm1[:, :, :, 0] 596 | diag_terms = tf.concat([diag_terms, covs_last], axis=2) 597 | offdiag_terms = covs_allm1[:, :, :, 1:] 598 | offdiag_terms = tf.reshape(offdiag_terms, 599 | [batch_size, in_size, n_chan*(n_chan-1)]) 600 | 601 | return diag_terms, offdiag_terms 602 | 603 | 604 | def MultiLayerSpacialFilter(in_X, 605 | layer_sizes=[19], # or list for MultiLayer 606 | activations=lrelu): 607 | 608 | # here we simply learn len(layer_sizes) spatial linear filters 609 | # and apply nonlinearities 610 | # the output channels are trained to be 611 | # maximally discriminative for the downstream task 612 | 613 | in_size = in_X.get_shape()[1].value 614 | n_chan_in = in_X.get_shape()[2].value 615 | 616 | c = tf.reshape( 617 | in_X, [-1, 1, in_size, n_chan_in]) 618 | 619 | for l in range(len(layer_sizes)): 620 | c = tf.layers.conv2d( 621 | c, 622 | filters=layer_sizes[l], 623 | kernel_size=[1, 1], 624 | strides=1, 625 | activation=activations, 626 | name="spa_fil_%d" %l) 627 | 628 | c = tf.reshape(c, [-1, in_size, layer_sizes[-1]]) 629 | return c 630 | 631 | 632 | def AdaptiveSpacialLayer(in_X, 633 | layer_size, 634 | activation_conv11, 635 | activation_to_W, 636 | activation_to_out, 637 | use_backconnect=False, 638 | prev_spa_feats=None, 639 | use_var_feat_only=False, 640 | normalize_Wb=True, 641 | name=""): 642 | # adaptive spacial layer, to be used in MultiLayerAdaptiveSpacialFilter 643 | # expects and returns input of shape [batch_size, 1, in_size, n_chan] 644 | 645 | # if activation_to_out is not None, we also use a nonlinearity 646 | # with the main filter application 647 | 648 | batch_size = in_X.get_shape()[0].value 649 | in_size = in_X.get_shape()[2].value 650 | n_chan = in_X.get_shape()[3].value 651 | 652 | # make sure that if using backconnects, 653 | # a prev_spa_feats is provided 654 | if use_backconnect and (prev_spa_feats is None): 655 | raise ValueError('If using a backconnect, please provide ' + \ 656 | 'a prev_spa_feats as input. ') 657 | 658 | c11 = tf.layers.conv2d( 659 | in_X, filters=layer_size, 660 | kernel_size=[1, 1], strides=1, 661 | activation=activation_conv11, 662 | name="c11_"+name) 663 | # optionally also add 3*1 filter here for spatio- 664 | # temporal filtering, and concatenate (later) 665 | mean, var = tf.nn.moments(c11, [1, 2]) # [batch_size, layer_size] 666 | spa_feats = var 667 | mulfac_spa_feats = 1 668 | if not use_var_feat_only: 669 | spa_feats = tf.concat([mean, var], axis=1) 670 | mulfac_spa_feats = 2 671 | 672 | if use_backconnect: 673 | # note that if using backconnects we do not necessarily need 674 | # that the current layer has the same number of units as 675 | # the previous layer, but it must have the SAME NUMBER 676 | # OF SPACIAL FEATURES 677 | 678 | # let's learn ONE alpha coeff per feature, so that 679 | # the network is able to 'decide to use different depths...' 680 | alpha = tf.get_variable( 681 | 'alpha_'+name, 682 | initializer=tf.ones([1, layer_size*mulfac_spa_feats])) 683 | spa_feats = tf.multiply(alpha, spa_feats) \ 684 | + tf.multiply(1. 
- alpha, prev_spa_feats) 685 | 686 | Wbdense_units = (n_chan + 1) * n_chan \ 687 | if activation_to_out is not None else n_chan * n_chan 688 | 689 | if activation_to_out is None: 690 | Wb = tf.layers.dense( 691 | spa_feats, units=n_chan*n_chan, 692 | activation=activation_to_W, 693 | name="dense_"+name) 694 | W = Wb 695 | if normalize_Wb: 696 | Wmean = tf.reduce_mean(W, axis=-1, keep_dims=True) 697 | Wscale = tf.get_variable( 698 | "Wscale_"+name, initializer=0.01) 699 | W = tf.multiply(Wscale, tf.divide(W, Wmean)) 700 | else: 701 | Wb = tf.layers.dense( 702 | spa_feats, units=n_chan*(n_chan+1), 703 | activation=activation_to_W, 704 | name="denseWb_"+name) 705 | W = Wb[:, :n_chan*n_chan] 706 | b = Wb[:, n_chan*n_chan:] 707 | if normalize_Wb: 708 | _, Wvar = tf.nn.moments(W, [-1], keep_dims=True) 709 | _, bvar = tf.nn.moments(b, [-1], keep_dims=True) 710 | Wstd = tf.sqrt(1e-8 + Wvar) 711 | bstd = tf.sqrt(1e-8 + bvar) 712 | Wscale = tf.get_variable( 713 | "Wscale_"+name, initializer=0.01) 714 | bscale = tf.get_variable( 715 | "bscale_"+name, initializer=0.01) 716 | W = tf.multiply(Wscale, tf.divide(W, Wstd)) 717 | b = tf.multiply(bscale, tf.divide(b, bstd)) 718 | I = tf.eye(n_chan) 719 | I = tf.reshape(I, [1, n_chan*n_chan]) 720 | W = W + I # learn residual part 721 | # we will use this as weights of the input transformation layer 722 | 723 | # ------------------------- 724 | # Now we apply the W (one DIFFERENT for EACH element of the batch) 725 | # For this we 'hack' a depthwise convolution (per example b, this computes out[b] = in[b] @ W[b], a batched matmul): 726 | W = tf.reshape(W, [1, 1, batch_size*n_chan, n_chan]) 727 | # in_X has shape (batch_size, 1, in_size, n_chan) 728 | in_X_r = tf.transpose(in_X, [1, 2, 0, 3]) 729 | in_X_r = tf.reshape(in_X_r, [1, 1, in_size, batch_size*n_chan]) 730 | 731 | out_filtered = tf.nn.depthwise_conv2d( 732 | in_X_r, 733 | filter=W, 734 | strides=[1, 1, 1, 1], 735 | padding='VALID') #REM: here we don't care about padding 736 | # but still have to provide a value 737 | # out_filtered shape: (1, 1, in_size, batch_size*n_chan*n_chan) 738 | out_filtered = tf.reshape( 739 | out_filtered, [1, in_size, batch_size, n_chan, n_chan]) 740 | out_filtered = tf.transpose(out_filtered, [2, 0, 1, 3, 4]) 741 | # and finally sum on input chans 742 | out_filtered = tf.reduce_sum(out_filtered, axis=3) 743 | # shape (batch_size, 1, in_size, n_chan) 744 | if activation_to_out is not None: 745 | b = tf.reshape(b, [batch_size, 1, 1, n_chan]) 746 | out_filtered = out_filtered + b 747 | out_filtered = activation_to_out(out_filtered) 748 | 749 | # (spa_feats is used by the multilayer wrapper for backconnects) 750 | return spa_feats, out_filtered 751 | 752 | 753 | def MultiLayerAdaptiveSpacialFilter(in_X, 754 | layer_sizes=[19], 755 | activations_conv11=lrelu, 756 | activations_to_W=lrelu, 757 | activations_to_out=lrelu, 758 | use_backconnects=False, 759 | use_var_feat_only=False): 760 | 761 | # here at each layer we determine spacial features, 762 | # which we then use to determine (through an FC layer) 763 | # the spacial filter to be applied 764 | 765 | # an optional backconnect injects spatial features from 766 | # previous layers into the current layer 767 | # if using backconnects, all layers must use the same 768 | # number of spatial features: 769 | if use_backconnects: 770 | s_ = [layer_sizes[0]]*len(layer_sizes) 771 | if s_ != layer_sizes: 772 | raise ValueError('If using a backconnect, please use ' + \ 773 | 'layer sizes that are all equal') 774 | 775 | batch_size = in_X.get_shape()[0].value 776 | in_size = in_X.get_shape()[1].value 777 | n_chan = 
in_X.get_shape()[2].value 778 | 779 | c = tf.reshape( 780 | in_X, [batch_size, 1, in_size, n_chan]) 781 | # if using backconnects, the first layer receives 782 | # zeros as previous spacial features 783 | mulfac_spa_feats = 2 if not use_var_feat_only else 1 784 | prev_spa_feats = None if not use_backconnects \ 785 | else tf.zeros([batch_size, layer_sizes[0]*mulfac_spa_feats]) 786 | 787 | for l in range(len(layer_sizes)): 788 | prev_spa_feats, c = AdaptiveSpacialLayer( 789 | c, layer_size=layer_sizes[l], 790 | activation_conv11=activations_conv11, 791 | activation_to_W=activations_to_W, 792 | activation_to_out=activations_to_out, 793 | use_backconnect=use_backconnects, 794 | prev_spa_feats=prev_spa_feats, 795 | use_var_feat_only=use_var_feat_only, 796 | name="adapt_%d" %l) 797 | # (prev_spa_feats is simply ignored by the next call 798 | # when use_backconnects is False) 799 | 800 | c = tf.reshape(c, [-1, in_size, n_chan]) 801 | return c 802 | 803 | 804 | 805 | 806 | --------------------------------------------------------------------------------
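For intuition, the depthwise-convolution 'hack' in `AdaptiveSpacialLayer` applies one (n_chan, n_chan) mixing matrix per example; a minimal NumPy equivalent (assuming inputs `x` of shape (batch, time, n_chan) and adaptive weights `W` of shape (batch, n_chan, n_chan), laid out as W[b, c_in, c_out]):

import numpy as np

def adaptive_mix(x, W):
    # out[b, t, c_out] = sum over c_in of x[b, t, c_in] * W[b, c_in, c_out]
    return np.matmul(x, W)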