├── README.md ├── pictures ├── calc_example_bayesnet.png └── variable_elimination_demo.png └── src ├── calculate_example1.m ├── calculate_example2.m ├── calculate_example3.m ├── calculate_example4.m ├── calculate_example5.m ├── direct_sample.m ├── elim.m ├── find_valid.m ├── gen_key.m ├── get_conditional.m ├── make_factor.m ├── make_product.m ├── refine_input.m ├── refine_keys_obs.m ├── refine_obs_cell.m ├── var_elim.m ├── variables.m └── variables2.m /README.md: -------------------------------------------------------------------------------- 1 | # bayesnet-variable-elimination 2 | MATLAB implementation of variable elimination in Bayesian networks. Since variable elimination is essentially based on factors, it is also possible to use the implementation on MRFs, CRFs, etc. 3 | # Run a demo 4 | Clone the GitHub repository and try running calculate_example1.m and calculate_example2.m. These scripts compute the conditional probabilities P(B|J,M) and P(E|J,M) for the Bayesian network shown in 'pictures/calc_example_bayesnet.png'. 5 | # Implementation details 6 | Each factor is represented using a MATLAB containers.Map (a.k.a. a hash table). In addition, each variable configuration is represented using a string. For example, if a model has 3 binary variables A, B, and C and A = 1, B = 0, C = 1, then the string representing this configuration is 'TFT'. Also, an auxiliary character 'N' is used to mark variables that are out of scope. 7 | # TODO 8 | For now, the implementation only works on Bayesian networks with binary variables.
9 | -------------------------------------------------------------------------------- /pictures/calc_example_bayesnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/82magnolia/bayesnet-variable-elimination/c74dd83e57e8eb75e18b61f5a2a187f58be5401d/pictures/calc_example_bayesnet.png -------------------------------------------------------------------------------- /pictures/variable_elimination_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/82magnolia/bayesnet-variable-elimination/c74dd83e57e8eb75e18b61f5a2a187f58be5401d/pictures/variable_elimination_demo.png -------------------------------------------------------------------------------- /src/calculate_example1.m: -------------------------------------------------------------------------------- 1 | % Calculate p(B|J,M) by variable elimination (variables are ordered B,E,A,J,M) 2 | variables % run variables.m to load the probability tables 3 | default_str = 'NNNNN'; 4 | p_list = {p_B,p_E,p_A_BE,p_J_A,p_M_A}; % Probability tables, sorted in the 5 | % same order the variables are eliminated 6 | p_partition = {[1],[2],[3,4,5]}; % The order 1~5 should not change: the only 7 | % thing that should be modified is the partition 8 | query_list = [1,4,5]; % query variables B, J, M 9 | elim_list = [2,3]; % eliminated variables E, A 10 | p_joint_BJM = var_elim(p_list,p_partition,query_list,elim_list,default_str); 11 | cond_B_JM = get_conditional(p_joint_BJM,[4,5],[1],default_str); 12 | %{ 13 | cond_B_JM.keys 14 | 15 | ans = 16 | 17 | 1×8 cell array 18 | 19 | Columns 1 through 7 20 | 21 | {'FNNFF'} {'FNNFT'} {'FNNTF'} {'FNNTT'} {'TNNFF'} {'TNNFT'} {'TNNTF'} 22 | 23 | Column 8 24 | 25 | {'TNNTT'} 26 | 27 | cond_B_JM.values 28 | 29 | ans = 30 | 31 | 1×8 cell array 32 | 33 | Columns 1 through 6 34 | 35 | {[0.9987]} {[0.9119]} {[0.9944]} {[0.6258]} {[0.0013]} {[0.0881]} 36 | 37 | Columns 7 through 8 38 | 39 | 40 | {[0.0056]} {[0.3742]} 41 | 42 | 'TNNFF' indicates p(B=T|J=F,M=F) and so on 43 | %} 44 | % To run an actual demo,
clone the following GitHub repository: 45 | % https://github.com/82magnolia/bayesnet-variable-elimination.git -------------------------------------------------------------------------------- /src/calculate_example2.m: -------------------------------------------------------------------------------- 1 | % Calculate p(E|J,M) by variable elimination (variables are ordered B,E,A,J,M) 2 | variables % run variables.m to load the probability tables 3 | default_str = 'NNNNN'; 4 | p_list = {p_E,p_B,p_A_BE,p_J_A,p_M_A}; 5 | p_partition = {[1],[2],[3,4,5]}; % The order 1~5 should not change: the only 6 | % thing that should be modified is the partition 7 | query_list = [2,4,5]; % query variables E, J, M 8 | elim_list = [1,3]; % eliminated variables B, A 9 | p_joint_EJM = var_elim(p_list,p_partition,query_list,elim_list,default_str); 10 | cond_E_JM = get_conditional(p_joint_EJM,[4,5],[2],default_str); 11 | %{ 12 | cond_E_JM.keys 13 | 14 | ans = 15 | 16 | 1×8 cell array 17 | 18 | Columns 1 through 7 19 | 20 | {'NFNFF'} {'NFNFT'} {'NFNTF'} {'NFNTT'} {'NTNFF'} {'NTNFT'} {'NTNTF'} 21 | 22 | Column 8 23 | 24 | {'NTNTT'} 25 | 26 | cond_E_JM.values 27 | 28 | ans = 29 | 30 | 1×8 cell array 31 | 32 | Columns 1 through 6 33 | 34 | {[0.9922]} {[0.9806]} {[0.9916]} {[0.9427]} {[0.0078]} {[0.0194]} 35 | 36 | Columns 7 through 8 37 | 38 | {[0.0084]} {[0.0573]} 39 | 40 | 'NTNFF' indicates p(E=T|J=F,M=F) and so on 41 | 42 | 43 | %} 44 | % To run an actual demo, clone the following GitHub repository: 45 | % https://github.com/82magnolia/bayesnet-variable-elimination.git -------------------------------------------------------------------------------- /src/calculate_example3.m: -------------------------------------------------------------------------------- 1 | % Calculate p(S,R|W) by variable elimination (variables are ordered C,S,R,W) 2 | variables2 % run variables2.m to load the probability tables 3 | default_str = 'NNNN'; 4 | p_list = {p_W_SR,p_C,p_S_C,p_R_C}; 5 | p_partition = {[1],[2,3,4]}; % The order 1~4 should not change: the only 6 | % thing that should be modified is the partition 7 | query_list = [2,3,4]; % query variables S, R, W 8 | elim_list = [1]; % eliminated variable C 9 | p_joint_SRW =
var_elim(p_list,p_partition,query_list,elim_list,default_str); 10 | cond_SR_W = get_conditional(p_joint_SRW,[4],[2,3],default_str); -------------------------------------------------------------------------------- /src/calculate_example4.m: -------------------------------------------------------------------------------- 1 | % Estimate p(S,R|W) empirically by direct sampling (variables are ordered C,S,R,W) 2 | variables2 % run variables2.m to load the probability tables and p_list 3 | default_str = 'NNNN'; 4 | sampling_no = 15000; 5 | [samples,emp_table] = direct_sample(p_list,sampling_no,default_str); 6 | mod_emp_table = elim(emp_table,1,default_str); % sum out variable 1 (C) 7 | cond_SR_W = get_conditional(mod_emp_table,[4],[2,3],default_str); -------------------------------------------------------------------------------- /src/calculate_example5.m: -------------------------------------------------------------------------------- 1 | % Script for calculating differences between exact and empirical methods 2 | calculate_example3 3 | exact_table = cond_SR_W; 4 | exact_joint = p_joint_SRW; 5 | exact_values = cell2mat(exact_table.values); 6 | calculate_example4 7 | emp_table = cond_SR_W; 8 | emp_joint = mod_emp_table; 9 | emp_values = cell2mat(emp_table.values); 10 | norm(emp_values-exact_values) 11 | % Using 15000 samples gave a difference (2-norm) of 0.0052 12 | % To run an actual demo, clone the following GitHub repository: 13 | % https://github.com/82magnolia/bayesnet-variable-elimination.git -------------------------------------------------------------------------------- /src/direct_sample.m: -------------------------------------------------------------------------------- 1 | function [samples,emp_table] = direct_sample(p_list,sampling_no,default_str) 2 | % sampling_no is the total number of samples to draw; p_list{j} is the table used to sample variable j; emp_table maps every full T/F assignment to its relative frequency 3 | samples = repmat({default_str},1,sampling_no); 4 | emp_table_keys = gen_key(1:size(default_str,2),default_str); 5 | emp_table_values = zeros(1,size(emp_table_keys,2)); 6 | emp_table = containers.Map(emp_table_keys,emp_table_values); 7 | for i = 1:sampling_no 8 | for j =
1:size(p_list,2) 9 | samples{i}(j) = 'T'; % tentatively set variable j to T so the table lookup below returns the probability of T 10 | test = rand; 11 | prob_keys = p_list{j}.keys; 12 | target_key = prob_keys{1}; 13 | new_input = refine_input(target_key,samples{i}); % mask positions outside table j's scope with 'N' 14 | p_table = p_list{j}; 15 | if test > p_table(new_input) 16 | samples{i}(j) = 'F'; 17 | end 18 | end 19 | % samples{i} = modify_sample(samples{i}); 20 | emp_table(samples{i}) = emp_table(samples{i}) + 1; 21 | end 22 | temp = cell2mat(emp_table.values); 23 | temp = temp/sampling_no; % convert counts to relative frequencies 24 | emp_table = containers.Map(emp_table.keys,temp); 25 | end 26 | 27 | function[new_sample] = modify_sample(sample) % flips every 'T' to 'F' and every other character to 'T'; only referenced from the commented-out call above 28 | new_sample = sample; 29 | for i = 1:size(sample,2) 30 | if sample(i) == 'T' 31 | new_sample(i) = 'F'; 32 | else 33 | new_sample(i) = 'T'; 34 | end 35 | end 36 | end -------------------------------------------------------------------------------- /src/elim.m: -------------------------------------------------------------------------------- 1 | function [simp_factor] = elim(factor,target_no,default_str) 2 | % factor is a hash table (containers.Map) keyed by configuration strings 3 | % target_no indicates the variable to be eliminated (summed out); if it is -1, 4 | % no variable is eliminated and factor is returned unchanged 5 | if target_no == -1 6 | simp_factor = factor; 7 | else 8 | factor_var_list = find_valid(factor.keys); 9 | new_var_list = setdiff(factor_var_list,target_no); 10 | new_keys = gen_key(new_var_list,default_str); 11 | new_value = []; 12 | for i = 1:size(new_keys,2) 13 | T_key = new_keys{i}; 14 | F_key = new_keys{i}; 15 | T_key(target_no) = 'T'; 16 | F_key(target_no) = 'F'; 17 | new_value(i) = factor(T_key)+factor(F_key); % sum over both values of the eliminated variable 18 | end 19 | simp_factor = containers.Map(new_keys,new_value); 20 | simp_factor; % no-op statement (output suppressed) 21 | end 22 | end -------------------------------------------------------------------------------- /src/find_valid.m: -------------------------------------------------------------------------------- 1 | function [var_list] = find_valid(keys) % returns the positions that are not 'N' in the first key 2 | var_list = []; 3 | sample = keys{1}; % only the first key is inspected 4 | for i = 1:size(sample,2) 5 | if sample(i) ~= 'N' 6 | var_list = union(var_list,i); 7 | end
8 | end 9 | end -------------------------------------------------------------------------------- /src/gen_key.m: -------------------------------------------------------------------------------- 1 | function [new_keys] = gen_key(var_list,default_str) 2 | % var_list is a vector containing variables that are present in the scope 3 | % ex) [2,3,4]; returns a cell of keys covering every T/F assignment of those positions 4 | new_keys = {default_str}; % seed with the caller-supplied default key; positions not in var_list keep its characters 5 | for i = 1:size(var_list,2) 6 | new_keys = modify_key(new_keys,var_list(i)); 7 | end 8 | end 9 | 10 | function [mod_key] = modify_key(input_keys,var_pos) % doubles the key list: each key yields a 'T' and an 'F' copy at var_pos 11 | mod_key = cell(1,2*size(input_keys,2)); 12 | for i = 1:size(input_keys,2) 13 | mod_T = input_keys{i}; 14 | mod_F = input_keys{i}; 15 | mod_T(var_pos) = 'T'; 16 | mod_F(var_pos) = 'F'; 17 | mod_key{2*i-1} = mod_T; 18 | mod_key{2*i} = mod_F; 19 | end 20 | end -------------------------------------------------------------------------------- /src/get_conditional.m: -------------------------------------------------------------------------------- 1 | function [conditional_table] = get_conditional(joint_distrib,cond_vars,query_vars,default_str) 2 | % joint_distrib is a probability table for the joint distribution 3 | % cond_vars are the variables that are conditioned on ex)[1,3] 4 | % query_vars are the variables that are queried ex)[2,4] 5 | cond_keys = gen_key(cond_vars,default_str); 6 | query_keys = gen_key(query_vars,default_str); 7 | new_keys = cell(1,size(cond_keys,2)*size(query_keys,2)); 8 | new_values = []; 9 | key_count = 1; 10 | for i = 1:size(cond_keys,2) 11 | temp_values = zeros(1,size(query_keys,2)); 12 | for j = 1:size(query_keys,2) 13 | new_keys{key_count} = key_join(cond_keys{i},query_keys{j}); 14 | temp_values(j) = joint_distrib(new_keys{key_count}); 15 | key_count = key_count + 1; 16 | end 17 | temp_values = temp_values/sum(temp_values); % normalize over the query assignments for this conditioning assignment (no guard against a zero sum) 18 | new_values = [new_values,temp_values]; 19 | end 20 | conditional_table = containers.Map(new_keys,new_values); 21 | end 22 | 23 |
function [new_key] = key_join(key1,key2) 24 | new_key = key1; % start from key1 and overlay key2's non-'N' characters 25 | for i = 1:size(key2,2) 26 | if key2(i) ~= 'N' 27 | new_key(i) = key2(i); 28 | end 29 | end 30 | end -------------------------------------------------------------------------------- /src/make_factor.m: -------------------------------------------------------------------------------- 1 | function [factor] = make_factor(obs_cell,p_table) 2 | % obs_cell: Cell specifying observed variables ex){{3,'T'},{4,'F'}} 3 | % p_table: Input probability table 4 | orig_keys = p_table.keys; 5 | obs_cell = refine_obs_cell(orig_keys{1},obs_cell); % keep only observations on variables in this table's scope 6 | keys = refine_keys_obs(p_table.keys,obs_cell); % keep only keys that agree with the observations 7 | value = zeros(1,size(keys,2)); 8 | for i = 1:size(keys,2) 9 | value(i) = p_table(keys{i}); 10 | for j = 1:size(obs_cell,2) 11 | keys{i}(obs_cell{j}{1}) = 'N'; % observed positions are marked out of scope in the resulting factor 12 | end 13 | end 14 | factor = containers.Map(keys,value); 15 | end -------------------------------------------------------------------------------- /src/make_product.m: -------------------------------------------------------------------------------- 1 | function [factor] = make_product(factor_list,default_str) 2 | % factor_list is a cell containing factors 3 | % Multiplies the factors into a single factor: ex) f1(A)*f2(A,B,E)*f3(A,B) = f4(A,B,E) 4 | var_list = []; 5 | for i = 1:size(factor_list,2) 6 | var_list = union(var_list,find_valid(factor_list{i}.keys)); 7 | var_list = var_list'; 8 | end 9 | new_keys = gen_key(var_list,default_str); 10 | new_value = zeros(1,size(new_keys,2)); 11 | val_count = 1; 12 | for i = 1:size(new_keys,2) 13 | val = 1; 14 | for j = 1:size(factor_list,2) 15 | factor_keys = factor_list{j}.keys; 16 | ref_input = refine_input(factor_keys{1},new_keys{i}); % restrict the combined key to factor j's scope 17 | val = val * factor_list{j}(ref_input); 18 | end 19 | new_value(val_count) = val; 20 | val_count = val_count + 1; 21 | end 22 | factor = containers.Map(new_keys,new_value); 23 | end -------------------------------------------------------------------------------- /src/refine_input.m:
-------------------------------------------------------------------------------- 1 | function [new_input] = refine_input(target_key,query_string) 2 | % Copies query_string and writes 'N' at every position where target_key is 'N' (out-of-scope variables) 3 | new_input = query_string; 4 | for i = 1:size(query_string,2) 5 | if target_key(i) == 'N' 6 | new_input(i) = 'N'; 7 | end 8 | end -------------------------------------------------------------------------------- /src/refine_keys_obs.m: -------------------------------------------------------------------------------- 1 | function [new_keys] = refine_keys_obs(target_keys,obs_cell) 2 | new_keys = {}; % collects the keys that match every {position,value} pair in obs_cell 3 | key_count = 1; 4 | for i = 1:size(target_keys,2) 5 | validity = true; 6 | for j = 1:size(obs_cell,2) 7 | if target_keys{i}(obs_cell{j}{1}) ~= obs_cell{j}{2} 8 | validity = false; 9 | break; 10 | end 11 | end 12 | if validity == true 13 | new_keys{key_count} = target_keys{i}; 14 | key_count = key_count + 1; 15 | end 16 | end -------------------------------------------------------------------------------- /src/refine_obs_cell.m: -------------------------------------------------------------------------------- 1 | function [new_obs_cell] = refine_obs_cell(target_key,obs_cell) 2 | new_obs_cell = {}; % collects the observations whose position is not 'N' in target_key 3 | cell_count = 1; 4 | for i = 1:size(obs_cell,2) 5 | if target_key(obs_cell{i}{1}) ~= 'N' 6 | new_obs_cell{cell_count} = obs_cell{i}; 7 | cell_count = cell_count + 1; 8 | end 9 | end -------------------------------------------------------------------------------- /src/var_elim.m: -------------------------------------------------------------------------------- 1 | function [result_table] = var_elim(p_list,p_partition,query_list,elim_list,default_str) 2 | % One should specify variable orderings in 'variables.m' 3 | % p_list is a cell containing probability tables: p_list should be sorted 4 | % in the same order variables are eliminated 5 | % p_partition is a cell containing groups of factors to be calculated at 6 | % the same elimination step: ex){[1],[2],[3,4,5]} 7 | %
query_list is a list of query variables: ex)[1,4,5] 8 | % elim_list is a list of variables to eliminate in their respective 9 | % orders: ex)[2,3] if variable 2 is eliminated after 3 is eliminated (it is 10 | % analogous to the 'sigma' operation in variable elimination) 11 | % -1 is a placeholder indicating no more variables are left to be 12 | % eliminated 13 | % default_str designates input when no variable is in scope 14 | if size(elim_list,2) ~= size(p_partition,2) 15 | elim_list = [-1,elim_list]; % prepend the no-elimination placeholder so p_partition{1} eliminates nothing 16 | end 17 | result_keys = gen_key(query_list,default_str); 18 | result_values = zeros(1,size(result_keys,2)); 19 | for i = 1:size(result_keys,2) 20 | factors = cell(1,size(p_list,2)); 21 | obs_cell = convert_from_key(result_keys{i}); % treat the current query assignment as observations 22 | for j = 1:size(p_list,2) 23 | factors{j} = make_factor(obs_cell,p_list{j}); 24 | end 25 | prod_factor = containers.Map({default_str},1.0); 26 | for j = size(p_partition,2):-1:1 % partitions are processed from last to first 27 | target_factors = factors(p_partition{j}); 28 | target_factors{size(target_factors,2)+1} = prod_factor; 29 | prod_factor = make_product(target_factors,default_str); 30 | prod_factor = elim(prod_factor,elim_list(j),default_str); 31 | end 32 | result_values(i) = prod_factor(default_str); % expects every variable to be observed or eliminated by now, leaving the default_str entry 33 | end 34 | result_table = containers.Map(result_keys,result_values); 35 | end 36 | 37 | function [obs_cell] = convert_from_key(key) % turns each non-'N' position of key into an {index,value} pair 38 | obs_cell = {}; 39 | cell_count = 1; 40 | for i = 1:size(key,2) 41 | if key(i) ~= 'N' 42 | obs_cell{cell_count} = {i,key(i)}; 43 | cell_count = cell_count + 1; 44 | end 45 | end 46 | end 47 | -------------------------------------------------------------------------------- /src/variables.m: -------------------------------------------------------------------------------- 1 | % Variables are ordered in the following order:B,E,A,J,M 2 | % probability notations: p(A|B,C) = p_A_BC 3 | % p_B 4 | keys = {'TNNNN','FNNNN'}; 5 | values = {0.01,0.99}; 6 | p_B = containers.Map(keys,values); 7 | % p_E 8 | keys = {'NTNNN','NFNNN'}; 9 | values =
{0.009,0.991}; 10 | p_E = containers.Map(keys,values); 11 | % p_A_BE: p(A|B,E); key positions are B,E,A,J,M and 'N' marks a variable outside the table's scope 12 | keys = {'TTTNN','FTTNN','TFTNN','FFTNN','TTFNN','FTFNN','TFFNN','FFFNN'}; 13 | values = {0.98,0.14,0.89,0.01,0.02,0.86,0.11,0.99}; 14 | p_A_BE = containers.Map(keys,values); 15 | % p_J_A: p(J|A) 16 | keys = {'NNTTN','NNFTN','NNTFN','NNFFN'}; 17 | values = {0.65,0.08,0.35,0.92}; 18 | p_J_A = containers.Map(keys,values); 19 | % p_M_A: p(M|A) 20 | keys = {'NNTNT','NNFNT','NNTNF','NNFNF'}; 21 | values = {0.94,0.03,0.06,0.97}; 22 | p_M_A = containers.Map(keys,values); 23 | 24 | p_list = {p_B,p_E,p_A_BE,p_J_A,p_M_A}; 25 | -------------------------------------------------------------------------------- /src/variables2.m: -------------------------------------------------------------------------------- 1 | % Variables are ordered in the following order:C,S,R,W 2 | % probability notations: p(A|B,C) = p_A_BC 3 | % p_C: p(C) 4 | keys = {'TNNN','FNNN'}; 5 | values = {0.7,0.3}; 6 | p_C = containers.Map(keys,values); 7 | % p_S_C: p(S|C) 8 | keys = {'TTNN','TFNN','FTNN','FFNN'}; 9 | values = {0.05,0.95,0.65,0.35}; 10 | p_S_C = containers.Map(keys,values); 11 | % p_R_C: p(R|C) 12 | keys = {'TNTN','TNFN','FNTN','FNFN'}; 13 | values = {0.9,0.1,0.3,0.7}; 14 | p_R_C = containers.Map(keys,values); 15 | % p_W_SR: p(W|S,R) 16 | keys = {'NTTT','NTTF','NTFT','NTFF','NFTT','NFTF','NFFT','NFFF'}; 17 | values = {0.99,0.01,0.85,0.15,0.8,0.2,0.1,0.9}; 18 | p_W_SR = containers.Map(keys,values); 19 | 20 | p_list = {p_C,p_S_C,p_R_C,p_W_SR}; 21 | --------------------------------------------------------------------------------