├── compile.sh
├── run.sh
├── BayesNet.png
├── data
│   ├── .DS_Store
│   ├── gold_alarm.bif
│   └── solved_alarm.bif
├── main.py
├── README.md
├── bayesnet.py
├── format_checker.cpp
└── utils.py
/compile.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | python main.py $1 $2
--------------------------------------------------------------------------------
/BayesNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/navreeetkaur/bayesian-network-learning/HEAD/BayesNet.png
--------------------------------------------------------------------------------
/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/navreeetkaur/bayesian-network-learning/HEAD/data/.DS_Store
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
 1 | import sys
 2 | import time
 3 | import utils
 4 | 
 5 | 
 6 | # Setup network
 7 | step0 = time.time()
 8 | print "Initialising . . . . "
 9 | bn, df, mis_index = utils.setup_network(sys.argv[1], sys.argv[2])
10 | step1 = time.time()
11 | print "Initialisation time: (%ss)" % (round((step1 - step0), 5))
12 | print
13 | # Learn parameters
14 | print "Expectation-Maximisation . . . . "
15 | Alarm = utils.Expectation_Maximisation(df, bn, mis_index)
16 | step2 = time.time()
17 | print
18 | print "EM time: (%ss)" % (round((step2 - step1), 5))
19 | print
20 | print "Parsing output file . . . . 
" 21 | utils.parse_output(Alarm, sys.argv[1]) 22 | step3 = time.time() 23 | print "Output file parsing: (%ss)" % (round((step3 - step2), 5)) 24 | print 25 | print "TOTAL Time taken: (%ss)" % (round((step3 - step1), 5)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bayesian Network Parameter Learning 2 | ### Course Project - COL884(Spring'18):Uncertainity in AI 3 | #### Creator: Navreet Kaur[2015TT10917] 4 | 5 | #### Objective: 6 | Bayesian Parameter Learning of Alarm Bayesian Net given data with at most one missing value in each row. 7 | #### Algorithm Used: 8 | Expectation-Maximisation 9 | #### Goal: 10 | The goal of this assignment is to get experience with learning of Bayesian Networks and understanding their value in the real world. 11 | #### Scenario: 12 | Medical diagnosis. Some medical researchers have created a Bayesian network that models the inter-relationship between (some) diseases and observed symptoms. Our job as computer scientists is to learn parameters for the network based on health records. Unfortunately, as it happens in the real world, certain records have missing values. We need to do our best to compute the parameters for the network, so that it can be used for diagnosis later on. 13 | #### Problem Statement: 14 | We are given the Bayesian Network created by the researchers(as shown in BayesNet.png).Notice that eight diagnoses are modeled here: hypovolemia, left ventricular failure, Anaphylaxis, insufficient analgesia, pulmonary embolus, intubation, kinked tube, and disconnection. The observable nodes are CVP, PCWP, History, TPR, Blood Pressure, CO, HR BP, HR EKG, HR SAT, SaO2, PAP, MV, Min Vol, Exp CO2, FiO2 and Pres. Such networks can be represented in many formats. We will use the .bif format. BIF stands for Bayesian Interchange Format. 
The details about the format are [here](http://sites.poli.usp.br/p/fabio.cozman/). We are also providing a .bif parser so that you can start directly from a parsed Bayesian network represented as a graph.
15 | 
16 | The goal of the assignment is to learn the Bayes net from a healthcare dataset.
17 | #### Input format:
18 | We will work with the alarm.bif network. Please have a look at this file to get a basic understanding of how this information relates to the Bayes net image above. A sample Bayes net is as follows:
19 | variable “X” {
20 | 
21 | type discrete[2] { “True” “False” };
22 | 
23 | }
24 | 
25 | variable “Y” {
26 | 
27 | type discrete[2] { “True” “False” };
28 | 
29 | }
30 | 
31 | variable “Z” {
32 | 
33 | type discrete[2] { “True” “False” };
34 | 
35 | }
36 | probability(“X”) { table 0.2 0.8 ; }
37 | 
38 | probability(“Y”) { table 0.4 0.6 ; }
39 | 
40 | probability(“Z” “X” “Y”) { table 0.2 0.4 0.3 0.5 0.8 0.6 0.7 0.5; }
41 | 
42 | This says that X, Y, and Z all have two values each. X and Y have no parents, with priors P(X=True)=0.2, P(X=False)=0.8, and so on. Z has both X and Y as parents. Its probability table says P(Z=True|X=True, Y=True) = 0.2, P(Z=True|X=True, Y=False) = 0.4 and so on.
43 | 
44 | Our input network will have the Bayes net structure, including variables and parents, but will not have probability values. We will use -1 to represent that a probability value is unknown.
45 | probability(“X”) { table -1 -1 ; } will represent that the prior probability of X is unknown and needs to be computed via learning.
46 | 
47 | To learn these values we will provide a data file. Each line will be a patient record. All features will be listed in exactly the same order as in the .bif network and will be comma-separated. If a feature value is unknown we will use the special symbol “?” for it. There will be no more than 1 unknown value per row. Example:
48 | 
49 | “True”, “False”, “True” “?”, “False”, “False”
50 | 
51 | Here the first row says that X=True, Y=False and Z=True. 
The second row says that X is not known, and Y and Z are both False.
52 | Overall, your input will be alarm.bif with most probability values -1, plus this datafile. The datafile will have about 10,000 patient records.
53 | #### Output format:
54 | The output will be the result of learning each probability value in the conditional probability tables. In other words, all -1s are replaced with a probability value up to four decimal places. Thus, the output is a complete alarm.bif network.
55 | #### Files:
56 | 1) records.dat:
57 | A dataset file where a single line is a single patient record and the variables in a record are separated by spaces. Unknown values are marked by “?”. Each line contains at most 1 missing value. The file contains more than 11000 records.
58 | 2) format_checker.cpp:
59 | A format checker to verify that your output file adheres to the alarm.bif format. The format checker assumes that alarm.bif, solved_alarm.bif and gold_alarm.bif are present in the current directory and outputs its results. (A next version will also compute the total learning error). 
60 | 3) alarm.bif:
61 | BIF format file whose parameters need to be learned
62 | 4) gold_alarm.bif:
63 | BIF file having the true parameters
64 | 5) bayesnet.py:
65 | classes:
66 | Graph_Node
67 | Network
68 | methods:
69 | read_network: Parse the .bif format file and build a Bayesian net
70 | markov_blanket: Get variables in the Markov blanket of variable 'val_name'
71 | get_data: Read data from records.dat and store it as a pandas DataFrame
72 | normalise_counts: Normalise a list of counts from a given CPT
73 | 6) utils.py:
74 | methods:
75 | setup_network
76 | get_missing_index: List of the indices of nodes which have missing values in each data point; equal to -1 if no value is missing
77 | init_params: Initialise parameters
78 | normalise_array: Normalise a numpy array
79 | get_assignment_for: Return the rows of the factor table with assignments as specified in evidence E
80 | markov_blanket_sampling: Inference by Markov blanket sampling
81 | Expectation
82 | Maximisation
83 | Expectation_Maximisation
84 | parse_output
85 | 7) main.py: main file that calls methods from bayesnet and utils to build a Bayes net, read the data and learn its parameters
86 | #### Compilation:
87 | Run the file run.sh - it takes 2 input files, alarm.bif and records.dat, and outputs a file named
88 | solved_alarm.bif:
89 | `./run.sh alarm.bif records.dat`
90 | 
91 | #### Assumptions:
92 | • All variables are missing completely (or unconditionally) at random (MCAR), and none of them are missing at random (MAR), missing systematically, or hidden; i.e. initially the probability of each missing value is the same, and the sample mean of variable v is an unbiased estimator of the true value of v
93 | #### Parameter Initialisation:
94 | • Initialisation of parameters by available-case analysis (ignoring rows with missing values if the missing value is that of the parent). 
Since the data is MCAR, estimators based on this subsample of the data are unbiased estimators of the ones computed from complete data
95 | #### Design Choices:
96 | 1. Data Records:
97 | (a) String values for each class of random variable were mapped to integers
98 | (b) The data file was stored as a pandas DataFrame so as to perform grouping and aggregation of certain data occurrences to get their counts
99 | 2. Network
100 | (a) Ordered dictionary to represent nodes in the graph (keys = name of random variable, value = node object)
101 | (b) Ordered dictionary to store the Markov blanket (MB) of all nodes (keys = name of random variable (X), value = list of Strings of names of nodes in the MB of X) - this is stored so as to avoid recomputation of the Markov blanket at each step while doing Markov blanket sampling inference
102 | 3. Graph_Node
103 | (a) List of Strings to store names of Parents
104 | (b) List of integers to store indices of Children in the ordered dictionary of nodes in the Bayes net
105 | (c) Pandas DataFrame to store the CPT
106 | 4. CPTs
107 | (a) All CPTs are represented by pandas DataFrames (columns are the names of the variables, plus a column ‘p’ for the probability value) so as to easily access entries by specifying a dictionary of ‘Evidence’ with keys as variable names and values as the integers
108 | #### Optimisation/Techniques:
109 | 1. Storage of only counts, and not probabilities, in all the CPTs, normalising them before performing the Expectation step
110 | 2. **Smoothing**: Since all possible instances might not be observed, due to the small size of the dataset as compared to the number of network nodes, the counts of all possible instances in the CPTs were initialised to one. Similarly, in the Maximisation step, when any observed count was equal to zero, it was set to 0.00005 (since the required precision of probabilities is up to 4 decimal places, and counts in maximisation might be less than one due to the weights of the data points being considered, which themselves lie between 0 and 1)
111 | 3. 
**Inference**: Since the probability of variable X is independent of all other variables given its Markov blanket, and only one value is missing per row (i.e. all other values are given, hence the MB is given), we have P(X | data) = P(X | mb(X)), where mb(X) is the Markov blanket of X. Therefore, Markov blanket sampling was used to calculate P(X | MB(X))
112 | 4. Using **log probabilities**, as addition is faster than multiplication and it also helps to avoid numerical underflow.
113 | 5. **Convergence Criterion**: The maximum change in the CPTs between the previous and current iterations is less than or equal to 0.00005
114 | 
--------------------------------------------------------------------------------
/bayesnet.py:
--------------------------------------------------------------------------------
 1 | from __future__ import division
 2 | from collections import OrderedDict
 3 | import numpy as np
 4 | import pandas as pd
 5 | import time
 6 | 
 7 | 
 8 | __author__ = "Navreet Kaur"
 9 | __entrynumber__ = "2015TT10917"
10 | 
11 | 
12 | class Graph_Node():
13 |     """Our graph consists of a list of nodes where each node is represented as follows"""
14 | 
15 |     def __init__(self, name, n, vals):
16 |         self.Node_Name = name  # Variable name
17 |         self.nvalues = n  # Number of categories the variable represented by this node can take
18 |         self.values = vals  # Categories of possible values
19 |         self.Children = []  # Children of a particular node - these are indices of nodes in the graph. 
20 |         self.Parents = []  # Parents of a particular node - note these are names of parents
21 |         self.CPT = []
22 |         self.cpt_data = pd.DataFrame()  # conditional probability table as a DataFrame (counts)
23 |         self.markov_blanket = []  # List of nodes in the Markov Blanket - note that these are the names of the nodes
24 | 
25 |     def get_name(self):
26 |         return self.Node_Name
27 | 
28 |     def get_children(self):
29 |         return self.Children
30 | 
31 |     def get_Parents(self):
32 |         return self.Parents
33 | 
34 |     def get_n_parents(self):
35 |         return len(self.Parents)
36 | 
37 |     def get_CPT(self):
38 |         return self.CPT
39 | 
40 |     def get_nvalues(self):
41 |         return self.nvalues
42 | 
43 |     def get_values(self):
44 |         return self.values
45 | 
46 |     def set_CPT(self, new_CPT):
47 |         del(self.CPT[:])
48 |         self.CPT = new_CPT
49 | 
50 |     def set_counts(self, new_counts):
51 |         # note: self.counts is created here on first use
52 |         self.counts = new_counts
53 | 
54 |     def set_MB(self, new_mb):
55 |         self.markov_blanket = new_mb
56 | 
57 |     def set_cpt_data(self, new_cpt_data):
58 |         # replace the old table outright (drop() without assignment was a no-op)
59 |         self.cpt_data = new_cpt_data
60 | 
61 |     def set_Parents(self, Parent_Nodes):
62 |         self.Parents = Parent_Nodes
63 | 
64 |     def add_child(self, new_child_index):
65 |         if new_child_index in self.Children:
66 |             return 0
67 |         else:
68 |             self.Children.append(new_child_index)
69 |             return 1
70 | 
71 |     def print_node(self):
72 |         print(self.Node_Name)
73 |         print(self.values)
74 |         print(self.Parents)
75 |         print(self.CPT)
76 |         print
77 | 
78 | 
79 | class network():
80 |     """
81 |     The whole network represented as a dictionary of nodes
82 |     Pres_Graph:
83 |         Ordered Dictionary - Keys: variable names, Values: Node Objects
84 |     MB:
85 |         Ordered Dictionary - Keys: variable names, Values: List of names of the nodes in the markov blanket of the key
86 |     """
87 | 
88 |     def __init__(self, Pres_Graph=None, MB=None):
89 |         self.Pres_Graph = Pres_Graph if Pres_Graph is not None else OrderedDict()
90 |         self.MB = MB if MB is not None else OrderedDict()
91 | 
92 |     def addNode(self, node):
93 | 
self.Pres_Graph[node.Node_Name] = node
 94 | 
 95 |     def netSize(self):
 96 |         return len(self.Pres_Graph)
 97 | 
 98 |     def get_index(self, val_name):
 99 |         try:
100 |             return self.Pres_Graph.keys().index(val_name)
101 |         except ValueError:
102 |             print "No node of the name: " + str(val_name)
103 |             return None
104 | 
105 |     def get_nth_node(self, n):
106 |         return self.Pres_Graph.values()[n]
107 | 
108 |     def search_node(self, val_name):
109 |         try:
110 |             return self.Pres_Graph[val_name]
111 |         except KeyError:
112 |             print "Node NOT found"
113 |             return None
114 | 
115 |     def get_parent_nodes(self, node):
116 |         parent_nodes = []
117 |         parents = node.get_Parents()
118 |         for p in parents:
119 |             parent_nodes.append(self.search_node(p))
120 |         return parent_nodes
121 | 
122 |     def get_children(self, val_name):
123 |         Children = self.Pres_Graph[val_name].Children
124 |         c = []
125 |         for n in Children:
126 |             c.append(self.Pres_Graph.keys()[n])
127 |         return c
128 | 
129 |     def set_mb(self):
130 |         for vals in self.Pres_Graph.keys():
131 |             self.MB[vals] = markov_blanket(self, vals)
132 | 
133 | 
134 |     def normalise_cpt(self, X):
135 |         l = [X] + self.Pres_Graph[X].Parents + ['counts', 'p']
136 |         cpt = self.Pres_Graph[X].cpt_data
137 |         nvals = self.Pres_Graph[X].nvalues
138 |         cardinality = cpt.shape[0]
139 |         no_grps = int(cardinality / nvals)
140 |         list_dfs = []
141 |         df = pd.DataFrame()
142 |         i = 0
143 |         for n in range(no_grps):
144 |             curr_df = pd.DataFrame(cpt.iloc[i:i+nvals, :])
145 |             curr_df['p'] = normalise_counts(curr_df['counts'])
146 |             df = df.append(curr_df)
147 |             i = i + nvals
148 |         self.Pres_Graph[X].cpt_data = df[l]
149 | 
150 | 
151 | """ Reading network from .bif format """
152 | def read_network(bif_filepath):
153 |     Alarm = network()
154 |     find = 0
155 | 
156 |     with open(bif_filepath, 'r') as myfile:
157 |         while True:
158 |             line = myfile.readline()
159 |             line = line.strip()
160 | 
161 |             if line == '':
162 |                 break
163 | 
164 |             tokens = line.split()
165 |             first_word = tokens[0]
166 | 
167 | 
168 |             if first_word == 
"variable": 169 | values = [] 170 | name = tokens[1] # random varible name 171 | line_ = myfile.readline() # read next line 172 | line_ = line_.strip() 173 | tokens_ = line_.split() 174 | for i in range(3,len(tokens_)-1): 175 | values.append(tokens_[i]) 176 | new_node = Graph_Node(name = name, n = len(values), vals = values) 177 | Alarm.addNode(new_node) 178 | 179 | 180 | if first_word == "probability": 181 | vals = [] 182 | temp = tokens[2] 183 | node = Alarm.search_node(temp) 184 | index = Alarm.get_index(temp) 185 | i = 3 186 | # setting parents 187 | while True: 188 | if tokens[i]==")": 189 | break 190 | node_ = Alarm.search_node(tokens[i]) 191 | node_.add_child(index) 192 | vals.append(tokens[i]) 193 | i = i + 1 194 | 195 | node.set_Parents(vals) 196 | 197 | line_ = myfile.readline() 198 | tokens_ = line_.split() 199 | curr_CPT = [] 200 | for i in range(1,len(tokens_)-1): 201 | curr_CPT.append(int(tokens_[i])) 202 | 203 | node.set_CPT(curr_CPT) 204 | 205 | myfile.close() 206 | 207 | return Alarm 208 | 209 | 210 | # Get variables in the markov blanket of variable 'val_name' 211 | def markov_blanket(net, val_name): 212 | node = net.search_node(val_name) 213 | mb = [] 214 | # Parents 215 | parents = node.Parents 216 | mb = mb + parents 217 | # Children 218 | children_names = node.Children 219 | for c in children_names: 220 | child_node = net.Pres_Graph[net.Pres_Graph.keys()[c]] 221 | mb.append(child_node.Node_Name) 222 | # Spouses 223 | spouses = child_node.Parents 224 | for var in spouses: 225 | if var not in mb and var!=val_name: 226 | mb.append(var) 227 | 228 | return mb 229 | 230 | 231 | # Get the datafile as a pandas dataframe 232 | def get_data(filepath): 233 | with open(filepath,'r') as f: 234 | df = pd.DataFrame(l.rstrip().split() for l in f) 235 | 236 | df.columns = ['"Hypovolemia"','"StrokeVolume"','"LVFailure"','"LVEDVolume"','"PCWP"','"CVP"','"History"', 237 | 
'"MinVolSet"','"VentMach"','"Disconnect"','"VentTube"','"KinkedTube"','"Press"','"ErrLowOutput"', 238 | '"HRBP"','"ErrCauter"','"HREKG"','"HRSat"','"BP"','"CO"','"HR"','"TPR"','"Anaphylaxis"','"InsuffAnesth"','"PAP"','"PulmEmbolus"', 239 | '"FiO2"','"Catechol"','"SaO2"','"Shunt"','"PVSat"','"MinVol"','"ExpCO2"','"ArtCO2"','"VentAlv"','"VentLung"','"Intubation"'] 240 | 241 | features = list(df.columns) 242 | 243 | mapping_1 = {'"True"': 0, '"False"': 1, '"?"': float('nan')} 244 | mapping_2 = {'"Zero"': 0, '"Low"': 1, '"Normal"': 2, '"High"': 3, '"?"': float('nan')} 245 | mapping_3 = { '"Normal"': 0, '"Esophageal"': 1 , '"OneSided"': 2, '"?"': float('nan') } 246 | mapping_4 = {'"Low"':0, '"Normal"':1, '"High"':2, '"?"': float('nan')} 247 | mapping_5 = {'"Low"':0, '"Normal"':1, '"?"': float('nan')} 248 | mapping_6 = {'"Normal"':0, '"High"':1, '"?"': float('nan')} 249 | overall_mapping = { '"Hypovolemia"':mapping_1 , u'"StrokeVolume"':mapping_4, u'"LVFailure"':mapping_1, 250 | u'"LVEDVolume"':mapping_4, u'"PCWP"':mapping_4, u'"CVP"':mapping_4, 251 | u'"History"':mapping_1, u'"MinVolSet"':mapping_4, u'"VentMach"':mapping_2, u'"Disconnect"':mapping_1, 252 | u'"VentTube"':mapping_2, u'"KinkedTube"':mapping_1, u'"Press"':mapping_2, 253 | u'"ErrLowOutput"':mapping_1, u'"HRBP"':mapping_4, 254 | u'"ErrCauter"':mapping_1, u'"HREKG"':mapping_4, u'"HRSat"':mapping_4, 255 | u'"BP"':mapping_4, u'"CO"':mapping_4, u'"HR"':mapping_4, u'"TPR"':mapping_4, 256 | u'"Anaphylaxis"':mapping_1, u'"InsuffAnesth"':mapping_1, u'"PAP"':mapping_4, 257 | u'"PulmEmbolus"':mapping_1, u'"FiO2"':mapping_5, 258 | u'"Catechol"':mapping_6, u'"SaO2"':mapping_4, u'"Shunt"':mapping_6, 259 | u'"PVSat"':mapping_4, u'"MinVol"':mapping_2, u'"ExpCO2"':mapping_2, 260 | u'"ArtCO2"':mapping_4, u'"VentAlv"':mapping_2, u'"VentLung"':mapping_2, u'"Intubation"':mapping_3} 261 | df = df.replace(overall_mapping) 262 | # to get csv file of data 263 | # df.to_csv('records.csv') 264 | return df 265 | 266 | 267 | # normalise 
a list of counts 268 | def normalise_counts(vals): 269 | vals[vals==0] = 0.000005 270 | denom = np.sum(vals) 271 | normalised_vals = [] 272 | for val in vals: 273 | normalised_vals.append(val/float(denom)) 274 | return normalised_vals 275 | 276 | 277 | if __name__ == '__main__': 278 | print "This file contains Bayes Net classes: Run main.py" 279 | -------------------------------------------------------------------------------- /format_checker.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | 11 | using namespace std; 12 | 13 | class Graph_Node{ 14 | 15 | private: 16 | string Node_Name; 17 | vector Children; 18 | vector Parents; 19 | int nvalues; 20 | vector values; 21 | vector CPT; 22 | 23 | public: 24 | //Graph_Node(string name, vector Child_Nodes,vector Parent_Nodes,int n, vector vals,vector curr_CPT) 25 | Graph_Node(string name,int n,vector vals) 26 | { 27 | Node_Name=name; 28 | //Children=Child_Nodes; 29 | //Parents=Parent_Nodes; 30 | nvalues=n; 31 | values=vals; 32 | //CPT=curr_CPT; 33 | 34 | } 35 | string get_name() 36 | { 37 | return Node_Name; 38 | } 39 | vector get_children() 40 | { 41 | return Children; 42 | } 43 | vector get_Parents() 44 | { 45 | return Parents; 46 | } 47 | vector get_CPT() 48 | { 49 | return CPT; 50 | } 51 | int get_nvalues() 52 | { 53 | return nvalues; 54 | } 55 | vector get_values() 56 | { 57 | return values; 58 | } 59 | void set_CPT(vector new_CPT) 60 | { 61 | CPT.clear(); 62 | CPT=new_CPT; 63 | } 64 | void set_Parents(vector Parent_Nodes) 65 | { 66 | Parents.clear(); 67 | Parents=Parent_Nodes; 68 | } 69 | int add_child(int new_child_index ) 70 | { 71 | for(int i=0;i Pres_Graph; 89 | 90 | public: 91 | int addNode(Graph_Node node) 92 | { 93 | Pres_Graph.push_back(node); 94 | return 0; 95 | } 96 | list::iterator getNode(int i) 97 | { 98 | int count=0; 99 | list::iterator listIt; 100 | 
for(listIt=Pres_Graph.begin();listIt!=Pres_Graph.end();listIt++) 101 | { 102 | if(count++==i) 103 | break; 104 | 105 | } 106 | return listIt; 107 | } 108 | int netSize() 109 | { 110 | return Pres_Graph.size(); 111 | } 112 | int get_index(string val_name) 113 | { 114 | list::iterator listIt; 115 | int count=0; 116 | for(listIt=Pres_Graph.begin();listIt!=Pres_Graph.end();listIt++) 117 | { 118 | if(listIt->get_name().compare(val_name)==0) 119 | return count; 120 | count++; 121 | } 122 | return -1; 123 | } 124 | 125 | list::iterator get_nth_node(int n) 126 | { 127 | list::iterator listIt; 128 | int count=0; 129 | for(listIt=Pres_Graph.begin();listIt!=Pres_Graph.end();listIt++) 130 | { 131 | if(count==n) 132 | return listIt; 133 | count++; 134 | } 135 | return listIt; 136 | } 137 | 138 | list::iterator search_node(string val_name) 139 | { 140 | list::iterator listIt; 141 | for(listIt=Pres_Graph.begin();listIt!=Pres_Graph.end();listIt++) 142 | { 143 | if(listIt->get_name().compare(val_name)==0) 144 | return listIt; 145 | } 146 | 147 | cout<<"node not found\n"; 148 | return listIt; 149 | } 150 | 151 | 152 | }; 153 | 154 | void check_format() 155 | { 156 | network Alarm; 157 | string line,testline; 158 | int find=0; 159 | ifstream myfile("alarm.bif"); 160 | ifstream testfile("solved_alarm.bif"); 161 | string temp; 162 | string name; 163 | vector values; 164 | int line_count=1; 165 | if (myfile.is_open()) 166 | { 167 | 168 | while (! 
myfile.eof() ) 169 | { 170 | 171 | getline (myfile,line); 172 | 173 | 174 | 175 | 176 | 177 | getline (testfile,testline); 178 | if(testline.compare(line)!=0) 179 | { 180 | cout<<"Error Here in line number"<>temp; 187 | 188 | 189 | 190 | if(temp.compare("probability")==0) 191 | { 192 | string test_temp; 193 | 194 | getline (myfile,line); 195 | getline (testfile,testline); 196 | 197 | stringstream ss2; 198 | stringstream testss2; 199 | ss2.str(line); 200 | ss2>> temp; 201 | testss2.str(testline); 202 | testss2>>test_temp; 203 | if(test_temp.compare(temp)!=0) 204 | { 205 | cout<<"Error Here in line number"<> temp; 209 | testss2>>test_temp; 210 | vector curr_CPT; 211 | string::size_type sz; 212 | while(temp.compare(";")!=0) 213 | { 214 | 215 | if(!atof(test_temp.c_str())) 216 | { 217 | cout<<" Probem in Probab values in line "<>temp; 222 | testss2>>test_temp; 223 | 224 | 225 | 226 | } 227 | if(test_temp.compare(";")!=0) 228 | { 229 | cout<<" Probem in Semi-colon in line "< values; 265 | 266 | if (myfile.is_open()) 267 | { 268 | while (! 
myfile.eof() ) 269 | { 270 | stringstream ss; 271 | getline (myfile,line); 272 | 273 | 274 | ss.str(line); 275 | ss>>temp; 276 | 277 | 278 | if(temp.compare("variable")==0) 279 | { 280 | 281 | ss>>name; 282 | getline (myfile,line); 283 | 284 | stringstream ss2; 285 | ss2.str(line); 286 | for(int i=0;i<4;i++) 287 | { 288 | 289 | ss2>>temp; 290 | 291 | 292 | } 293 | values.clear(); 294 | while(temp.compare("};")!=0) 295 | { 296 | values.push_back(temp); 297 | 298 | ss2>>temp; 299 | } 300 | Graph_Node new_node(name,values.size(),values); 301 | int pos=Alarm.addNode(new_node); 302 | 303 | 304 | } 305 | else if(temp.compare("probability")==0) 306 | { 307 | 308 | ss>>temp; 309 | ss>>temp; 310 | 311 | list::iterator listIt; 312 | list::iterator listIt1; 313 | listIt=Alarm.search_node(temp); 314 | int index=Alarm.get_index(temp); 315 | ss>>temp; 316 | values.clear(); 317 | while(temp.compare(")")!=0) 318 | { 319 | listIt1=Alarm.search_node(temp); 320 | listIt1->add_child(index); 321 | values.push_back(temp); 322 | 323 | ss>>temp; 324 | 325 | } 326 | listIt->set_Parents(values); 327 | getline (myfile,line); 328 | stringstream ss2; 329 | 330 | ss2.str(line); 331 | ss2>> temp; 332 | 333 | ss2>> temp; 334 | 335 | vector curr_CPT; 336 | string::size_type sz; 337 | while(temp.compare(";")!=0) 338 | { 339 | 340 | curr_CPT.push_back(atof(temp.c_str())); 341 | 342 | ss2>>temp; 343 | 344 | 345 | 346 | } 347 | 348 | listIt->set_CPT(curr_CPT); 349 | 350 | 351 | } 352 | else 353 | { 354 | 355 | } 356 | 357 | 358 | 359 | 360 | 361 | } 362 | 363 | if(find==1) 364 | myfile.close(); 365 | } 366 | 367 | return Alarm; 368 | } 369 | 370 | int main() 371 | { 372 | network Alarm1,Alarm2; 373 | check_format(); 374 | Alarm1=read_network((char*)"solved_alarm.bif"); 375 | Alarm2=read_network((char*)"gold_alarm.bif"); 376 | float score=0; 377 | for(int i=0;i::iterator listIt1=Alarm1.get_nth_node(i); 380 | list::iterator listIt2=Alarm2.get_nth_node(i); 381 | vector cpt1=listIt1->get_CPT(); 382 | 
vector cpt2=listIt2->get_CPT(); 383 | for(int j=0;j660): 358 | # print "OVER TIME. . . . " 359 | break 360 | if delta <= 0.00005: 361 | break 362 | curr_iter +=1 363 | print "Converged in (" + str(curr_iter) + ") iterations" 364 | 365 | return bn 366 | 367 | # Parse learned parameters to 'solved_alarm.bif' 368 | def parse_output(Alarm, bif_alarm): 369 | i = 0 370 | with open('solved_alarm.bif', 'w') as output, open(bif_alarm, 'r') as input: 371 | while True: 372 | line0 = input.readline() 373 | line = line0.strip() 374 | if line == '': 375 | break 376 | tokens = line.split() 377 | first_word = tokens[0] 378 | if first_word == 'table': 379 | X = Alarm.Pres_Graph.keys()[i] 380 | l = [X] + Alarm.Pres_Graph[X].Parents 381 | to_write = np.asarray(Alarm.Pres_Graph[X].cpt_data.sort_values(l, ascending = True)['p']) 382 | to_write = ["{:10.4f}".format(item) for item in to_write] 383 | to_write = str(to_write)[1:len(str(to_write))-1].replace("'", "") 384 | to_write = to_write.replace(",", "") 385 | to_write = to_write.replace(" ", " ") 386 | to_write = to_write.replace(" ", "") 387 | output.write('\ttable '+ to_write + " ;\n") 388 | i+=1 389 | else: 390 | output.write(line0) 391 | 392 | 393 | if __name__ == '__main__': 394 | print "This file contains utility functions: Run main.py" 395 | 396 | 397 | 398 | 399 | --------------------------------------------------------------------------------
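The E-step inference the README describes - P(X | data) = P(X | mb(X)), with the products accumulated as log probabilities - can be illustrated standalone. The sketch below is not part of the repository; it reuses the toy X/Y/Z network and CPTs from the README's input-format section, and the function name `posterior_x_given_mb` is hypothetical.

```python
import math

# Toy CPTs from the README's X/Y/Z example (not the Alarm network).
# Prior P(X); P(Z | X, Y) indexed by (z, x, y) with values 'T'/'F'.
P_X = {'T': 0.2, 'F': 0.8}
P_Z = {('T', 'T', 'T'): 0.2, ('T', 'T', 'F'): 0.4,
       ('T', 'F', 'T'): 0.3, ('T', 'F', 'F'): 0.5,
       ('F', 'T', 'T'): 0.8, ('F', 'T', 'F'): 0.6,
       ('F', 'F', 'T'): 0.7, ('F', 'F', 'F'): 0.5}

def posterior_x_given_mb(y, z):
    """Weight for a row with X missing: P(X | mb(X)) is proportional to
    P(X) * P(z | X, y), accumulated in log space."""
    log_w = {}
    for x in ('T', 'F'):
        log_w[x] = math.log(P_X[x]) + math.log(P_Z[(z, x, y)])
    m = max(log_w.values())  # subtract the max before exp to avoid underflow
    w = dict((x, math.exp(lw - m)) for x, lw in log_w.items())
    s = sum(w.values())
    return dict((x, w[x] / s) for x in w)

# E-step weights for the README's sample record "?", "False", "False":
w = posterior_x_given_mb('F', 'F')
# w['T'] = 0.2*0.6 / (0.2*0.6 + 0.8*0.5) = 0.12/0.52 ≈ 0.2308
```

These normalised weights play the role of the fractional counts the README mentions for the Maximisation step: each possible completion of the missing value contributes its weight (between 0 and 1) to the relevant CPT counts before re-normalisation.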