├── DataGenerator.py
├── README.md
├── example.py
├── graphics
│   ├── CellLine
│   │   └── GM_200_results.png
│   ├── MNIST
│   │   ├── mnist_3_full_importance.png
│   │   ├── mnist_3_importance.png
│   │   ├── mnist_8_full_importance.png
│   │   ├── mnist_8_importance.png
│   │   ├── mnist_colorbar.png
│   │   ├── mnist_full_importance.png
│   │   └── mnist_masked_importance.png
│   ├── XOR.png
│   ├── model.png
│   ├── update.pdf
│   └── yale
│       ├── face_10.png
│       ├── face_30.png
│       ├── raw_10.png
│       └── raw_30.png
├── requirements.txt
└── src
    ├── FeatureSelector.py
    ├── MaskOptimizer.py
    ├── Operator.py
    └── Selector.py

/DataGenerator.py:
--------------------------------------------------------------------------------
import numpy as np


def generate_data(n=100, seed=0):
    """
    Generate data (X, y).
    Args:
        n (int): number of samples
        seed (int): random seed used
    Returns:
        X (float): array of shape [n, 10].
        y (float): n-dimensional array of class labels.
    Taken from https://github.com/Jianbo-Lab/CCM
    See:
    http://papers.nips.cc/paper/7270-kernel-feature-selection-via-conditional-covariance-minimization.pdf
    for details.
    """
    np.random.seed(seed)
    X = np.random.randn(n, 10)
    y = np.zeros(n)
    splits = np.linspace(0, n, num=8 + 1, dtype=int)
    signals = [[1, 1, 1], [-1, -1, -1], [1, 1, -1], [-1, -1, 1],
               [1, -1, -1], [-1, 1, 1], [-1, 1, -1], [1, -1, 1]]
    for i in range(8):
        X[splits[i]:splits[i + 1], :3] += np.array([signals[i]])
        y[splits[i]:splits[i + 1]] = i // 2
    perm_inds = np.random.permutation(n)
    X, y = X[perm_inds], y[perm_inds]
    return X, y


def get_one_hot(targets, nb_classes):
    res = np.eye(nb_classes)[np.array(targets).reshape(-1)]
    return res.reshape(list(targets.shape) + [nb_classes])

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FeatureImportanceDL
Contains an example of the deep-learning-based dual-net feature selection method from https://arxiv.org/abs/2010.08973 :


This is an embedded method for supervised tasks: after training, it is able to give predictions (better than a vanilla architecture) while also returning population-wise (global) feature importances.

See the ```example.py``` script, which shows how the method works on a basic XOR dataset.
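A minimal usage sketch, condensed from ```example.py```: build the operator and selector nets, train, then read off the importances and the chosen subset.

```python
import numpy as np
import tensorflow.keras as keras
from tensorflow.keras import backend as K
from src.FeatureSelector import FeatureSelector
from DataGenerator import generate_data, get_one_hot

X, y = generate_data(n=512, seed=0)    # 10 features, only the first 3 informative
y = get_one_hot(y.astype(np.int8), 4)  # 4 XOR classes

fs = FeatureSelector((10,), 5, data_batch_size=32, mask_batch_size=32)
fs.create_dense_operator([60, 30, 20, 4], "softmax",
                         metrics=[keras.metrics.CategoricalAccuracy()],
                         error_func=K.categorical_crossentropy)
fs.operator.set_early_stopping_params(6000)  # phase 2 begins after 6000 batches
fs.create_dense_selector([100, 50, 10, 1])
fs.create_mask_optimizer(epoch_condition=6000, perturbation_size=2)

# Without val_data early stopping never triggers; pass val_data=(X_val, y_val)
# to enable it, as example.py does.
fs.train_networks_on_data(X, y, 15000)

importances, optimal_mask = fs.get_importances(return_chosen_features=True)
print(importances, np.nonzero(optimal_mask))
```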
Some examples for the MNIST dataset (digit 3 & 8 differentiation):

![MNIST importance](./graphics/MNIST/mnist_full_importance.png)
![Top features](./graphics/MNIST/mnist_masked_importance.png)

![Superimposed on mean 3](./graphics/MNIST/mnist_3_importance.png)
![Superimposed on mean 8](./graphics/MNIST/mnist_8_importance.png)

--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import tensorflow.keras as keras
from tensorflow.keras import backend as K
from src.FeatureSelector import FeatureSelector
from DataGenerator import generate_data, get_one_hot

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Dataset parameters
N_TRAIN_SAMPLES = 512
N_VAL_SAMPLES = 256
N_TEST_SAMPLES = 1024
N_FEATURES = 10
FEATURE_SHAPE = (10,)
dataset_label = "XOR_"

# Training parameters
data_batch_size = 32
mask_batch_size = 32
# final batch_size is data_batch_size x mask_batch_size
s = 5  # size of the optimal subset that we are looking for
s_p = 2  # number of flipped bits in a mask when looking around m_opt
phase_2_start = 6000  # after how many batches phase 2 will begin
max_batches = 15000  # how many batches to run if the early stopping condition is not satisfied
early_stopping_patience = 600  # how many patience batches (after phase 2 starts)
# before the training stops

# Generate data for the XOR dataset:
# the first three features are used to create the target (y),
# all the following features are Gaussian noise - 10 features in total.
# Distinct seeds keep the train/validation/test sets from overlapping.
X_tr, y_tr = generate_data(n=N_TRAIN_SAMPLES, seed=0)
X_val, y_val = generate_data(n=N_VAL_SAMPLES, seed=1)
X_te, y_te = generate_data(n=N_TEST_SAMPLES, seed=2)

# Get one-hot encoding of the labels
y_tr = get_one_hot(y_tr.astype(np.int8), 4)
y_te = get_one_hot(y_te.astype(np.int8), 4)
y_val = get_one_hot(y_val.astype(np.int8), 4)

# Create the framework; needs the feature shape, subset size and batch sizes,
# plus a str_id for the TensorBoard logs
fs = FeatureSelector(FEATURE_SHAPE, s, data_batch_size, mask_batch_size, str_id=dataset_label)

# Create a dense operator net with the architecture:
# N_FEATURES x 2 -> 60 -> 30 -> 20 -> 4
# with softmax activation in the final layer.
fs.create_dense_operator([60, 30, 20, 4], "softmax", metrics=[keras.metrics.CategoricalAccuracy()],
                         error_func=K.categorical_crossentropy)
# Early stopping activates after phase 2 of the training starts.
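# The monitored quantity is the operator's validation loss on the current
# optimal mask (the m_opt performance metric); training stops once it fails
# to improve for `early_stopping_patience` consecutive batches.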
fs.operator.set_early_stopping_params(phase_2_start, patience_batches=early_stopping_patience, minimize=True)

# Create a dense selector net with the architecture:
# N_FEATURES -> 100 -> 50 -> 10 -> 1
fs.create_dense_selector([100, 50, 10, 1])

# Set when phase 2 starts and the number of flipped bits when perturbing masks
fs.create_mask_optimizer(epoch_condition=phase_2_start, perturbation_size=s_p)

# Train the networks, capped at a maximum number of batches
fs.train_networks_on_data(X_tr, y_tr, max_batches, val_data=(X_val, y_val))

# Results
importances, optimal_mask = fs.get_importances(return_chosen_features=True)
optimal_subset = np.nonzero(optimal_mask)
test_performance = fs.operator.test_one(X_te, optimal_mask[None, :], y_te)
print("Importances: ", importances)
print("Optimal_subset: ", optimal_subset)
print("Test performance (CE): ", test_performance[0])
print("Test performance (ACC): ", test_performance[1])

--------------------------------------------------------------------------------
/graphics/CellLine/GM_200_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maksym33/FeatureImportanceDL/836096edb9f822e509cadcf9cd2e7cc5fa2324cc/graphics/CellLine/GM_200_results.png
--------------------------------------------------------------------------------
/graphics/MNIST/mnist_3_full_importance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maksym33/FeatureImportanceDL/836096edb9f822e509cadcf9cd2e7cc5fa2324cc/graphics/MNIST/mnist_3_full_importance.png
--------------------------------------------------------------------------------
/graphics/MNIST/mnist_3_importance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maksym33/FeatureImportanceDL/836096edb9f822e509cadcf9cd2e7cc5fa2324cc/graphics/MNIST/mnist_3_importance.png
--------------------------------------------------------------------------------
/graphics/MNIST/mnist_8_full_importance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maksym33/FeatureImportanceDL/836096edb9f822e509cadcf9cd2e7cc5fa2324cc/graphics/MNIST/mnist_8_full_importance.png
--------------------------------------------------------------------------------
/graphics/MNIST/mnist_8_importance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maksym33/FeatureImportanceDL/836096edb9f822e509cadcf9cd2e7cc5fa2324cc/graphics/MNIST/mnist_8_importance.png
--------------------------------------------------------------------------------
/graphics/MNIST/mnist_colorbar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maksym33/FeatureImportanceDL/836096edb9f822e509cadcf9cd2e7cc5fa2324cc/graphics/MNIST/mnist_colorbar.png
--------------------------------------------------------------------------------
/graphics/MNIST/mnist_full_importance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maksym33/FeatureImportanceDL/836096edb9f822e509cadcf9cd2e7cc5fa2324cc/graphics/MNIST/mnist_full_importance.png
--------------------------------------------------------------------------------
/graphics/MNIST/mnist_masked_importance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maksym33/FeatureImportanceDL/836096edb9f822e509cadcf9cd2e7cc5fa2324cc/graphics/MNIST/mnist_masked_importance.png
--------------------------------------------------------------------------------
/graphics/XOR.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maksym33/FeatureImportanceDL/836096edb9f822e509cadcf9cd2e7cc5fa2324cc/graphics/XOR.png
--------------------------------------------------------------------------------
/graphics/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maksym33/FeatureImportanceDL/836096edb9f822e509cadcf9cd2e7cc5fa2324cc/graphics/model.png
--------------------------------------------------------------------------------
/graphics/update.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maksym33/FeatureImportanceDL/836096edb9f822e509cadcf9cd2e7cc5fa2324cc/graphics/update.pdf
--------------------------------------------------------------------------------
/graphics/yale/face_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maksym33/FeatureImportanceDL/836096edb9f822e509cadcf9cd2e7cc5fa2324cc/graphics/yale/face_10.png
--------------------------------------------------------------------------------
/graphics/yale/face_30.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maksym33/FeatureImportanceDL/836096edb9f822e509cadcf9cd2e7cc5fa2324cc/graphics/yale/face_30.png
--------------------------------------------------------------------------------
/graphics/yale/raw_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maksym33/FeatureImportanceDL/836096edb9f822e509cadcf9cd2e7cc5fa2324cc/graphics/yale/raw_10.png
--------------------------------------------------------------------------------
/graphics/yale/raw_30.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maksym33/FeatureImportanceDL/836096edb9f822e509cadcf9cd2e7cc5fa2324cc/graphics/yale/raw_30.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
tensorflow==2.2.0
numpy==1.18.1
--------------------------------------------------------------------------------
/src/FeatureSelector.py:
--------------------------------------------------------------------------------
import datetime
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

from .MaskOptimizer import MaskOptimizer
from .Operator import OperatorNetwork
from .Selector import SelectorNetwork

logs_base_dir = "./logs"
os.makedirs(logs_base_dir, exist_ok=True)


def mean_squared_error(y_true, y_pred):
    return K.mean((y_true - y_pred) * (y_true - y_pred), axis=1)


def tf_mean_ax_0(losses):
    return tf.reduce_mean(losses, axis=0)


def progressbar(it, prefix="", size=60):
    count = len(it)

    def show(j):
        x = int(size * j / count)
        print("\r%s[%s%s] %i/%i" % (prefix, "#" * x, "." * (size - x), j, count), end=" ")

    show(0)
    for i, item in enumerate(it):
        yield item
        show(i + 1)
    print()


class FeatureSelector():
    def __init__(self, data_shape, unmasked_data_size, data_batch_size, mask_batch_size, str_id="",
                 epoch_on_which_selector_trained=8):
        self.data_shape = data_shape
        self.data_size = np.zeros(data_shape).size
        self.unmasked_data_size = unmasked_data_size
        self.logdir = os.path.join(logs_base_dir, datetime.datetime.now().strftime("%m%d-%H%M%S"))
        self.data_batch_size = data_batch_size
        self.mask_batch_size = mask_batch_size
        self.x_batch_size = mask_batch_size * data_batch_size
        self.str_id = str_id
        self.prev_mopt_condition = False
        self.epoch_on_which_selector_trained = epoch_on_which_selector_trained

    def create_dense_operator(self, arch, activation, metrics=None, error_func=mean_squared_error, es_patience=800):
        self.operator = OperatorNetwork(self.data_batch_size, self.mask_batch_size,
                                        self.logdir + "operator" + self.str_id)
        print("Creating operator model")
        self.operator.create_dense_model(self.data_shape, arch, activation)
        print("Compiling operator")
        self.operator.compile_model(error_func, tf.reduce_mean, tf_mean_ax_0, metrics)
        print("Created operator")

    def create_conv_operator(self, filters, kernels, dense_arch, activation, img_shape=None, channels=1,
                             padding="same", metrics=None, error_func=None, es_patience=800):
        self.operator = OperatorNetwork(self.data_batch_size, self.mask_batch_size,
                                        self.logdir + "operator" + self.str_id)
        print("Creating operator model")
        if channels == 1:
            self.operator.create_1ch_conv_model(self.data_shape, image_shape=img_shape, filter_sizes=filters,
                                                kernel_sizes=kernels, dense_arch=dense_arch, padding=padding,
                                                last_activation=activation)
        else:
            self.operator.create_2ch_conv_model(self.data_shape, image_shape=img_shape, filter_sizes=filters,
                                                kernel_sizes=kernels, dense_arch=dense_arch, padding=padding,
                                                last_activation=activation)
        print("Compiling operator")
        self.operator.compile_model(error_func, tf.reduce_mean, tf_mean_ax_0, metrics)
        print("Created operator")

    def create_dense_selector(self, arch):
        self.selector = SelectorNetwork(self.mask_batch_size,
                                        tensorboard_logs_dir=self.logdir + "selector_" + self.str_id)
        self.selector.create_dense_model(self.data_shape, arch)
        self.selector.compile_model()

    def create_mask_optimizer(self, epoch_condition=5000, maximize_error=False, record_best_masks=False,
                              perturbation_size=2, use_new_optimization=False):
        self.mopt = MaskOptimizer(self.mask_batch_size, self.data_shape, self.unmasked_data_size,
                                  epoch_condition=epoch_condition, perturbation_size=perturbation_size)
        self.selector.sample_weights = self.mopt.get_mask_weights(self.epoch_on_which_selector_trained)

    def test_networks_on_data(self, x, y, masks):
        m = masks
        # test_one populates operator.losses_per_sample as a side effect;
        # the per-mask losses are then reshaped out of it below.
        losses = self.operator.test_one(x, m, y)
        target_shape = (len(y), len(masks))
        losses = self.operator.get_per_mask_loss(target_shape)
        print("SN targets: " + str(losses))
        sn_preds = np.squeeze(self.selector.predict(m))
        print("SN preds: " + str(sn_preds))
        return losses

    def train_networks_on_data(self, x_tr, y_tr, number_of_batches, val_data=None, val_freq=16):
        use_val_data = True
        if val_data is None:
            use_val_data = False
            X_val = None
            y_val = None
        if use_val_data is True:
            X_val = val_data[0]
            y_val = val_data[1]

        for i in progressbar(range(number_of_batches), "Training batch: ", 50):
            mopt_condition = self.mopt.check_condition()

            random_indices = np.random.randint(0, len(x_tr), self.data_batch_size)
            x = x_tr[random_indices, :]
            y = y_tr[random_indices]
            selector_train_condition = ((self.operator.epoch_counter % self.epoch_on_which_selector_trained) == 0)
            m = self.mopt.get_new_mask_batch(self.selector.model, self.selector.best_performing_mask,
                                             gen_new_opt_mask=selector_train_condition)

            self.operator.train_one(x, m, y)
            losses = self.operator.get_per_mask_loss()
            losses = losses.numpy()
            self.selector.append_data(m, losses)
            if selector_train_condition:
                self.selector.train_one(self.operator.epoch_counter, mopt_condition)

            self.prev_mopt_condition = mopt_condition
            if use_val_data is True and self.operator.epoch_counter % val_freq == 0:
                self.operator.validate_one(X_val, m, y_val)
            if self.operator.useEarlyStopping is True and self.operator.ES_stop_training is True:
                print("Activating early stopping at training epoch/batch: " + str(self.operator.epoch_counter))
                print("Loading weights from epoch: " + str(self.operator.ES_best_epoch))
                self.operator.model.set_weights(self.operator.ES_best_weights)
                break

    def get_importances(self, return_chosen_features=True):
        features_opt_used = np.squeeze(
            np.argwhere(self.mopt.get_opt_mask(self.unmasked_data_size, self.selector.model, 12) == 1))
        m_best_used_features = np.zeros((1, self.data_size))
        m_best_used_features[0, features_opt_used] = 1
        grad_used_opt = -MaskOptimizer.gradient(self.selector.model, m_best_used_features)[0][0, :]
        importances = grad_used_opt
        if not return_chosen_features:
            return importances
        else:
            optimal_mask = m_best_used_features[0]
            return importances, optimal_mask

--------------------------------------------------------------------------------
/src/MaskOptimizer.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf


class MaskOptimizer:
    def __init__(self, mask_batch_size, data_shape, unmasked_data_size, perturbation_size,
                 frac_of_rand_masks=0.5, epoch_condition=1000):
        self.data_shape = data_shape
        self.unmasked_data_size = unmasked_data_size
        self.data_size = np.zeros(data_shape).size
        self.mask_history = []
        self.raw_mask_history = []
        self.loss_history = []
        self.epoch_counter = 0
        self.mask_batch_size = mask_batch_size
        self.frac_of_rand_masks = frac_of_rand_masks
        self.epoch_condition = epoch_condition
        self.perturbation_size = perturbation_size
        self.max_optimization_iters = 5
        self.step_count_history = []

    @staticmethod
    def gradient(model, x):
        x_tensor = tf.convert_to_tensor(x, dtype=tf.float32)
        with tf.GradientTape() as t:
            t.watch(x_tensor)
            loss_model = model(x_tensor)
            loss = loss_model
        return t.gradient(loss, x_tensor).numpy(), loss_model

    @staticmethod
    def new_get_mask_from_grads(grads, unmasked_size, mask_size):
        m_opt = np.zeros(shape=mask_size)
        top_arg_grad = np.argpartition(grads, -unmasked_size)[-unmasked_size:]
        m_opt[top_arg_grad] = 1
        return m_opt

    @staticmethod
    def new_get_m_opt(model, unmasked_size):
        input_img = np.ones(shape=model.layers[0].output_shape[0][1:])[None, :] / 2  # constant 0.5 input - a neutral starting mask
        grad, loss = MaskOptimizer.gradient(model, input_img)
        grad = np.negative(np.squeeze(grad))  # change sign
        m_opt = MaskOptimizer.new_get_mask_from_grads(grad, unmasked_size, model.layers[0].output_shape[0][1:])
        return m_opt

    @staticmethod
    def new_check_for_opposite_grad(m_opt_grad, m_opt_indexes):
        m_opt_grad_cp = np.copy(m_opt_grad[m_opt_indexes])
        m_opt_arg_opposite_grad = np.argwhere(m_opt_grad_cp < 0)
        return m_opt_indexes[m_opt_arg_opposite_grad]

    @staticmethod
    def new_check_loss_for_opposite_indexes(model, m_opt, min_index, max_index, opposite_indexes):
        m_opt_loss = model.predict(m_opt[None, :])
        for ind in opposite_indexes:
            m_new_opt = np.copy(m_opt)
            m_new_opt[max_index] = 1
            m_new_opt[ind] = 0
            m_new_opt_loss = model.predict(m_new_opt[None, :])
            if m_new_opt_loss < m_opt_loss:
                # Changed max_index from 0->1 and ind from 1->0.
                return True, m_new_opt
        return False, m_opt

    @staticmethod
    def new_check_for_likely_change(model, m_opt, min_index, max_index, m_opt_grad):
        m_opt_loss = np.squeeze(model.predict(m_opt[None, :]))
        not_m_opt_indexes = np.argwhere(m_opt == 0)
        max_index = not_m_opt_indexes[np.argmax(m_opt_grad[not_m_opt_indexes])]
        m_new_opt = np.copy(m_opt)
        m_new_opt[min_index] = 0
        m_new_opt[max_index] = 1
        m_new_opt_loss = np.squeeze(model.predict(m_new_opt[None, :]))
        if m_new_opt_loss < m_opt_loss:
            return True, m_new_opt
        else:
            return False, m_opt

    def get_opt_mask(self, unmasked_size, model, steps=None):
        m_opt = MaskOptimizer.new_get_m_opt(model, unmasked_size)
        repeat_optimization = True
        step_count = 0
        if steps is None:
            steps = self.max_optimization_iters
        while repeat_optimization == True and step_count < steps:
            step_count += 1
            repeat_optimization = False
            m_opt_grad, m_opt_loss = MaskOptimizer.gradient(model, m_opt[None, :])
            m_opt_grad = -np.squeeze(m_opt_grad)
            m_opt_indexes = np.squeeze(np.argwhere(m_opt == 1))
            min_index = m_opt_indexes[np.argmin(m_opt_grad[m_opt_indexes])]
            not_m_opt_indexes = np.squeeze(np.argwhere(m_opt == 0))
            if not_m_opt_indexes.size > 1:
                max_index = not_m_opt_indexes[np.argmax(m_opt_grad[not_m_opt_indexes])]
            elif not_m_opt_indexes.size == 1:
                max_index = not_m_opt_indexes
            opposite_indexes = MaskOptimizer.new_check_for_opposite_grad(m_opt_grad, m_opt_indexes)
            repeat_optimization, m_opt = MaskOptimizer.new_check_loss_for_opposite_indexes(model, m_opt, min_index,
                                                                                           max_index,
                                                                                           opposite_indexes)
            if repeat_optimization == True:
                # Repeat: some unmasked inputs had negative gradients.
                continue
            repeat_optimization, m_opt = MaskOptimizer.new_check_for_likely_change(model, m_opt, min_index,
                                                                                   max_index, m_opt_grad)
            if repeat_optimization == True:
                # Repeat: swapping the lowest-gradient unmasked index with the
                # highest-gradient masked index gave better results.
                continue
        self.step_count_history.append(step_count - 1)
        return m_opt

    def check_condition(self):
        if self.epoch_counter >= self.epoch_condition:
            return True
        else:
            return False

    def get_random_masks(self):
        masks_zero = np.zeros(shape=(self.mask_batch_size, self.data_size - self.unmasked_data_size))
        masks_one = np.ones(shape=(self.mask_batch_size, self.unmasked_data_size))
        masks = np.concatenate([masks_zero, masks_one], axis=1)
        masks_permuted = np.apply_along_axis(np.random.permutation, 1, masks)
        return masks_permuted

    @staticmethod
    def get_perturbed_masks(mask, n_masks, n_times=1):
        masks = np.tile(mask, (n_masks, 1))
        for i in range(n_times):
            masks = MaskOptimizer.perturb_masks(masks)
        return masks

    @staticmethod
    def perturb_masks(masks):
        def perturb_one_mask(mask):
            where_0 = np.nonzero(mask - 1)[0]
            where_1 = np.nonzero(mask)[0]
            i0 = np.random.randint(0, len(where_0), 1)
            i1 = np.random.randint(0, len(where_1), 1)
            mask[where_0[i0]] = 1
            mask[where_1[i1]] = 0
            return mask

        masks = np.apply_along_axis(perturb_one_mask, 1, masks)
        return masks

    def get_new_mask_batch(self, model, best_performing_mask, gen_new_opt_mask):
        self.epoch_counter += 1
        random_masks = self.get_random_masks()
        if gen_new_opt_mask:
            self.mask_opt = self.get_opt_mask(self.unmasked_data_size, model)
        if self.check_condition() is True:
            index = int(self.frac_of_rand_masks * self.mask_batch_size)

            random_masks[index] = self.mask_opt
            random_masks[index + 1] = best_performing_mask
            random_masks[index + 2:] = MaskOptimizer.get_perturbed_masks(random_masks[index],
                                                                         self.mask_batch_size - (index + 2),
                                                                         self.perturbation_size)
        return random_masks

    def get_mask_weights(self, tiling):
        w = np.ones(shape=self.mask_batch_size)
        index = int(self.frac_of_rand_masks * self.mask_batch_size)
        w[index] = 5
        w[index + 1] = 10
        return np.tile(w, tiling)

--------------------------------------------------------------------------------
/src/Operator.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input, Flatten, Reshape, Conv2D, MaxPool2D
from tensorflow.keras.callbacks import TensorBoard


class OperatorNetwork:
    def __init__(self, x_batch_size, mask_batch_size, tensorboard_logs_dir="", add_mopt_perf_metric=True,
                 use_early_stopping=True):
        self.batch_size = mask_batch_size * x_batch_size
        self.mask_batch_size = mask_batch_size
        self.x_batch_size = x_batch_size
        self.losses_per_sample = None
        self.tr_loss_history = []
        self.te_loss_history = []
        self.tf_logs = tensorboard_logs_dir
        self.epoch_counter = 0
        self.add_mopt_perf_metric = add_mopt_perf_metric
        self.useEarlyStopping = use_early_stopping

    def create_dense_model(self, input_shape, dense_arch, last_activation="linear"):
        self.x_shape = input_shape
        self.y_shape = dense_arch[-1]
        input_data_layer = Input(shape=input_shape)
        x = Flatten()(input_data_layer)
        input_mask_layer = Input(shape=input_shape)
        mask = Flatten()(input_mask_layer)
        x = tf.keras.layers.Concatenate(axis=1)([x, mask])
        for units in dense_arch[:-1]:
            x = Dense(units, activation="sigmoid")(x)
        x = Dense(dense_arch[-1], activation=last_activation)(x)
        self.model = Model(inputs=[input_data_layer, input_mask_layer], outputs=x)
        print("Operator network model built:")
        # self.model.summary()

    def create_1ch_conv_model(self, input_shape, image_shape, filter_sizes, kernel_sizes, dense_arch, padding,
                              last_activation="softmax"):  # only for grayscale images
        self.x_shape = input_shape
        self.y_shape = dense_arch[-1]
        input_data_layer = Input(shape=input_shape)
        in1 = Reshape(target_shape=(1,) + image_shape)(input_data_layer)
        input_mask_layer = Input(shape=input_shape)
        in2 = Reshape(target_shape=(1,) + image_shape)(input_mask_layer)

        x = tf.keras.layers.Concatenate(axis=1)([in1, in2])
        for i in range(len(filter_sizes)):
            x = Conv2D(filters=filter_sizes[i], kernel_size=kernel_sizes[i], data_format="channels_first",
                       activation="relu", padding=padding)(x)
            x = MaxPool2D(pool_size=(2, 2), padding=padding, data_format="channels_first")(x)
        x = Flatten()(x)
        for units in dense_arch[:-1]:
            x = Dense(units, activation="relu")(x)
        x = Dense(dense_arch[-1], activation=last_activation)(x)
        self.model = Model(inputs=[input_data_layer, input_mask_layer], outputs=x)
        print("Operator network model built:")
        # self.model.summary()

    def create_2ch_conv_model(self, input_shape, image_shape, filter_sizes, kernel_sizes, dense_arch, padding,
                              last_activation="softmax"):  # only for grayscale images
        self.x_shape = input_shape
        self.y_shape = dense_arch[-1]
        input_data_layer = Input(shape=input_shape)
        ch_data = Reshape(target_shape=(1,) + image_shape)(input_data_layer)
        input_mask_layer = Input(shape=input_shape)
        ch_mask = Reshape(target_shape=(1,) + image_shape)(input_mask_layer)

        for i in range(len(filter_sizes)):
            ch_data = Conv2D(filters=filter_sizes[i], kernel_size=kernel_sizes[i], data_format="channels_last",
                             activation="relu", padding=padding)(ch_data)
            ch_data = MaxPool2D(pool_size=(2, 2), padding=padding, data_format="channels_last")(ch_data)
            ch_mask = Conv2D(filters=filter_sizes[i], kernel_size=kernel_sizes[i], data_format="channels_last",
                             activation="relu", padding=padding)(ch_mask)
            ch_mask = MaxPool2D(pool_size=(2, 2), padding=padding, data_format="channels_last")(ch_mask)
        ch_mask = Flatten()(ch_mask)
        ch_data = Flatten()(ch_data)

        x = tf.keras.layers.Concatenate(axis=1)([ch_mask, ch_data])
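        # Flattened data and mask feature streams are merged here and fed
        # through the shared dense head built below.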
        for units in dense_arch[:-1]:
            x = Dense(units, activation="relu")(x)
        x = Dense(dense_arch[-1], activation=last_activation)(x)
        self.model = Model(inputs=[input_data_layer, input_mask_layer], outputs=x)
        print("Operator network model built:")
        # self.model.summary()

    def create_batch(self, x, masks, y):
        """
        Tile the data batch and the mask batch against each other, e.g.:
        x = [[1,2],[3,4]] -> [[1,2],[1,2],[1,2],[3,4],[3,4],[3,4]]
        masks = [[0,0],[1,0],[1,1]] -> [[0,0],[1,0],[1,1],[0,0],[1,0],[1,1]]
        y = [1,3] -> [1,1,1,3,3,3]
        """
        x_prim = np.repeat(x, len(masks), axis=0)
        y_prim = np.repeat(y, len(masks), axis=0)
        masks_prim = np.tile(masks, (len(x), 1))

        x_prim *= masks_prim  # MASKING

        return x_prim, masks_prim, y_prim

    def named_logs(self, model, logs, mode="train"):
        result = {}
        try:
            iterator = iter(logs)
        except TypeError:
            logs = [logs]
        metricNames = (mode + "_" + i for i in model.metrics_names)
        for l in zip(metricNames, logs):
            result[l[0]] = l[1]
        return result

    def compile_model(self, loss_per_sample, combine_losses, combine_mask_losses, metrics=None):
        self.mask_loss_combine_function = combine_mask_losses
        if self.add_mopt_perf_metric is True:
            if metrics is None:
                metrics = [self.get_mopt_perf_metric()]
            else:
                metrics.append(self.get_mopt_perf_metric())

        def logging_loss_function(y_true, y_pred):
            losses = loss_per_sample(y_true, y_pred)
            self.losses_per_sample = losses
            return combine_losses(losses)

        self.model.compile(loss=logging_loss_function, optimizer='nadam', metrics=metrics, run_eagerly=True)
        if self.tf_logs != "":
            self.tb_clbk = TensorBoard(self.tf_logs)
            self.tb_clbk.set_model(self.model)

    def get_per_mask_loss(self, used_target_shape=None):
        if used_target_shape is None:
            used_target_shape = (self.x_batch_size, self.mask_batch_size)
        losses = tf.reshape(self.losses_per_sample, used_target_shape)
        losses = self.mask_loss_combine_function(losses)
        return losses

    def get_per_mask_loss_with_custom_batch(self, losses, new_x_batch_size, new_mask_batch_size):
        losses = np.reshape(losses, newshape=(new_x_batch_size, new_mask_batch_size))
        losses = np.apply_along_axis(self.mask_loss_combine_function, 0, losses)
        return losses

    def train_one(self, x, masks, y):
        x_prim, masks_prim, y_prim = self.create_batch(x, masks, y)
        curr_loss = self.model.train_on_batch(x=[x_prim, masks_prim], y=y_prim)
        self.tr_loss_history.append(curr_loss)
        self.epoch_counter += 1
        if self.tf_logs != "":
            self.tb_clbk.on_epoch_end(self.epoch_counter, self.named_logs(self.model, curr_loss))
        return x_prim, masks_prim, y_prim

    def validate_one(self, x, masks, y):
        x_prim, masks_prim, y_prim = self.create_batch(x, masks, y)
        curr_loss = self.model.test_on_batch(x=[x_prim, masks_prim], y=y_prim)
        self.te_loss_history.append(curr_loss)
        if self.tf_logs != "":
            self.tb_clbk.on_epoch_end(self.epoch_counter, self.named_logs(self.model, curr_loss, "val"))
        if self.useEarlyStopping is True:
            self.check_ES()
        return x_prim, masks_prim, y_prim, self.losses_per_sample.numpy()

    def test_one(self, x, masks, y):
        x_prim, masks_prim, y_prim = self.create_batch(x, masks, y)
        curr_loss = self.model.test_on_batch(x=[x_prim, masks_prim], y=y_prim)
        self.te_loss_history.append(curr_loss)
        return curr_loss

    def get_mopt_perf_metric(self):
        def m_opt_loss(y_true, y_pred):
            if self.losses_per_sample.shape[0] % self.mask_batch_size != 0:  # not a full train/val batch (e.g. testing)
                return 0.0
            else:  # for training and validation batches
                losses = tf.reshape(self.losses_per_sample, (-1, self.mask_batch_size))
                # m_opt sits at column int(frac_of_rand_masks * mask_batch_size), i.e.
                # half-way through each mask batch (see MaskOptimizer.get_new_mask_batch)
                self.last_m_opt_perf = np.mean(losses[:, int(0.5 * self.mask_batch_size)])
                return self.last_m_opt_perf

        return m_opt_loss

    def set_early_stopping_params(self, starting_epoch, patience_batches=800, minimize=True):
        self.ES_patience = patience_batches
        self.ES_minimize = minimize
        if minimize is True:
            self.ES_best_perf = 1000000.0
        else:
            self.ES_best_perf = -1000000.0
        self.ES_best_epoch = starting_epoch
        self.ES_stop_training = False
        self.ES_start_epoch = starting_epoch
        self.ES_best_weights = None
        return

    def check_ES(self):
        if self.epoch_counter >= self.ES_start_epoch:
            if self.ES_minimize is True:
                if self.last_m_opt_perf < self.ES_best_perf:
                    self.ES_best_perf = self.last_m_opt_perf
                    self.ES_best_epoch = self.epoch_counter
                    self.ES_best_weights = self.model.get_weights()
            else:
                if self.last_m_opt_perf > self.ES_best_perf:
                    self.ES_best_perf = self.last_m_opt_perf
                    self.ES_best_epoch = self.epoch_counter
                    self.ES_best_weights = self.model.get_weights()
            if self.epoch_counter - self.ES_best_epoch > self.ES_patience:
                self.ES_stop_training = True

--------------------------------------------------------------------------------
/src/Selector.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input, Flatten
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import TensorBoard


class SelectorNetwork:
    def __init__(self, mask_batch_size, tensorboard_logs_dir=""):
        self.batch_size = mask_batch_size
        self.mask_batch_size = mask_batch_size
        self.tr_loss_history = []
        self.te_loss_history = []
        self.y_pred_std_history = []
        self.y_true_std_history = []
        self.tf_logs = tensorboard_logs_dir
        self.epoch_counter = 0
        self.data_masks = None
        self.data_targets = None
        self.best_performing_mask = None
        self.sample_weights = None

    def set_label_input_params(self, y_shape, y_input_layer):
        self.label_input_layer = y_input_layer
        self.label_shape = y_shape

    def create_dense_model(self, input_shape, dense_arch):
        input_mask_layer = Input(shape=input_shape)
        x = Flatten()(input_mask_layer)
        for i in range(len(dense_arch[:-1])):
            x = Dense(dense_arch[i], activation="sigmoid")(x)
        x = Dense(dense_arch[-1], activation="linear")(x)
        self.model = Model(inputs=[input_mask_layer], outputs=x)
        print("Selector network model built:")
        # self.model.summary()

    def named_logs(self, model, logs):
        result = {}
        try:
            iterator = iter(logs)
        except TypeError:
            logs = [logs]
        for l in zip(model.metrics_names, logs):
            result[l[0]] = l[1]
        return result

    def compile_model(self):
        self.model.compile(loss='mae', optimizer='adam',
                           metrics=[self.get_y_std_metric(True), self.get_y_std_metric(False)])
        if self.tf_logs != "":
            self.tb_clbk = TensorBoard(self.tf_logs)
            self.tb_clbk.set_model(self.model)

    def train_one(self, epoch_number, apply_weights):  # train on the masks/losses accumulated in memory
        if apply_weights == False:
            curr_loss = self.model.train_on_batch(x=self.data_masks, y=self.data_targets)
        else:
            curr_loss = self.model.train_on_batch(x=self.data_masks, y=self.data_targets,
                                                  sample_weight=self.sample_weights)
        self.best_performing_mask = self.data_masks[np.argmin(self.data_targets, axis=0)]
        self.tr_loss_history.append(curr_loss)
        self.epoch_counter = epoch_number
        if self.tf_logs != "":
            self.tb_clbk.on_epoch_end(self.epoch_counter, self.named_logs(self.model, curr_loss))
        self.data_masks = None
        self.data_targets = None

    def append_data(self, x, y):
        if self.data_masks is None:
            self.data_masks = x
            self.data_targets = y
        else:
            self.data_masks = np.concatenate([self.data_masks, x], axis=0)
            self.data_targets = np.concatenate([self.data_targets, y], axis=0)

    def test_one(self, x, y):
        y_pred = self.model.predict(x=x)
        curr_loss = self.model.test_on_batch(x=x, y=y)
        self.te_loss_history.append(curr_loss)
        # print("SN test loss: " + str(curr_loss))
        # print("SN predictions: " + str(np.squeeze(y_pred)))
        return curr_loss

    def predict(self, x):
        y_pred = self.model.predict(x=x)
        return y_pred

    def get_y_std_metric(self, ifpred=True):
        def y_pred_std_metric(y_true, y_pred):
            y_pred_std = K.std(y_pred)
            self.y_pred_std_history.append(y_pred_std)
            return y_pred_std

        def y_true_std_metric(y_true, y_pred):
            y_true_std = K.std(y_true)
            self.y_true_std_history.append(y_true_std)
            return y_true_std

        if ifpred == True:
            return y_pred_std_metric
        else:
            return y_true_std_metric

--------------------------------------------------------------------------------
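For reference, the importance read-out in ```FeatureSelector.get_importances``` reduces to a single gradient of the selector net, which acts as a learned surrogate from masks to operator loss. A self-contained sketch of that step (the helper name `importance_from_selector` is ours, for illustration only, not part of the package):

```python
import numpy as np
import tensorflow as tf

def importance_from_selector(selector_model, m_opt):
    """Score each feature by the negative gradient of the selector's
    predicted loss with respect to the mask entries at the optimal mask."""
    m = tf.convert_to_tensor(np.asarray(m_opt, dtype=np.float32)[None, :])
    with tf.GradientTape() as tape:
        tape.watch(m)
        predicted_loss = selector_model(m)
    # Unmasking a feature that strongly decreases the predicted loss
    # corresponds to a large positive importance.
    return -tape.gradient(predicted_loss, m).numpy()[0]
```

Usage, after training: `importances = importance_from_selector(fs.selector.model, optimal_mask)`.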