├── .DS_Store
├── Input
│   ├── .DS_Store
│   ├── alarm.rds
│   └── ecoli70.rds
├── output
│   └── .DS_Store
├── lib
│   ├── __pycache__
│   │   ├── hc.cpython-37.pyc
│   │   ├── mmhc.cpython-37.pyc
│   │   ├── mmpc.cpython-37.pyc
│   │   ├── accessory.cpython-37.pyc
│   │   └── evaluation.cpython-37.pyc
│   ├── evaluation.py
│   ├── mmhc.py
│   ├── mmpc.py
│   ├── hc.py
│   └── accessory.py
├── .idea
│   ├── vcs.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── test.iml
│   └── workspace.xml
├── README.md
├── BN_structure_learning.py
└── .Rhistory
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/.DS_Store
--------------------------------------------------------------------------------
/Input/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/Input/.DS_Store
--------------------------------------------------------------------------------
/Input/alarm.rds:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/Input/alarm.rds
--------------------------------------------------------------------------------
/Input/ecoli70.rds:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/Input/ecoli70.rds
--------------------------------------------------------------------------------
/output/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/output/.DS_Store
--------------------------------------------------------------------------------
/lib/__pycache__/hc.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/lib/__pycache__/hc.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/mmhc.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/lib/__pycache__/mmhc.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/mmpc.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/lib/__pycache__/mmpc.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/accessory.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/lib/__pycache__/accessory.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/evaluation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/lib/__pycache__/evaluation.cpython-37.pyc
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/.idea/vcs.xml
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/.idea/misc.xml
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/.idea/modules.xml
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 | This is a Python implementation of the MMHC (Max-Min Hill-Climbing) structure learning algorithm. All data sets and models are placed in the "Input" folder and the results are written to
3 | the "output" folder.
4 |
5 | If you want to test your own data set, put it in the "Input" folder and change the corresponding variables in
6 | "BN_structure_learning.py", which also serves as an example script for running the code. If you want to evaluate the result, provide the ground-truth model in bnlearn format.
7 |
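8 | A minimal usage sketch (the file name "Input/mydata.csv" below is a placeholder for your own data set):
9 |
10 | ```python
11 | import pandas as pd
12 | from lib.mmhc import mmhc
13 |
14 | # use dtype='float64' instead if the data is continuous
15 | data = pd.read_csv('Input/mydata.csv', dtype='category')
16 |
17 | # learn the structure; the result is a DAG in bnlearn format
18 | dag = mmhc(data)
19 | ```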
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/.idea/inspectionProfiles/Project_Default.xml
--------------------------------------------------------------------------------
/lib/evaluation.py:
--------------------------------------------------------------------------------
1 | # evaluation methods
2 | from rpy2.robjects.packages import importr
3 | base, bnlearn = importr('base'), importr('bnlearn')
4 |
5 | # compute the F1 score of a learned graph given true graph
6 | def f1(dag_true, dag_learned):
7 | '''
8 | :param dag_true: true DAG
9 | :param dag_learned: learned DAG
10 | :return: the F1 score of learned DAG
11 | '''
12 | compare = bnlearn.compare(bnlearn.cpdag(dag_true), bnlearn.cpdag(dag_learned))  # returns (tp, fp, fn) on the CPDAGs
13 | return compare[0][0] * 2 / (compare[0][0] * 2 + compare[1][0] + compare[2][0])  # F1 = 2 * tp / (2 * tp + fp + fn)
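14 |
15 | # A worked example (numbers are illustrative): with tp = 34, fp = 12 and fn = 29,
16 | # the F1 score is 2 * 34 / (2 * 34 + 12 + 29) ≈ 0.62.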
--------------------------------------------------------------------------------
/.idea/test.iml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/.idea/test.iml
--------------------------------------------------------------------------------
/BN_structure_learning.py:
--------------------------------------------------------------------------------
1 | import random
2 | import urllib.request
3 | from os import path
4 |
5 | import pandas as pd
6 | from graphviz import Digraph
7 | from rpy2.robjects import pandas2ri
8 | from rpy2.robjects.packages import importr
9 |
10 | from lib.evaluation import f1
11 | from lib.mmhc import mmhc
12 |
13 | pandas2ri.activate()
14 | base, bnlearn = importr('base'), importr('bnlearn')
15 |
16 | # load network
17 | network = 'alarm'
18 | network_path = 'Input/' + network + '.rds'
19 | if not path.isfile(network_path):
20 | url = 'https://www.bnlearn.com/bnrepository/' + network + '/' + network + '.rds'
21 | urllib.request.urlretrieve(url, network_path)
22 | dag_true = base.readRDS(network_path)
23 |
24 | # generate data
25 | datasize = 10000
26 | filename = 'Input/' + network + '_' + str(datasize) + '.csv'
27 | if path.isfile(filename):
28 | data = pd.read_csv(filename, dtype='category')  # use dtype='float64' for continuous data and dtype='category' for categorical data
29 | else:
30 | data = bnlearn.rbn(dag_true, datasize)
31 | data = data[random.sample(list(data.columns), data.shape[1])]  # shuffle the column order so learning does not depend on the topological ordering of the generating network
32 | data.to_csv(filename, index=False)
33 |
34 |
35 | # learn bayesian network from data
36 | dag_learned = mmhc(data)
37 |
38 | # plot the learned graph
39 | dot = Digraph()
40 | for node in bnlearn.nodes(dag_learned):
41 | dot.node(node)
42 | for parent in bnlearn.parents(dag_learned, node):
43 | dot.edge(parent, node)  # arcs point from parent to child
44 | dot.render('output/' + network + '_' + str(datasize) + '.gv', view=False)
45 |
46 | # evaluate the learned graph
47 | print('f1 score is ' + str(f1(dag_true, dag_learned)))
48 | print('shd score is ' + str(bnlearn.shd(bnlearn.cpdag(dag_true), bnlearn.cpdag(dag_learned))[0]))  # SHD between the CPDAGs, matching the F1 evaluation
--------------------------------------------------------------------------------
/lib/mmhc.py:
--------------------------------------------------------------------------------
1 | from lib.mmpc import mmpc_forward, mmpc_backward, symmetry
2 | from lib.hc import hc
3 | import time
4 | import numpy as np
5 |
6 | def mmhc(data, test = None, score = None, prune = True, threshold = 0.05):
7 | '''
8 | mmhc algorithm
9 | :param data: input data (pandas dataframe)
10 | :param test: type of independence test (currently supported: g-test for discrete data, z-test for continuous data)
11 | :param score: type of score function (currently supported: bic, which maps to bic_g for continuous data)
12 | :param prune: whether to use the pruning method
13 | :param threshold: significance threshold for the CI tests
14 | :return: the DAG learned from data (bnlearn format)
15 | '''
16 |
17 | # initialise pc set as empty for all variables
18 | pc = {}
19 |
20 | # initialise the candidate set for variables
21 | can = {}
22 | for tar in data:
23 | can[tar] = list(data.columns)
24 | can[tar].remove(tar)
25 |
26 | # preprocess the data
27 | varnames = list(data.columns)
28 | if all(data[var].dtype.name == 'category' for var in data):
29 | arities = np.array(data.nunique())
30 | data = data.apply(lambda x: x.cat.codes).to_numpy()
31 | if test is None:
32 | test = 'g-test'
33 | if score is None:
34 | score = 'bic'
35 | elif all(data[var].dtype.name != 'category' for var in data):
36 | arities = None
37 | if test is None:
38 | test = 'z-test'
39 | if score is None or score == 'bic':
40 | score = 'bic_g'
41 | else:
42 | raise Exception('Mixed data is not supported.')
43 |
44 | # run MMPC on each variable
45 | start = time.time()
46 | for tar in varnames:
47 | # forward phase
48 | pc[tar] = []
49 | pc[tar], can = mmpc_forward(tar, pc[tar], can, data, arities, varnames, prune, test, threshold)
50 | # backward phase
51 | if pc[tar]:
52 | pc[tar], can = mmpc_backward(tar, pc[tar], can, data, arities, varnames, prune, test, threshold)
53 | # check the symmetry of the pc sets
54 | # when the number of variables is large, this function may be computationally costly;
55 | # it could be merged into the pruning process during forward and backward mmpc by passing the whole
56 | # pc set into mmpc_forward and mmpc_backward
57 | pc = symmetry(pc)
58 | print('MMPC phase costs %.2f seconds' % (time.time() - start))
59 | # run hill-climbing
60 | start = time.time()
61 | dag = hc(data, arities, varnames, pc, score)
62 | print('HC phase costs %.2f seconds' % (time.time() - start))
63 | return dag
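64 |
65 | # A minimal sketch of calling mmhc directly (assumes the categorical data set
66 | # generated by BN_structure_learning.py exists):
67 | #   import pandas as pd
68 | #   data = pd.read_csv('Input/alarm_10000.csv', dtype='category')
69 | #   dag = mmhc(data, prune=True, threshold=0.05)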
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Enderlogic/MMHC-Python/HEAD/.idea/workspace.xml
--------------------------------------------------------------------------------
/.Rhistory:
--------------------------------------------------------------------------------
1 | install.packages('devtools')
2 | library(devtools)
3 | install_github('andreacirilloac/updateR')
4 | library(updateR)
5 | library(readr)
6 | X1_MCAR <- read_csv("Documents/GitHub/SFMV/data/1/1_MCAR.csv")
7 | View(X1_MCAR)
8 | data = X1_MCAR
9 | data = as.data.frame(lapply(data, as.factor))
10 | library(bnlearn)
11 | structural.em(data[1:100, ])
12 | dag = structural.em(data[1:100, ])
13 | View(dag)
14 | dag
15 | arcs(dag)
16 | structural.em(data[1:100, ], score = 'bdeu')
17 | structural.em(data[1:100, ], maximize.args = list(score='bdeu'))
18 | structural.em(data[1:100, ], maximize.args = list(score='BDe'))
19 | structural.em(data[1:100, ], maximize.args = list(score='bde'))
20 | 10-0.9*1.3
21 | 5+0.9*0.4
22 | source("~/.active-rstudio-document")
23 | View(V)
24 | source("~/.active-rstudio-document")
25 | View(V)
26 | -1-0.9*1.3
27 | -1-1.9*0.9
28 | -4-0.9*(1+1.3+1.9+1.9)
29 | (-4-0.9*(1+1.3+1.9+1.9))/4
30 | (-2-0.9*(1+1.3))/2
31 | (-4-0.9*(1+1.3+2-3.3))/4
32 | (-4-0.9*(0.1+0.4+0.7+3))/4
33 | (-4+0.9*(0.1+0.4+0.7+3))/4
34 | (-1+0.9*(0.1+0.4+0.7+3))/4
35 | (-1+0.9*(-1-1.3-1.9-1.9))/4
36 | (-1+0.9*(0.4+0.4+0.7+2.3))/4
37 | (-4+0.9*(0.4+0.4+0.7+2.3))/4
38 | ((0.4+0.4+0.7+2.3))/4
39 | (-1 + (0.4+0.4+0.7+2.3))/4
40 | (-1 + (-1-1.3-1.9-1.9))/4
41 | -1 + ((-1-1.3-1.9-1.9))/4
42 | -1 + (0.9 * (-1-1.3-1.9-1.9))/4
43 | -1 + (0.9 * (1.5+8.8+0.7+2.3))/4
44 | -1 + ((1.5+8.8+0.7+2.3))/4
45 | (2-0.3-0.9-1.4)/4
46 | (-4 + 0.9*(3+0.8-0.4))/4
47 | (-4 + (3+0.8-0.4))/4
48 | -1+0.9*24.4
49 | -1+24.4
50 | -1+24.4*0.9
51 | -1+24.4*0.95
52 | -1+22*0.9
53 | 22*0.9
54 | 24.4*0.9
55 | 16*0.9
56 | 0.9*(-1-1.3-1-1)/4
57 | 3.4*0.9/4
58 | 3/4*0.9
59 | 24.4*0.9
60 | 22*0.9
61 | 22*0.9*0.9
62 | 3.4/4
63 | 3.4/4*0.9
64 | 0.781-0.025+0.184*2
65 | (0.781-0.025+0.184*2)/4
66 | source("~/.active-rstudio-document")
67 | View(V)
68 | source("~/.active-rstudio-document")
69 | View(V)
70 | source("~/.active-rstudio-document")
71 | View(V)
72 | source("~/.active-rstudio-document")
73 | View(V)
74 | View(V)
75 | View(V)
76 | source("~/.active-rstudio-document")
77 | View(V)
78 | source("~/.active-rstudio-document")
79 | View(V)
80 | source("~/.active-rstudio-document")
81 | source("~/.active-rstudio-document")
82 | View(V)
83 | source("~/.active-rstudio-document")
84 | source("~/.active-rstudio-document")
85 | View(V)
86 | debugSource("~/Documents/R/base/RL.R")
87 | View(V)
88 | View(V)
89 | source("~/Documents/R/base/RL.R")
90 | source("~/Documents/R/base/RL.R")
91 | source("~/Documents/R/base/RL.R")
92 | source("~/Documents/R/base/RL.R")
93 | source("~/Documents/R/base/RL.R")
94 | source("~/Documents/R/base/RL.R")
95 | source("~/Documents/R/base/RL.R")
96 | source("~/Documents/R/base/RL.R")
97 | source("~/Documents/R/base/RL.R")
98 | source("~/Documents/R/base/RL.R")
99 | source("~/Documents/R/base/RL.R")
100 | source("~/Documents/R/base/RL.R")
101 | source("~/Documents/R/base/RL.R")
102 | source("~/Documents/R/base/RL.R")
103 | View(V)
104 | source("~/Documents/R/base/RL.R")
105 | source("~/Documents/R/base/RL.R")
106 | source("~/Documents/R/base/RL.R")
107 | source("~/Documents/R/base/RL.R")
108 | source("~/Documents/R/base/RL.R")
109 | source("~/Documents/R/base/RL.R")
110 | source("~/Documents/R/base/RL.R")
111 | source("~/Documents/R/base/RL.R")
112 | source("~/Documents/R/base/RL.R")
113 | source("~/Documents/R/base/RL.R")
114 | source("~/Documents/R/base/RL.R")
115 | source("~/Documents/R/base/RL.R")
116 | source("~/Documents/R/base/RL.R")
117 | install.packages(c("broom", "haven", "jpeg", "matrixStats", "pillar", "R6", "readr", "survival", "vroom", "xfun", "XML"))
118 | debugSource("~/Documents/GitHub/SFMV/test.R")
119 | ci.test('gNBgE', 'imiDk', data = data)
120 | ci.test('gNBgE', 'imiDk', data = data, test = 'zf')
121 | source("~/Documents/GitHub/MMHC-Python/test.R")
122 | ci.test('eutG', 'lacZ', 'flgD', data = data, test = 'zf')
123 | ci.test('eutG', 'lacZ', data = data, test = 'zf')
124 | ci.test('eutG', 'flgD', data = data, test = 'zf')
125 | ci.test('eutG', 'nuoM', data = data, test = 'zf')
126 | dag = mmhc(data)
127 | source("~/Documents/GitHub/MMHC-Python/test.R")
128 | source("~/Documents/GitHub/MMHC-Python/test.R")
129 | compare(dag, dag.learned)
130 | compare(cpdag(dag), cpdag(dag.learned))
131 | 60/(60+64)
132 | source("~/Documents/GitHub/MMHC-Python/test.R")
133 | compare(cpdag(dag), cpdag(dag.learned))
134 | shd(cpdag(dag), cpdag(dag.learned))
135 | debug = capture.output(mmhc(data, debug = TRUE))
136 | save(debug, 'debug.txt')
137 | save('debug.txt', debug)
138 | write('debug.txt', debug)
139 | write(debug, 'debug.txt')
140 | source("~/Documents/GitHub/MMHC-Python/test.R")
141 | ci.test('eutG', 'lacZ', data = data)
142 | ci.test('eutG', 'lacZ', data = data, test = 'zf')
143 | source("~/Documents/GitHub/MMHC-Python/test.R")
144 | compare(dag, dag.learned)
145 | compare(cpdag(dag), cpdag(dag.learned))
146 | 60/(60+64)
147 | shd(cpdag(dag), dag.learned)
148 | debug = capture.output(mmhc(data, restrict.args = list(test = 'zf'), debug = TRUE))
149 | write(debug, 'debug.txt')
150 | ci.test('eutG', 'flgD', data = data, test = 'zf')
151 | ci.test('eutG', 'lacZ', data = data, test = 'zf')
152 | ci.test('eutG', 'aceB', data = data, test = 'zf')
153 | a = ci.test('eutG', 'aceB', data = data, test = 'zf')
154 | View(a)
155 | source("~/Documents/GitHub/MMHC-Python/test.R")
156 | network.path
157 | source("~/Documents/GitHub/MMHC-Python/test.R")
158 | compare(dag, dag.learned)
159 | compare(cpdag(dag), cpdag(dag.learned))
160 | 34/(34+12+29)
161 | shd(cpdag(dag), cpdag(dag.learned))
162 | source("~/Documents/GitHub/MMHC-Python/test.R")
163 |
--------------------------------------------------------------------------------
/lib/mmpc.py:
--------------------------------------------------------------------------------
1 | from lib.accessory import independence_test
2 | import operator
3 |
4 | # forward phase of MMPC
5 | def mmpc_forward(tar, pc, can, data, arities, varnames, prune, test, threshold):
6 | '''
7 | forward phase of mmpc
8 | :param tar: target variable
9 | :param pc: parents and children set of the target variable
10 | :param can: candidate variable for the pc set of the target variable
11 | :param data: input data (numpy array)
12 | :param arities: number of distinct values for each variable
13 | :param varnames: variable names
14 | :param prune: whether to use the pruning method
15 | :param test: type of statistical test (currently supported: g-test, z-test)
16 | :param threshold: threshold for the statistical test to determine independence
17 | :return: pc set and candidate pc set
18 | '''
19 |
20 | # run until no candidate variables remain for the current variable
21 | p_value = {}
22 | for can_var in can[tar]:
23 | p_value[can_var] = 0
24 |
25 | while can[tar]:
26 | # run a conditional independence test between each candidate variable and the target variable
27 | p_value = independence_test(p_value, tar, pc, can[tar], data, arities, varnames, test, threshold)
28 | # print(p_value)
29 | # update pc set and candidate set
30 | pc, can = update_forward(p_value, tar, pc, can, prune, threshold)
31 | return pc, can
32 |
33 |
34 | # backward phase of MMPC
35 | def mmpc_backward(tar, pc, can, data, arities, varnames, prune, test, threshold):
36 | '''
37 | backward phase of mmpc
38 | :param tar: target variable
39 | :param pc: parents and children set of the target variable
40 | :param can: candidate variable for the pc set of the target variable
41 | :param data: input data (numpy array)
42 | :param arities: number of distinct values for each variable
43 | :param varnames: variable names
44 | :param prune: whether to use the pruning method
45 | :param test: type of statistical test (currently supported: g-test, z-test)
46 | :param threshold: threshold for the statistical test to determine independence
47 | :return: pc set and candidate pc set
48 | '''
49 |
50 | # move every variable in the pc set except the last one back to the candidate set
51 | can[tar] = pc[0: -1]
52 | pc_output = []
53 | pc_output.append(pc[-1])
54 | can[tar].reverse()
55 |
56 | while can[tar]:
57 | # run a conditional independence test between each candidate variable and the target variable
58 | p_value = independence_test({}, tar, pc_output, can[tar], data, arities, varnames, test, threshold)
59 | # update pc set and candidate set
60 | pc_output, can = update_backward(p_value, tar, pc_output, can, prune, threshold)
61 | return pc_output, can
62 |
63 |
64 | def update_forward(p_value, tar, pc, can, prune, threshold):
65 | '''
66 | add the variable with the lowest p-value to the pc set and remove it from the candidate set
67 | :param p_value: a dictionary containing the maximum p-value of the CI tests for each variable
68 | :param tar: target variable
69 | :param pc: parents and children set of the target variable
70 | :param can: candidate variables for the pc set of the target variable
71 | :param prune: whether to use the pruning method
72 | :param threshold: threshold for the statistical test to determine independence
73 | :return: updated pc set and candidate variables
74 | '''
75 | sorted_p_value = sorted(p_value.items(), key=operator.itemgetter(1))
76 |
77 | if sorted_p_value[0][1] < threshold:
78 | pc.append(sorted_p_value[0][0])
79 | can[tar].remove(sorted_p_value[0][0])
80 | p_value.pop(sorted_p_value[0][0], None)
81 |
82 | # remove independent variables from candidate set
83 | independent_can = [x for x in sorted_p_value if x[1] > threshold]
84 | for ind in independent_can:
85 | can[tar].remove(ind[0])
86 | p_value.pop(ind[0])
87 | # prune the target variable from the candidate set of the candidate variable if they are independent
88 | if prune:
89 | if tar in can[ind[0]]:
90 | can[ind[0]].remove(tar)
91 | return pc, can
92 |
93 |
94 | def update_backward(p_value, tar, pc, can, prune, threshold):
95 | '''
96 | pc and candidate set update function for the backward phase
97 | :param p_value: a dictionary containing the maximum p-value of the CI tests for each variable
98 | :param tar: target variable
99 | :param pc: parents and children set of the target variable
100 | :param can: candidate variables for the pc set of the target variable
101 | :param prune: whether to use the pruning method
102 | :param threshold: threshold for the statistical test to determine independence
103 | :return: updated pc set and candidate variables
104 | '''
105 | # initialise the output candidate set
106 | can_output = []
107 |
108 | # flag: the first dependent (significant) variable found is moved into the pc set, the rest stay candidates
109 | sig_import = 1
110 | for can_var in can[tar]:
111 | if p_value[can_var] <= threshold:
112 | if sig_import:
113 | pc.append(can_var)
114 | sig_import = 0
115 | else:
116 | can_output.append(can_var)
117 | else:
118 | if prune:
119 | if tar in can[can_var]:
120 | can[can_var].remove(tar)
121 |
122 | can[tar] = can_output
123 |
124 | return pc, can
125 |
126 |
127 | # symmetry check for the pc sets: a variable stays in pc[var] only if var is also in its pc set
128 | def symmetry(pc):
129 | for var in pc:
130 | pc_remove = []
131 | for par in pc[var]:
132 | if var not in pc[par]:
133 | pc_remove.append(par)
134 | if pc_remove:
135 | for par in pc_remove:
136 | pc[var].remove(par)
137 | return pc
138 |
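139 | # Example (a sketch): symmetry({'A': ['B'], 'B': []}) returns {'A': [], 'B': []},
140 | # because 'A' is missing from pc['B'], so the asymmetric link between 'A' and 'B' is dropped.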
--------------------------------------------------------------------------------
/lib/hc.py:
--------------------------------------------------------------------------------
1 | from rpy2.robjects.packages import importr
2 |
3 | from lib.accessory import local_score, to_bnlearn
4 |
5 | base, bnlearn = importr('base'), importr('bnlearn')
6 |
7 |
8 | def check_cycle(vi, vj, dag):
9 | # check whether adding or orienting the edge vi->vj would create a cycle; in other words, check whether there is a directed path from vj to vi other than the possible single edge vj->vi
10 | underchecked = [x for x in dag[vi]['par'] if x != vj]
11 | checked = []
12 | cyc_flag = False
13 | while underchecked:
14 | if cyc_flag:
15 | break
16 | underchecked_copy = list(underchecked)
17 | for vk in underchecked_copy:
18 | if dag[vk]['par']:
19 | if vj in dag[vk]['par']:
20 | cyc_flag = True
21 | break
22 | else:
23 | for key in dag[vk]['par']:
24 | if key not in checked + underchecked:
25 | underchecked.append(key)
26 | underchecked.remove(vk)
27 | checked.append(vk)
28 | return cyc_flag
29 |
30 |
31 | def hc(data, arities, varnames, pc=None, score='default'):
32 | '''
33 | :param data: the training data used to learn the BN (numpy array)
34 | :param arities: number of distinct values for each variable
35 | :param varnames: variable names
36 | :param pc: the candidate parents and children set for each variable
37 | :param score: score function, including:
38 | bic (Bayesian Information Criterion for discrete variable)
39 | bic_g (Bayesian Information Criterion for continuous variable)
40 |
41 | :return: the learned BN (bnlearn format)
42 | '''
43 | if score == 'default':
44 | score = 'bic_g' if arities is None else 'bic'
45 | # initialize the candidate parents-set for each variable
46 | candidate = {}
47 | dag = {}
48 | cache = {}
49 | for var in varnames:
50 | if pc is None:
51 | candidate[var] = list(varnames)
52 | candidate[var].remove(var)
53 | else:
54 | candidate[var] = list(pc[var])
55 | dag[var] = {}
56 | dag[var]['par'] = []
57 | dag[var]['nei'] = []
58 | cache[var] = {}
59 | cache[var][tuple([])] = local_score(data, arities, [varnames.index(var)], score)
60 | diff = 1
61 | while diff > 0:
62 | diff = 0
63 | edge_candidate = []
64 | for vi in varnames:
65 | # attempt to add edges vi->vj
66 | for vj in candidate[vi]:
67 | cyc_flag = check_cycle(vi, vj, dag)
68 | if not cyc_flag:
69 | par_sea = tuple(sorted(dag[vj]['par'] + [vi]))
70 | if par_sea not in cache[vj]:
71 | cols = [varnames.index(x) for x in (vj, ) + par_sea]
72 | cache[vj][par_sea] = local_score(data, arities, cols, score)
73 | diff_temp = cache[vj][par_sea] - cache[vj][tuple(dag[vj]['par'])]
74 | if diff_temp - diff > 1e-10:
75 | diff = diff_temp
76 | edge_candidate = [vi, vj, 'a']
77 | for par_vi in dag[vi]['par']:
78 | # attempt to reverse edges from vi<-par_vi to vi->par_vi
79 | cyc_flag = check_cycle(vi, par_vi, dag)
80 | if not cyc_flag:
81 | par_sea_par_vi = tuple(sorted(dag[par_vi]['par'] + [vi]))
82 | if par_sea_par_vi not in cache[par_vi]:
83 | cols = [varnames.index(x) for x in (par_vi, ) + par_sea_par_vi]
84 | cache[par_vi][par_sea_par_vi] = local_score(data, arities, cols, score)
85 | par_sea_vi = tuple([x for x in dag[vi]['par'] if x != par_vi])
86 | if par_sea_vi not in cache[vi]:
87 | cols = [varnames.index(x) for x in (vi, ) + par_sea_vi]
88 | cache[vi][par_sea_vi] = local_score(data, arities, cols, score)
89 | diff_temp = cache[par_vi][par_sea_par_vi] + cache[vi][par_sea_vi] - cache[par_vi][
90 | tuple(dag[par_vi]['par'])] - cache[vi][tuple(dag[vi]['par'])]
91 | if diff_temp - diff > 1e-10:
92 | diff = diff_temp
93 | edge_candidate = [vi, par_vi, 'r']
94 | # attempt to delete edges vi<-par_vi
95 | par_sea = tuple([x for x in dag[vi]['par'] if x != par_vi])
96 | if par_sea not in cache[vi]:
97 | cols = [varnames.index(x) for x in (vi, ) + par_sea]
98 | cache[vi][par_sea] = local_score(data, arities, cols, score)
99 | diff_temp = cache[vi][par_sea] - cache[vi][tuple(dag[vi]['par'])]
100 | if diff_temp - diff > 1e-10:
101 | diff = diff_temp
102 | edge_candidate = [par_vi, vi, 'd']
103 | if edge_candidate:
104 | if edge_candidate[-1] == 'a':
105 | dag[edge_candidate[1]]['par'] = sorted(dag[edge_candidate[1]]['par'] + [edge_candidate[0]])
106 | candidate[edge_candidate[0]].remove(edge_candidate[1])
107 | candidate[edge_candidate[1]].remove(edge_candidate[0])
108 | elif edge_candidate[-1] == 'r':
109 | dag[edge_candidate[1]]['par'] = sorted(dag[edge_candidate[1]]['par'] + [edge_candidate[0]])
110 | dag[edge_candidate[0]]['par'].remove(edge_candidate[1])
111 | elif edge_candidate[-1] == 'd':
112 | dag[edge_candidate[1]]['par'].remove(edge_candidate[0])
113 | candidate[edge_candidate[0]].append(edge_candidate[1])
114 | candidate[edge_candidate[1]].append(edge_candidate[0])
115 | dag = bnlearn.model2network(to_bnlearn(dag))
116 | return dag
117 |
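118 | # Example (a sketch): with dag = {'A': {'par': [], 'nei': []}, 'B': {'par': ['A'], 'nei': []},
119 | # 'C': {'par': ['B'], 'nei': []}}, check_cycle('C', 'A', dag) returns True, because the
120 | # directed path A -> B -> C already exists and adding C -> A would close a cycle.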
--------------------------------------------------------------------------------
/lib/accessory.py:
--------------------------------------------------------------------------------
1 | from itertools import combinations, chain
2 |
3 | import numpy as np
4 | import pingouin as pg
5 | from numba import njit
6 | from scipy.stats import chi2, norm
7 | from sklearn.linear_model import LinearRegression
8 |
9 |
10 | def powerset(iterable):
11 | "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
12 | s = list(iterable)
13 | return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
14 |
15 |
16 | @njit(fastmath=True)
17 | def bic(data, arities, cols):
18 | strides = np.empty(len(cols), dtype=np.uint32)
19 | idx = len(cols) - 1
20 | stride = 1
21 | while idx > -1:
22 | strides[idx] = stride
23 | stride *= arities[cols[idx]]
24 | idx -= 1
25 | N_ijk = np.zeros(stride)
26 | N_ij = np.zeros(stride)
27 | for rowidx in range(data.shape[0]):
28 | idx_ijk = 0
29 | idx_ij = 0
30 | for i in range(len(cols)):
31 | idx_ijk += data[rowidx, cols[i]] * strides[i]
32 | if i != 0:
33 | idx_ij += data[rowidx, cols[i]] * strides[i]
34 | N_ijk[idx_ijk] += 1
35 | for i in range(arities[cols[0]]):
36 | N_ij[idx_ij + i * strides[0]] += 1
37 | bic = 0
38 | for i in range(stride):
39 | if N_ijk[i] != 0:
40 | bic += N_ijk[i] * np.log(N_ijk[i] / N_ij[i])
41 | bic -= 0.5 * np.log(data.shape[0]) * (arities[cols[0]] - 1) * strides[0]
42 | return bic
43 |
44 |
45 | def bic_g(data, arities, cols):
46 | data = data.to_numpy()
47 | y = data[:, cols[0]]
48 | if len(cols) == 1:
49 | resids = np.mean(y) - y
50 | else:
51 | X = data[:, cols[1:]]
52 | reg = LinearRegression().fit(X, y)
53 | preds = reg.predict(X)
54 | resids = y - preds
55 | sd = np.std(resids)
56 | numparams = len(cols) + 1 # include intercept AND sd (even though latter is not a free param)
57 | bic = norm.logpdf(resids, scale=sd).sum() - np.log(data.shape[0]) / 2 * numparams
58 | return bic
59 |
60 |
61 | def local_score(data, arities, cols, score='default'):
62 | '''
63 | :param data: integer-coded version of the data set
64 | :param arities: number of distinct values for each variable (None for continuous data)
65 | :param cols: the indices of a node and its parents; the first element is the index of the node and the remaining elements are the indices of its parents
66 | :param score: name of the score function, currently supported: bic, bic_g
67 | :return: local score of the node (cols[0]) given its parents (cols[1:])
68 | '''
69 | if len(data) == 0:
70 | return np.nan
71 | else:
72 | if score == 'default':
73 | score = 'bic_g' if arities is None else 'bic'
74 | try:
75 | ls = globals()[score](data, arities, np.asarray(cols))
76 | except Exception as e:
77 | raise Exception('score function ' + str(
78 | e) + ' is undefined or does not fit to data type. Available score functions are: bic (BIC for discrete variables) and bic_g (BIC for continuous variables).')
79 | return ls
80 |
81 |
82 | # statistical test
83 | def independence_test(p_value, tar, pc, can, data, arities, varnames, test='g-test', threshold=0.05):
84 | '''
85 | statistical independence test
86 | :param p_value: a dictionary containing the maximum p-value of the CI tests for each variable
87 | :param tar: target variable
88 | :param pc: parents and children set of the target variable
89 | :param can: candidate variables for the pc set of the target variable
90 | :param data: input data (numpy array for discrete data, pandas dataframe for continuous data)
91 | :param arities: number of distinct values for each variable (None for continuous data)
92 | :param varnames: variable names
93 | :param test: type of statistical test (currently supported: g-test, z-test)
94 | :param threshold: threshold for the statistical test to determine independence
95 | :return: a dictionary containing the maximum p-value of the CI tests for each variable
96 | '''
97 | for can_var in can:
98 | if can_var not in p_value.keys():
99 | p_value[can_var] = 0
100 | for con in powerset(pc[0:-1]):
101 | # avoid checking the separation set that has been checked in previous iterations
102 | con = list(con)
103 | if len(pc) != 0:
104 | con.append(pc[-1])
105 | cols = np.array([varnames.index(x) for x in [tar, can_var] + con])
106 | if test == 'g-test':
107 | G, dof = it_counter(data, arities, cols)
108 | p = chi2.sf(G, dof)
109 | p_value[can_var] = max(p, p_value[can_var])
110 | elif test == 'z-test':
111 | # z-test: Fisher's z-transform of the partial correlation
112 | r = pg.partial_corr(data=data, x=tar, y=can_var, covar=con)['r'][0]
113 | z = np.sqrt(data.shape[0] - len(con) - 3) * np.arctanh(r)
114 | p = 2 * min(norm.cdf(z), norm.cdf(-z))
115 | p_value[can_var] = max(p, p_value[can_var])
116 | else:
117 | raise Exception('statistical test ' + test + ' is undefined; currently supported tests are: g-test and z-test')
118 | if p_value[can_var] > threshold:
119 | break
120 | return p_value
121 |
122 |
123 | @njit(fastmath=True)
124 | def it_counter(data, arities, cols):
125 | strides = np.empty(len(cols), dtype=np.uint32)
126 | idx = len(cols) - 1
127 | stride = 1
128 | while idx > -1:
129 | strides[idx] = stride
130 | stride *= arities[cols[idx]]
131 | idx -= 1
132 | N_ijk = np.zeros(stride)
133 | N_ik = np.zeros(stride)
134 | N_jk = np.zeros(stride)
135 | N_k = np.zeros(stride)
136 | for rowidx in range(data.shape[0]):
137 | idx_ijk = 0
138 | idx_ik = 0
139 | idx_jk = 0
140 | idx_k = 0
141 | for i in range(len(cols)):
142 | idx_ijk += data[rowidx, cols[i]] * strides[i]
143 | if i != 0:
144 | idx_jk += data[rowidx, cols[i]] * strides[i]
145 | if i != 1:
146 | idx_ik += data[rowidx, cols[i]] * strides[i]
147 | if (i != 0) & (i != 1):
148 | idx_k += data[rowidx, cols[i]] * strides[i]
149 | N_ijk[idx_ijk] += 1
150 | for j in range(arities[cols[1]]):
151 | N_ik[idx_ik + j * strides[1]] += 1
152 | for i in range(arities[cols[0]]):
153 | N_jk[idx_jk + i * strides[0]] += 1
154 | for i in range(arities[cols[0]]):
155 | for j in range(arities[cols[1]]):
156 | N_k[idx_k + i * strides[0] + j * strides[1]] += 1
157 | G = 0
158 | for i in range(stride):
159 | if N_ijk[i] != 0:
160 | G += 2 * N_ijk[i] * np.log(N_ijk[i] * N_k[i] / N_ik[i] / N_jk[i])
161 |
162 | dof = max((arities[cols[0]] - 1) * (arities[cols[1]] - 1) * strides[1], 1)  # (r_x - 1)(r_y - 1) times the number of conditioning configurations
163 | return G, dof
164 |
165 |
166 | # convert the dag to bnlearn format
167 | def to_bnlearn(dag):
168 | output = ''
169 | for var in dag:
170 | output += '[' + var
171 | if dag[var]['par']:
172 | output += '|'
173 | for par in dag[var]['par']:
174 | output += par + ':'
175 | output = output[:-1]
176 | output += ']'
177 | return output
178 |
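179 | # Example (a sketch): to_bnlearn({'A': {'par': []}, 'B': {'par': ['A']}}) returns '[A][B|A]',
180 | # the modelstring format accepted by bnlearn.model2network.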
--------------------------------------------------------------------------------