├── ExplainAnomaliesUsingSHAP.py
├── README.txt
├── flow_example.ipynb
├── prefect_autoencoder_data_top_10k.csv
├── requirements.txt
└── simulate perfect autoencoder data.ipynb
/ExplainAnomaliesUsingSHAP.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 | from tensorflow.keras.models import Model
5 | from tensorflow.keras.layers import Input, Dense
6 | from tensorflow.keras import regularizers
7 | from tensorflow.keras.callbacks import EarlyStopping
8 |
9 | import shap
10 | import warnings
11 | import logging
12 |
13 | warnings.filterwarnings("ignore")
14 | logger = logging.getLogger('shap')
15 | logger.disabled = True
16 |
17 |
18 | class ExplainAnomaliesUsingSHAP:
19 | '''
20 | This class implements the method described in 'Explaining Anomalies Detected by Autoencoders Using SHAP' to
21 | explain anomalies revealed by an unsupervised Autoencoder model using SHAP.
22 | '''
23 |
24 | autoencoder = None
25 | num_anomalies_to_explain = None
26 | reconstruction_error_percent = None
27 | shap_values_selection = None
28 | counter = None
29 |
30 | def __init__(self, num_anomalies_to_explain=100, reconstruction_error_percent=0.5, shap_values_selection='mean'):
31 | """
32 | Args:
33 | num_anomalies_to_explain (int): Number of top-ranked anomalies (ranked by anomaly score, i.e. the MSE) to
34 | explain.
35 | reconstruction_error_percent (float): Number between 0 and 1; see the explanation of this parameter in
36 | 'Explaining Anomalies Detected by Autoencoders Using SHAP' under
37 | ReconstructionErrorPercent.
38 | shap_values_selection (str): One of the possible methods for choosing explaining features by their SHAP
39 | values. Can be: 'mean', 'median', 'constant'. See the explanation of this
40 | parameter in 'Explaining Anomalies Detected by Autoencoders Using SHAP' under
41 | SHAPvaluesSelection.
42 | """
43 |
44 | self.num_anomalies_to_explain = num_anomalies_to_explain
45 | self.reconstruction_error_percent = reconstruction_error_percent
46 | self.shap_values_selection = shap_values_selection
47 |
48 | def train_model(self, x_train, nb_epoch=1000, batch_size=64):
49 | """
50 | Train an Autoencoder model (an input layer followed by four Dense layers) on the given x_train data.
51 |
52 | Args:
53 | x_train (data frame): The data to train the Autoencoder model on
54 | nb_epoch (int): Maximum number of epochs to train for (early stopping may end training sooner)
55 | batch_size (int): Size of each batch of data fed into the model
56 |
57 | Returns:
58 | model: Trained autoencoder
59 | """
60 |
61 | input_dim = x_train.shape[1]
62 |
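# Symmetric bottleneck architecture: input_dim -> input_dim/2 -> input_dim/4 -> input_dim/2 -> input_dim.
# The sigmoid output layer assumes inputs scaled to [0, 1], as in the accompanying notebooks.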
63 | input_layer = Input(shape=(input_dim,))
64 |
65 | encoder = Dense(int(input_dim / 2), activation="relu", activity_regularizer=regularizers.l1(10e-7))(
66 | input_layer)
67 |
68 | encoder = Dense(int(input_dim / 4), activation="relu", kernel_regularizer=regularizers.l2(10e-7))(encoder)
69 |
70 | decoder = Dense(int(input_dim / 2), activation='relu', kernel_regularizer=regularizers.l2(10e-7))(encoder)
71 |
72 | decoder = Dense(input_dim, activation='sigmoid', kernel_regularizer=regularizers.l2(10e-7))(decoder)
73 |
74 | self.autoencoder = Model(inputs=input_layer, outputs=decoder)
75 |
76 | self.autoencoder.summary()
77 |
78 | self.autoencoder.compile(optimizer='adam', loss='mean_squared_error', metrics=['mse'])
79 |
80 | earlystopper = EarlyStopping(monitor='val_loss', patience=5, verbose=1)
81 | self.autoencoder.fit(x_train, x_train, epochs=nb_epoch, batch_size=batch_size, shuffle=True,
82 | validation_split=0.1, verbose=2, callbacks=[earlystopper])
83 |
84 | return self.autoencoder
85 |
86 | def get_top_anomaly_to_explain(self, x_explain):
87 | """
88 | Sort all records in x_explain by their MSE, calculated from their reconstruction by the trained Autoencoder,
89 | and return the top num_anomalies_to_explain (set at class initialization) records.
90 |
91 | Args:
92 | x_explain (data frame): Set of records from which we want to explain the most anomalous ones.
93 |
94 | Returns:
95 | list: Indexes of the top num_anomalies_to_explain records with the highest MSE; these records will be explained.
96 | """
97 |
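# Anomaly score of record i is its reconstruction MSE: the mean over features j of (x_ij - xhat_ij)^2.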
98 | predictions = self.autoencoder.predict(x_explain)
99 | square_errors = np.power(x_explain - predictions, 2)
100 | mse_series = pd.Series(np.mean(square_errors, axis=1))
101 |
102 | most_anomal_trx = mse_series.sort_values(ascending=False)
103 | columns = ["id", "mse_all_columns"]
104 | columns.extend(["squared_error_" + x for x in list(x_explain.columns)])
105 | items = []
106 | for x in most_anomal_trx.items():  # items() replaces the deprecated Series.iteritems()
107 | item = [x[0], x[1]]
108 | item.extend(square_errors.loc[x[0]])
109 | items.append(item)
110 |
111 | df_anomalies = pd.DataFrame(items, columns=columns)
112 | df_anomalies.set_index('id', inplace=True)
113 |
114 | top_anomalies_to_explain = df_anomalies.head(self.num_anomalies_to_explain).index
115 | return top_anomalies_to_explain
116 |
117 | def get_num_features_with_highest_reconstruction_error(self, total_squared_error, errors_df):
118 | """
119 | Calculate the number of features whose reconstruction errors sum to reconstruction_error_percent of the
120 | total_squared_error of the record currently being explained. This is the number of top-error features that
121 | will be explained; eventually these features, together with their explanations, build up the features
122 | explanation set of this record.
123 |
124 | Args:
125 | total_squared_error (float): Total squared reconstruction error of the record currently being explained
126 | errors_df (data frame): The reconstruction error of each feature; this is the first output of the
127 | get_errors_df_per_record function
128 |
129 | Returns:
130 | int: Number of features whose reconstruction errors sum to reconstruction_error_percent of the
131 | total_squared_error of the record currently being explained
132 | """
133 |
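# Worked example (hypothetical numbers): with sorted errors [0.5, 0.3, 0.2] (total 1.0) and
# reconstruction_error_percent=0.5, the first feature alone reaches 0.5 >= 0.5 * 1.0, so 1 is returned.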
134 | error = 0
135 | for num_of_features, index in enumerate(errors_df.index):
136 | error += errors_df.loc[index, 'err']
137 | if error >= self.reconstruction_error_percent * total_squared_error:
138 | break
139 | return num_of_features + 1
140 |
141 | def get_background_set(self, x_train, background_size=200):
142 | """
143 | Get the first background_size records from the x_train data and return them. Used for the SHAP explanation process.
144 |
145 | Args:
146 | x_train (data frame): the data we will get the background set from
147 | background_size (int): The number of records to select from x_train. Default value is 200.
148 |
149 | Returns:
150 | data frame: Records from x_train that serve as the background set when explaining the current record
151 | using SHAP.
152 | """
153 |
154 | background_set = x_train.head(background_size)
155 | return background_set
156 |
157 | def get_errors_df_per_record(self, record):
158 | """
159 | Create a data frame of the reconstruction error of each feature of the given record. Each row contains the
160 | index of a feature, its name, and its reconstruction error based on the record's reconstruction by the
161 | trained autoencoder. The data frame is sorted by the features' reconstruction errors in descending
162 | order.
163 |
164 | Args:
165 | record (pandas series): The record we explain at the moment; values of all its features.
166 |
167 | Returns:
168 | tuple: Data frame of all features' reconstruction errors sorted in descending order, and the record's MSE.
169 | """
170 |
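# Example shape (hypothetical): for a 4-feature record, errors_df has 4 rows of (col_name, err) sorted so
# the worst-reconstructed feature comes first, and total_mse is the mean of those 4 errors.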
171 | prediction = self.autoencoder.predict(np.array([record]))[0]  # wrap the record in a batch of size 1
172 | square_errors = np.power(record - prediction, 2)
173 | errors_df = pd.DataFrame({'col_name': square_errors.index, 'err': square_errors}).reset_index(drop=True)
174 | total_mse = np.mean(square_errors)
175 | errors_df.sort_values(by='err', ascending=False, inplace=True)
176 | return errors_df, total_mse
177 |
178 | def get_highest_shap_values(self, shap_values_df):
179 | """
180 | Choose explaining features based on their SHAP values using the shap_values_selection method (mean, median,
181 | constant), i.e. remove all features whose SHAP values do not meet the method's requirements, as described in
182 | 'Explaining Anomalies Detected by Autoencoders Using SHAP' under SHAPvaluesSelection.
183 |
184 | Args:
185 | shap_values_df (data frame): Data frame with all existing features and their SHAP values.
186 |
187 | Returns:
188 | data frame: Data frame that contains, for each feature we explain (features with high reconstruction error),
189 | its explaining features selected by the shap_values_selection method, together with their SHAP values.
190 | """
191 |
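# Worked example (hypothetical numbers): with 'mean' selection and SHAP values [0.6, 0.3, 0.1], the
# threshold is their mean (~0.33), so only the feature with value 0.6 is kept as an explainer.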
192 | all_explaining_features_df = pd.DataFrame()
193 |
194 | for i in range(shap_values_df.shape[0]):
195 | shap_values = shap_values_df.iloc[i]
196 |
197 | if self.shap_values_selection == 'mean':
198 | threshold_val = np.mean(shap_values)
199 |
200 | elif self.shap_values_selection == 'median':
201 | threshold_val = np.median(shap_values)
202 |
203 | elif self.shap_values_selection == 'constant':
204 | num_explaining_features = 5
205 | explaining_features = shap_values_df[i:i + 1].stack().nlargest(num_explaining_features)
206 | all_explaining_features_df = pd.concat([all_explaining_features_df, explaining_features], axis=0)
207 | continue
208 |
209 | else:
210 | raise ValueError('unknown SHAP value selection method')
211 |
212 | num_explaining_features = 0
213 | for j in range(len(shap_values)):
214 | if shap_values.iloc[j] > threshold_val:  # positional access avoids deprecated integer indexing on a labeled Series
215 | num_explaining_features += 1
216 | explaining_features = shap_values_df[i:i + 1].stack().nlargest(num_explaining_features)
217 | all_explaining_features_df = pd.concat([all_explaining_features_df, explaining_features], axis=0)
218 | return all_explaining_features_df
219 |
220 | def func_predict_feature(self, record):
221 | """
222 | Predict the value of a specific feature (at the 'counter' index) using the trained autoencoder.
223 |
224 | Args:
225 | record (array): One or more records (SHAP passes batches of perturbed samples); values of all features.
226 |
227 | Returns:
228 | array: One predicted value per input record - the reconstructed value of the feature at the
229 | 'counter' index (the feature we explain at the moment)
230 | """
231 |
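# shap.KernelExplainer explains a single scalar output, so this wrapper exposes only column
# self.counter of the autoencoder's reconstruction.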
232 | record_prediction = self.autoencoder.predict(record)[:, self.counter]
233 | return record_prediction
234 |
235 | def explain_unsupervised_data(self, x_train, x_explain, autoencoder=None, return_shap_values=False):
236 | """
237 | First, if an Autoencoder model is not provided ('autoencoder' is None), train an Autoencoder model on the given
238 | x_train data. Then, for each record in 'top_records_to_explain', selected from the given 'x_explain' as described
239 | in the 'get_top_anomaly_to_explain' function, we use SHAP to explain the features with the highest reconstruction
240 | error, based on the output of the 'get_num_features_with_highest_reconstruction_error' function described above.
241 | Then, after we have the SHAP value of each feature in the explanation of a high-reconstruction-error feature, we
242 | select the explaining features using the 'get_highest_shap_values' function described above. Eventually, when we
243 | have the explaining features for each of the features with the highest reconstruction error, we build the
244 | explaining features set: the feature with the highest reconstruction error and its explaining features enter the
245 | explaining features set first, then the next feature with the highest reconstruction error and its explaining
246 | features enter the set only if they do not already exist in it, and so on (a full explanation and example are
247 | given in 'Explaining Anomalies Detected by Autoencoders Using SHAP').
248 |
249 | Args:
250 | x_train (data frame): The data to train the autoencoder model on and to select the background set from (for
251 | the SHAP explanation process)
252 | x_explain (data frame): The data from which the top 'num_anomalies_to_explain' records are selected by their
253 | MSE to be explained.
254 | autoencoder (model): Trained Autoencoder model that will be used to explain the x_explain data. If None (model
255 | not provided), an Autoencoder model is built and trained from scratch as described in the
256 | train_model function.
257 | return_shap_values (bool): If False, the resulting explanation features set for each record includes only
258 | the names of the explaining features. If True, in addition to each explaining feature's
259 | name, the explanation features set includes the SHAP value of each feature, so the
260 | explanation features set is composed of tuples of (str, float), where str is the name
261 | of the explaining feature and float is its SHAP value.
262 | Note that the explained features themselves (features with high reconstruction error)
263 | have no SHAP value if they did not appear in a previous feature's explanation (the
264 | explanation of a feature with a higher reconstruction error). They therefore get the
265 | unique value of -1.
266 |
267 | Returns:
268 | dict: The all_sets_explaining_features dictionary, which contains the explanations for the
269 | 'top_records_to_explain' records: the keys are ints (the record indexes) and the values are
270 | lists (the explanation features sets).
271 | """
272 |
273 | self.autoencoder = autoencoder
274 | if self.autoencoder is None:
275 | self.train_model(x_train)
276 |
277 | top_records_to_explain = self.get_top_anomaly_to_explain(x_explain)
278 | all_sets_explaining_features = {}
279 |
280 | for record_idx in top_records_to_explain:
281 | print(record_idx)
282 |
283 | record_to_explain = x_explain.loc[record_idx]
284 |
285 | df_err, total_mse = self.get_errors_df_per_record(record_to_explain)
286 | num_of_features = self.get_num_features_with_highest_reconstruction_error(total_mse * df_err.shape[0],
287 | df_err)
288 |
289 | df_top_err = df_err.head(num_of_features)
290 | all_sets_explaining_features[record_idx] = []
291 | shap_values_all_features = [[] for num in range(num_of_features)]
292 |
293 | background_set = self.get_background_set(x_train, 200).values
294 | for i in range(num_of_features):
295 | self.counter = df_top_err.index[i]
296 | explainer = shap.KernelExplainer(self.func_predict_feature, background_set)
297 | shap_values = explainer.shap_values(record_to_explain, nsamples='auto')
298 | shap_values_all_features[i] = shap_values
299 |
300 | shap_values_all_features = np.fabs(shap_values_all_features)
301 |
302 | shap_values_all_features = pd.DataFrame(data=shap_values_all_features, columns=x_train.columns)
303 | highest_contributing_features = self.get_highest_shap_values(shap_values_all_features)
304 |
305 | for idx_explained_feature in range(num_of_features):
306 | set_explaining_features = []
307 | for idx, row in highest_contributing_features.iterrows():
308 | if idx[0] == idx_explained_feature:
309 | set_explaining_features.append((idx[1], row[0]))
310 | explained_feature_index = df_top_err.index[idx_explained_feature]
311 | set_explaining_features.insert(0, (x_train.columns[explained_feature_index], -1))
312 |
313 | all_sets_explaining_features[record_idx].append(set_explaining_features)
314 |
315 | final_set_features = []
316 | final_set_items = []
317 | for item in sum(all_sets_explaining_features[record_idx], []):
318 | if item[0] not in final_set_features:
319 | final_set_features.append(item[0])
320 | final_set_items.append(item)
321 |
322 | if return_shap_values:
323 | all_sets_explaining_features[record_idx] = final_set_items
324 | else:
325 | all_sets_explaining_features[record_idx] = final_set_features
326 |
327 | return all_sets_explaining_features
328 |
--------------------------------------------------------------------------------
/README.txt:
--------------------------------------------------------------------------------
1 | Install the packages listed in the given requirements.txt file to run the code.
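2 |
3 | A minimal usage sketch (assumption: your data is a pandas DataFrame of numeric features min-max scaled
4 | to [0, 1]; the file name 'my_data.csv' and the train/explain split below are illustrative, not part of
5 | this repository):
6 |
7 | import pandas as pd
8 | from ExplainAnomaliesUsingSHAP import ExplainAnomaliesUsingSHAP
9 |
10 | df = pd.read_csv('my_data.csv')  # numeric features only, scaled to [0, 1]
11 | x_train = df.iloc[:-100]   # records assumed mostly normal
12 | x_explain = df.iloc[-100:]  # records to rank and explain
13 |
14 | exp_model = ExplainAnomaliesUsingSHAP(num_anomalies_to_explain=10)
15 | explanations = exp_model.explain_unsupervised_data(x_train=x_train, x_explain=x_explain,
16 | return_shap_values=True)
17 | print(explanations)  # {record_index: [(feature, shap_value), ...], ...}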
--------------------------------------------------------------------------------
/flow_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Example of using our ExplainAnomaliesUsingSHAP code on data taken from kaggle: https://www.kaggle.com/mlg-ulb/creditcardfraud"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "from ExplainAnomaliesUsingSHAP import ExplainAnomaliesUsingSHAP\n",
17 | "import pandas as pd"
18 | ]
19 | },
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {},
23 | "source": [
24 | "# Normaliztion of data"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 2,
30 | "metadata": {},
31 | "outputs": [
32 | {
33 | "data": {
34 | "text/html": [
35 | "
\n",
36 | "\n",
49 | "
\n",
50 | " \n",
51 | " \n",
52 | " | \n",
53 | " V1 | \n",
54 | " V2 | \n",
55 | " V3 | \n",
56 | " V4 | \n",
57 | " V5 | \n",
58 | " V6 | \n",
59 | " V7 | \n",
60 | " V8 | \n",
61 | " V9 | \n",
62 | " V10 | \n",
63 | " ... | \n",
64 | " V21 | \n",
65 | " V22 | \n",
66 | " V23 | \n",
67 | " V24 | \n",
68 | " V25 | \n",
69 | " V26 | \n",
70 | " V27 | \n",
71 | " V28 | \n",
72 | " Amount | \n",
73 | " Class | \n",
74 | "
\n",
75 | " \n",
76 | " \n",
77 | " \n",
78 | " 0 | \n",
79 | " 0.935192 | \n",
80 | " 0.766490 | \n",
81 | " 0.881365 | \n",
82 | " 0.313023 | \n",
83 | " 0.763439 | \n",
84 | " 0.267669 | \n",
85 | " 0.266815 | \n",
86 | " 0.786444 | \n",
87 | " 0.475312 | \n",
88 | " 0.510600 | \n",
89 | " ... | \n",
90 | " 0.561184 | \n",
91 | " 0.522992 | \n",
92 | " 0.663793 | \n",
93 | " 0.391253 | \n",
94 | " 0.585122 | \n",
95 | " 0.394557 | \n",
96 | " 0.418976 | \n",
97 | " 0.312697 | \n",
98 | " 0.005824 | \n",
99 | " 0 | \n",
100 | "
\n",
101 | " \n",
102 | " 1 | \n",
103 | " 0.978542 | \n",
104 | " 0.770067 | \n",
105 | " 0.840298 | \n",
106 | " 0.271796 | \n",
107 | " 0.766120 | \n",
108 | " 0.262192 | \n",
109 | " 0.264875 | \n",
110 | " 0.786298 | \n",
111 | " 0.453981 | \n",
112 | " 0.505267 | \n",
113 | " ... | \n",
114 | " 0.557840 | \n",
115 | " 0.480237 | \n",
116 | " 0.666938 | \n",
117 | " 0.336440 | \n",
118 | " 0.587290 | \n",
119 | " 0.446013 | \n",
120 | " 0.416345 | \n",
121 | " 0.313423 | \n",
122 | " 0.000105 | \n",
123 | " 0 | \n",
124 | "
\n",
125 | " \n",
126 | " 2 | \n",
127 | " 0.935217 | \n",
128 | " 0.753118 | \n",
129 | " 0.868141 | \n",
130 | " 0.268766 | \n",
131 | " 0.762329 | \n",
132 | " 0.281122 | \n",
133 | " 0.270177 | \n",
134 | " 0.788042 | \n",
135 | " 0.410603 | \n",
136 | " 0.513018 | \n",
137 | " ... | \n",
138 | " 0.565477 | \n",
139 | " 0.546030 | \n",
140 | " 0.678939 | \n",
141 | " 0.289354 | \n",
142 | " 0.559515 | \n",
143 | " 0.402727 | \n",
144 | " 0.415489 | \n",
145 | " 0.311911 | \n",
146 | " 0.014739 | \n",
147 | " 0 | \n",
148 | "
\n",
149 | " \n",
150 | " 3 | \n",
151 | " 0.941878 | \n",
152 | " 0.765304 | \n",
153 | " 0.868484 | \n",
154 | " 0.213661 | \n",
155 | " 0.765647 | \n",
156 | " 0.275559 | \n",
157 | " 0.266803 | \n",
158 | " 0.789434 | \n",
159 | " 0.414999 | \n",
160 | " 0.507585 | \n",
161 | " ... | \n",
162 | " 0.559734 | \n",
163 | " 0.510277 | \n",
164 | " 0.662607 | \n",
165 | " 0.223826 | \n",
166 | " 0.614245 | \n",
167 | " 0.389197 | \n",
168 | " 0.417669 | \n",
169 | " 0.314371 | \n",
170 | " 0.004807 | \n",
171 | " 0 | \n",
172 | "
\n",
173 | " \n",
174 | " 4 | \n",
175 | " 0.938617 | \n",
176 | " 0.776520 | \n",
177 | " 0.864251 | \n",
178 | " 0.269796 | \n",
179 | " 0.762975 | \n",
180 | " 0.263984 | \n",
181 | " 0.268968 | \n",
182 | " 0.782484 | \n",
183 | " 0.490950 | \n",
184 | " 0.524303 | \n",
185 | " ... | \n",
186 | " 0.561327 | \n",
187 | " 0.547271 | \n",
188 | " 0.663392 | \n",
189 | " 0.401270 | \n",
190 | " 0.566343 | \n",
191 | " 0.507497 | \n",
192 | " 0.420561 | \n",
193 | " 0.317490 | \n",
194 | " 0.002724 | \n",
195 | " 0 | \n",
196 | "
\n",
197 | " \n",
198 | "
\n",
199 | "
5 rows × 30 columns
\n",
200 | "
"
201 | ],
202 | "text/plain": [
203 | " V1 V2 V3 V4 V5 V6 V7 \\\n",
204 | "0 0.935192 0.766490 0.881365 0.313023 0.763439 0.267669 0.266815 \n",
205 | "1 0.978542 0.770067 0.840298 0.271796 0.766120 0.262192 0.264875 \n",
206 | "2 0.935217 0.753118 0.868141 0.268766 0.762329 0.281122 0.270177 \n",
207 | "3 0.941878 0.765304 0.868484 0.213661 0.765647 0.275559 0.266803 \n",
208 | "4 0.938617 0.776520 0.864251 0.269796 0.762975 0.263984 0.268968 \n",
209 | "\n",
210 | " V8 V9 V10 ... V21 V22 V23 V24 \\\n",
211 | "0 0.786444 0.475312 0.510600 ... 0.561184 0.522992 0.663793 0.391253 \n",
212 | "1 0.786298 0.453981 0.505267 ... 0.557840 0.480237 0.666938 0.336440 \n",
213 | "2 0.788042 0.410603 0.513018 ... 0.565477 0.546030 0.678939 0.289354 \n",
214 | "3 0.789434 0.414999 0.507585 ... 0.559734 0.510277 0.662607 0.223826 \n",
215 | "4 0.782484 0.490950 0.524303 ... 0.561327 0.547271 0.663392 0.401270 \n",
216 | "\n",
217 | " V25 V26 V27 V28 Amount Class \n",
218 | "0 0.585122 0.394557 0.418976 0.312697 0.005824 0 \n",
219 | "1 0.587290 0.446013 0.416345 0.313423 0.000105 0 \n",
220 | "2 0.559515 0.402727 0.415489 0.311911 0.014739 0 \n",
221 | "3 0.614245 0.389197 0.417669 0.314371 0.004807 0 \n",
222 | "4 0.566343 0.507497 0.420561 0.317490 0.002724 0 \n",
223 | "\n",
224 | "[5 rows x 30 columns]"
225 | ]
226 | },
227 | "execution_count": 2,
228 | "metadata": {},
229 | "output_type": "execute_result"
230 | }
231 | ],
232 | "source": [
233 | "df = pd.read_csv('../data/creditcard.csv', delimiter=',')\n",
234 | "df = df.drop(['Time'], axis=1)\n",
235 | "for col in df.columns[:-1]:\n",
236 | " min_val = df[col].min()\n",
237 | " max_val = df[col].max()\n",
238 | " if min_val != max_val:\n",
239 | " df[col] = (df[col] - min_val) / (max_val - min_val)\n",
240 | " \n",
241 | "df.head()"
242 | ]
243 | },
244 | {
245 | "cell_type": "markdown",
246 | "metadata": {},
247 | "source": [
248 | "# Split to Train and Test"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": 3,
254 | "metadata": {},
255 | "outputs": [
256 | {
257 | "name": "stdout",
258 | "output_type": "stream",
259 | "text": [
260 | "x shape: (284807, 29)\n",
261 | "y shape: (284807,)\n",
262 | "0 284315\n",
263 | "1 492\n",
264 | "Name: Class, dtype: int64\n"
265 | ]
266 | }
267 | ],
268 | "source": [
269 | "X = df.iloc[:,:-1]\n",
270 | "y = df.iloc[:, -1]\n",
271 | "\n",
272 | "print('x shape:', X.shape)\n",
273 | "print('y shape:', y.shape)\n",
274 | "print(y.value_counts())\n",
275 | "\n",
276 | "train_idx = y[y==0].index.values\n",
277 | "test_idx = y[y==1].index.values\n",
278 | "\n",
279 | "X_train = X.iloc[train_idx]\n",
280 | "y_train = y[train_idx]\n",
281 | "\n",
282 | "X_test = X.iloc[test_idx]\n",
283 | "y_test = y[test_idx]"
284 | ]
285 | },
286 | {
287 | "cell_type": "markdown",
288 | "metadata": {},
289 | "source": [
290 | "# Get explnation"
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "execution_count": 4,
296 | "metadata": {},
297 | "outputs": [],
298 | "source": [
299 | "exp_model = ExplainAnomaliesUsingSHAP(num_anomalies_to_explain=10)"
300 | ]
301 | },
302 | {
303 | "cell_type": "code",
304 | "execution_count": 5,
305 | "metadata": {},
306 | "outputs": [
307 | {
308 | "name": "stdout",
309 | "output_type": "stream",
310 | "text": [
311 | "Model: \"model\"\n",
312 | "_________________________________________________________________\n",
313 | "Layer (type) Output Shape Param # \n",
314 | "=================================================================\n",
315 | "input_1 (InputLayer) [(None, 29)] 0 \n",
316 | "_________________________________________________________________\n",
317 | "dense (Dense) (None, 14) 420 \n",
318 | "_________________________________________________________________\n",
319 | "dense_1 (Dense) (None, 7) 105 \n",
320 | "_________________________________________________________________\n",
321 | "dense_2 (Dense) (None, 14) 112 \n",
322 | "_________________________________________________________________\n",
323 | "dense_3 (Dense) (None, 29) 435 \n",
324 | "=================================================================\n",
325 | "Total params: 1,072\n",
326 | "Trainable params: 1,072\n",
327 | "Non-trainable params: 0\n",
328 | "_________________________________________________________________\n",
329 | "WARNING:tensorflow:Falling back from v2 loop because of error: Failed to find data adapter that can handle input: , \n",
330 | "Train on 255883 samples, validate on 28432 samples\n",
331 | "Epoch 1/10\n",
332 | "255883/255883 - 8s - loss: 0.0024 - mse: 0.0023 - val_loss: 0.0012 - val_mse: 0.0012\n",
333 | "Epoch 2/10\n",
334 | "255883/255883 - 8s - loss: 0.0012 - mse: 0.0012 - val_loss: 9.7971e-04 - val_mse: 9.4585e-04\n",
335 | "Epoch 3/10\n",
336 | "255883/255883 - 8s - loss: 9.4128e-04 - mse: 9.0603e-04 - val_loss: 7.7618e-04 - val_mse: 7.3871e-04\n",
337 | "Epoch 4/10\n",
338 | "255883/255883 - 8s - loss: 7.9653e-04 - mse: 7.6116e-04 - val_loss: 7.5711e-04 - val_mse: 7.2357e-04\n",
339 | "Epoch 5/10\n",
340 | "255883/255883 - 8s - loss: 7.7851e-04 - mse: 7.4647e-04 - val_loss: 7.5049e-04 - val_mse: 7.1972e-04\n",
341 | "Epoch 6/10\n",
342 | "255883/255883 - 8s - loss: 7.7088e-04 - mse: 7.4122e-04 - val_loss: 7.4166e-04 - val_mse: 7.1278e-04\n",
343 | "Epoch 7/10\n",
344 | "255883/255883 - 9s - loss: 7.6278e-04 - mse: 7.3418e-04 - val_loss: 7.3703e-04 - val_mse: 7.0830e-04\n",
345 | "Epoch 8/10\n",
346 | "255883/255883 - 9s - loss: 7.5110e-04 - mse: 7.2236e-04 - val_loss: 7.3026e-04 - val_mse: 7.0129e-04\n",
347 | "Epoch 9/10\n",
348 | "255883/255883 - 9s - loss: 7.4219e-04 - mse: 7.1342e-04 - val_loss: 7.2397e-04 - val_mse: 6.9522e-04\n",
349 | "Epoch 10/10\n",
350 | "255883/255883 - 8s - loss: 7.3582e-04 - mse: 7.0741e-04 - val_loss: 7.2062e-04 - val_mse: 6.9238e-04\n",
351 | "WARNING:tensorflow:Falling back from v2 loop because of error: Failed to find data adapter that can handle input: , \n",
352 | "154587\n",
353 | "154684\n",
354 | "154371\n",
355 | "154234\n",
356 | "150644\n",
357 | "150647\n",
358 | "150665\n",
359 | "150654\n",
360 | "42528\n",
361 | "42635\n"
362 | ]
363 | }
364 | ],
365 | "source": [
366 | "all_sets_explaining_features = exp_model.explain_unsupervised_data(x_train=X_train, \n",
367 | " x_explain=X_test,\n",
368 | " return_shap_values=True)"
369 | ]
370 | },
371 | {
372 | "cell_type": "code",
373 | "execution_count": 6,
374 | "metadata": {},
375 | "outputs": [
376 | {
377 | "data": {
378 | "text/plain": [
379 | "{154587: [('V4', -1),\n",
380 | " ('V11', 0.0731265756822315),\n",
381 | " ('V15', 0.02369090910031154),\n",
382 | " ('V7', 0.016886018593518293),\n",
383 | " ('V3', 0.016570416382713796),\n",
384 | " ('V13', 0.015665877182100774),\n",
385 | " ('V10', 0.014052846919758964),\n",
386 | " ('V18', 0.00019320900654557518),\n",
387 | " ('V1', 0.00017568476566009978),\n",
388 | " ('V17', 0.00017336466464923974),\n",
389 | " ('V9', 0.0006888261077363056),\n",
390 | " ('V21', 0.0020651601972084717),\n",
391 | " ('V12', 0.0018995651805052225),\n",
392 | " ('V6', 0.001422672056074865),\n",
393 | " ('V8', -1)],\n",
394 | " 154684: [('V4', -1),\n",
395 | " ('V11', 0.06857125766422263),\n",
396 | " ('V15', 0.024017587211378415),\n",
397 | " ('V13', 0.01748191127094885),\n",
398 | " ('V7', 0.01707902271076717),\n",
399 | " ('V3', 0.01379392178383141),\n",
400 | " ('V10', 0.01240802991961949),\n",
401 | " ('V1', -1),\n",
402 | " ('V18', 0.0025376704074176696),\n",
403 | " ('V9', 0.000837353113461579),\n",
404 | " ('V17', 0.00358466848771998),\n",
405 | " ('V12', 0.002072845369657088),\n",
406 | " ('V21', 0.0020574753036938336),\n",
407 | " ('V22', 0.00046080208520149593),\n",
408 | " ('V8', -1),\n",
409 | " ('V24', 0.0003597983812910178)],\n",
410 | " 154371: [('V4', -1),\n",
411 | " ('V11', 0.08294142890489672),\n",
412 | " ('V15', 0.02898130105862798),\n",
413 | " ('V3', 0.01590907677201072),\n",
414 | " ('V7', 0.014711642443088473),\n",
415 | " ('V13', 0.014587836279879807),\n",
416 | " ('V10', 0.013403577555324252),\n",
417 | " ('V18', 0.0003123828531369758),\n",
418 | " ('V1', 0.0001992197589061244),\n",
419 | " ('V17', 0.00017078683719280304),\n",
420 | " ('V12', 0.002066954758604087),\n",
421 | " ('V24', 0.0016792627630605912),\n",
422 | " ('V21', 0.0016285431507732772)],\n",
423 | " 154234: [('V18', -1),\n",
424 | " ('V15', 0.04582284749101369),\n",
425 | " ('V4', 0.01298634659155869),\n",
426 | " ('V24', 0.005400751656493175),\n",
427 | " ('V11', 0.004962711118773501),\n",
428 | " ('V1', 0.004888440731168236),\n",
429 | " ('V9', 0.004203121378734948),\n",
430 | " ('V17', -1),\n",
431 | " ('V7', 0.0028306002556112014),\n",
432 | " ('V3', 0.0026911820955583705),\n",
433 | " ('V10', 0.0019216298091299933),\n",
434 | " ('V12', 0.0011375011534669778),\n",
435 | " ('V13', 0.01092450612555844),\n",
436 | " ('V16', 9.700471898336474e-05)],\n",
437 | " 150644: [('V12', -1),\n",
438 | " ('V11', 0.010486840442011657),\n",
439 | " ('V17', 0.004155066659667708),\n",
440 | " ('V4', 0.003106695971850944),\n",
441 | " ('V13', 0.0028908881702827925),\n",
442 | " ('V24', 0.001753354192488222),\n",
443 | " ('V14', 0.001660075517201605),\n",
444 | " ('V15', 0.004020149749684945),\n",
445 | " ('V18', 0.0034792307493295323),\n",
446 | " ('V10', 0.0012823640295097126),\n",
447 | " ('V16', 0.007824932111410356)],\n",
448 | " 150647: [('V12', -1),\n",
449 | " ('V11', 0.009069677191330994),\n",
450 | " ('V4', 0.00438623604807333),\n",
451 | " ('V17', 0.0036944688557951216),\n",
452 | " ('V26', 0.0025862953143419055),\n",
453 | " ('V24', 0.002056737456413358),\n",
454 | " ('V14', 0.0014729120689511745),\n",
455 | " ('V18', 0.0035930031667444967),\n",
456 | " ('V15', 0.001178029124017615),\n",
457 | " ('V16', 0.001203094323917563),\n",
458 | " ('V19', 0.0007287963722595273)],\n",
459 | " 150665: [('V17', -1),\n",
460 | " ('V24', 0.006989545665981571),\n",
461 | " ('V11', 0.006511480206894255),\n",
462 | " ('V4', 0.003946292279777723),\n",
463 | " ('V15', 0.003525485111059065),\n",
464 | " ('V18', 0.003038005743877573),\n",
465 | " ('V12', 0.002472561714746555),\n",
466 | " ('V16', 0.008939958118581916),\n",
467 | " ('V19', 0.007667361038725944)],\n",
468 | " 150654: [('V12', -1),\n",
469 | " ('V11', 0.009297464939338189),\n",
470 | " ('V4', 0.003955479118907272),\n",
471 | " ('V17', 0.003865460078225134),\n",
472 | " ('V24', 0.0019398804405763668),\n",
473 | " ('V26', 0.0018423242061232359),\n",
474 | " ('V14', 0.0016718447008473962),\n",
475 | " ('V18', 0.0033557790402469314),\n",
476 | " ('V3', 0.0012585499497945432),\n",
477 | " ('V7', 0.0011889138875021267),\n",
478 | " ('V16', 0.0010422633904812783)],\n",
479 | " 42528: [('V17', -1),\n",
480 | " ('V24', 0.006736426024435591),\n",
481 | " ('V18', 0.0050778753111189595),\n",
482 | " ('V12', 0.002842628989519194),\n",
483 | " ('V4', 0.0024794057912630846),\n",
484 | " ('V14', 0.002269837239830145),\n",
485 | " ('V11', 0.002218586625682887),\n",
486 | " ('V19', 0.0016399651294882627),\n",
487 | " ('V16', 0.0013763903992034172),\n",
488 | " ('V9', 0.004386193418645018),\n",
489 | " ('V26', 0.001635154304709785),\n",
490 | " ('V10', 0.0010430554925531486),\n",
491 | " ('V13', 0.0009982975334452363)],\n",
492 | " 42635: [('V17', -1),\n",
493 | " ('V24', 0.00605546732778159),\n",
494 | " ('V18', 0.004766048030447719),\n",
495 | " ('V4', 0.002795297282351012),\n",
496 | " ('V11', 0.0021679928865168505),\n",
497 | " ('V12', 0.0020859762105622697),\n",
498 | " ('V15', 0.0019091023771055922),\n",
499 | " ('V14', 0.0014634281110916776),\n",
500 | " ('V10', 0.001426458113448244),\n",
501 | " ('V19', 0.012697520780446276),\n",
502 | " ('V16', 0.01025355503344674),\n",
503 | " ('V13', 0.0018196956159556073)]}"
504 | ]
505 | },
506 | "execution_count": 6,
507 | "metadata": {},
508 | "output_type": "execute_result"
509 | }
510 | ],
511 | "source": [
512 | "all_sets_explaining_features"
513 | ]
514 | }
515 | ],
516 | "metadata": {
517 | "kernelspec": {
518 | "display_name": "Python 3",
519 | "language": "python",
520 | "name": "python3"
521 | },
522 | "language_info": {
523 | "codemirror_mode": {
524 | "name": "ipython",
525 | "version": 3
526 | },
527 | "file_extension": ".py",
528 | "mimetype": "text/x-python",
529 | "name": "python",
530 | "nbconvert_exporter": "python",
531 | "pygments_lexer": "ipython3",
532 | "version": "3.7.4"
533 | }
534 | },
535 | "nbformat": 4,
536 | "nbformat_minor": 2
537 | }
538 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ronniemi/explainAnomaliesUsingSHAP/3df55220d459d27e7029158563178a40751732c5/requirements.txt
--------------------------------------------------------------------------------
/simulate perfect autoencoder data.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "colab": {
8 | "base_uri": "https://localhost:8080/",
9 | "height": 33
10 | },
11 | "colab_type": "code",
12 | "executionInfo": {
13 | "elapsed": 1040,
14 | "status": "ok",
15 | "timestamp": 1566741408085,
16 | "user": {
17 | "displayName": "Ronnie Mindlin",
18 | "photoUrl": "",
19 | "userId": "17245682626254463986"
20 | },
21 | "user_tz": -180
22 | },
23 | "id": "3DXBYgkglrYm",
24 | "outputId": "c96b2248-17a5-4965-9433-20d0018ed7bb"
25 | },
26 | "outputs": [],
27 | "source": [
28 | "import numpy as np\n",
29 | "import pandas as pd"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 2,
35 | "metadata": {
36 | "colab": {},
37 | "colab_type": "code",
38 | "id": "60s3D0U7luAm"
39 | },
40 | "outputs": [],
41 | "source": [
42 | "n_records = 1000000\n",
43 | "n_features = 4\n",
44 | "\n",
45 | "features_names = []\n",
46 | "for feature in range(1, n_features+1):\n",
47 | " features_names.append('feature_' + str(feature))\n",
48 | "\n",
49 | "data = np.random.rand(n_records, n_features)"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": 3,
55 | "metadata": {
56 | "colab": {},
57 | "colab_type": "code",
58 | "id": "39AvCMiImdCg"
59 | },
60 | "outputs": [],
61 | "source": [
62 | "data_df = pd.DataFrame(data=data, columns=features_names)"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 4,
68 | "metadata": {
69 | "colab": {
70 | "base_uri": "https://localhost:8080/",
71 | "height": 190
72 | },
73 | "colab_type": "code",
74 | "executionInfo": {
75 | "elapsed": 1854,
76 | "status": "ok",
77 | "timestamp": 1566741408926,
78 | "user": {
79 | "displayName": "Ronnie Mindlin",
80 | "photoUrl": "",
81 | "userId": "17245682626254463986"
82 | },
83 | "user_tz": -180
84 | },
85 | "id": "k4NisjLRmg-b",
86 | "outputId": "c826d658-7b88-4fe9-efd5-0da960077899"
87 | },
88 | "outputs": [
89 | {
90 | "data": {
91 | "text/html": [
92 | "\n",
93 | "\n",
106 | "
\n",
107 | " \n",
108 | " \n",
109 | " | \n",
110 | " feature_1 | \n",
111 | " feature_2 | \n",
112 | " feature_3 | \n",
113 | " feature_4 | \n",
114 | "
\n",
115 | " \n",
116 | " \n",
117 | " \n",
118 | " 0 | \n",
119 | " 0.651137 | \n",
120 | " 0.987213 | \n",
121 | " 0.806742 | \n",
122 | " 0.732936 | \n",
123 | "
\n",
124 | " \n",
125 | " 1 | \n",
126 | " 0.125991 | \n",
127 | " 0.601596 | \n",
128 | " 0.247740 | \n",
129 | " 0.909559 | \n",
130 | "
\n",
131 | " \n",
132 | " 2 | \n",
133 | " 0.567075 | \n",
134 | " 0.976509 | \n",
135 | " 0.032646 | \n",
136 | " 0.860562 | \n",
137 | "
\n",
138 | " \n",
139 | " 3 | \n",
140 | " 0.500608 | \n",
141 | " 0.692740 | \n",
142 | " 0.984498 | \n",
143 | " 0.520351 | \n",
144 | "
\n",
145 | " \n",
146 | " 4 | \n",
147 | " 0.623654 | \n",
148 | " 0.068421 | \n",
149 | " 0.563977 | \n",
150 | " 0.954146 | \n",
151 | "
\n",
152 | " \n",
153 | "
\n",
154 | "
"
155 | ],
156 | "text/plain": [
157 | " feature_1 feature_2 feature_3 feature_4\n",
158 | "0 0.651137 0.987213 0.806742 0.732936\n",
159 | "1 0.125991 0.601596 0.247740 0.909559\n",
160 | "2 0.567075 0.976509 0.032646 0.860562\n",
161 | "3 0.500608 0.692740 0.984498 0.520351\n",
162 | "4 0.623654 0.068421 0.563977 0.954146"
163 | ]
164 | },
165 | "execution_count": 4,
166 | "metadata": {},
167 | "output_type": "execute_result"
168 | }
169 | ],
170 | "source": [
171 | "data_df.head()"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": 5,
177 | "metadata": {
178 | "colab": {},
179 | "colab_type": "code",
180 | "id": "I8StXZhwmro0"
181 | },
182 | "outputs": [],
183 | "source": [
184 | "data_df['feature_1_2'] = data_df['feature_1'] + data_df['feature_2']\n",
185 | "data_df['feature_3_4'] = data_df['feature_3'] + data_df['feature_4']\n",
186 | "\n",
187 | "data_df['class'] = np.zeros(n_records)"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": 6,
193 | "metadata": {
194 | "colab": {
195 | "base_uri": "https://localhost:8080/",
196 | "height": 190
197 | },
198 | "colab_type": "code",
199 | "executionInfo": {
200 | "elapsed": 1834,
201 | "status": "ok",
202 | "timestamp": 1566741408927,
203 | "user": {
204 | "displayName": "Ronnie Mindlin",
205 | "photoUrl": "",
206 | "userId": "17245682626254463986"
207 | },
208 | "user_tz": -180
209 | },
210 | "id": "w3aH7rcqnArk",
211 | "outputId": "0c800a44-e0ee-49a4-ca4b-03c6c070a974"
212 | },
213 | "outputs": [
214 | {
215 | "data": {
216 | "text/html": [
217 | "\n",
218 | "\n",
231 | "
\n",
232 | " \n",
233 | " \n",
234 | " | \n",
235 | " feature_1 | \n",
236 | " feature_2 | \n",
237 | " feature_3 | \n",
238 | " feature_4 | \n",
239 | " feature_1_2 | \n",
240 | " feature_3_4 | \n",
241 | " class | \n",
242 | "
\n",
243 | " \n",
244 | " \n",
245 | " \n",
246 | " 0 | \n",
247 | " 0.651137 | \n",
248 | " 0.987213 | \n",
249 | " 0.806742 | \n",
250 | " 0.732936 | \n",
251 | " 1.638351 | \n",
252 | " 1.539678 | \n",
253 | " 0.0 | \n",
254 | "
\n",
255 | " \n",
256 | " 1 | \n",
257 | " 0.125991 | \n",
258 | " 0.601596 | \n",
259 | " 0.247740 | \n",
260 | " 0.909559 | \n",
261 | " 0.727587 | \n",
262 | " 1.157300 | \n",
263 | " 0.0 | \n",
264 | "
\n",
265 | " \n",
266 | " 2 | \n",
267 | " 0.567075 | \n",
268 | " 0.976509 | \n",
269 | " 0.032646 | \n",
270 | " 0.860562 | \n",
271 | " 1.543584 | \n",
272 | " 0.893208 | \n",
273 | " 0.0 | \n",
274 | "
\n",
275 | " \n",
276 | " 3 | \n",
277 | " 0.500608 | \n",
278 | " 0.692740 | \n",
279 | " 0.984498 | \n",
280 | " 0.520351 | \n",
281 | " 1.193347 | \n",
282 | " 1.504849 | \n",
283 | " 0.0 | \n",
284 | "
\n",
285 | " \n",
286 | " 4 | \n",
287 | " 0.623654 | \n",
288 | " 0.068421 | \n",
289 | " 0.563977 | \n",
290 | " 0.954146 | \n",
291 | " 0.692075 | \n",
292 | " 1.518123 | \n",
293 | " 0.0 | \n",
294 | "
\n",
295 | " \n",
296 | "
\n",
297 | "
"
298 | ],
299 | "text/plain": [
300 | " feature_1 feature_2 feature_3 feature_4 feature_1_2 feature_3_4 class\n",
301 | "0 0.651137 0.987213 0.806742 0.732936 1.638351 1.539678 0.0\n",
302 | "1 0.125991 0.601596 0.247740 0.909559 0.727587 1.157300 0.0\n",
303 | "2 0.567075 0.976509 0.032646 0.860562 1.543584 0.893208 0.0\n",
304 | "3 0.500608 0.692740 0.984498 0.520351 1.193347 1.504849 0.0\n",
305 | "4 0.623654 0.068421 0.563977 0.954146 0.692075 1.518123 0.0"
306 | ]
307 | },
308 | "execution_count": 6,
309 | "metadata": {},
310 | "output_type": "execute_result"
311 | }
312 | ],
313 | "source": [
314 | "data_df.head()"
315 | ]
316 | },
317 | {
318 | "cell_type": "code",
319 | "execution_count": 7,
320 | "metadata": {
321 | "colab": {},
322 | "colab_type": "code",
323 | "id": "q0o6tjtqoh7t"
324 | },
325 | "outputs": [],
326 | "source": [
327 | "n_to_change = 15000\n",
328 | "\n",
329 | "idx = data_df.sample(n=n_to_change).index\n",
330 | "\n",
331 | "for i in idx:\n",
332 | "# rand_vals = np.random.rand(2)\n",
333 | "# data_df.loc[i, 'feature_1_2'] = rand_vals[0]\n",
334 | "# data_df.loc[i, 'feature_3_4'] = rand_vals[1]\n",
335 | "# data_df.loc[i, 'class'] = '1'\n",
336 | " rand_num = np.random.rand(1)[0]\n",
337 | " if rand_num < 0.5:\n",
338 | " data_df.loc[i, 'feature_1_2'] = np.random.rand(1)[0]\n",
339 | " data_df.loc[i, 'class'] = '1_2'\n",
340 | " else:\n",
341 | " data_df.loc[i, 'feature_3_4'] = np.random.rand(1)[0]\n",
342 | " data_df.loc[i, 'class'] = '3_4'\n",
343 | " "
344 | ]
345 | },
346 | {
347 | "cell_type": "code",
348 | "execution_count": 12,
349 | "metadata": {
350 | "colab": {
351 | "base_uri": "https://localhost:8080/",
352 | "height": 334
353 | },
354 | "colab_type": "code",
355 | "executionInfo": {
356 | "elapsed": 794547,
357 | "status": "ok",
358 | "timestamp": 1566742201654,
359 | "user": {
360 | "displayName": "Ronnie Mindlin",
361 | "photoUrl": "",
362 | "userId": "17245682626254463986"
363 | },
364 | "user_tz": -180
365 | },
366 | "id": "J_C1LZCY8tmY",
367 | "outputId": "05a0c369-b3f1-4e03-c34c-9775f5e2b2c3"
368 | },
369 | "outputs": [
370 | {
371 | "data": {
372 | "text/html": [
373 | "\n",
374 | "\n",
387 | "
\n",
388 | " \n",
389 | " \n",
390 | " | \n",
391 | " feature_1 | \n",
392 | " feature_2 | \n",
393 | " feature_3 | \n",
394 | " feature_4 | \n",
395 | " feature_1_2 | \n",
396 | " feature_3_4 | \n",
397 | " class | \n",
398 | "
\n",
399 | " \n",
400 | " \n",
401 | " \n",
402 | " 0 | \n",
403 | " 0.651137 | \n",
404 | " 0.987213 | \n",
405 | " 0.806742 | \n",
406 | " 0.732936 | \n",
407 | " 1.638351 | \n",
408 | " 1.539678 | \n",
409 | " 0 | \n",
410 | "
\n",
411 | " \n",
412 | " 1 | \n",
413 | " 0.125991 | \n",
414 | " 0.601596 | \n",
415 | " 0.247740 | \n",
416 | " 0.909559 | \n",
417 | " 0.727587 | \n",
418 | " 1.157300 | \n",
419 | " 0 | \n",
420 | "
\n",
421 | " \n",
422 | " 2 | \n",
423 | " 0.567075 | \n",
424 | " 0.976509 | \n",
425 | " 0.032646 | \n",
426 | " 0.860562 | \n",
427 | " 1.543584 | \n",
428 | " 0.893208 | \n",
429 | " 0 | \n",
430 | "
\n",
431 | " \n",
432 | " 3 | \n",
433 | " 0.500608 | \n",
434 | " 0.692740 | \n",
435 | " 0.984498 | \n",
436 | " 0.520351 | \n",
437 | " 1.193347 | \n",
438 | " 1.504849 | \n",
439 | " 0 | \n",
440 | "
\n",
441 | " \n",
442 | " 4 | \n",
443 | " 0.623654 | \n",
444 | " 0.068421 | \n",
445 | " 0.563977 | \n",
446 | " 0.954146 | \n",
447 | " 0.692075 | \n",
448 | " 1.518123 | \n",
449 | " 0 | \n",
450 | "
\n",
451 | " \n",
452 | " 5 | \n",
453 | " 0.124422 | \n",
454 | " 0.082783 | \n",
455 | " 0.909318 | \n",
456 | " 0.915040 | \n",
457 | " 0.207205 | \n",
458 | " 1.824358 | \n",
459 | " 0 | \n",
460 | "
\n",
461 | " \n",
462 | " 6 | \n",
463 | " 0.607873 | \n",
464 | " 0.708921 | \n",
465 | " 0.210336 | \n",
466 | " 0.417735 | \n",
467 | " 1.316794 | \n",
468 | " 0.628071 | \n",
469 | " 0 | \n",
470 | "
\n",
471 | " \n",
472 | " 7 | \n",
473 | " 0.040884 | \n",
474 | " 0.455809 | \n",
475 | " 0.681880 | \n",
476 | " 0.706162 | \n",
477 | " 0.496693 | \n",
478 | " 1.388042 | \n",
479 | " 0 | \n",
480 | "
\n",
481 | " \n",
482 | " 8 | \n",
483 | " 0.386996 | \n",
484 | " 0.364388 | \n",
485 | " 0.241929 | \n",
486 | " 0.003834 | \n",
487 | " 0.751385 | \n",
488 | " 0.245763 | \n",
489 | " 0 | \n",
490 | "
\n",
491 | " \n",
492 | " 9 | \n",
493 | " 0.206910 | \n",
494 | " 0.966390 | \n",
495 | " 0.903446 | \n",
496 | " 0.932776 | \n",
497 | " 1.173300 | \n",
498 | " 1.836222 | \n",
499 | " 0 | \n",
500 | "
\n",
501 | " \n",
502 | "
\n",
503 | "
"
504 | ],
505 | "text/plain": [
506 | " feature_1 feature_2 feature_3 feature_4 feature_1_2 feature_3_4 class\n",
507 | "0 0.651137 0.987213 0.806742 0.732936 1.638351 1.539678 0\n",
508 | "1 0.125991 0.601596 0.247740 0.909559 0.727587 1.157300 0\n",
509 | "2 0.567075 0.976509 0.032646 0.860562 1.543584 0.893208 0\n",
510 | "3 0.500608 0.692740 0.984498 0.520351 1.193347 1.504849 0\n",
511 | "4 0.623654 0.068421 0.563977 0.954146 0.692075 1.518123 0\n",
512 | "5 0.124422 0.082783 0.909318 0.915040 0.207205 1.824358 0\n",
513 | "6 0.607873 0.708921 0.210336 0.417735 1.316794 0.628071 0\n",
514 | "7 0.040884 0.455809 0.681880 0.706162 0.496693 1.388042 0\n",
515 | "8 0.386996 0.364388 0.241929 0.003834 0.751385 0.245763 0\n",
516 | "9 0.206910 0.966390 0.903446 0.932776 1.173300 1.836222 0"
517 | ]
518 | },
519 | "execution_count": 12,
520 | "metadata": {},
521 | "output_type": "execute_result"
522 | }
523 | ],
524 | "source": [
525 | "data_df.head(10)"
526 | ]
527 | },
528 | {
529 | "cell_type": "code",
530 | "execution_count": 13,
531 | "metadata": {},
532 | "outputs": [
533 | {
534 | "name": "stdout",
535 | "output_type": "stream",
536 | "text": [
537 | "1.9996749006122247\n"
538 | ]
539 | }
540 | ],
541 | "source": [
542 | "x_max_val = data_df.iloc[:,:-1].max().max()\n",
543 | "print(x_max_val)\n",
544 | "data_df.iloc[:,:-1] = data_df.iloc[:,:-1] / x_max_val"
545 | ]
546 | },
547 | {
548 | "cell_type": "code",
549 | "execution_count": 14,
550 | "metadata": {},
551 | "outputs": [
552 | {
553 | "data": {
554 | "text/html": [
555 | "\n",
556 | "\n",
569 | "
\n",
570 | " \n",
571 | " \n",
572 | " | \n",
573 | " feature_1 | \n",
574 | " feature_2 | \n",
575 | " feature_3 | \n",
576 | " feature_4 | \n",
577 | " feature_1_2 | \n",
578 | " feature_3_4 | \n",
579 | " class | \n",
580 | "
\n",
581 | " \n",
582 | " \n",
583 | " \n",
584 | " 0 | \n",
585 | " 0.325622 | \n",
586 | " 0.493687 | \n",
587 | " 0.403437 | \n",
588 | " 0.366528 | \n",
589 | " 0.819309 | \n",
590 | " 0.769964 | \n",
591 | " 0 | \n",
592 | "
\n",
593 | " \n",
594 | " 1 | \n",
595 | " 0.063006 | \n",
596 | " 0.300847 | \n",
597 | " 0.123890 | \n",
598 | " 0.454853 | \n",
599 | " 0.363853 | \n",
600 | " 0.578744 | \n",
601 | " 0 | \n",
602 | "
\n",
603 | " \n",
604 | " 2 | \n",
605 | " 0.283584 | \n",
606 | " 0.488334 | \n",
607 | " 0.016326 | \n",
608 | " 0.430351 | \n",
609 | " 0.771917 | \n",
610 | " 0.446677 | \n",
611 | " 0 | \n",
612 | "
\n",
613 | " \n",
614 | " 3 | \n",
615 | " 0.250345 | \n",
616 | " 0.346426 | \n",
617 | " 0.492329 | \n",
618 | " 0.260218 | \n",
619 | " 0.596771 | \n",
620 | " 0.752547 | \n",
621 | " 0 | \n",
622 | "
\n",
623 | " \n",
624 | " 4 | \n",
625 | " 0.311878 | \n",
626 | " 0.034216 | \n",
627 | " 0.282034 | \n",
628 | " 0.477150 | \n",
629 | " 0.346094 | \n",
630 | " 0.759185 | \n",
631 | " 0 | \n",
632 | "
\n",
633 | " \n",
634 | " 5 | \n",
635 | " 0.062221 | \n",
636 | " 0.041398 | \n",
637 | " 0.454733 | \n",
638 | " 0.457594 | \n",
639 | " 0.103620 | \n",
640 | " 0.912327 | \n",
641 | " 0 | \n",
642 | "
\n",
643 | " \n",
644 | " 6 | \n",
645 | " 0.303986 | \n",
646 | " 0.354518 | \n",
647 | " 0.105185 | \n",
648 | " 0.208902 | \n",
649 | " 0.658504 | \n",
650 | " 0.314087 | \n",
651 | " 0 | \n",
652 | "
\n",
653 | " \n",
654 | " 7 | \n",
655 | " 0.020445 | \n",
656 | " 0.227942 | \n",
657 | " 0.340996 | \n",
658 | " 0.353139 | \n",
659 | " 0.248387 | \n",
660 | " 0.694134 | \n",
661 | " 0 | \n",
662 | "
\n",
663 | " \n",
664 | " 8 | \n",
665 | " 0.193530 | \n",
666 | " 0.182224 | \n",
667 | " 0.120984 | \n",
668 | " 0.001918 | \n",
669 | " 0.375753 | \n",
670 | " 0.122902 | \n",
671 | " 0 | \n",
672 | "
\n",
673 | " \n",
674 | " 9 | \n",
675 | " 0.103472 | \n",
676 | " 0.483274 | \n",
677 | " 0.451796 | \n",
678 | " 0.466464 | \n",
679 | " 0.586746 | \n",
680 | " 0.918260 | \n",
681 | " 0 | \n",
682 | "
\n",
683 | " \n",
684 | "
\n",
685 | "
"
686 | ],
687 | "text/plain": [
688 | " feature_1 feature_2 feature_3 feature_4 feature_1_2 feature_3_4 class\n",
689 | "0 0.325622 0.493687 0.403437 0.366528 0.819309 0.769964 0\n",
690 | "1 0.063006 0.300847 0.123890 0.454853 0.363853 0.578744 0\n",
691 | "2 0.283584 0.488334 0.016326 0.430351 0.771917 0.446677 0\n",
692 | "3 0.250345 0.346426 0.492329 0.260218 0.596771 0.752547 0\n",
693 | "4 0.311878 0.034216 0.282034 0.477150 0.346094 0.759185 0\n",
694 | "5 0.062221 0.041398 0.454733 0.457594 0.103620 0.912327 0\n",
695 | "6 0.303986 0.354518 0.105185 0.208902 0.658504 0.314087 0\n",
696 | "7 0.020445 0.227942 0.340996 0.353139 0.248387 0.694134 0\n",
697 | "8 0.193530 0.182224 0.120984 0.001918 0.375753 0.122902 0\n",
698 | "9 0.103472 0.483274 0.451796 0.466464 0.586746 0.918260 0"
699 | ]
700 | },
701 | "execution_count": 14,
702 | "metadata": {},
703 | "output_type": "execute_result"
704 | }
705 | ],
706 | "source": [
707 | "data_df.head(10)"
708 | ]
709 | },
710 | {
711 | "cell_type": "code",
712 | "execution_count": 16,
713 | "metadata": {
714 | "colab": {},
715 | "colab_type": "code",
716 | "id": "4QW7Tk56nCYv"
717 | },
718 | "outputs": [],
719 | "source": [
720 | "data_df.to_csv('../data/perfect_autoencoder/random_data_' + str(n_features+2) + '_features.csv', index=False)"
721 | ]
722 | }
723 | ],
724 | "metadata": {
725 | "colab": {
726 | "collapsed_sections": [],
727 | "name": "simulate_data.ipynb",
728 | "provenance": [],
729 | "version": "0.3.2"
730 | },
731 | "kernelspec": {
732 | "display_name": "Python 3",
733 | "language": "python",
734 | "name": "python3"
735 | },
736 | "language_info": {
737 | "codemirror_mode": {
738 | "name": "ipython",
739 | "version": 3
740 | },
741 | "file_extension": ".py",
742 | "mimetype": "text/x-python",
743 | "name": "python",
744 | "nbconvert_exporter": "python",
745 | "pygments_lexer": "ipython3",
746 | "version": "3.7.4"
747 | }
748 | },
749 | "nbformat": 4,
750 | "nbformat_minor": 1
751 | }
752 |
--------------------------------------------------------------------------------