├── .DS_Store
├── Input
│   ├── .DS_Store
│   ├── alarm.rds
│   └── ecoli70.rds
├── output
│   └── .DS_Store
├── lib
│   ├── __pycache__
│   │   ├── hc.cpython-37.pyc
│   │   ├── mmhc.cpython-37.pyc
│   │   ├── mmpc.cpython-37.pyc
│   │   ├── accessory.cpython-37.pyc
│   │   └── evaluation.cpython-37.pyc
│   ├── evaluation.py
│   ├── mmhc.py
│   ├── mmpc.py
│   ├── hc.py
│   └── accessory.py
├── .idea
│   ├── vcs.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── test.iml
│   └── workspace.xml
├── README.md
├── BN_structure_learning.py
└── .Rhistory
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/.DS_Store
--------------------------------------------------------------------------------
/Input/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/Input/.DS_Store
--------------------------------------------------------------------------------
/Input/alarm.rds:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/Input/alarm.rds
--------------------------------------------------------------------------------
/Input/ecoli70.rds:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/Input/ecoli70.rds
--------------------------------------------------------------------------------
/output/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/output/.DS_Store
--------------------------------------------------------------------------------
/lib/__pycache__/hc.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/lib/__pycache__/hc.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/mmhc.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/lib/__pycache__/mmhc.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/mmpc.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/lib/__pycache__/mmpc.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/accessory.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/lib/__pycache__/accessory.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/evaluation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/lib/__pycache__/evaluation.cpython-37.pyc
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/.idea/vcs.xml
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/.idea/misc.xml
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/.idea/modules.xml
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 | This is a Python implementation of the MMHC (Max-Min Hill-Climbing) structure learning algorithm. All data sets and models are placed in the "Input" folder and the results are written to
3 | the "output" folder.
4 |
5 | If you want to test your own data set, put it in the "Input" folder and change the corresponding variables in
6 | "BN_structure_learning.py", which also serves as an example script for running the code. If you want to evaluate the result, provide the ground-truth model in bnlearn format.
7 |
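8 | A minimal usage sketch (the file name "Input/mydata.csv" below is a placeholder for your own data set):
9 |
10 | ```python
11 | import pandas as pd
12 | from lib.mmhc import mmhc
13 |
14 | # use dtype='float64' instead if the data is continuous
15 | data = pd.read_csv('Input/mydata.csv', dtype='category')
16 |
17 | # learn the structure; the result is a DAG in bnlearn format
18 | dag = mmhc(data)
19 | ```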
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/.idea/inspectionProfiles/Project_Default.xml
--------------------------------------------------------------------------------
/lib/evaluation.py:
--------------------------------------------------------------------------------
1 | # evaluation methods
2 | from rpy2.robjects.packages import importr
3 | base, bnlearn = importr('base'), importr('bnlearn')
4 |
5 | # compute the F1 score of a learned graph given true graph
6 | def f1(dag_true, dag_learned):
7 | '''
8 | :param dag_true: true DAG
9 | :param dag_learned: learned DAG
10 | :return: the F1 score of learned DAG
11 | '''
12 | compare = bnlearn.compare(bnlearn.cpdag(dag_true), bnlearn.cpdag(dag_learned))  # returns (tp, fp, fn) on the CPDAGs
13 | return compare[0][0] * 2 / (compare[0][0] * 2 + compare[1][0] + compare[2][0])  # F1 = 2 * tp / (2 * tp + fp + fn)
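14 |
15 | # A worked example (numbers are illustrative): with tp = 34, fp = 12 and fn = 29,
16 | # the F1 score is 2 * 34 / (2 * 34 + 12 + 29) ≈ 0.62.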
--------------------------------------------------------------------------------
/.idea/test.iml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/.idea/test.iml
--------------------------------------------------------------------------------
/BN_structure_learning.py:
--------------------------------------------------------------------------------
1 | import random
2 | import urllib.request
3 | from os import path
4 |
5 | import pandas as pd
6 | from graphviz import Digraph
7 | from rpy2.robjects import pandas2ri
8 | from rpy2.robjects.packages import importr
9 |
10 | from lib.evaluation import f1
11 | from lib.mmhc import mmhc
12 |
13 | pandas2ri.activate()
14 | base, bnlearn = importr('base'), importr('bnlearn')
15 |
16 | # load network
17 | network = 'alarm'
18 | network_path = 'Input/' + network + '.rds'
19 | if not path.isfile(network_path):
20 | url = 'https://www.bnlearn.com/bnrepository/' + network + '/' + network + '.rds'
21 | urllib.request.urlretrieve(url, network_path)
22 | dag_true = base.readRDS(network_path)
23 |
24 | # generate data
25 | datasize = 10000
26 | filename = 'Input/' + network + '_' + str(datasize) + '.csv'
27 | if path.isfile(filename):
28 | data = pd.read_csv(filename, dtype='category')  # use dtype='float64' for continuous data and dtype='category' for categorical data
29 | else:
30 | data = bnlearn.rbn(dag_true, datasize)
31 | data = data[random.sample(list(data.columns), data.shape[1])]  # shuffle the column order so learning does not depend on the topological ordering of the generating network
32 | data.to_csv(filename, index=False)
33 |
34 |
35 | # learn bayesian network from data
36 | dag_learned = mmhc(data)
37 |
38 | # plot the learned graph
39 | dot = Digraph()
40 | for node in bnlearn.nodes(dag_learned):
41 | dot.node(node)
42 | for parent in bnlearn.parents(dag_learned, node):
43 | dot.edge(parent, node)  # arcs point from parent to child
44 | dot.render('output/' + network + '_' + str(datasize) + '.gv', view=False)
45 |
46 | # evaluate the learned graph
47 | print('f1 score is ' + str(f1(dag_true, dag_learned)))
48 | print('shd score is ' + str(bnlearn.shd(bnlearn.cpdag(dag_true), bnlearn.cpdag(dag_learned))[0]))  # SHD between the CPDAGs, matching the F1 evaluation
--------------------------------------------------------------------------------
/lib/mmhc.py:
--------------------------------------------------------------------------------
1 | from lib.mmpc import mmpc_forward, mmpc_backward, symmetry
2 | from lib.hc import hc
3 | import time
4 | import numpy as np
5 |
6 | def mmhc(data, test = None, score = None, prune = True, threshold = 0.05):
7 | '''
8 | mmhc algorithm
9 | :param data: input data (pandas dataframe)
10 | :param test: type of independence test (currently supported: g-test for discrete data, z-test for continuous data)
11 | :param score: type of score function (currently supported: bic, which maps to bic_g for continuous data)
12 | :param prune: whether to use the pruning method
13 | :param threshold: significance threshold for the CI tests
14 | :return: the DAG learned from data (bnlearn format)
15 | '''
16 |
17 | # initialise pc set as empty for all variables
18 | pc = {}
19 |
20 | # initialise the candidate set for variables
21 | can = {}
22 | for tar in data:
23 | can[tar] = list(data.columns)
24 | can[tar].remove(tar)
25 |
26 | # preprocess the data
27 | varnames = list(data.columns)
28 | if all(data[var].dtype.name == 'category' for var in data):
29 | arities = np.array(data.nunique())
30 | data = data.apply(lambda x: x.cat.codes).to_numpy()
31 | if test is None:
32 | test = 'g-test'
33 | if score is None:
34 | score = 'bic'
35 | elif all(data[var].dtype.name != 'category' for var in data):
36 | arities = None
37 | if test is None:
38 | test = 'z-test'
39 | if score is None or score == 'bic':
40 | score = 'bic_g'
41 | else:
42 | raise Exception('Mixed data is not supported.')
43 |
44 | # run MMPC on each variable
45 | start = time.time()
46 | for tar in varnames:
47 | # forward phase
48 | pc[tar] = []
49 | pc[tar], can = mmpc_forward(tar, pc[tar], can, data, arities, varnames, prune, test, threshold)
50 | # backward phase
51 | if pc[tar]:
52 | pc[tar], can = mmpc_backward(tar, pc[tar], can, data, arities, varnames, prune, test, threshold)
53 | # check the symmetry of the pc sets
54 | # when the number of variables is large, this function may be computationally costly;
55 | # it could be merged into the pruning process during forward and backward mmpc by passing the whole
56 | # pc set into mmpc_forward and mmpc_backward
57 | pc = symmetry(pc)
58 | print('MMPC phase costs %.2f seconds' % (time.time() - start))
59 | # run hill-climbing
60 | start = time.time()
61 | dag = hc(data, arities, varnames, pc, score)
62 | print('HC phase costs %.2f seconds' % (time.time() - start))
63 | return dag
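64 |
65 | # A minimal sketch of calling mmhc directly (assumes the categorical data set
66 | # generated by BN_structure_learning.py exists):
67 | #   import pandas as pd
68 | #   data = pd.read_csv('Input/alarm_10000.csv', dtype='category')
69 | #   dag = mmhc(data, prune=True, threshold=0.05)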
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/.idea/workspace.xml
--------------------------------------------------------------------------------
/.Rhistory:
--------------------------------------------------------------------------------
1 | install.packages('devtools')
2 | library(devtools)
3 | install_github('andreacirilloac/updateR')
4 | library(updateR)
5 | library(readr)
6 | X1_MCAR <- read_csv("Documents/GitHub/SFMV/data/1/1_MCAR.csv")
7 | View(X1_MCAR)
8 | data = X1_MCAR
9 | data = as.data.frame(lapply(data, as.factor))
10 | library(bnlearn)
11 | structural.em(data[1:100, ])
12 | dag = structural.em(data[1:100, ])
13 | View(dag)
14 | dag
15 | arcs(dag)
16 | structural.em(data[1:100, ], score = 'bdeu')
17 | structural.em(data[1:100, ], maximize.args = list(score='bdeu'))
18 | structural.em(data[1:100, ], maximize.args = list(score='BDe'))
19 | structural.em(data[1:100, ], maximize.args = list(score='bde'))
20 | 10-0.9*1.3
21 | 5+0.9*0.4
22 | source("~/.active-rstudio-document")
23 | View(V)
24 | source("~/.active-rstudio-document")
25 | View(V)
26 | -1-0.9*1.3
27 | -1-1.9*0.9
28 | -4-0.9*(1+1.3+1.9+1.9)
29 | (-4-0.9*(1+1.3+1.9+1.9))/4
30 | (-2-0.9*(1+1.3))/2
31 | (-4-0.9*(1+1.3+2-3.3))/4
32 | (-4-0.9*(0.1+0.4+0.7+3))/4
33 | (-4+0.9*(0.1+0.4+0.7+3))/4
34 | (-1+0.9*(0.1+0.4+0.7+3))/4
35 | (-1+0.9*(-1-1.3-1.9-1.9))/4
36 | (-1+0.9*(0.4+0.4+0.7+2.3))/4
37 | (-4+0.9*(0.4+0.4+0.7+2.3))/4
38 | ((0.4+0.4+0.7+2.3))/4
39 | (-1 + (0.4+0.4+0.7+2.3))/4
40 | (-1 + (-1-1.3-1.9-1.9))/4
41 | -1 + ((-1-1.3-1.9-1.9))/4
42 | -1 + (0.9 * (-1-1.3-1.9-1.9))/4
43 | -1 + (0.9 * (1.5+8.8+0.7+2.3))/4
44 | -1 + ((1.5+8.8+0.7+2.3))/4
45 | (2-0.3-0.9-1.4)/4
46 | (-4 + 0.9*(3+0.8-0.4))/4
47 | (-4 + (3+0.8-0.4))/4
48 | -1+0.9*24.4
49 | -1+24.4
50 | -1+24.4*0.9
51 | -1+24.4*0.95
52 | -1+22*0.9
53 | 22*0.9
54 | 24.4*0.9
55 | 16*0.9
56 | 0.9*(-1-1.3-1-1)/4
57 | 3.4*0.9/4
58 | 3/4*0.9
59 | 24.4*0.9
60 | 22*0.9
61 | 22*0.9*0.9
62 | 3.4/4
63 | 3.4/4*0.9
64 | 0.781-0.025+0.184*2
65 | (0.781-0.025+0.184*2)/4
66 | source("~/.active-rstudio-document")
67 | View(V)
68 | source("~/.active-rstudio-document")
69 | View(V)
70 | source("~/.active-rstudio-document")
71 | View(V)
72 | source("~/.active-rstudio-document")
73 | View(V)
74 | View(V)
75 | View(V)
76 | source("~/.active-rstudio-document")
77 | View(V)
78 | source("~/.active-rstudio-document")
79 | View(V)
80 | source("~/.active-rstudio-document")
81 | source("~/.active-rstudio-document")
82 | View(V)
83 | source("~/.active-rstudio-document")
84 | source("~/.active-rstudio-document")
85 | View(V)
86 | debugSource("~/Documents/R/base/RL.R")
87 | View(V)
88 | View(V)
89 | source("~/Documents/R/base/RL.R")
90 | source("~/Documents/R/base/RL.R")
91 | source("~/Documents/R/base/RL.R")
92 | source("~/Documents/R/base/RL.R")
93 | source("~/Documents/R/base/RL.R")
94 | source("~/Documents/R/base/RL.R")
95 | source("~/Documents/R/base/RL.R")
96 | source("~/Documents/R/base/RL.R")
97 | source("~/Documents/R/base/RL.R")
98 | source("~/Documents/R/base/RL.R")
99 | source("~/Documents/R/base/RL.R")
100 | source("~/Documents/R/base/RL.R")
101 | source("~/Documents/R/base/RL.R")
102 | source("~/Documents/R/base/RL.R")
103 | View(V)
104 | source("~/Documents/R/base/RL.R")
105 | source("~/Documents/R/base/RL.R")
106 | source("~/Documents/R/base/RL.R")
107 | source("~/Documents/R/base/RL.R")
108 | source("~/Documents/R/base/RL.R")
109 | source("~/Documents/R/base/RL.R")
110 | source("~/Documents/R/base/RL.R")
111 | source("~/Documents/R/base/RL.R")
112 | source("~/Documents/R/base/RL.R")
113 | source("~/Documents/R/base/RL.R")
114 | source("~/Documents/R/base/RL.R")
115 | source("~/Documents/R/base/RL.R")
116 | source("~/Documents/R/base/RL.R")
117 | install.packages(c("broom", "haven", "jpeg", "matrixStats", "pillar", "R6", "readr", "survival", "vroom", "xfun", "XML"))
118 | debugSource("~/Documents/GitHub/SFMV/test.R")
119 | ci.test('gNBgE', 'imiDk', data = data)
120 | ci.test('gNBgE', 'imiDk', data = data, test = 'zf')
121 | source("~/Documents/GitHub/MMHC-Python/test.R")
122 | ci.test('eutG', 'lacZ', 'flgD', data = data, test = 'zf')
123 | ci.test('eutG', 'lacZ', data = data, test = 'zf')
124 | ci.test('eutG', 'flgD', data = data, test = 'zf')
125 | ci.test('eutG', 'nuoM', data = data, test = 'zf')
126 | dag = mmhc(data)
127 | source("~/Documents/GitHub/MMHC-Python/test.R")
128 | source("~/Documents/GitHub/MMHC-Python/test.R")
129 | compare(dag, dag.learned)
130 | compare(cpdag(dag), cpdag(dag.learned))
131 | 60/(60+64)
132 | source("~/Documents/GitHub/MMHC-Python/test.R")
133 | compare(cpdag(dag), cpdag(dag.learned))
134 | shd(cpdag(dag), cpdag(dag.learned))
135 | debug = capture.output(mmhc(data, debug = TRUE))
136 | save(debug, 'debug.txt')
137 | save('debug.txt', debug)
138 | write('debug.txt', debug)
139 | write(debug, 'debug.txt')
140 | source("~/Documents/GitHub/MMHC-Python/test.R")
141 | ci.test('eutG', 'lacZ', data = data)
142 | ci.test('eutG', 'lacZ', data = data, test = 'zf')
143 | source("~/Documents/GitHub/MMHC-Python/test.R")
144 | compare(dag, dag.learned)
145 | compare(cpdag(dag), cpdag(dag.learned))
146 | 60/(60+64)
147 | shd(cpdag(dag), dag.learned)
148 | debug = capture.output(mmhc(data, restrict.args = list(test = 'zf'), debug = TRUE))
149 | write(debug, 'debug.txt')
150 | ci.test('eutG', 'flgD', data = data, test = 'zf')
151 | ci.test('eutG', 'lacZ', data = data, test = 'zf')
152 | ci.test('eutG', 'aceB', data = data, test = 'zf')
153 | a = ci.test('eutG', 'aceB', data = data, test = 'zf')
154 | View(a)
155 | source("~/Documents/GitHub/MMHC-Python/test.R")
156 | network.path
157 | source("~/Documents/GitHub/MMHC-Python/test.R")
158 | compare(dag, dag.learned)
159 | compare(cpdag(dag), cpdag(dag.learned))
160 | 34/(34+12+29)
161 | shd(cpdag(dag), cpdag(dag.learned))
162 | source("~/Documents/GitHub/MMHC-Python/test.R")
163 |
--------------------------------------------------------------------------------
/lib/mmpc.py:
--------------------------------------------------------------------------------
1 | from lib.accessory import independence_test
2 | import operator
3 |
4 | # forward phase of MMPC
5 | def mmpc_forward(tar, pc, can, data, arities, varnames, prune, test, threshold):
6 | '''
7 | forward phase of mmpc
8 | :param tar: target variable
9 | :param pc: parents and children set of the target variable
10 | :param can: candidate variable for the pc set of the target variable
11 | :param data: input data (numpy array)
12 | :param arities: number of distinct values for each variable
13 | :param varnames: variable names
14 | :param prune: whether to use the pruning method
15 | :param test: type of statistical test (currently supported: g-test, z-test)
16 | :param threshold: threshold for the statistical test to determine independence
17 | :return: pc set and candidate pc set
18 | '''
19 |
20 | # run until no candidate variables remain for the current variable
21 | p_value = {}
22 | for can_var in can[tar]:
23 | p_value[can_var] = 0
24 |
25 | while can[tar]:
26 | # run a conditional independence test between each candidate variable and the target variable
27 | p_value = independence_test(p_value, tar, pc, can[tar], data, arities, varnames, test, threshold)
28 | # print(p_value)
29 | # update pc set and candidate set
30 | pc, can = update_forward(p_value, tar, pc, can, prune, threshold)
31 | return pc, can
32 |
33 |
34 | # backward phase of MMPC
35 | def mmpc_backward(tar, pc, can, data, arities, varnames, prune, test, threshold):
36 | '''
37 | backward phase of mmpc
38 | :param tar: target variable
39 | :param pc: parents and children set of the target variable
40 | :param can: candidate variable for the pc set of the target variable
41 | :param data: input data (numpy array)
42 | :param arities: number of distinct values for each variable
43 | :param varnames: variable names
44 | :param prune: whether to use the pruning method
45 | :param test: type of statistical test (currently supported: g-test, z-test)
46 | :param threshold: threshold for the statistical test to determine independence
47 | :return: pc set and candidate pc set
48 | '''
49 |
50 | # move every variable in the pc set except the last one back to the candidate set
51 | can[tar] = pc[0: -1]
52 | pc_output = []
53 | pc_output.append(pc[-1])
54 | can[tar].reverse()
55 |
56 | while can[tar]:
57 | # run a conditional independence test between each candidate variable and the target variable
58 | p_value = independence_test({}, tar, pc_output, can[tar], data, arities, varnames, test, threshold)
59 | # update pc set and candidate set
60 | pc_output, can = update_backward(p_value, tar, pc_output, can, prune, threshold)
61 | return pc_output, can
62 |
63 |
64 | def update_forward(p_value, tar, pc, can, prune, threshold):
65 | '''
66 | add the variable with the lowest p-value to the pc set and remove it from the candidate set
67 | :param p_value: a dictionary containing the maximum p-value of the CI tests for each variable
68 | :param tar: target variable
69 | :param pc: parents and children set of the target variable
70 | :param can: candidate variables for the pc set of the target variable
71 | :param prune: whether to use the pruning method
72 | :param threshold: threshold for the statistical test to determine independence
73 | :return: updated pc set and candidate variables
74 | '''
75 | sorted_p_value = sorted(p_value.items(), key=operator.itemgetter(1))
76 |
77 | if sorted_p_value[0][1] < threshold:
78 | pc.append(sorted_p_value[0][0])
79 | can[tar].remove(sorted_p_value[0][0])
80 | p_value.pop(sorted_p_value[0][0], None)
81 |
82 | # remove independent variables from candidate set
83 | independent_can = [x for x in sorted_p_value if x[1] > threshold]
84 | for ind in independent_can:
85 | can[tar].remove(ind[0])
86 | p_value.pop(ind[0])
87 | # prune the target variable from the candidate set of the candidate variable if they are independent
88 | if prune:
89 | if tar in can[ind[0]]:
90 | can[ind[0]].remove(tar)
91 | return pc, can
92 |
93 |
94 | def update_backward(p_value, tar, pc, can, prune, threshold):
95 | '''
96 | pc and candidate set update function for the backward phase
97 | :param p_value: a dictionary containing the maximum p-value of the CI tests for each variable
98 | :param tar: target variable
99 | :param pc: parents and children set of the target variable
100 | :param can: candidate variables for the pc set of the target variable
101 | :param prune: whether to use the pruning method
102 | :param threshold: threshold for the statistical test to determine independence
103 | :return: updated pc set and candidate variables
104 | '''
105 | # initialise the output candidate set
106 | can_output = []
107 |
108 | # flag: the first dependent (significant) variable found is moved into the pc set, the rest stay candidates
109 | sig_import = 1
110 | for can_var in can[tar]:
111 | if p_value[can_var] <= threshold:
112 | if sig_import:
113 | pc.append(can_var)
114 | sig_import = 0
115 | else:
116 | can_output.append(can_var)
117 | else:
118 | if prune:
119 | if tar in can[can_var]:
120 | can[can_var].remove(tar)
121 |
122 | can[tar] = can_output
123 |
124 | return pc, can
125 |
126 |
127 | # symmetry check for the pc sets: a variable stays in pc[var] only if var is also in its pc set
128 | def symmetry(pc):
129 | for var in pc:
130 | pc_remove = []
131 | for par in pc[var]:
132 | if var not in pc[par]:
133 | pc_remove.append(par)
134 | if pc_remove:
135 | for par in pc_remove:
136 | pc[var].remove(par)
137 | return pc
138 |
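139 | # Example (a sketch): symmetry({'A': ['B'], 'B': []}) returns {'A': [], 'B': []},
140 | # because 'A' is missing from pc['B'], so the asymmetric link between 'A' and 'B' is dropped.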
--------------------------------------------------------------------------------
/lib/hc.py:
--------------------------------------------------------------------------------
1 | from rpy2.robjects.packages import importr
2 |
3 | from lib.accessory import local_score, to_bnlearn
4 |
5 | base, bnlearn = importr('base'), importr('bnlearn')
6 |
7 |
8 | def check_cycle(vi, vj, dag):
9 | # check whether adding or orienting the edge vi->vj would create a cycle; in other words, check whether there is a directed path from vj to vi other than the possible single edge vj->vi
10 | underchecked = [x for x in dag[vi]['par'] if x != vj]
11 | checked = []
12 | cyc_flag = False
13 | while underchecked:
14 | if cyc_flag:
15 | break
16 | underchecked_copy = list(underchecked)
17 | for vk in underchecked_copy:
18 | if dag[vk]['par']:
19 | if vj in dag[vk]['par']:
20 | cyc_flag = True
21 | break
22 | else:
23 | for key in dag[vk]['par']:
24 | if key not in checked + underchecked:
25 | underchecked.append(key)
26 | underchecked.remove(vk)
27 | checked.append(vk)
28 | return cyc_flag
29 |
30 |
31 | def hc(data, arities, varnames, pc=None, score='default'):
32 | '''
33 | :param data: the training data used to learn the BN (numpy array)
34 | :param arities: number of distinct values for each variable
35 | :param varnames: variable names
36 | :param pc: the candidate parents and children set for each variable
37 | :param score: score function, including:
38 | bic (Bayesian Information Criterion for discrete variable)
39 | bic_g (Bayesian Information Criterion for continuous variable)
40 |
41 | :return: the learned BN (bnlearn format)
42 | '''
43 | if score == 'default':
44 | score = 'bic_g' if arities is None else 'bic'
45 | # initialize the candidate parents-set for each variable
46 | candidate = {}
47 | dag = {}
48 | cache = {}
49 | for var in varnames:
50 | if pc is None:
51 | candidate[var] = list(varnames)
52 | candidate[var].remove(var)
53 | else:
54 | candidate[var] = list(pc[var])
55 | dag[var] = {}
56 | dag[var]['par'] = []
57 | dag[var]['nei'] = []
58 | cache[var] = {}
59 | cache[var][tuple([])] = local_score(data, arities, [varnames.index(var)], score)
60 | diff = 1
61 | while diff > 0:
62 | diff = 0
63 | edge_candidate = []
64 | for vi in varnames:
65 | # attempt to add edges vi->vj
66 | for vj in candidate[vi]:
67 | cyc_flag = check_cycle(vi, vj, dag)
68 | if not cyc_flag:
69 | par_sea = tuple(sorted(dag[vj]['par'] + [vi]))
70 | if par_sea not in cache[vj]:
71 | cols = [varnames.index(x) for x in (vj, ) + par_sea]
72 | cache[vj][par_sea] = local_score(data, arities, cols, score)
73 | diff_temp = cache[vj][par_sea] - cache[vj][tuple(dag[vj]['par'])]
74 | if diff_temp - diff > 1e-10:
75 | diff = diff_temp
76 | edge_candidate = [vi, vj, 'a']
77 | for par_vi in dag[vi]['par']:
78 | # attempt to reverse edges from vi<-par_vi to vi->par_vi
79 | cyc_flag = check_cycle(vi, par_vi, dag)
80 | if not cyc_flag:
81 | par_sea_par_vi = tuple(sorted(dag[par_vi]['par'] + [vi]))
82 | if par_sea_par_vi not in cache[par_vi]:
83 | cols = [varnames.index(x) for x in (par_vi, ) + par_sea_par_vi]
84 | cache[par_vi][par_sea_par_vi] = local_score(data, arities, cols, score)
85 | par_sea_vi = tuple([x for x in dag[vi]['par'] if x != par_vi])
86 | if par_sea_vi not in cache[vi]:
87 | cols = [varnames.index(x) for x in (vi, ) + par_sea_vi]
88 | cache[vi][par_sea_vi] = local_score(data, arities, cols, score)
89 | diff_temp = cache[par_vi][par_sea_par_vi] + cache[vi][par_sea_vi] - cache[par_vi][
90 | tuple(dag[par_vi]['par'])] - cache[vi][tuple(dag[vi]['par'])]
91 | if diff_temp - diff > 1e-10:
92 | diff = diff_temp
93 | edge_candidate = [vi, par_vi, 'r']
94 | # attempt to delete edges vi<-par_vi
95 | par_sea = tuple([x for x in dag[vi]['par'] if x != par_vi])
96 | if par_sea not in cache[vi]:
97 | cols = [varnames.index(x) for x in (vi, ) + par_sea]
98 | cache[vi][par_sea] = local_score(data, arities, cols, score)
99 | diff_temp = cache[vi][par_sea] - cache[vi][tuple(dag[vi]['par'])]
100 | if diff_temp - diff > 1e-10:
101 | diff = diff_temp
102 | edge_candidate = [par_vi, vi, 'd']
103 | if edge_candidate:
104 | if edge_candidate[-1] == 'a':
105 | dag[edge_candidate[1]]['par'] = sorted(dag[edge_candidate[1]]['par'] + [edge_candidate[0]])
106 | candidate[edge_candidate[0]].remove(edge_candidate[1])
107 | candidate[edge_candidate[1]].remove(edge_candidate[0])
108 | elif edge_candidate[-1] == 'r':
109 | dag[edge_candidate[1]]['par'] = sorted(dag[edge_candidate[1]]['par'] + [edge_candidate[0]])
110 | dag[edge_candidate[0]]['par'].remove(edge_candidate[1])
111 | elif edge_candidate[-1] == 'd':
112 | dag[edge_candidate[1]]['par'].remove(edge_candidate[0])
113 | candidate[edge_candidate[0]].append(edge_candidate[1])
114 | candidate[edge_candidate[1]].append(edge_candidate[0])
115 | dag = bnlearn.model2network(to_bnlearn(dag))
116 | return dag
117 |
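118 | # Example (a sketch): with dag = {'A': {'par': [], 'nei': []}, 'B': {'par': ['A'], 'nei': []},
119 | # 'C': {'par': ['B'], 'nei': []}}, check_cycle('C', 'A', dag) returns True, because the
120 | # directed path A -> B -> C already exists and adding C -> A would close a cycle.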
--------------------------------------------------------------------------------
/lib/accessory.py:
--------------------------------------------------------------------------------
1 | from itertools import combinations, chain
2 |
3 | import numpy as np
4 | import pingouin as pg
5 | from numba import njit
6 | from scipy.stats import chi2, norm
7 | from sklearn.linear_model import LinearRegression
8 |
9 |
10 | def powerset(iterable):
11 | "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
12 | s = list(iterable)
13 | return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
14 |
15 |
16 | @njit(fastmath=True)
17 | def bic(data, arities, cols):
18 | strides = np.empty(len(cols), dtype=np.uint32)
19 | idx = len(cols) - 1
20 | stride = 1
21 | while idx > -1:
22 | strides[idx] = stride
23 | stride *= arities[cols[idx]]
24 | idx -= 1
25 | N_ijk = np.zeros(stride)
26 | N_ij = np.zeros(stride)
27 | for rowidx in range(data.shape[0]):
28 | idx_ijk = 0
29 | idx_ij = 0
30 | for i in range(len(cols)):
31 | idx_ijk += data[rowidx, cols[i]] * strides[i]
32 | if i != 0:
33 | idx_ij += data[rowidx, cols[i]] * strides[i]
34 | N_ijk[idx_ijk] += 1
35 | for i in range(arities[cols[0]]):
36 | N_ij[idx_ij + i * strides[0]] += 1
37 | bic = 0
38 | for i in range(stride):
39 | if N_ijk[i] != 0:
40 | bic += N_ijk[i] * np.log(N_ijk[i] / N_ij[i])
41 | bic -= 0.5 * np.log(data.shape[0]) * (arities[cols[0]] - 1) * strides[0]
42 | return bic
43 |
44 |
45 | def bic_g(data, arities, cols):
46 | data = data.to_numpy()
47 | y = data[:, cols[0]]
48 | if len(cols) == 1:
49 | resids = np.mean(y) - y
50 | else:
51 | X = data[:, cols[1:]]
52 | reg = LinearRegression().fit(X, y)
53 | preds = reg.predict(X)
54 | resids = y - preds
55 | sd = np.std(resids)
56 | numparams = len(cols) + 1 # include intercept AND sd (even though latter is not a free param)
57 | bic = norm.logpdf(resids, scale=sd).sum() - np.log(data.shape[0]) / 2 * numparams
58 | return bic
59 |
60 |
61 | def local_score(data, arities, cols, score='default'):
62 | '''
63 | :param data: integer-coded version of the data set
64 | :param arities: number of distinct values for each variable (None for continuous data)
65 | :param cols: the indices of a node and its parents; the first element is the index of the node and the remaining elements are the indices of its parents
66 | :param score: name of the score function, currently supported: bic, bic_g
67 | :return: local score of the node (cols[0]) given its parents (cols[1:])
68 | '''
69 | if len(data) == 0:
70 | return np.nan
71 | else:
72 | if score == 'default':
73 | score = 'bic_g' if arities is None else 'bic'
74 | try:
75 | ls = globals()[score](data, arities, np.asarray(cols))
76 | except Exception as e:
77 | raise Exception('score function ' + str(
78 | e) + ' is undefined or does not fit to data type. Available score functions are: bic (BIC for discrete variables) and bic_g (BIC for continuous variables).')
79 | return ls
80 |
81 |
82 | # statistical test
83 | def independence_test(p_value, tar, pc, can, data, arities, varnames, test='g-test', threshold=0.05):
84 | '''
85 | statistical independence test
86 | :param p_value: a dictionary containing the maximum p-value of the CI tests for each variable
87 | :param tar: target variable
88 | :param pc: parents and children set of the target variable
89 | :param can: candidate variables for the pc set of the target variable
90 | :param data: input data (numpy array for discrete data, pandas dataframe for continuous data)
91 | :param arities: number of distinct values for each variable (None for continuous data)
92 | :param varnames: variable names
93 | :param test: type of statistical test (currently supported: g-test, z-test)
94 | :param threshold: threshold for the statistical test to determine independence
95 | :return: a dictionary containing the maximum p-value of the CI tests for each variable
96 | '''
97 | for can_var in can:
98 | if can_var not in p_value.keys():
99 | p_value[can_var] = 0
100 | for con in powerset(pc[0:-1]):
101 | # avoid checking the separation set that has been checked in previous iterations
102 | con = list(con)
103 | if len(pc) != 0:
104 | con.append(pc[-1])
105 | cols = np.array([varnames.index(x) for x in [tar, can_var] + con])
106 | if test == 'g-test':
107 | G, dof = it_counter(data, arities, cols)
108 | p = chi2.sf(G, dof)
109 | p_value[can_var] = max(p, p_value[can_var])
110 | elif test == 'z-test':
111 | # z-test: Fisher's z-transform of the partial correlation
112 | r = pg.partial_corr(data=data, x=tar, y=can_var, covar=con)['r'][0]
113 | z = np.sqrt(data.shape[0] - len(con) - 3) * np.arctanh(r)
114 | p = 2 * min(norm.cdf(z), norm.cdf(-z))
115 | p_value[can_var] = max(p, p_value[can_var])
116 | else:
117 | raise Exception('statistical test ' + test + ' is undefined; currently supported tests are: g-test and z-test')
118 | if p_value[can_var] > threshold:
119 | break
120 | return p_value
121 |
122 |
123 | @njit(fastmath=True)
124 | def it_counter(data, arities, cols):
125 | strides = np.empty(len(cols), dtype=np.uint32)
126 | idx = len(cols) - 1
127 | stride = 1
128 | while idx > -1:
129 | strides[idx] = stride
130 | stride *= arities[cols[idx]]
131 | idx -= 1
132 | N_ijk = np.zeros(stride)
133 | N_ik = np.zeros(stride)
134 | N_jk = np.zeros(stride)
135 | N_k = np.zeros(stride)
136 | for rowidx in range(data.shape[0]):
137 | idx_ijk = 0
138 | idx_ik = 0
139 | idx_jk = 0
140 | idx_k = 0
141 | for i in range(len(cols)):
142 | idx_ijk += data[rowidx, cols[i]] * strides[i]
143 | if i != 0:
144 | idx_jk += data[rowidx, cols[i]] * strides[i]
145 | if i != 1:
146 | idx_ik += data[rowidx, cols[i]] * strides[i]
147 | if (i != 0) & (i != 1):
148 | idx_k += data[rowidx, cols[i]] * strides[i]
149 | N_ijk[idx_ijk] += 1
150 | for j in range(arities[cols[1]]):
151 | N_ik[idx_ik + j * strides[1]] += 1
152 | for i in range(arities[cols[0]]):
153 | N_jk[idx_jk + i * strides[0]] += 1
154 | for i in range(arities[cols[0]]):
155 | for j in range(arities[cols[1]]):
156 | N_k[idx_k + i * strides[0] + j * strides[1]] += 1
157 | G = 0
158 | for i in range(stride):
159 | if N_ijk[i] != 0:
160 | G += 2 * N_ijk[i] * np.log(N_ijk[i] * N_k[i] / N_ik[i] / N_jk[i])
161 |
162 | dof = max((arities[cols[0]] - 1) * (arities[cols[1]] - 1) * strides[1], 1)  # (r_x - 1)(r_y - 1) times the number of conditioning configurations
163 | return G, dof
164 |
165 |
166 | # convert the dag to bnlearn format
167 | def to_bnlearn(dag):
168 | output = ''
169 | for var in dag:
170 | output += '[' + var
171 | if dag[var]['par']:
172 | output += '|'
173 | for par in dag[var]['par']:
174 | output += par + ':'
175 | output = output[:-1]
176 | output += ']'
177 | return output
178 |
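179 | # Example (a sketch): to_bnlearn({'A': {'par': []}, 'B': {'par': ['A']}}) returns '[A][B|A]',
180 | # the modelstring format accepted by bnlearn.model2network.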
--------------------------------------------------------------------------------