├── .gitignore
├── LICENSE
├── README.markdown
├── graph.py
├── graphTests.py
└── node.py

/.gitignore:
--------------------------------------------------------------------------------

*.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2013 Ryan Lester

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.markdown:
--------------------------------------------------------------------------------
Python implementation of Sum-product (aka Belief-Propagation) for discrete Factor Graphs.

See [this paper](http://www.comm.utoronto.ca/frank/papers/KFL01.pdf) for more details on the Factor Graph framework and the sum-product algorithm. This code was originally written as part of a grad student seminar taught by Erik Sudderth at Brown University; the [seminar web page](http://cs.brown.edu/courses/csci2420/) is an excellent resource for learning more about graphical models.

Requires NumPy and future.
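Both are available on PyPI; for example, `pip install numpy future` should be enough to set them up.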

To use:

    from graph import Graph
    import numpy as np

    G = Graph()

    # add variable nodes
    a = G.addVarNode('a', 3)
    b = G.addVarNode('b', 2)

    # add factors
    # unary factor
    Pb = np.array([[0.3], [0.7]])
    G.addFacNode(Pb, b)

    # connecting factor
    Pab = np.array([[0.2, 0.8], [0.4, 0.6], [0.1, 0.9]])
    G.addFacNode(Pab, a, b)

    # factors can connect an arbitrary number of variables

    # run sum-product and get marginals for variables
    marg = G.marginals()
    distA = marg['a']
    distB = marg['b']

    # reset before altering graph further
    G.reset()

    # condition on variables
    G.var['a'].condition(0)

    # disable and enable to run sum-product on subgraphs
    G.var['b'].disable()
    G.var['b'].enable()
    G.disableAll()
    # reset automatically enables all variables and removes conditioning

--------------------------------------------------------------------------------
/graph.py:
--------------------------------------------------------------------------------
# Graph class
from __future__ import print_function
from builtins import range
from future.utils import iteritems
import numpy as np
from node import FacNode, VarNode
import pdb

""" Factor Graph classes forming structure for PGMs
    Basic structure is port of MATLAB code by J. Pacheco
    Central difference: nbrs stored as references, not ids
    (makes message propagation easier)
"""

class Graph:
    """ Putting everything together
    """

    def __init__(self):
        self.var = {}
        self.fac = []
        self.dims = []
        self.converged = False

    def addVarNode(self, name, dim):
        newId = len(self.var)
        newVar = VarNode(name, dim, newId)
        self.var[name] = newVar
        self.dims.append(dim)

        return newVar

    def addFacNode(self, P, *args):
        newId = len(self.fac)
        newFac = FacNode(P, newId, *args)
        self.fac.append(newFac)

        return newFac

    def disableAll(self):
        """ Disable all nodes in graph
            Useful for switching on small subnetworks
            of Bayesian nets
        """
        for k, v in iteritems(self.var):
            v.disable()
        for f in self.fac:
            f.disable()

    def reset(self):
        """ Reset messages to original state
        """
        for k, v in iteritems(self.var):
            v.reset()
        for f in self.fac:
            f.reset()
        self.converged = False

    def sumProduct(self, maxsteps=500):
        """ This is the algorithm!
            Each timestep:
            take incoming messages and multiply together to produce outgoing for all nodes
            then push outgoing to neighbors' incoming
            compare outgoing with previous outgoing to check for convergence
        """
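        # Schedule ("flooding"): every factor sends to all of its neighbouring
        # variables, then every variable sends to all of its neighbouring factors.
        # Convergence is declared once no outgoing message has changed by more
        # than Node.epsilon between two consecutive sweeps.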
        # loop to convergence
        timestep = 0
        while timestep < maxsteps and not self.converged:  # run for maxsteps cycles
            timestep = timestep + 1
            print(timestep)

            for f in self.fac:
                # start with factor-to-variable
                # can send immediately since not sending to any other factors
                f.prepMessages()
                f.sendMessages()

            for k, v in iteritems(self.var):
                # variable-to-factor
                v.prepMessages()
                v.sendMessages()

            # check for convergence
            t = True
            for k, v in iteritems(self.var):
                t = t and v.checkConvergence()
                if not t:
                    break
            if t:
                for f in self.fac:
                    t = t and f.checkConvergence()
                    if not t:
                        break

            if t:  # we have convergence!
                self.converged = True

        # if run for maxsteps cycles and still no convergence
        if not self.converged:
            print("No convergence!")

    def marginals(self, maxsteps=500):
        """ Return dictionary of all marginal distributions
            indexed by corresponding variable name
        """
        # Message pass
        self.sumProduct(maxsteps)

        marginals = {}
        # for each var
        for k, v in iteritems(self.var):
            if v.enabled:  # only include enabled variables
                # multiply together messages
                vmarg = 1
                for i in range(0, len(v.incoming)):
                    vmarg = vmarg * v.incoming[i]

                # normalize
                n = np.sum(vmarg)
                vmarg = vmarg / n

                marginals[k] = vmarg

        return marginals

    def bruteForce(self):
        """ Brute force method. Only here for completeness.
            Don't use unless you want your code to take forever to produce results.
            Note: index corresponding to var determined by order added
            Problem: max number of dims in numpy is 32???
            Limit to enabled vars as work-around
        """
        # Figure out what is enabled and save dimensionality
        enabledDims = []
        enabledNids = []
        enabledNames = []
        enabledObserved = []
        for k, v in iteritems(self.var):
            if v.enabled:
                enabledNids.append(v.nid)
                enabledNames.append(k)
                enabledObserved.append(v.observed)
                if v.observed < 0:
                    enabledDims.append(v.dim)
                else:
                    enabledDims.append(1)

        # initialize matrix over all joint configurations
        joint = np.zeros(enabledDims)

        # loop over all configurations
        self.configurationLoop(joint, enabledNids, enabledObserved, [])

        # normalize
        joint = joint / np.sum(joint)
        return {'joint': joint, 'names': enabledNames}

    def configurationLoop(self, joint, enabledNids, enabledObserved, currentState):
        """ Recursive loop over all configurations
            Used for brute force computation
            joint - matrix storing joint probabilities
            enabledNids - nids of enabled variables
            enabledObserved - observed values of enabled variables (-1 if not observed)
            currentState - list storing current configuration of vars up to this point
        """
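        # Depth-first enumeration: extend currentState one variable at a time
        # (observed variables contribute only their observed value); once a full
        # assignment is built, the product of the matching factor entries is
        # stored in the joint table.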
        currVar = len(currentState)
        if currVar != len(enabledNids):
            # need to continue assembling current configuration
            if enabledObserved[currVar] < 0:
                for i in range(0, joint.shape[currVar]):
                    # add new variable value to state
                    currentState.append(i)
                    self.configurationLoop(joint, enabledNids, enabledObserved, currentState)
                    # remove it for next value
                    currentState.pop()
            else:
                # do the same thing but only once w/ observed value!
                currentState.append(enabledObserved[currVar])
                self.configurationLoop(joint, enabledNids, enabledObserved, currentState)
                currentState.pop()

        else:
            # compute value for current configuration
            potential = 1.
            for f in self.fac:
                if f.enabled and False not in [x.enabled for x in f.nbrs]:
                    # figure out which vars are part of factor
                    # then get current values of those vars in correct order
                    args = [currentState[enabledNids.index(x.nid)] for x in f.nbrs]

                    # get value and multiply in
                    potential = potential * f.P[tuple(args)]

            # now add it to joint after correcting state for observed nodes
            ind = [currentState[i] if enabledObserved[i] < 0 else 0 for i in range(0, currVar)]
            joint[tuple(ind)] = potential

    def marginalizeBrute(self, brute, var):
        """ Util for marginalizing over joint configuration arrays produced by bruteForce
        """
        sumout = list(range(0, len(brute['names'])))
        del sumout[brute['names'].index(var)]
        marg = np.sum(brute['joint'], tuple(sumout))
        return marg / np.sum(marg)  # normalize to sum to one

--------------------------------------------------------------------------------
/graphTests.py:
--------------------------------------------------------------------------------
from __future__ import print_function
from graph import Graph
import numpy as np

""" Graphs for testing sum product implementation
"""

def checkEq(a, b):
    epsilon = 10**-6
    return abs(a-b) < epsilon

def makeToyGraph():
    """ Simple graph encoding, basic testing
        2 vars, 2 facs
        f_b, f_ab - p(b)p(a|b)
        factor functions are a little funny but it works
    """
    G = Graph()

    a = G.addVarNode('a', 3)
    b = G.addVarNode('b', 2)

    Pb = np.array([[0.3], [0.7]])
    G.addFacNode(Pb, b)

    Pab = np.array([[0.2, 0.8], [0.4, 0.6], [0.1, 0.9]])
    G.addFacNode(Pab, a, b)

    return G

def testToyGraph():
    """ Actual test case
    """

    G = makeToyGraph()
    marg = G.marginals()
    brute = G.bruteForce()

    # check the results
    # want to verify incoming messages
    # if vars are correct then factors must be as well
    a = G.var['a'].incoming
    assert checkEq(a[0][0], 0.34065934)
    assert checkEq(a[0][1], 0.2967033)
    assert checkEq(a[0][2], 0.36263736)

    b = G.var['b'].incoming
    assert checkEq(b[0][0], 0.3)
    assert checkEq(b[0][1], 0.7)
    assert checkEq(b[1][0], 0.23333333)
    assert checkEq(b[1][1], 0.76666667)

    # check the marginals
    am = marg['a']
    assert checkEq(am[0], 0.34065934)
    assert checkEq(am[1], 0.2967033)
    assert checkEq(am[2], 0.36263736)

    bm = marg['b']
    assert checkEq(bm[0], 0.11538462)
    assert checkEq(bm[1], 0.88461538)
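
    # brute force enumerates the whole joint and marginalizes it directly,
    # so on this small tree it should agree with the sum-product marginals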
    # check brute force against sum-product
    amm = G.marginalizeBrute(brute, 'a')
    bmm = G.marginalizeBrute(brute, 'b')
    assert checkEq(am[0], amm[0])
    assert checkEq(am[1], amm[1])
    assert checkEq(am[2], amm[2])
    assert checkEq(bm[0], bmm[0])
    assert checkEq(bm[1], bmm[1])

    print("All tests passed!")

def makeTestGraph():
    """ Graph for HW problem 1.c.
        4 vars, 3 facs
        f_a, f_ba, f_dca
    """
    G = Graph()

    a = G.addVarNode('a', 2)
    b = G.addVarNode('b', 3)
    c = G.addVarNode('c', 4)
    d = G.addVarNode('d', 5)

    p = np.array([[0.3], [0.7]])
    G.addFacNode(p, a)

    p = np.array([[0.2, 0.8], [0.4, 0.6], [0.1, 0.9]])
    G.addFacNode(p, b, a)

    p = np.array([[[3., 1.], [1.2, 0.4], [0.1, 0.9], [0.1, 0.9]],
                  [[11., 9.], [8.8, 9.4], [6.4, 0.1], [8.8, 9.4]],
                  [[3., 2.], [2., 2.], [2., 2.], [3., 2.]],
                  [[0.3, 0.7], [0.44, 0.56], [0.37, 0.63], [0.44, 0.56]],
                  [[0.2, 0.1], [0.64, 0.44], [0.37, 0.63], [0.2, 0.1]]])
    G.addFacNode(p, d, c, a)

    # add a loop - not a part of 1.c., just for testing
    # p = np.array([[0.3, 0.2214532], [0.1, 0.4], [0.33333, 0.76], [0.1, 0.98]])
    # G.addFacNode(p, c, a)

    return G

def testTestGraph():
    """ Automated test case
    """
    G = makeTestGraph()
    marg = G.marginals()
    brute = G.bruteForce()

    # check the marginals
    am = marg['a']
    assert checkEq(am[0], 0.13755539)
    assert checkEq(am[1], 0.86244461)

    bm = marg['b']
    assert checkEq(bm[0], 0.33928227)
    assert checkEq(bm[1], 0.30358863)
    assert checkEq(bm[2], 0.3571291)

    cm = marg['c']
    assert checkEq(cm[0], 0.30378128)
    assert checkEq(cm[1], 0.29216947)
    assert checkEq(cm[2], 0.11007584)
    assert checkEq(cm[3], 0.29397341)

    dm = marg['d']
    assert checkEq(dm[0], 0.076011)
    assert checkEq(dm[1], 0.65388724)
    assert checkEq(dm[2], 0.18740039)
    assert checkEq(dm[3], 0.05341787)
    assert checkEq(dm[4], 0.0292835)

    # check brute force against sum-product
    amm = G.marginalizeBrute(brute, 'a')
    bmm = G.marginalizeBrute(brute, 'b')
    cmm = G.marginalizeBrute(brute, 'c')
    dmm = G.marginalizeBrute(brute, 'd')

    assert checkEq(am[0], amm[0])
    assert checkEq(am[1], amm[1])

    assert checkEq(bm[0], bmm[0])
    assert checkEq(bm[1], bmm[1])
    assert checkEq(bm[2], bmm[2])

    assert checkEq(cm[0], cmm[0])
    assert checkEq(cm[1], cmm[1])
    assert checkEq(cm[2], cmm[2])
    assert checkEq(cm[3], cmm[3])

    assert checkEq(dm[0], dmm[0])
    assert checkEq(dm[1], dmm[1])
    assert checkEq(dm[2], dmm[2])
    assert checkEq(dm[3], dmm[3])
    assert checkEq(dm[4], dmm[4])

    print("All tests passed!")

# standard run of test cases
testToyGraph()
testTestGraph()

--------------------------------------------------------------------------------
/node.py:
--------------------------------------------------------------------------------
from builtins import range
from functools import reduce
import numpy as np

""" Factor Graph classes forming structure for PGMs
    Basic structure is port of MATLAB code by J. Pacheco
    Central difference: nbrs stored as references, not ids
    (makes message propagation easier)

    Note to self: use %pdb and %load_ext autoreload followed by %autoreload 2
"""
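
# Messages are NumPy column vectors of shape (dim, 1), one per edge; each node
# keeps parallel lists of incoming, outgoing, and previous outgoing messages,
# indexed consistently with its nbrs list.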

class Node(object):
    """ Superclass for graph nodes
    """
    epsilon = 10**(-4)

    def __init__(self, nid):
        self.enabled = True
        self.nid = nid
        self.nbrs = []
        self.incoming = []
        self.outgoing = []
        self.oldoutgoing = []

    def reset(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    def enable(self):
        self.enabled = True
        for n in self.nbrs:
            # don't call enable() as it will recursively enable entire graph
            n.enabled = True

    def nextStep(self):
        """ Used to have this line in prepMessages
            but it didn't work?
        """
        self.oldoutgoing = self.outgoing[:]

    def normalizeMessages(self):
        """ Normalize to sum to 1
        """
        self.outgoing = [x / np.sum(x) for x in self.outgoing]

    def receiveMessage(self, f, m):
        """ Places new message into correct location in new message list
        """
        if self.enabled:
            i = self.nbrs.index(f)
            self.incoming[i] = m

    def sendMessages(self):
        """ Sends all outgoing messages
        """
        for i in range(0, len(self.outgoing)):
            self.nbrs[i].receiveMessage(self, self.outgoing[i])

    def checkConvergence(self):
        """ Check if any messages have changed
        """
        if self.enabled:
            for i in range(0, len(self.outgoing)):
                # coerce old message to the same shape before comparing
                self.oldoutgoing[i].shape = self.outgoing[i].shape
                delta = np.absolute(self.outgoing[i] - self.oldoutgoing[i])
                if (delta > Node.epsilon).any():  # if there has been change
                    return False
            return True
        else:
            # Always return True if disabled to avoid interrupting check
            return True

class VarNode(Node):
    """ Variable node in factor graph
    """
    def __init__(self, name, dim, nid):
        super(VarNode, self).__init__(nid)
        self.name = name
        self.dim = dim
        self.observed = -1  # only >= 0 if variable is observed

    def reset(self):
        super(VarNode, self).reset()
        size = range(0, len(self.incoming))
        self.incoming = [np.ones((self.dim, 1)) for i in size]
        self.outgoing = [np.ones((self.dim, 1)) for i in size]
        self.oldoutgoing = [np.ones((self.dim, 1)) for i in size]
        self.observed = -1

    def condition(self, observation):
        """ Condition on observing certain value
        """
        self.enable()
        self.observed = observation
        # set messages (won't change)
        for i in range(0, len(self.outgoing)):
            self.outgoing[i] = np.zeros((self.dim, 1))
            self.outgoing[i][self.observed] = 1.
        self.nextStep()  # copy into oldoutgoing

    def prepMessages(self):
        """ Multiplies together incoming messages to make new outgoing
        """
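
        # For neighbouring factor i, the outgoing message is the elementwise
        # product of the incoming messages from every other factor. Observed
        # variables keep the indicator messages fixed in condition(), and
        # variables with a single neighbour keep their initial all-ones message.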
        # compute new messages if no observation has been made
        if self.enabled and self.observed < 0 and len(self.nbrs) > 1:
            # switch reference for old messages
            self.nextStep()
            for i in range(0, len(self.incoming)):
                # multiply together all excluding message at current index
                curr = self.incoming[:]
                del curr[i]
                self.outgoing[i] = reduce(np.multiply, curr)

            # normalize once finished with all messages
            self.normalizeMessages()

class FacNode(Node):
    """ Factor node in factor graph
    """
    def __init__(self, P, nid, *args):
        super(FacNode, self).__init__(nid)
        self.P = P
        self.nbrs = list(args)  # list storing refs to variable nodes

        # num of edges
        numNbrs = len(self.nbrs)
        numDependencies = self.P.squeeze().ndim

        # init messages
        for i in range(0, numNbrs):
            v = self.nbrs[i]
            vdim = v.dim

            # init for factor
            self.incoming.append(np.ones((vdim, 1)))
            self.outgoing.append(np.ones((vdim, 1)))
            self.oldoutgoing.append(np.ones((vdim, 1)))

            # init for variable
            v.nbrs.append(self)
            v.incoming.append(np.ones((vdim, 1)))
            v.outgoing.append(np.ones((vdim, 1)))
            v.oldoutgoing.append(np.ones((vdim, 1)))

        # error check
        assert (numNbrs == numDependencies), "Factor dimensions do not match size of domain."

    def reset(self):
        super(FacNode, self).reset()
        for i in range(0, len(self.incoming)):
            self.incoming[i] = np.ones((self.nbrs[i].dim, 1))
            self.outgoing[i] = np.ones((self.nbrs[i].dim, 1))
            self.oldoutgoing[i] = np.ones((self.nbrs[i].dim, 1))

    def prepMessages(self):
        """ Multiplies incoming messages w/ P to make new outgoing
        """
        if self.enabled:
            # switch references for old messages
            self.nextStep()

            mnum = len(self.incoming)

            # do tiling in advance
            # roll axes to match shape of newMessage after
            for i in range(0, mnum):
                # find tiling size
                nextShape = list(self.P.shape)
                del nextShape[i]
                nextShape.insert(0, 1)
                # need to expand incoming message to correct num of dims to tile properly
                prepShape = [1 for x in nextShape]
                prepShape[0] = self.incoming[i].shape[0]
                self.incoming[i].shape = prepShape
                # tile and roll
                self.incoming[i] = np.tile(self.incoming[i], nextShape)
                self.incoming[i] = np.rollaxis(self.incoming[i], 0, i+1)

            # loop over subsets
            for i in range(0, mnum):
                curr = self.incoming[:]
                del curr[i]
                newMessage = reduce(np.multiply, curr, self.P)

                # sum over all vars except i!
                # roll axis i to front then sum over all other axes
                newMessage = np.rollaxis(newMessage, i, 0)
                newMessage = np.sum(newMessage, tuple(range(1, mnum)))
                newMessage.shape = (newMessage.shape[0], 1)

                # store new message
                self.outgoing[i] = newMessage

            # normalize once finished with all messages
            self.normalizeMessages()

--------------------------------------------------------------------------------