├── ExplainAnomaliesUsingSHAP.py ├── README.txt ├── flow_example.ipynb ├── prefect_autoencoder_data_top_10k.csv ├── requirements.txt └── simulate perfect autoencoder data.ipynb /ExplainAnomaliesUsingSHAP.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | from tensorflow.keras.models import Model 5 | from tensorflow.keras.layers import Input, Dense 6 | from tensorflow.keras import regularizers 7 | from tensorflow.keras.callbacks import EarlyStopping 8 | 9 | import shap 10 | import warnings 11 | import logging 12 | 13 | warnings.filterwarnings("ignore") 14 | logger = logging.getLogger('shap') 15 | logger.disabled = True 16 | 17 | 18 | class ExplainAnomaliesUsingSHAP: 19 | ''' 20 | This class implements the method described in 'Explaining Anomalies Detected by Autoencoders Using SHAP' to explain 21 | anomalies revealed by an unsupervised Autoencoder model using SHAP. 22 | ''' 23 | 24 | autoencoder = None 25 | num_anomalies_to_explain = None 26 | reconstruction_error_percent = None 27 | shap_values_selection = None 28 | counter = None 29 | 30 | def __init__(self, num_anomalies_to_explain=100, reconstruction_error_percent=0.5, shap_values_selection='mean'): 31 | """ 32 | Args: 33 | num_anomalies_to_explain (int): Number of top-ranked anomalies (ranked by anomaly score, i.e. the MSE) to 34 | explain. 35 | reconstruction_error_percent (float): Number between 0 and 1; see the explanation of this parameter in 36 | 'Explaining Anomalies Detected by Autoencoders Using SHAP' under 37 | ReconstructionErrorPercent. 38 | shap_values_selection (str): One of the possible methods to choose explaining features by their SHAP values. 39 | Can be: 'mean', 'median', 'constant'. See the explanation of this parameter in 40 | 'Explaining Anomalies Detected by Autoencoders Using SHAP' under 41 | SHAPvaluesSelection. 42 | """ 43 | 44 | self.num_anomalies_to_explain = num_anomalies_to_explain 45 | self.reconstruction_error_percent = reconstruction_error_percent 46 | self.shap_values_selection = shap_values_selection 47 | 48 | def train_model(self, x_train, nb_epoch=1000, batch_size=64): 49 | """ 50 | Train a 6-layer Autoencoder model on the given x_train data.
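
        Example (a minimal sketch; assumes x_train is an already min-max normalized pandas
        DataFrame, as prepared in flow_example.ipynb; the epoch and batch values shown here
        are illustrative):

            explainer = ExplainAnomaliesUsingSHAP(num_anomalies_to_explain=10)
            autoencoder = explainer.train_model(x_train, nb_epoch=50, batch_size=128)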
51 | 52 | Args: 53 | x_train (data frame): The data to train the Autoencoder model on 54 | nb_epoch (int): Number of epochs the model will train for 55 | batch_size (int): Size of each batch of data fed to the model 56 | 57 | Returns: 58 | model: Trained autoencoder 59 | """ 60 | 61 | input_dim = x_train.shape[1] 62 | 63 | input_layer = Input(shape=(input_dim,)) 64 | 65 | encoder = Dense(int(input_dim / 2), activation="relu", activity_regularizer=regularizers.l1(10e-7))( 66 | input_layer) 67 | 68 | encoder = Dense(int(input_dim / 4), activation="relu", kernel_regularizer=regularizers.l2(10e-7))(encoder) 69 | 70 | decoder = Dense(int(input_dim / 2), activation='relu', kernel_regularizer=regularizers.l2(10e-7))(encoder) 71 | 72 | decoder = Dense(input_dim, activation='sigmoid', kernel_regularizer=regularizers.l2(10e-7))(decoder) 73 | 74 | self.autoencoder = Model(inputs=input_layer, outputs=decoder) 75 | 76 | self.autoencoder.summary() 77 | 78 | self.autoencoder.compile(optimizer='adam', loss='mean_squared_error', metrics=['mse']) 79 | 80 | earlystopper = EarlyStopping(monitor='val_loss', patience=5, verbose=1) 81 | self.autoencoder.fit(x_train, x_train, epochs=nb_epoch, batch_size=batch_size, shuffle=True, 82 | validation_split=0.1, verbose=2, callbacks=[earlystopper]) 83 | 84 | return self.autoencoder 85 | 86 | def get_top_anomaly_to_explain(self, x_explain): 87 | """ 88 | Sort all records in x_explain by the MSE of their reconstruction by the trained Autoencoder and return 89 | the top num_anomalies_to_explain records (the value given at class initialization). 90 | 91 | Args: 92 | x_explain (data frame): Set of records from which the most anomalous ones will be explained. 93 | 94 | Returns: 95 | list: Indices of the top num_anomalies_to_explain records with the highest MSE; these will be explained. 96 | """ 97 | 98 | predictions = self.autoencoder.predict(x_explain) 99 | square_errors = np.power(x_explain - predictions, 2) 100 | mse_series = pd.Series(np.mean(square_errors, axis=1)) 101 | 102 | most_anomal_trx = mse_series.sort_values(ascending=False) 103 | columns = ["id", "mse_all_columns"] 104 | columns.extend(["squared_error_" + x for x in list(x_explain.columns)]) 105 | items = [] 106 | for x in most_anomal_trx.iteritems(): 107 | item = [x[0], x[1]] 108 | item.extend(square_errors.loc[x[0]]) 109 | items.append(item) 110 | 111 | df_anomalies = pd.DataFrame(items, columns=columns) 112 | df_anomalies.set_index('id', inplace=True) 113 | 114 | top_anomalies_to_explain = df_anomalies.head(self.num_anomalies_to_explain).index 115 | return top_anomalies_to_explain 116 | 117 | def get_num_features_with_highest_reconstruction_error(self, total_squared_error, errors_df): 118 | """ 119 | Calculate the number of features whose reconstruction errors sum to reconstruction_error_percent of the 120 | total_squared_error of the record currently being explained. This is the number of the 121 | top reconstruction-error features that are going to be explained; eventually these features together with their 122 | explaining features will build up the features explanation set of this record.
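
        For example, with reconstruction_error_percent=0.5 and per-feature squared errors
        [0.4, 0.3, 0.2, 0.1] (total squared error 1.0), the two largest errors already cover
        0.4 + 0.3 = 0.7 >= 0.5 * 1.0, so the function returns 2.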
123 | 124 | Args: 125 | total_squared_error (float): Total squared reconstruction error (the MSE times the number of features) of the record selected to be explained 126 | errors_df (data frame): The reconstruction error of each feature; this is the first output of 127 | the get_errors_df_per_record function 128 | 129 | Returns: 130 | int: Number of features whose reconstruction errors sum to reconstruction_error_percent of the 131 | total_squared_error of the record currently being explained 132 | """ 133 | 134 | error = 0 135 | for num_of_features, index in enumerate(errors_df.index): 136 | error += errors_df.loc[index, 'err'] 137 | if error >= self.reconstruction_error_percent * total_squared_error: 138 | break 139 | return num_of_features + 1 140 | 141 | def get_background_set(self, x_train, background_size=200): 142 | """ 143 | Get the first background_size records from the x_train data and return them. Used for the SHAP explanation process. 144 | 145 | Args: 146 | x_train (data frame): The data the background set will be taken from 147 | background_size (int): The number of records to select from x_train. Default value is 200. 148 | 149 | Returns: 150 | data frame: Records from x_train that will serve as the background set for the SHAP explanation of the 151 | record currently being explained. 152 | """ 153 | 154 | background_set = x_train.head(background_size) 155 | return background_set 156 | 157 | def get_errors_df_per_record(self, record): 158 | """ 159 | Create a data frame of the reconstruction error of each feature of the given record. Each row of the resulting 160 | data frame contains the index of a feature, its name, and its reconstruction error based on the record's 161 | prediction by the trained autoencoder. The data frame is sorted by the features' reconstruction errors in 162 | descending order 163 | 164 | Args: 165 | record (pandas series): The record we explain at the moment; values of all its features. 166 | 167 | Returns: 168 | data frame: Data frame of all features' reconstruction errors, sorted by error, together with the record's total MSE. 169 | """ 170 | 171 | prediction = self.autoencoder.predict(np.array([record]))[0] 172 | square_errors = np.power(record - prediction, 2) 173 | errors_df = pd.DataFrame({'col_name': square_errors.index, 'err': square_errors}).reset_index(drop=True) 174 | total_mse = np.mean(square_errors) 175 | errors_df.sort_values(by='err', ascending=False, inplace=True) 176 | return errors_df, total_mse 177 | 178 | def get_highest_shap_values(self, shap_values_df): 179 | """ 180 | Choose explaining features based on their SHAP values using the shap_values_selection method (mean, median or 181 | constant), i.e. remove all features whose SHAP values do not meet the method's requirement, as described in 'Explaining 182 | Anomalies Detected by Autoencoders Using SHAP' under SHAPvaluesSelection. 183 | 184 | Args: 185 | shap_values_df (data frame): Data frame with all existing features and their SHAP values. 186 | 187 | Returns: 188 | data frame: Data frame that contains, for each explained feature (feature with high reconstruction error), 189 | the explaining features selected by the shap_values_selection method and their SHAP values.
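
        For example, with shap_values_selection='mean' and absolute SHAP values of
        (V1: 0.5, V2: 0.3, V3: 0.1, V4: 0.1) for one explained feature, the threshold is the
        mean 0.25, so only V1 and V2 are kept as its explaining features (the feature names
        here are illustrative).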
190 | """ 191 | 192 | all_explaining_features_df = pd.DataFrame() 193 | 194 | for i in range(shap_values_df.shape[0]): 195 | shap_values = shap_values_df.iloc[i] 196 | 197 | if self.shap_values_selection == 'mean': 198 | treshold_val = np.mean(shap_values) 199 | 200 | elif self.shap_values_selection == 'median': 201 | treshold_val = np.median(shap_values) 202 | 203 | elif self.shap_values_selection == 'constant': 204 | num_explaining_features = 5 205 | explaining_features = shap_values_df[i:i + 1].stack().nlargest(num_explaining_features) 206 | all_explaining_features_df = pd.concat([all_explaining_features_df, explaining_features], axis=0) 207 | continue 208 | 209 | else: 210 | raise ValueError('unknown SHAP value selection method') 211 | 212 | num_explaining_features = 0 213 | for j in range(len(shap_values)): 214 | if shap_values[j] > treshold_val: 215 | num_explaining_features += 1 216 | explaining_features = shap_values_df[i:i + 1].stack().nlargest(num_explaining_features) 217 | all_explaining_features_df = pd.concat([all_explaining_features_df, explaining_features], axis=0) 218 | return all_explaining_features_df 219 | 220 | def func_predict_feature(self, record): 221 | """ 222 | Predict the value of specific feature (with 'counter' index) using the trained autoencoder 223 | 224 | Args: 225 | record (pandas series): The record we explain at the moment; values of all its features. 226 | 227 | Returns: 228 | list: List the size of the number of features, contain the value of the predicted features with 'counter' 229 | index (the feature we explain at the moment) 230 | """ 231 | 232 | record_prediction = self.autoencoder.predict(record)[:, self.counter] 233 | return record_prediction 234 | 235 | def explain_unsupervised_data(self, x_train, x_explain, autoencoder=None, return_shap_values=False): 236 | """ 237 | First, if Autoencoder model not provided ('autoencoder' is None) train Autoencoder model on given x_train data. 238 | Then, for each record in 'top_records_to_explain' selected from given 'x_explain' as described in 239 | 'get_top_anomaly_to_explain' function, we use SHAP to explain the features with the highest reconstruction 240 | error based on the output of 'get_num_features_with_highest_reconstruction_error' function described above. 241 | Then, after we got the SHAP value of each feature in the explanation of the high reconstructed error feature, 242 | we select the explaining features using 'highest_contributing_features' function described above. Eventually, 243 | when we got the explaining features for each one of the features with highest reconstruction error, we build the 244 | explaining features set so the feature with the highest reconstruction error and its explaining features enter 245 | first to the explaining features set, then the next feature with highest reconstruction error and its explaining 246 | features enter to the explaining features set only if they don't already exist in the explaining features set 247 | and so on (full explanation + example exist in 'Explaining Anomalies Detected by Autoencoders Using SHAP') 248 | 249 | Args: 250 | x_train (data frame): The data to train the autoencoder model on and to select the background set from (for 251 | SHAP explanation process) 252 | x_explain (data frame): The data from which the top 'num_anomalies_to_explain' records are selected by their 253 | MSE to be explained. 254 | autoencoder (model): Trained Autoencoder model that will be used to explain x_explain data. 
If None (model 255 | not provided), an Autoencoder model is built and trained from scratch as described 256 | in the train_model function. 257 | return_shap_values (bool): If False, the resulting explanation features set for each record will include only 258 | the names of the explaining features. If True, in addition to the explaining feature name, 259 | the explanation features set will include the SHAP value of each feature in it, 260 | so the explanation features set will be composed of tuples of (str, float), 261 | where str is the name of the explaining feature and float is its SHAP value. 262 | Note that for the explained features (features with high reconstruction error), if they 263 | did not appear in a previous feature explanation (the explanation of a feature with higher 264 | reconstruction error), they will not have any SHAP values. Therefore they get the unique 265 | value of -1. 266 | 267 | Returns: 268 | dict: The all_sets_explaining_features dictionary that contains the explanations of the 269 | 'top_records_to_explain' records, where the keys are the record indices (int) and the values are 270 | the explanation features sets (lists). 271 | """ 272 | 273 | self.autoencoder = autoencoder 274 | if self.autoencoder is None: 275 | self.train_model(x_train) 276 | 277 | top_records_to_explain = self.get_top_anomaly_to_explain(x_explain) 278 | all_sets_explaining_features = {} 279 | 280 | for record_idx in top_records_to_explain: 281 | print(record_idx) 282 | 283 | record_to_explain = x_explain.loc[record_idx] 284 | 285 | df_err, total_mse = self.get_errors_df_per_record(record_to_explain) 286 | num_of_features = self.get_num_features_with_highest_reconstruction_error(total_mse * df_err.shape[0], 287 | df_err) 288 | 289 | df_top_err = df_err.head(num_of_features) 290 | all_sets_explaining_features[record_idx] = [] 291 | shap_values_all_features = [[] for num in range(num_of_features)] 292 | 293 | background_set = self.get_background_set(x_train, 200).values 294 | for i in range(num_of_features): 295 | self.counter = df_top_err.index[i] 296 | explainer = shap.KernelExplainer(self.func_predict_feature, background_set) 297 | shap_values = explainer.shap_values(record_to_explain, nsamples='auto') 298 | shap_values_all_features[i] = shap_values 299 | 300 | shap_values_all_features = np.fabs(shap_values_all_features) 301 | 302 | shap_values_all_features = pd.DataFrame(data=shap_values_all_features, columns=x_train.columns) 303 | highest_contributing_features = self.get_highest_shap_values(shap_values_all_features) 304 | 305 | for idx_explained_feature in range(num_of_features): 306 | set_explaining_features = [] 307 | for idx, row in highest_contributing_features.iterrows(): 308 | if idx[0] == idx_explained_feature: 309 | set_explaining_features.append((idx[1], row[0])) 310 | explained_feature_index = df_top_err.index[idx_explained_feature] 311 | set_explaining_features.insert(0, (x_train.columns[explained_feature_index], -1)) 312 | 313 | all_sets_explaining_features[record_idx].append(set_explaining_features) 314 | 315 | final_set_features = [] 316 | final_set_items = [] 317 | for item in sum(all_sets_explaining_features[record_idx], []): 318 | if item[0] not in final_set_features: 319 | final_set_features.append(item[0]) 320 | final_set_items.append(item) 321 | 322 | if return_shap_values: 323 | all_sets_explaining_features[record_idx] = final_set_items 324 | else: 325 | all_sets_explaining_features[record_idx] = final_set_features 326 | 327 | return all_sets_explaining_features 328 |
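
# Minimal usage sketch (an illustrative example, not part of the module itself: the
# CSV path below is hypothetical; any min-max normalized DataFrame whose last column
# is a 0/1 anomaly label, as in flow_example.ipynb, will do):
#
#     import pandas as pd
#     from ExplainAnomaliesUsingSHAP import ExplainAnomaliesUsingSHAP
#
#     df = pd.read_csv('normalized_data.csv')
#     X, y = df.iloc[:, :-1], df.iloc[:, -1]
#     explainer = ExplainAnomaliesUsingSHAP(num_anomalies_to_explain=10,
#                                           reconstruction_error_percent=0.5,
#                                           shap_values_selection='mean')
#     # Train on the normal records, explain the labeled anomalies.
#     explanations = explainer.explain_unsupervised_data(x_train=X[y == 0],
#                                                        x_explain=X[y == 1],
#                                                        return_shap_values=True)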
-------------------------------------------------------------------------------- /README.txt: -------------------------------------------------------------------------------- 1 | Install the packages listed in the given requirements.txt file to run the code; see flow_example.ipynb for a usage example. -------------------------------------------------------------------------------- /flow_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Example of using our ExplainAnomaliesUsingSHAP code on data taken from Kaggle: https://www.kaggle.com/mlg-ulb/creditcardfraud" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from ExplainAnomaliesUsingSHAP import ExplainAnomaliesUsingSHAP\n", 17 | "import pandas as pd" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "# Normalization of data" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "data": { 34 | "text/html": [ 35 | "
\n", 36 | "\n", 49 | "\n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | "
V1V2V3V4V5V6V7V8V9V10...V21V22V23V24V25V26V27V28AmountClass
00.9351920.7664900.8813650.3130230.7634390.2676690.2668150.7864440.4753120.510600...0.5611840.5229920.6637930.3912530.5851220.3945570.4189760.3126970.0058240
10.9785420.7700670.8402980.2717960.7661200.2621920.2648750.7862980.4539810.505267...0.5578400.4802370.6669380.3364400.5872900.4460130.4163450.3134230.0001050
20.9352170.7531180.8681410.2687660.7623290.2811220.2701770.7880420.4106030.513018...0.5654770.5460300.6789390.2893540.5595150.4027270.4154890.3119110.0147390
30.9418780.7653040.8684840.2136610.7656470.2755590.2668030.7894340.4149990.507585...0.5597340.5102770.6626070.2238260.6142450.3891970.4176690.3143710.0048070
40.9386170.7765200.8642510.2697960.7629750.2639840.2689680.7824840.4909500.524303...0.5613270.5472710.6633920.4012700.5663430.5074970.4205610.3174900.0027240
\n", 199 | "

5 rows × 30 columns

\n", 200 | "
" 201 | ], 202 | "text/plain": [ 203 | " V1 V2 V3 V4 V5 V6 V7 \\\n", 204 | "0 0.935192 0.766490 0.881365 0.313023 0.763439 0.267669 0.266815 \n", 205 | "1 0.978542 0.770067 0.840298 0.271796 0.766120 0.262192 0.264875 \n", 206 | "2 0.935217 0.753118 0.868141 0.268766 0.762329 0.281122 0.270177 \n", 207 | "3 0.941878 0.765304 0.868484 0.213661 0.765647 0.275559 0.266803 \n", 208 | "4 0.938617 0.776520 0.864251 0.269796 0.762975 0.263984 0.268968 \n", 209 | "\n", 210 | " V8 V9 V10 ... V21 V22 V23 V24 \\\n", 211 | "0 0.786444 0.475312 0.510600 ... 0.561184 0.522992 0.663793 0.391253 \n", 212 | "1 0.786298 0.453981 0.505267 ... 0.557840 0.480237 0.666938 0.336440 \n", 213 | "2 0.788042 0.410603 0.513018 ... 0.565477 0.546030 0.678939 0.289354 \n", 214 | "3 0.789434 0.414999 0.507585 ... 0.559734 0.510277 0.662607 0.223826 \n", 215 | "4 0.782484 0.490950 0.524303 ... 0.561327 0.547271 0.663392 0.401270 \n", 216 | "\n", 217 | " V25 V26 V27 V28 Amount Class \n", 218 | "0 0.585122 0.394557 0.418976 0.312697 0.005824 0 \n", 219 | "1 0.587290 0.446013 0.416345 0.313423 0.000105 0 \n", 220 | "2 0.559515 0.402727 0.415489 0.311911 0.014739 0 \n", 221 | "3 0.614245 0.389197 0.417669 0.314371 0.004807 0 \n", 222 | "4 0.566343 0.507497 0.420561 0.317490 0.002724 0 \n", 223 | "\n", 224 | "[5 rows x 30 columns]" 225 | ] 226 | }, 227 | "execution_count": 2, 228 | "metadata": {}, 229 | "output_type": "execute_result" 230 | } 231 | ], 232 | "source": [ 233 | "df = pd.read_csv('../data/creditcard.csv', delimiter=',')\n", 234 | "df = df.drop(['Time'], axis=1)\n", 235 | "for col in df.columns[:-1]:\n", 236 | " min_val = df[col].min()\n", 237 | " max_val = df[col].max()\n", 238 | " if min_val != max_val:\n", 239 | " df[col] = (df[col] - min_val) / (max_val - min_val)\n", 240 | " \n", 241 | "df.head()" 242 | ] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "metadata": {}, 247 | "source": [ 248 | "# Split to Train and Test" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 3, 254 | "metadata": {}, 255 | "outputs": [ 256 | { 257 | "name": "stdout", 258 | "output_type": "stream", 259 | "text": [ 260 | "x shape: (284807, 29)\n", 261 | "y shape: (284807,)\n", 262 | "0 284315\n", 263 | "1 492\n", 264 | "Name: Class, dtype: int64\n" 265 | ] 266 | } 267 | ], 268 | "source": [ 269 | "X = df.iloc[:,:-1]\n", 270 | "y = df.iloc[:, -1]\n", 271 | "\n", 272 | "print('x shape:', X.shape)\n", 273 | "print('y shape:', y.shape)\n", 274 | "print(y.value_counts())\n", 275 | "\n", 276 | "train_idx = y[y==0].index.values\n", 277 | "test_idx = y[y==1].index.values\n", 278 | "\n", 279 | "X_train = X.iloc[train_idx]\n", 280 | "y_train = y[train_idx]\n", 281 | "\n", 282 | "X_test = X.iloc[test_idx]\n", 283 | "y_test = y[test_idx]" 284 | ] 285 | }, 286 | { 287 | "cell_type": "markdown", 288 | "metadata": {}, 289 | "source": [ 290 | "# Get explnation" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 4, 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "exp_model = ExplainAnomaliesUsingSHAP(num_anomalies_to_explain=10)" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": 5, 305 | "metadata": {}, 306 | "outputs": [ 307 | { 308 | "name": "stdout", 309 | "output_type": "stream", 310 | "text": [ 311 | "Model: \"model\"\n", 312 | "_________________________________________________________________\n", 313 | "Layer (type) Output Shape Param # \n", 314 | "=================================================================\n", 315 | "input_1 
(InputLayer) [(None, 29)] 0 \n", 316 | "_________________________________________________________________\n", 317 | "dense (Dense) (None, 14) 420 \n", 318 | "_________________________________________________________________\n", 319 | "dense_1 (Dense) (None, 7) 105 \n", 320 | "_________________________________________________________________\n", 321 | "dense_2 (Dense) (None, 14) 112 \n", 322 | "_________________________________________________________________\n", 323 | "dense_3 (Dense) (None, 29) 435 \n", 324 | "=================================================================\n", 325 | "Total params: 1,072\n", 326 | "Trainable params: 1,072\n", 327 | "Non-trainable params: 0\n", 328 | "_________________________________________________________________\n", 329 | "WARNING:tensorflow:Falling back from v2 loop because of error: Failed to find data adapter that can handle input: , \n", 330 | "Train on 255883 samples, validate on 28432 samples\n", 331 | "Epoch 1/10\n", 332 | "255883/255883 - 8s - loss: 0.0024 - mse: 0.0023 - val_loss: 0.0012 - val_mse: 0.0012\n", 333 | "Epoch 2/10\n", 334 | "255883/255883 - 8s - loss: 0.0012 - mse: 0.0012 - val_loss: 9.7971e-04 - val_mse: 9.4585e-04\n", 335 | "Epoch 3/10\n", 336 | "255883/255883 - 8s - loss: 9.4128e-04 - mse: 9.0603e-04 - val_loss: 7.7618e-04 - val_mse: 7.3871e-04\n", 337 | "Epoch 4/10\n", 338 | "255883/255883 - 8s - loss: 7.9653e-04 - mse: 7.6116e-04 - val_loss: 7.5711e-04 - val_mse: 7.2357e-04\n", 339 | "Epoch 5/10\n", 340 | "255883/255883 - 8s - loss: 7.7851e-04 - mse: 7.4647e-04 - val_loss: 7.5049e-04 - val_mse: 7.1972e-04\n", 341 | "Epoch 6/10\n", 342 | "255883/255883 - 8s - loss: 7.7088e-04 - mse: 7.4122e-04 - val_loss: 7.4166e-04 - val_mse: 7.1278e-04\n", 343 | "Epoch 7/10\n", 344 | "255883/255883 - 9s - loss: 7.6278e-04 - mse: 7.3418e-04 - val_loss: 7.3703e-04 - val_mse: 7.0830e-04\n", 345 | "Epoch 8/10\n", 346 | "255883/255883 - 9s - loss: 7.5110e-04 - mse: 7.2236e-04 - val_loss: 7.3026e-04 - val_mse: 7.0129e-04\n", 347 | "Epoch 9/10\n", 348 | "255883/255883 - 9s - loss: 7.4219e-04 - mse: 7.1342e-04 - val_loss: 7.2397e-04 - val_mse: 6.9522e-04\n", 349 | "Epoch 10/10\n", 350 | "255883/255883 - 8s - loss: 7.3582e-04 - mse: 7.0741e-04 - val_loss: 7.2062e-04 - val_mse: 6.9238e-04\n", 351 | "WARNING:tensorflow:Falling back from v2 loop because of error: Failed to find data adapter that can handle input: , \n", 352 | "154587\n", 353 | "154684\n", 354 | "154371\n", 355 | "154234\n", 356 | "150644\n", 357 | "150647\n", 358 | "150665\n", 359 | "150654\n", 360 | "42528\n", 361 | "42635\n" 362 | ] 363 | } 364 | ], 365 | "source": [ 366 | "all_sets_explaining_features = exp_model.explain_unsupervised_data(x_train=X_train, \n", 367 | " x_explain=X_test,\n", 368 | " return_shap_values=True)" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": 6, 374 | "metadata": {}, 375 | "outputs": [ 376 | { 377 | "data": { 378 | "text/plain": [ 379 | "{154587: [('V4', -1),\n", 380 | " ('V11', 0.0731265756822315),\n", 381 | " ('V15', 0.02369090910031154),\n", 382 | " ('V7', 0.016886018593518293),\n", 383 | " ('V3', 0.016570416382713796),\n", 384 | " ('V13', 0.015665877182100774),\n", 385 | " ('V10', 0.014052846919758964),\n", 386 | " ('V18', 0.00019320900654557518),\n", 387 | " ('V1', 0.00017568476566009978),\n", 388 | " ('V17', 0.00017336466464923974),\n", 389 | " ('V9', 0.0006888261077363056),\n", 390 | " ('V21', 0.0020651601972084717),\n", 391 | " ('V12', 0.0018995651805052225),\n", 392 | " ('V6', 0.001422672056074865),\n", 393 | " ('V8', 
-1)],\n", 394 | " 154684: [('V4', -1),\n", 395 | " ('V11', 0.06857125766422263),\n", 396 | " ('V15', 0.024017587211378415),\n", 397 | " ('V13', 0.01748191127094885),\n", 398 | " ('V7', 0.01707902271076717),\n", 399 | " ('V3', 0.01379392178383141),\n", 400 | " ('V10', 0.01240802991961949),\n", 401 | " ('V1', -1),\n", 402 | " ('V18', 0.0025376704074176696),\n", 403 | " ('V9', 0.000837353113461579),\n", 404 | " ('V17', 0.00358466848771998),\n", 405 | " ('V12', 0.002072845369657088),\n", 406 | " ('V21', 0.0020574753036938336),\n", 407 | " ('V22', 0.00046080208520149593),\n", 408 | " ('V8', -1),\n", 409 | " ('V24', 0.0003597983812910178)],\n", 410 | " 154371: [('V4', -1),\n", 411 | " ('V11', 0.08294142890489672),\n", 412 | " ('V15', 0.02898130105862798),\n", 413 | " ('V3', 0.01590907677201072),\n", 414 | " ('V7', 0.014711642443088473),\n", 415 | " ('V13', 0.014587836279879807),\n", 416 | " ('V10', 0.013403577555324252),\n", 417 | " ('V18', 0.0003123828531369758),\n", 418 | " ('V1', 0.0001992197589061244),\n", 419 | " ('V17', 0.00017078683719280304),\n", 420 | " ('V12', 0.002066954758604087),\n", 421 | " ('V24', 0.0016792627630605912),\n", 422 | " ('V21', 0.0016285431507732772)],\n", 423 | " 154234: [('V18', -1),\n", 424 | " ('V15', 0.04582284749101369),\n", 425 | " ('V4', 0.01298634659155869),\n", 426 | " ('V24', 0.005400751656493175),\n", 427 | " ('V11', 0.004962711118773501),\n", 428 | " ('V1', 0.004888440731168236),\n", 429 | " ('V9', 0.004203121378734948),\n", 430 | " ('V17', -1),\n", 431 | " ('V7', 0.0028306002556112014),\n", 432 | " ('V3', 0.0026911820955583705),\n", 433 | " ('V10', 0.0019216298091299933),\n", 434 | " ('V12', 0.0011375011534669778),\n", 435 | " ('V13', 0.01092450612555844),\n", 436 | " ('V16', 9.700471898336474e-05)],\n", 437 | " 150644: [('V12', -1),\n", 438 | " ('V11', 0.010486840442011657),\n", 439 | " ('V17', 0.004155066659667708),\n", 440 | " ('V4', 0.003106695971850944),\n", 441 | " ('V13', 0.0028908881702827925),\n", 442 | " ('V24', 0.001753354192488222),\n", 443 | " ('V14', 0.001660075517201605),\n", 444 | " ('V15', 0.004020149749684945),\n", 445 | " ('V18', 0.0034792307493295323),\n", 446 | " ('V10', 0.0012823640295097126),\n", 447 | " ('V16', 0.007824932111410356)],\n", 448 | " 150647: [('V12', -1),\n", 449 | " ('V11', 0.009069677191330994),\n", 450 | " ('V4', 0.00438623604807333),\n", 451 | " ('V17', 0.0036944688557951216),\n", 452 | " ('V26', 0.0025862953143419055),\n", 453 | " ('V24', 0.002056737456413358),\n", 454 | " ('V14', 0.0014729120689511745),\n", 455 | " ('V18', 0.0035930031667444967),\n", 456 | " ('V15', 0.001178029124017615),\n", 457 | " ('V16', 0.001203094323917563),\n", 458 | " ('V19', 0.0007287963722595273)],\n", 459 | " 150665: [('V17', -1),\n", 460 | " ('V24', 0.006989545665981571),\n", 461 | " ('V11', 0.006511480206894255),\n", 462 | " ('V4', 0.003946292279777723),\n", 463 | " ('V15', 0.003525485111059065),\n", 464 | " ('V18', 0.003038005743877573),\n", 465 | " ('V12', 0.002472561714746555),\n", 466 | " ('V16', 0.008939958118581916),\n", 467 | " ('V19', 0.007667361038725944)],\n", 468 | " 150654: [('V12', -1),\n", 469 | " ('V11', 0.009297464939338189),\n", 470 | " ('V4', 0.003955479118907272),\n", 471 | " ('V17', 0.003865460078225134),\n", 472 | " ('V24', 0.0019398804405763668),\n", 473 | " ('V26', 0.0018423242061232359),\n", 474 | " ('V14', 0.0016718447008473962),\n", 475 | " ('V18', 0.0033557790402469314),\n", 476 | " ('V3', 0.0012585499497945432),\n", 477 | " ('V7', 0.0011889138875021267),\n", 478 | " ('V16', 0.0010422633904812783)],\n", 
479 | " 42528: [('V17', -1),\n", 480 | " ('V24', 0.006736426024435591),\n", 481 | " ('V18', 0.0050778753111189595),\n", 482 | " ('V12', 0.002842628989519194),\n", 483 | " ('V4', 0.0024794057912630846),\n", 484 | " ('V14', 0.002269837239830145),\n", 485 | " ('V11', 0.002218586625682887),\n", 486 | " ('V19', 0.0016399651294882627),\n", 487 | " ('V16', 0.0013763903992034172),\n", 488 | " ('V9', 0.004386193418645018),\n", 489 | " ('V26', 0.001635154304709785),\n", 490 | " ('V10', 0.0010430554925531486),\n", 491 | " ('V13', 0.0009982975334452363)],\n", 492 | " 42635: [('V17', -1),\n", 493 | " ('V24', 0.00605546732778159),\n", 494 | " ('V18', 0.004766048030447719),\n", 495 | " ('V4', 0.002795297282351012),\n", 496 | " ('V11', 0.0021679928865168505),\n", 497 | " ('V12', 0.0020859762105622697),\n", 498 | " ('V15', 0.0019091023771055922),\n", 499 | " ('V14', 0.0014634281110916776),\n", 500 | " ('V10', 0.001426458113448244),\n", 501 | " ('V19', 0.012697520780446276),\n", 502 | " ('V16', 0.01025355503344674),\n", 503 | " ('V13', 0.0018196956159556073)]}" 504 | ] 505 | }, 506 | "execution_count": 6, 507 | "metadata": {}, 508 | "output_type": "execute_result" 509 | } 510 | ], 511 | "source": [ 512 | "all_sets_explaining_features" 513 | ] 514 | } 515 | ], 516 | "metadata": { 517 | "kernelspec": { 518 | "display_name": "Python 3", 519 | "language": "python", 520 | "name": "python3" 521 | }, 522 | "language_info": { 523 | "codemirror_mode": { 524 | "name": "ipython", 525 | "version": 3 526 | }, 527 | "file_extension": ".py", 528 | "mimetype": "text/x-python", 529 | "name": "python", 530 | "nbconvert_exporter": "python", 531 | "pygments_lexer": "ipython3", 532 | "version": "3.7.4" 533 | } 534 | }, 535 | "nbformat": 4, 536 | "nbformat_minor": 2 537 | } 538 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ronniemi/explainAnomaliesUsingSHAP/3df55220d459d27e7029158563178a40751732c5/requirements.txt -------------------------------------------------------------------------------- /simulate perfect autoencoder data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "colab": { 8 | "base_uri": "https://localhost:8080/", 9 | "height": 33 10 | }, 11 | "colab_type": "code", 12 | "executionInfo": { 13 | "elapsed": 1040, 14 | "status": "ok", 15 | "timestamp": 1566741408085, 16 | "user": { 17 | "displayName": "Ronnie Mindlin", 18 | "photoUrl": "", 19 | "userId": "17245682626254463986" 20 | }, 21 | "user_tz": -180 22 | }, 23 | "id": "3DXBYgkglrYm", 24 | "outputId": "c96b2248-17a5-4965-9433-20d0018ed7bb" 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "import numpy as np\n", 29 | "import pandas as pd" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "metadata": { 36 | "colab": {}, 37 | "colab_type": "code", 38 | "id": "60s3D0U7luAm" 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "n_records = 1000000\n", 43 | "n_features = 4\n", 44 | "\n", 45 | "features_names = []\n", 46 | "for feature in range(1, n_features+1):\n", 47 | " features_names.append('feature_' + str(feature))\n", 48 | "\n", 49 | "data = np.random.rand(n_records, n_features)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 3, 55 | "metadata": { 56 | "colab": {}, 57 | "colab_type": "code", 58 | "id": 
"39AvCMiImdCg" 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "data_df = pd.DataFrame(data=data, columns=features_names)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 4, 68 | "metadata": { 69 | "colab": { 70 | "base_uri": "https://localhost:8080/", 71 | "height": 190 72 | }, 73 | "colab_type": "code", 74 | "executionInfo": { 75 | "elapsed": 1854, 76 | "status": "ok", 77 | "timestamp": 1566741408926, 78 | "user": { 79 | "displayName": "Ronnie Mindlin", 80 | "photoUrl": "", 81 | "userId": "17245682626254463986" 82 | }, 83 | "user_tz": -180 84 | }, 85 | "id": "k4NisjLRmg-b", 86 | "outputId": "c826d658-7b88-4fe9-efd5-0da960077899" 87 | }, 88 | "outputs": [ 89 | { 90 | "data": { 91 | "text/html": [ 92 | "
\n", 93 | "\n", 106 | "\n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | "
feature_1feature_2feature_3feature_4
00.6511370.9872130.8067420.732936
10.1259910.6015960.2477400.909559
20.5670750.9765090.0326460.860562
30.5006080.6927400.9844980.520351
40.6236540.0684210.5639770.954146
\n", 154 | "
" 155 | ], 156 | "text/plain": [ 157 | " feature_1 feature_2 feature_3 feature_4\n", 158 | "0 0.651137 0.987213 0.806742 0.732936\n", 159 | "1 0.125991 0.601596 0.247740 0.909559\n", 160 | "2 0.567075 0.976509 0.032646 0.860562\n", 161 | "3 0.500608 0.692740 0.984498 0.520351\n", 162 | "4 0.623654 0.068421 0.563977 0.954146" 163 | ] 164 | }, 165 | "execution_count": 4, 166 | "metadata": {}, 167 | "output_type": "execute_result" 168 | } 169 | ], 170 | "source": [ 171 | "data_df.head()" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 5, 177 | "metadata": { 178 | "colab": {}, 179 | "colab_type": "code", 180 | "id": "I8StXZhwmro0" 181 | }, 182 | "outputs": [], 183 | "source": [ 184 | "data_df['feature_1_2'] = data_df['feature_1'] + data_df['feature_2']\n", 185 | "data_df['feature_3_4'] = data_df['feature_3'] + data_df['feature_4']\n", 186 | "\n", 187 | "data_df['class'] = np.zeros(n_records)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 6, 193 | "metadata": { 194 | "colab": { 195 | "base_uri": "https://localhost:8080/", 196 | "height": 190 197 | }, 198 | "colab_type": "code", 199 | "executionInfo": { 200 | "elapsed": 1834, 201 | "status": "ok", 202 | "timestamp": 1566741408927, 203 | "user": { 204 | "displayName": "Ronnie Mindlin", 205 | "photoUrl": "", 206 | "userId": "17245682626254463986" 207 | }, 208 | "user_tz": -180 209 | }, 210 | "id": "w3aH7rcqnArk", 211 | "outputId": "0c800a44-e0ee-49a4-ca4b-03c6c070a974" 212 | }, 213 | "outputs": [ 214 | { 215 | "data": { 216 | "text/html": [ 217 | "
\n", 218 | "\n", 231 | "\n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | "
feature_1feature_2feature_3feature_4feature_1_2feature_3_4class
00.6511370.9872130.8067420.7329361.6383511.5396780.0
10.1259910.6015960.2477400.9095590.7275871.1573000.0
20.5670750.9765090.0326460.8605621.5435840.8932080.0
30.5006080.6927400.9844980.5203511.1933471.5048490.0
40.6236540.0684210.5639770.9541460.6920751.5181230.0
\n", 297 | "
" 298 | ], 299 | "text/plain": [ 300 | " feature_1 feature_2 feature_3 feature_4 feature_1_2 feature_3_4 class\n", 301 | "0 0.651137 0.987213 0.806742 0.732936 1.638351 1.539678 0.0\n", 302 | "1 0.125991 0.601596 0.247740 0.909559 0.727587 1.157300 0.0\n", 303 | "2 0.567075 0.976509 0.032646 0.860562 1.543584 0.893208 0.0\n", 304 | "3 0.500608 0.692740 0.984498 0.520351 1.193347 1.504849 0.0\n", 305 | "4 0.623654 0.068421 0.563977 0.954146 0.692075 1.518123 0.0" 306 | ] 307 | }, 308 | "execution_count": 6, 309 | "metadata": {}, 310 | "output_type": "execute_result" 311 | } 312 | ], 313 | "source": [ 314 | "data_df.head()" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 7, 320 | "metadata": { 321 | "colab": {}, 322 | "colab_type": "code", 323 | "id": "q0o6tjtqoh7t" 324 | }, 325 | "outputs": [], 326 | "source": [ 327 | "n_to_change = 15000\n", 328 | "\n", 329 | "idx = data_df.sample(n=n_to_change).index\n", 330 | "\n", 331 | "for i in idx:\n", 332 | "# rand_vals = np.random.rand(2)\n", 333 | "# data_df.loc[i, 'feature_1_2'] = rand_vals[0]\n", 334 | "# data_df.loc[i, 'feature_3_4'] = rand_vals[1]\n", 335 | "# data_df.loc[i, 'class'] = '1'\n", 336 | " rand_num = np.random.rand(1)[0]\n", 337 | " if rand_num < 0.5:\n", 338 | " data_df.loc[i, 'feature_1_2'] = np.random.rand(1)[0]\n", 339 | " data_df.loc[i, 'class'] = '1_2'\n", 340 | " else:\n", 341 | " data_df.loc[i, 'feature_3_4'] = np.random.rand(1)[0]\n", 342 | " data_df.loc[i, 'class'] = '3_4'\n", 343 | " " 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 12, 349 | "metadata": { 350 | "colab": { 351 | "base_uri": "https://localhost:8080/", 352 | "height": 334 353 | }, 354 | "colab_type": "code", 355 | "executionInfo": { 356 | "elapsed": 794547, 357 | "status": "ok", 358 | "timestamp": 1566742201654, 359 | "user": { 360 | "displayName": "Ronnie Mindlin", 361 | "photoUrl": "", 362 | "userId": "17245682626254463986" 363 | }, 364 | "user_tz": -180 365 | }, 366 | "id": "J_C1LZCY8tmY", 367 | "outputId": "05a0c369-b3f1-4e03-c34c-9775f5e2b2c3" 368 | }, 369 | "outputs": [ 370 | { 371 | "data": { 372 | "text/html": [ 373 | "
\n", 374 | "\n", 387 | "\n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | "
feature_1feature_2feature_3feature_4feature_1_2feature_3_4class
00.6511370.9872130.8067420.7329361.6383511.5396780
10.1259910.6015960.2477400.9095590.7275871.1573000
20.5670750.9765090.0326460.8605621.5435840.8932080
30.5006080.6927400.9844980.5203511.1933471.5048490
40.6236540.0684210.5639770.9541460.6920751.5181230
50.1244220.0827830.9093180.9150400.2072051.8243580
60.6078730.7089210.2103360.4177351.3167940.6280710
70.0408840.4558090.6818800.7061620.4966931.3880420
80.3869960.3643880.2419290.0038340.7513850.2457630
90.2069100.9663900.9034460.9327761.1733001.8362220
\n", 503 | "
" 504 | ], 505 | "text/plain": [ 506 | " feature_1 feature_2 feature_3 feature_4 feature_1_2 feature_3_4 class\n", 507 | "0 0.651137 0.987213 0.806742 0.732936 1.638351 1.539678 0\n", 508 | "1 0.125991 0.601596 0.247740 0.909559 0.727587 1.157300 0\n", 509 | "2 0.567075 0.976509 0.032646 0.860562 1.543584 0.893208 0\n", 510 | "3 0.500608 0.692740 0.984498 0.520351 1.193347 1.504849 0\n", 511 | "4 0.623654 0.068421 0.563977 0.954146 0.692075 1.518123 0\n", 512 | "5 0.124422 0.082783 0.909318 0.915040 0.207205 1.824358 0\n", 513 | "6 0.607873 0.708921 0.210336 0.417735 1.316794 0.628071 0\n", 514 | "7 0.040884 0.455809 0.681880 0.706162 0.496693 1.388042 0\n", 515 | "8 0.386996 0.364388 0.241929 0.003834 0.751385 0.245763 0\n", 516 | "9 0.206910 0.966390 0.903446 0.932776 1.173300 1.836222 0" 517 | ] 518 | }, 519 | "execution_count": 12, 520 | "metadata": {}, 521 | "output_type": "execute_result" 522 | } 523 | ], 524 | "source": [ 525 | "data_df.head(10)" 526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "execution_count": 13, 531 | "metadata": {}, 532 | "outputs": [ 533 | { 534 | "name": "stdout", 535 | "output_type": "stream", 536 | "text": [ 537 | "1.9996749006122247\n" 538 | ] 539 | } 540 | ], 541 | "source": [ 542 | "x_max_val = data_df.iloc[:,:-1].max().max()\n", 543 | "print(x_max_val)\n", 544 | "data_df.iloc[:,:-1] = data_df.iloc[:,:-1] / x_max_val" 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "execution_count": 14, 550 | "metadata": {}, 551 | "outputs": [ 552 | { 553 | "data": { 554 | "text/html": [ 555 | "
\n", 556 | "\n", 569 | "\n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | "
feature_1feature_2feature_3feature_4feature_1_2feature_3_4class
00.3256220.4936870.4034370.3665280.8193090.7699640
10.0630060.3008470.1238900.4548530.3638530.5787440
20.2835840.4883340.0163260.4303510.7719170.4466770
30.2503450.3464260.4923290.2602180.5967710.7525470
40.3118780.0342160.2820340.4771500.3460940.7591850
50.0622210.0413980.4547330.4575940.1036200.9123270
60.3039860.3545180.1051850.2089020.6585040.3140870
70.0204450.2279420.3409960.3531390.2483870.6941340
80.1935300.1822240.1209840.0019180.3757530.1229020
90.1034720.4832740.4517960.4664640.5867460.9182600
\n", 685 | "
" 686 | ], 687 | "text/plain": [ 688 | " feature_1 feature_2 feature_3 feature_4 feature_1_2 feature_3_4 class\n", 689 | "0 0.325622 0.493687 0.403437 0.366528 0.819309 0.769964 0\n", 690 | "1 0.063006 0.300847 0.123890 0.454853 0.363853 0.578744 0\n", 691 | "2 0.283584 0.488334 0.016326 0.430351 0.771917 0.446677 0\n", 692 | "3 0.250345 0.346426 0.492329 0.260218 0.596771 0.752547 0\n", 693 | "4 0.311878 0.034216 0.282034 0.477150 0.346094 0.759185 0\n", 694 | "5 0.062221 0.041398 0.454733 0.457594 0.103620 0.912327 0\n", 695 | "6 0.303986 0.354518 0.105185 0.208902 0.658504 0.314087 0\n", 696 | "7 0.020445 0.227942 0.340996 0.353139 0.248387 0.694134 0\n", 697 | "8 0.193530 0.182224 0.120984 0.001918 0.375753 0.122902 0\n", 698 | "9 0.103472 0.483274 0.451796 0.466464 0.586746 0.918260 0" 699 | ] 700 | }, 701 | "execution_count": 14, 702 | "metadata": {}, 703 | "output_type": "execute_result" 704 | } 705 | ], 706 | "source": [ 707 | "data_df.head(10)" 708 | ] 709 | }, 710 | { 711 | "cell_type": "code", 712 | "execution_count": 16, 713 | "metadata": { 714 | "colab": {}, 715 | "colab_type": "code", 716 | "id": "4QW7Tk56nCYv" 717 | }, 718 | "outputs": [], 719 | "source": [ 720 | "data_df.to_csv('../data/perfect_autoencoder/random_data_' + str(n_features+2) + '_features.csv', index=False)" 721 | ] 722 | } 723 | ], 724 | "metadata": { 725 | "colab": { 726 | "collapsed_sections": [], 727 | "name": "simulate_data.ipynb", 728 | "provenance": [], 729 | "version": "0.3.2" 730 | }, 731 | "kernelspec": { 732 | "display_name": "Python 3", 733 | "language": "python", 734 | "name": "python3" 735 | }, 736 | "language_info": { 737 | "codemirror_mode": { 738 | "name": "ipython", 739 | "version": 3 740 | }, 741 | "file_extension": ".py", 742 | "mimetype": "text/x-python", 743 | "name": "python", 744 | "nbconvert_exporter": "python", 745 | "pygments_lexer": "ipython3", 746 | "version": "3.7.4" 747 | } 748 | }, 749 | "nbformat": 4, 750 | "nbformat_minor": 1 751 | } 752 | --------------------------------------------------------------------------------