├── LICENSE
├── README.md
└── Structure learning_for_multivariate_timeseries_dataset.ipynb
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Abiodun Ayodeji
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Structure-Learning-from-Time-series-Data-with-CausalNex
2 |
3 | ## Time series data structure learning with NOTEARS and DYNOTEARS
4 |
5 |
6 | ### Introduction
7 |
8 | You must have heard the phrase "correlation is not causation". It captures one of the issues with the conventional deep learning approach, especially for long-sequence time series datasets: the learned pattern (correlation) may be spurious.
9 |
10 | To address the spurious correlation issue, a number of approaches such as causal machine learning and graph neural networks have been proposed. However, learning the graph (structure) in time series datasets is a non-trivial task. This repository demonstrates how to use a public library - [CausalNex](https://causalnex.readthedocs.io/en/latest/) - with NOTEARS and [DYNOTEARS](https://arxiv.org/abs/2002.00498) to learn the dependencies (graph/structure) in sensed parameters that describe aircraft engine degradation history.
11 |
12 | Learning the connections (structure) between time-varying parameters, and using the learned connections to build a deep learning model, can improve the performance of the resulting model. The learned structure can be used to build a graph neural network with better predictive performance than conventional correlation-based deep learning models. The CausalNex API is also flexible: it enables the integration of domain knowledge (i.e., nodes and edges can be added to or removed from the extracted graph).
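
Edge thresholding, one of the domain-knowledge hooks mentioned above, is easy to illustrate even without CausalNex installed. Below is a minimal sketch: the weights are made up, and `edges_above_threshold` is a hypothetical helper mirroring the behaviour of CausalNex's `StructureModel.remove_edges_below_threshold`.

```python
# Hypothetical learned edge weights for three sensors (illustration only)
weights = {
    ("s1", "s2"): 0.95,
    ("s2", "s3"): 0.15,
    ("s1", "s3"): 0.82,
}

def edges_above_threshold(weights, threshold):
    """Keep only edges whose absolute weight meets the threshold,
    mirroring StructureModel.remove_edges_below_threshold."""
    return {edge: w for edge, w in weights.items() if abs(w) >= threshold}

kept = edges_above_threshold(weights, 0.8)
print(sorted(kept))  # [('s1', 's2'), ('s1', 's3')]
```

Pruning weak edges this way is how weakly supported dependencies are removed before the graph is handed to a downstream model.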
13 |
14 | ### Data
15 |
16 | The dataset is the run-to-failure degradation data of turbofan engines, provided by NASA.
17 |
18 | ### Usage
19 |
20 | See the notebook.
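
The notebook learns both a static structure (NOTEARS) and a time-lagged one (DYNOTEARS). Conceptually, the dynamic case augments each variable with lagged copies before learning. A minimal pandas sketch of that lag construction follows; the column names and the `add_lags` helper are illustrative, not part of the CausalNex API.

```python
import pandas as pd

df = pd.DataFrame({"s1": [1.0, 2.0, 3.0, 4.0], "s2": [10.0, 20.0, 30.0, 40.0]})

def add_lags(df, p):
    """Append p lagged copies of every column, as used when learning
    dynamic (time-lagged) structure; rows without full history are dropped."""
    frames = [df]
    for lag in range(1, p + 1):
        frames.append(df.shift(lag).add_suffix(f"_lag{lag}"))
    return pd.concat(frames, axis=1).dropna()

lagged = add_lags(df, p=1)
print(list(lagged.columns))  # ['s1', 's2', 's1_lag1', 's2_lag1']
```

In the notebook itself this bookkeeping is handled internally by `from_pandas_dynamic`, which takes the lag order `p` as an argument.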
21 |
22 | ### Links
23 |
24 | [1] Pamfil et al., "DYNOTEARS: Structure Learning from Time-Series Data", arXiv, 2020
25 |
26 | [2] CausalNex [API](https://causalnex.readthedocs.io/en/latest/causalnex.html)
27 |
--------------------------------------------------------------------------------
/Structure learning_for_multivariate_timeseries_dataset.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "anaconda-cloud": {},
6 | "kernelspec": {
7 | "display_name": "Python 3",
8 | "language": "python",
9 | "name": "python3"
10 | },
11 | "language_info": {
12 | "codemirror_mode": {
13 | "name": "ipython",
14 | "version": 3
15 | },
16 | "file_extension": ".py",
17 | "mimetype": "text/x-python",
18 | "name": "python",
19 | "nbconvert_exporter": "python",
20 | "pygments_lexer": "ipython3",
21 | "version": "3.7.7-final"
22 | },
23 | "colab": {
24 | "name": "Graph_neural_network_for_multivariate dataset.ipynb",
25 | "provenance": [],
26 | "collapsed_sections": []
27 | }
28 | },
29 | "cells": [
30 | {
31 | "cell_type": "code",
32 | "metadata": {
33 | "id": "0zmot6HPl4At",
34 | "colab": {
35 | "base_uri": "https://localhost:8080/"
36 | },
37 | "outputId": "46935114-fae6-4855-a8de-7df68568750b"
38 | },
39 | "source": [
40 | "!pip install causalNex"
41 | ],
42 | "execution_count": 1,
43 | "outputs": [
44 | {
45 | "output_type": "stream",
46 | "name": "stdout",
47 | "text": [
48 | "Collecting causalNex\n",
49 | " Downloading causalnex-0.11.0-py3-none-any.whl (154 kB)\n",
50 | "\u001b[K |████████████████████████████████| 154 kB 5.2 MB/s \n",
51 | "\u001b[?25hRequirement already satisfied: numpy<2.0,>=1.14.2 in /usr/local/lib/python3.7/dist-packages (from causalNex) (1.19.5)\n",
52 | "Collecting pathos<0.3.0,>=0.2.7\n",
53 | " Downloading pathos-0.2.8-py2.py3-none-any.whl (81 kB)\n",
54 | "\u001b[K |████████████████████████████████| 81 kB 9.6 MB/s \n",
55 | "\u001b[?25hCollecting pgmpy<0.2.0,>=0.1.12\n",
56 | " Downloading pgmpy-0.1.17-py3-none-any.whl (1.9 MB)\n",
57 | "\u001b[K |████████████████████████████████| 1.9 MB 57.4 MB/s \n",
58 | "\u001b[?25hCollecting wrapt<1.13,>=1.11.0\n",
59 | " Downloading wrapt-1.12.1.tar.gz (27 kB)\n",
60 | "Requirement already satisfied: scipy<1.7,>=1.2.0 in /usr/local/lib/python3.7/dist-packages (from causalNex) (1.4.1)\n",
61 | "Collecting scikit-learn!=0.22.2.post1,!=0.24.1,<0.25.0,>=0.22.0\n",
62 | " Downloading scikit_learn-0.24.2-cp37-cp37m-manylinux2010_x86_64.whl (22.3 MB)\n",
63 | "\u001b[K |████████████████████████████████| 22.3 MB 1.2 MB/s \n",
64 | "\u001b[?25hRequirement already satisfied: networkx~=2.5 in /usr/local/lib/python3.7/dist-packages (from causalNex) (2.6.3)\n",
65 | "Requirement already satisfied: torch<2.0,>=1.7 in /usr/local/lib/python3.7/dist-packages (from causalNex) (1.10.0+cu111)\n",
66 | "Requirement already satisfied: pandas<2.0,>=1.0 in /usr/local/lib/python3.7/dist-packages (from causalNex) (1.1.5)\n",
67 | "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas<2.0,>=1.0->causalNex) (2.8.2)\n",
68 | "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas<2.0,>=1.0->causalNex) (2018.9)\n",
69 | "Collecting pox>=0.3.0\n",
70 | " Downloading pox-0.3.0-py2.py3-none-any.whl (30 kB)\n",
71 | "Collecting ppft>=1.6.6.4\n",
72 | " Downloading ppft-1.6.6.4-py3-none-any.whl (65 kB)\n",
73 | "\u001b[K |████████████████████████████████| 65 kB 2.8 MB/s \n",
74 | "\u001b[?25hRequirement already satisfied: multiprocess>=0.70.12 in /usr/local/lib/python3.7/dist-packages (from pathos<0.3.0,>=0.2.7->causalNex) (0.70.12.2)\n",
75 | "Requirement already satisfied: dill>=0.3.4 in /usr/local/lib/python3.7/dist-packages (from pathos<0.3.0,>=0.2.7->causalNex) (0.3.4)\n",
76 | "Requirement already satisfied: pyparsing in /usr/local/lib/python3.7/dist-packages (from pgmpy<0.2.0,>=0.1.12->causalNex) (3.0.6)\n",
77 | "Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from pgmpy<0.2.0,>=0.1.12->causalNex) (1.1.0)\n",
78 | "Requirement already satisfied: statsmodels in /usr/local/lib/python3.7/dist-packages (from pgmpy<0.2.0,>=0.1.12->causalNex) (0.10.2)\n",
79 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from pgmpy<0.2.0,>=0.1.12->causalNex) (4.62.3)\n",
80 | "Requirement already satisfied: six>=1.7.3 in /usr/local/lib/python3.7/dist-packages (from ppft>=1.6.6.4->pathos<0.3.0,>=0.2.7->causalNex) (1.15.0)\n",
81 | "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn!=0.22.2.post1,!=0.24.1,<0.25.0,>=0.22.0->causalNex) (3.0.0)\n",
82 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch<2.0,>=1.7->causalNex) (3.10.0.2)\n",
83 | "Requirement already satisfied: patsy>=0.4.0 in /usr/local/lib/python3.7/dist-packages (from statsmodels->pgmpy<0.2.0,>=0.1.12->causalNex) (0.5.2)\n",
84 | "Building wheels for collected packages: wrapt\n",
85 | " Building wheel for wrapt (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
86 | " Created wheel for wrapt: filename=wrapt-1.12.1-cp37-cp37m-linux_x86_64.whl size=68720 sha256=5efb082645a6ec5abb8b5551cf6b7ae0e2bbef7cb7668d8f70389ba471c9f111\n",
87 | " Stored in directory: /root/.cache/pip/wheels/62/76/4c/aa25851149f3f6d9785f6c869387ad82b3fd37582fa8147ac6\n",
88 | "Successfully built wrapt\n",
89 | "Installing collected packages: scikit-learn, ppft, pox, wrapt, pgmpy, pathos, causalNex\n",
90 | " Attempting uninstall: scikit-learn\n",
91 | " Found existing installation: scikit-learn 1.0.1\n",
92 | " Uninstalling scikit-learn-1.0.1:\n",
93 | " Successfully uninstalled scikit-learn-1.0.1\n",
94 | " Attempting uninstall: wrapt\n",
95 | " Found existing installation: wrapt 1.13.3\n",
96 | " Uninstalling wrapt-1.13.3:\n",
97 | " Successfully uninstalled wrapt-1.13.3\n",
98 | "Successfully installed causalNex-0.11.0 pathos-0.2.8 pgmpy-0.1.17 pox-0.3.0 ppft-1.6.6.4 scikit-learn-0.24.2 wrapt-1.12.1\n"
99 | ]
100 | }
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "source": [
106 | "import numpy as np\n",
107 | "import pandas as pd\n",
108 | "import matplotlib.pyplot as plt\n"
108 | ],
109 | "metadata": {
110 | "id": "eqxrYcMEW-TN"
111 | },
112 | "execution_count": null,
113 | "outputs": []
114 | },
115 | {
116 | "cell_type": "code",
117 | "metadata": {
118 | "id": "ty-_q3IJlnqV"
119 | },
120 | "source": [
121 | "#libraries for causalnex"
122 | ],
123 | "execution_count": 3,
124 | "outputs": []
125 | },
126 | {
127 | "cell_type": "code",
128 | "metadata": {
129 | "id": "HVY-PurNlxja"
130 | },
131 | "source": [
132 | "import sys\n",
133 | "import networkx as nx\n",
134 | "import warnings\n",
135 | "from causalnex.structure import StructureModel\n",
136 | "from causalnex.structure.notears import from_pandas\n",
137 | "from IPython.display import Image\n",
138 | "from causalnex.plots import plot_structure, NODE_STYLE\n",
140 | "from causalnex.structure.dynotears import from_pandas_dynamic "
141 | ],
142 | "execution_count": 4,
143 | "outputs": []
144 | },
145 | {
146 | "cell_type": "code",
147 | "metadata": {
148 | "id": "mr4LOxdXmD4k"
149 | },
150 | "source": [
151 | "from __future__ import print_function\n",
152 | "import scipy.sparse as sp\n",
153 | "from scipy.sparse.linalg import eigsh, ArpackNoConvergence"
154 | ],
155 | "execution_count": 5,
156 | "outputs": []
157 | },
158 | {
159 | "cell_type": "code",
160 | "metadata": {
161 | "tags": [],
162 | "colab": {
163 | "base_uri": "https://localhost:8080/"
164 | },
165 | "id": "pdW-bP3EU0Q8",
166 | "outputId": "8df3783c-45d2-41a4-a9ea-52266e19b52d"
167 | },
168 | "source": [
169 | "# Data ingestion - reading the datasets from Azure blob \n",
170 | "!wget http://azuremlsamples.azureml.net/templatedata/PM_train.txt\n",
171 | "!wget http://azuremlsamples.azureml.net/templatedata/PM_test.txt\n",
172 | "!wget http://azuremlsamples.azureml.net/templatedata/PM_truth.txt "
173 | ],
174 | "execution_count": 6,
175 | "outputs": [
176 | {
177 | "output_type": "stream",
178 | "name": "stdout",
179 | "text": [
180 | "--2021-12-31 06:08:44-- http://azuremlsamples.azureml.net/templatedata/PM_train.txt\n",
181 | "Resolving azuremlsamples.azureml.net (azuremlsamples.azureml.net)... 13.107.246.40, 13.107.213.40, 2620:1ec:bdf::40, ...\n",
182 | "Connecting to azuremlsamples.azureml.net (azuremlsamples.azureml.net)|13.107.246.40|:80... connected.\n",
183 | "HTTP request sent, awaiting response... 200 OK\n",
184 | "Length: 3515356 (3.4M) [text/plain]\n",
185 | "Saving to: ‘PM_train.txt’\n",
186 | "\n",
187 | "PM_train.txt 100%[===================>] 3.35M 21.2MB/s in 0.2s \n",
188 | "\n",
189 | "2021-12-31 06:08:44 (21.2 MB/s) - ‘PM_train.txt’ saved [3515356/3515356]\n",
190 | "\n",
191 | "--2021-12-31 06:08:45-- http://azuremlsamples.azureml.net/templatedata/PM_test.txt\n",
192 | "Resolving azuremlsamples.azureml.net (azuremlsamples.azureml.net)... 13.107.246.40, 13.107.213.40, 2620:1ec:bdf::40, ...\n",
193 | "Connecting to azuremlsamples.azureml.net (azuremlsamples.azureml.net)|13.107.246.40|:80... connected.\n",
194 | "HTTP request sent, awaiting response... 200 OK\n",
195 | "Length: 2228855 (2.1M) [text/plain]\n",
196 | "Saving to: ‘PM_test.txt’\n",
197 | "\n",
198 | "PM_test.txt 100%[===================>] 2.12M --.-KB/s in 0.1s \n",
199 | "\n",
200 | "2021-12-31 06:08:45 (14.2 MB/s) - ‘PM_test.txt’ saved [2228855/2228855]\n",
201 | "\n",
202 | "--2021-12-31 06:08:45-- http://azuremlsamples.azureml.net/templatedata/PM_truth.txt\n",
203 | "Resolving azuremlsamples.azureml.net (azuremlsamples.azureml.net)... 13.107.246.40, 13.107.213.40, 2620:1ec:bdf::40, ...\n",
204 | "Connecting to azuremlsamples.azureml.net (azuremlsamples.azureml.net)|13.107.246.40|:80... connected.\n",
205 | "HTTP request sent, awaiting response... 200 OK\n",
206 | "Length: 429 [text/plain]\n",
207 | "Saving to: ‘PM_truth.txt’\n",
208 | "\n",
209 | "PM_truth.txt 100%[===================>] 429 --.-KB/s in 0s \n",
210 | "\n",
211 | "2021-12-31 06:08:45 (65.6 MB/s) - ‘PM_truth.txt’ saved [429/429]\n",
212 | "\n"
213 | ]
214 | }
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "metadata": {
220 | "collapsed": true,
221 | "id": "vl5YBlGdU0Q8"
222 | },
223 | "source": [
224 | "# read training data \n",
225 | "train_df = pd.read_csv('PM_train.txt', sep=\" \", header=None)\n",
226 | "train_df.drop(train_df.columns[[26, 27]], axis=1, inplace=True)\n",
227 | "train_df.columns = ['id', 'cycle', 'setting1', 'setting2', 'setting3', 's1', 's2', 's3',\n",
228 | " 's4', 's5', 's6', 's7', 's8', 's9', 's10', 's11', 's12', 's13', 's14',\n",
229 | " 's15', 's16', 's17', 's18', 's19', 's20', 's21']"
230 | ],
231 | "execution_count": 7,
232 | "outputs": []
233 | },
234 | {
235 | "cell_type": "code",
236 | "metadata": {
237 | "collapsed": true,
238 | "id": "AusQ_xtqU0Q8"
239 | },
240 | "source": [
241 | "# read test data\n",
242 | "test_df = pd.read_csv('PM_test.txt', sep=\" \", header=None)\n",
243 | "test_df.drop(test_df.columns[[26, 27]], axis=1, inplace=True)\n",
244 | "test_df.columns = ['id', 'cycle', 'setting1', 'setting2', 'setting3', 's1', 's2', 's3',\n",
245 | " 's4', 's5', 's6', 's7', 's8', 's9', 's10', 's11', 's12', 's13', 's14',\n",
246 | " 's15', 's16', 's17', 's18', 's19', 's20', 's21']"
247 | ],
248 | "execution_count": 8,
249 | "outputs": []
250 | },
251 | {
252 | "cell_type": "code",
253 | "metadata": {
254 | "collapsed": true,
255 | "id": "Jbm6B7TaU0Q9"
256 | },
257 | "source": [
258 | "# read ground truth data\n",
259 | "truth_df = pd.read_csv('PM_truth.txt', sep=\" \", header=None)\n",
260 | "truth_df.drop(truth_df.columns[[1]], axis=1, inplace=True)"
261 | ],
262 | "execution_count": 9,
263 | "outputs": []
264 | },
265 | {
266 | "cell_type": "code",
267 | "metadata": {
268 | "colab": {
269 | "base_uri": "https://localhost:8080/",
270 | "height": 270
271 | },
272 | "id": "dxN5Z1dzU0Q-",
273 | "outputId": "b84da78e-b13a-4d8c-8789-48e7f2da3cb0"
274 | },
275 | "source": [
276 | "train_df = train_df.sort_values(['id','cycle'])\n",
277 | "train_df.head()"
278 | ],
279 | "execution_count": 10,
280 | "outputs": [
281 | {
282 | "output_type": "execute_result",
283 | "data": {
560 | "text/plain": [
561 | " id cycle setting1 setting2 setting3 ... s17 s18 s19 s20 s21\n",
562 | "0 1 1 -0.0007 -0.0004 100.0 ... 392 2388 100.0 39.06 23.4190\n",
563 | "1 1 2 0.0019 -0.0003 100.0 ... 392 2388 100.0 39.00 23.4236\n",
564 | "2 1 3 -0.0043 0.0003 100.0 ... 390 2388 100.0 38.95 23.3442\n",
565 | "3 1 4 0.0007 0.0000 100.0 ... 392 2388 100.0 38.88 23.3739\n",
566 | "4 1 5 -0.0019 -0.0002 100.0 ... 393 2388 100.0 38.90 23.4044\n",
567 | "\n",
568 | "[5 rows x 26 columns]"
569 | ]
570 | },
571 | "metadata": {},
572 | "execution_count": 10
573 | }
574 | ]
575 | },
576 | {
577 | "cell_type": "markdown",
578 | "source": [
579 | "# Extract the graph (structural causal model) and the adjacency matrix"
580 | ],
581 | "metadata": {
582 | "id": "ztX1cULXlQNW"
583 | }
584 | },
585 | {
586 | "cell_type": "markdown",
587 | "source": [
588 | "Two methods are used here. CausalNex is used to learn the structure in the dataset:\n",
589 | "- NOTEARS is used to get the adjacency matrix\n",
590 | "- DYNOTEARS (dynamic NOTEARS for time series) is also used to extract the adjacency matrix"
591 | ],
592 | "metadata": {
593 | "id": "yoxlGtSjmu8f"
594 | }
595 | },
596 | {
597 | "cell_type": "code",
598 | "metadata": {
599 | "id": "PxXjOU-CnP1V"
600 | },
601 | "source": [
602 | "#the structure is learned using only the sensors\n",
603 | "df_g = train_df.drop(train_df.columns[[0, 1, 2,3,4]], axis=1)"
604 | ],
605 | "execution_count": 13,
606 | "outputs": []
607 | },
608 | {
609 | "cell_type": "code",
610 | "metadata": {
611 | "id": "vGETGxuDnUty",
612 | "colab": {
613 | "base_uri": "https://localhost:8080/",
614 | "height": 206
615 | },
616 | "outputId": "51d2c743-7028-4361-c882-0864308f6417"
617 | },
618 | "source": [
619 | "df_g.head()"
620 | ],
621 | "execution_count": 14,
622 | "outputs": [
623 | {
624 | "output_type": "execute_result",
625 | "data": {
872 | "text/plain": [
873 | " s1 s2 s3 s4 s5 ... s17 s18 s19 s20 s21\n",
874 | "0 518.67 641.82 1589.70 1400.60 14.62 ... 392 2388 100.0 39.06 23.4190\n",
875 | "1 518.67 642.15 1591.82 1403.14 14.62 ... 392 2388 100.0 39.00 23.4236\n",
876 | "2 518.67 642.35 1587.99 1404.20 14.62 ... 390 2388 100.0 38.95 23.3442\n",
877 | "3 518.67 642.35 1582.79 1401.87 14.62 ... 392 2388 100.0 38.88 23.3739\n",
878 | "4 518.67 642.37 1582.85 1406.22 14.62 ... 393 2388 100.0 38.90 23.4044\n",
879 | "\n",
880 | "[5 rows x 21 columns]"
881 | ]
882 | },
883 | "metadata": {},
884 | "execution_count": 14
885 | }
886 | ]
887 | },
888 | {
889 | "cell_type": "code",
890 | "metadata": {
891 | "id": "z3D3ur-GmTxH"
892 | },
893 | "source": [
894 | "warnings.filterwarnings(\"ignore\") # silence warnings\n",
895 | "\n",
896 | "sm = StructureModel()"
897 | ],
898 | "execution_count": 15,
899 | "outputs": []
900 | },
901 | {
902 | "cell_type": "code",
903 | "metadata": {
904 | "id": "mqiZfLBNmXce"
905 | },
906 | "source": [
907 | "# sm contains the nodes and edges learned from the dataset\n",
908 | "# with the NOTEARS algorithm\n",
909 | "sm = from_pandas(df_g)"
910 | ],
911 | "execution_count": 16,
912 | "outputs": []
913 | },
914 | {
915 | "cell_type": "code",
916 | "metadata": {
917 | "id": "4tutIiWmmaXy"
918 | },
919 | "source": [
920 | "# Plot the graph of the structural model (sm)\n",
921 | "fig, ax = plt.subplots()\n",
922 | "nx.draw_circular(sm, ax=ax)\n",
923 | "fig.show()"
924 | ],
925 | "execution_count": 17,
926 | "outputs": []
927 | },
928 | {
929 | "cell_type": "code",
930 | "metadata": {
931 | "id": "SrV28iVzmeXl"
932 | },
933 | "source": [
934 | "# With CausalNex, you can also add and remove edges in a few lines of code.\n",
935 | "# The code in this cell removes edges whose weight falls below a 0.8 threshold\n",
936 | "\n",
937 | "#sm.remove_edges_below_threshold(0.8)\n",
938 | "#fig, ax = plt.subplots()\n",
939 | "#nx.draw_circular(sm, ax=ax)\n",
940 | "#fig.show()"
941 | ],
942 | "execution_count": 57,
943 | "outputs": []
944 | },
945 | {
946 | "cell_type": "code",
947 | "metadata": {
948 | "id": "v9XFOSJpmlSe"
949 | },
950 | "source": [
951 | "# With CausalNex, you can also get the largest subgraph that represents the relationships in the data\n",
952 | "#sm = sm.get_largest_subgraph()\n",
953 | "#fig, ax = plt.subplots()\n",
954 | "#nx.draw_circular(sm, ax=ax)\n",
955 | "#fig.show()"
956 | ],
957 | "execution_count": 19,
958 | "outputs": []
959 | },
960 | {
961 | "cell_type": "code",
962 | "metadata": {
963 | "id": "lCr5jCO_mtsK"
964 | },
965 | "source": [
966 | "# Extract the adjacency matrix using networkx\n",
967 | "adj_mtrx = nx.adjacency_matrix(sm)"
968 | ],
969 | "execution_count": 20,
970 | "outputs": []
971 | },
972 | {
973 | "cell_type": "code",
974 | "metadata": {
975 | "colab": {
976 | "base_uri": "https://localhost:8080/"
977 | },
978 | "id": "PB6GFZ-v1Hps",
979 | "outputId": "24de70a7-bf8b-4bce-ebf7-f89bc3801fd1"
980 | },
981 | "source": [
982 | "adj_mtrx"
983 | ],
984 | "execution_count": 21,
985 | "outputs": [
986 | {
987 | "output_type": "execute_result",
988 | "data": {
989 | "text/plain": [
990 | "<21x21 sparse matrix of type ''\n",
991 | "\twith 420 stored elements in Compressed Sparse Row format>"
992 | ]
993 | },
994 | "metadata": {},
995 | "execution_count": 21
996 | }
997 | ]
998 | },
999 | {
1000 | "cell_type": "code",
1001 | "source": [
1002 | "# The dynamic NOTEARS (DYNOTEARS) function is used for dynamic (time series) datasets.\n",
1003 | "# The graph can also be extracted with DYNOTEARS"
1004 | ],
1005 | "metadata": {
1006 | "id": "tGuLgc-pHu_n"
1007 | },
1008 | "execution_count": null,
1009 | "outputs": []
1010 | },
1011 | {
1012 | "cell_type": "code",
1013 | "metadata": {
1014 | "id": "_Db-Qhz_neln",
1015 | "colab": {
1016 | "base_uri": "https://localhost:8080/"
1017 | },
1018 | "outputId": "d172cb22-01d3-4e8a-eb56-cb3901d5c2bf"
1019 | },
1020 | "source": [
1021 | "# Here is the DYNOTEARS implementation (lag order p=1)\n",
1022 | "g_learnt = from_pandas_dynamic(df_g, p=1, lambda_w=0.1, lambda_a=0.1, w_threshold=0.1)\n",
1023 | "g_learnt"
1024 | ],
1025 | "execution_count": 22,
1026 | "outputs": [
1027 | {
1028 | "output_type": "execute_result",
1029 | "data": {
1030 | "text/plain": [
1031 | ""
1032 | ]
1033 | },
1034 | "metadata": {},
1035 | "execution_count": 22
1036 | }
1037 | ]
1038 | },
1039 | {
1040 | "cell_type": "code",
1041 | "metadata": {
1042 | "id": "UF2SyIM7ntIR"
1043 | },
1044 | "source": [
1045 | "#from copy import deepcopy\n",
1046 | "#g_learnt_2 = deepcopy(g_learnt)\n",
1047 | "#g_learnt_2.remove_edges_below_threshold(.8)\n",
1048 | "fig, ax = plt.subplots()\n",
1049 | "nx.draw_circular(g_learnt, ax=ax)\n",
1050 | "fig.show()"
1051 | ],
1052 | "execution_count": 23,
1053 | "outputs": []
1054 | },
1055 | {
1056 | "cell_type": "code",
1057 | "metadata": {
1058 | "id": "Qf_mzJH9nxCC",
1059 | "colab": {
1060 | "base_uri": "https://localhost:8080/"
1061 | },
1062 | "outputId": "8428e330-c076-4f79-bfb2-9c467aed525d"
1063 | },
1064 | "source": [
1065 | "adj_mtrx_dyno = nx.adjacency_matrix(g_learnt)\n",
1066 | "adj_mtrx_dyno"
1067 | ],
1068 | "execution_count": 24,
1069 | "outputs": [
1070 | {
1071 | "output_type": "execute_result",
1072 | "data": {
1073 | "text/plain": [
1074 | "<42x42 sparse matrix of type ''\n",
1075 | "\twith 20 stored elements in Compressed Sparse Row format>"
1076 | ]
1077 | },
1078 | "metadata": {},
1079 | "execution_count": 24
1080 | }
1081 | ]
1082 | },
1083 | {
1084 | "cell_type": "code",
1085 | "metadata": {
1086 | "id": "O29Lmdmkn7O8"
1087 | },
1088 | "source": [
1089 | "# Utility functions to preprocess the adjacency matrix\n",
1090 | "def normalize_features(mx):\n",
1091 | " \"\"\"Row-normalize sparse matrix\"\"\"\n",
1092 | " rowsum = np.array(mx.sum(1))\n",
1093 | " r_inv = np.power(rowsum, -1).flatten()\n",
1094 | " r_inv[np.isinf(r_inv)] = 0.\n",
1095 | " r_mat_inv = sp.diags(r_inv)\n",
1096 | " mx = r_mat_inv.dot(mx)\n",
1097 | " return mx\n",
1098 | "\n",
1099 | "\n",
1100 | "def normalize_adj(adj, symmetric=True):\n",
1101 | " if symmetric:\n",
1102 | " d = sp.diags(np.power(np.array(adj.sum(1)), -0.5).flatten(), 0)\n",
1103 | " a_norm = adj.dot(d).transpose().dot(d).tocsr()\n",
1104 | " else:\n",
1105 | " d = sp.diags(np.power(np.array(adj.sum(1)), -1).flatten(), 0)\n",
1106 | " a_norm = d.dot(adj).tocsr()\n",
1107 | " return a_norm\n",
1108 | "\n",
1109 | "\n",
1110 | "def normalize_adj_numpy(adj, symmetric=True):\n",
1111 | " if symmetric:\n",
1112 | " d = np.diag(np.power(np.array(adj.sum(1)), -0.5).flatten(), 0)\n",
1113 | " a_norm = adj.dot(d).transpose().dot(d)\n",
1114 | " else:\n",
1115 | " d = np.diag(np.power(np.array(adj.sum(1)), -1).flatten(), 0)\n",
1116 | " a_norm = d.dot(adj)\n",
1117 | " return a_norm\n",
1118 | "\n",
1119 | "\n",
1120 | "def preprocess_adj(adj, symmetric=True):\n",
1121 | " adj = adj + sp.eye(adj.shape[0])\n",
1122 | " adj = normalize_adj(adj, symmetric)\n",
1123 | " return adj\n",
1124 | "\n",
1125 | "\n",
1126 | "def preprocess_adj_numpy(adj, symmetric=True):\n",
1127 | " adj = adj + np.eye(adj.shape[0])\n",
1128 | " adj = normalize_adj_numpy(adj, symmetric)\n",
1129 | " return adj\n",
1130 | "\n",
1131 | "\n",
1132 | "def preprocess_adj_tensor(adj_tensor, symmetric=True):\n",
1133 | " adj_out_tensor = []\n",
1134 | " for i in range(adj_tensor.shape[0]):\n",
1135 | " adj = adj_tensor[i]\n",
1136 | " adj = adj + np.eye(adj.shape[0])\n",
1137 | " adj = normalize_adj_numpy(adj, symmetric)\n",
1138 | " adj_out_tensor.append(adj)\n",
1139 | " adj_out_tensor = np.array(adj_out_tensor)\n",
1140 | " return adj_out_tensor\n",
1141 | "\n",
1142 | "\n",
1143 | "def preprocess_adj_tensor_with_identity(adj_tensor, symmetric=True):\n",
1144 | " adj_out_tensor = []\n",
1145 | " for i in range(adj_tensor.shape[0]):\n",
1146 | " adj = adj_tensor[i]\n",
1147 | " adj = adj + np.eye(adj.shape[0])\n",
1148 | " adj = normalize_adj_numpy(adj, symmetric)\n",
1149 | " adj = np.concatenate([np.eye(adj.shape[0]), adj], axis=0)\n",
1150 | " adj_out_tensor.append(adj)\n",
1151 | " adj_out_tensor = np.array(adj_out_tensor)\n",
1152 | " return adj_out_tensor\n",
1153 | "\n",
1154 | "\n",
1155 | "def preprocess_adj_tensor_with_identity_concat(adj_tensor, symmetric=True):\n",
1156 | " adj_out_tensor = []\n",
1157 | " for i in range(adj_tensor.shape[0]):\n",
1158 | " adj = adj_tensor[i]\n",
1159 | " adj = adj + np.eye(adj.shape[0])\n",
1160 | " adj = normalize_adj_numpy(adj, symmetric)\n",
1161 | " adj = np.concatenate([np.eye(adj.shape[0]), adj], axis=0)\n",
1162 | " adj_out_tensor.append(adj)\n",
1163 | " adj_out_tensor = np.concatenate(adj_out_tensor, axis=0)\n",
1164 | " return adj_out_tensor\n",
1165 | "\n",
1166 | "def preprocess_adj_tensor_concat(adj_tensor, symmetric=True):\n",
1167 | " adj_out_tensor = []\n",
1168 | " for i in range(adj_tensor.shape[0]):\n",
1169 | " adj = adj_tensor[i]\n",
1170 | " adj = adj + np.eye(adj.shape[0])\n",
1171 | " adj = normalize_adj_numpy(adj, symmetric)\n",
1172 | " adj_out_tensor.append(adj)\n",
1173 | " adj_out_tensor = np.concatenate(adj_out_tensor, axis=0)\n",
1174 | " return adj_out_tensor\n",
1175 | "\n",
1176 | "def preprocess_edge_adj_tensor(edge_adj_tensor, symmetric=True):\n",
1177 | " edge_adj_out_tensor = []\n",
1178 | " num_edge_features = int(edge_adj_tensor.shape[1]/edge_adj_tensor.shape[2])\n",
1179 | "\n",
1180 | " for i in range(edge_adj_tensor.shape[0]):\n",
1181 | " edge_adj = edge_adj_tensor[i]\n",
1182 | " edge_adj = np.split(edge_adj, num_edge_features, axis=0)\n",
1183 | " edge_adj = np.array(edge_adj)\n",
1184 | " edge_adj = preprocess_adj_tensor_concat(edge_adj, symmetric)\n",
1185 | " edge_adj_out_tensor.append(edge_adj)\n",
1186 | "\n",
1187 | " edge_adj_out_tensor = np.array(edge_adj_out_tensor)\n",
1188 | " return edge_adj_out_tensor\n",
1189 | "\n"
1190 | ],
1191 | "execution_count": 25,
1192 | "outputs": []
1193 | },
1194 | {
1195 | "cell_type": "code",
1196 | "source": [
1197 | "# Next is to develop a graph neural network.\n",
1198 | "# This can be done using a bootstrap approach or with StellarGraph/Spektral"
1199 | ],
1200 | "metadata": {
1201 | "id": "dIQAcfu9Lntg"
1202 | },
1203 | "execution_count": null,
1204 | "outputs": []
1205 | }
1206 | ]
1207 | }
--------------------------------------------------------------------------------
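
The adjacency-preprocessing utilities in the notebook implement the common symmetric normalization D^{-1/2}(A + I)D^{-1/2}. For symmetric matrices, a compact numpy equivalent of the notebook's `preprocess_adj_numpy` looks like this (a sketch with a made-up 2x2 matrix, not the notebook's exact code):

```python
import numpy as np

def preprocess_adj(adj):
    """Symmetrically normalize an adjacency matrix with self-loops:
    D^{-1/2} (A + I) D^{-1/2}, as in the notebook's preprocess_adj_numpy."""
    a = adj + np.eye(adj.shape[0])  # add self-loops
    d_inv_sqrt = np.power(a.sum(axis=1), -0.5)  # D^{-1/2} as a vector
    return a * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

adj = np.array([[0.0, 1.0], [1.0, 0.0]])
norm = preprocess_adj(adj)  # every entry becomes 0.5 for this matrix
```

This normalization keeps the scale of node features stable when the adjacency matrix is repeatedly multiplied into them by a graph neural network layer.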