├── .hgignore ├── README.md ├── setup.py └── spanning_tree ├── __init__.py ├── brute_force.py ├── matrix_tree.py └── test_matrix_tree.py /.hgignore: -------------------------------------------------------------------------------- 1 | syntax: glob 2 | __pycache__ 3 | *.egg-info/ 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A reference implementation of algorithms for distributions over spanning trees 2 | 3 | This project provides a reference implementation of the inference algorithms for 4 | globally normalized distributions over directed spanning trees. In natural 5 | language processing, such models are used to develop nonprojective dependency 6 | parsers. The current implementation shows the matrix-tree theorem in 7 | action. The implementation is carefully tested against brute-force algorithms 8 | that explicitly marginalize over directed trees. 9 | 10 | 11 | **Citation:** If you found this useful, please cite it as 12 | ```bibtex 13 | @software{vieira-spanningtree, 14 | author = {Tim Vieira}, 15 | title = {A reference implementation of algorithms for distributions over spanning trees}, 16 | url = {https://github.com/timvieira/spanning_tree} 17 | } 18 | ``` 19 | 20 | **References**: In 2007, three different groups published similar methods for 21 | inference in discriminative nonprojective dependency parsing. 
22 | 23 | - Koo, Globerson, Carreras and Collins (EMNLP'07) 24 | [Structured Prediction Models via the Matrix-Tree Theorem](https://www.aclweb.org/anthology/D07-1015) 25 | 26 | - Smith and Smith (EMNLP'07) 27 | [Probabilistic Models of Nonprojective Dependency Trees](https://www.aclweb.org/anthology/D07-1014) 28 | 29 | - McDonald & Satta (IWPT'07) 30 | [On the Complexity of Non-Projective Data-Driven Dependency Parsing](https://www.aclweb.org/anthology/W07-2216) 31 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name='spanning_tree', 4 | author='Tim Vieira', 5 | description='A reference implementation of the matrix-tree theorem with applications to nonprojective spanning tree models in natural language processing.', 6 | version='1.0', 7 | install_requires=[ # NOTE(review): numpy and networkx are imported by this package but are not listed here; confirm whether they should be. 8 | 'arsenal', 9 | ], 10 | dependency_links=[ # NOTE(review): recent pip versions ignore dependency_links; confirm how arsenal is expected to be installed. 11 | 'https://github.com/timvieira/arsenal.git', 12 | ], 13 | packages=['spanning_tree'], 14 | ) 15 | -------------------------------------------------------------------------------- /spanning_tree/__init__.py: -------------------------------------------------------------------------------- 1 | from spanning_tree.matrix_tree import matrix_tree_theorem 2 | -------------------------------------------------------------------------------- /spanning_tree/brute_force.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import networkx as nx 3 | from arsenal.maths import logsumexp 4 | from itertools import product 5 | 6 | 7 | class brute_force: 8 | "Brute-force computation for spanning tree distributions: enumerates every (root, tree) pair, so it is only feasible for small graphs." 
9 | # Attributes set by __init__: lnz (log-partition function), P[root, tree] (probability of each tree), R[i] (p(i is the root)), M[h, m] (p(edge (h, m) is in the tree)). 10 | def __init__(self, A, r): 11 | [N,_] = A.shape 12 | self.N = N; self.A = A; self.r = r 13 | # Score every (root, tree) pair in the support, then normalize in log space. 14 | self.scores = {(root, tree): self.score(root, tree) for root, tree in self.domain()} 15 | lnz = logsumexp(list(self.scores.values())) 16 | # Exact marginals: add each tree's probability to its root and to each of its edges. 17 | R = np.zeros(N) # root marginals 18 | M = np.zeros((N,N)) # edge marginals, indexed [head, modifier] 19 | P = {} 20 | for (root, tree), s in self.scores.items(): 21 | p = np.exp(s - lnz) 22 | P[root, tree] = p 23 | R[root] += p 24 | for h,m in tree: 25 | M[h,m] += p 26 | 27 | self.P = P 28 | self.lnz = lnz 29 | self.R = R 30 | self.M = M 31 | 32 | def lprob(self, root, tree): 33 | "Log-probability of a `tree` with a specific `root`." 34 | return self.scores[root, tree] - self.lnz 35 | 36 | def score(self, root, tree): 37 | "Unnormalized log probability of a `tree` with a specific `root`: r[root] plus A[h,m] for each edge (h, m) in the tree." 38 | s = self.r[root] 39 | for h,m in tree: 40 | assert m != root 41 | s += self.A[h,m] 42 | return s 43 | 44 | def domain(self): 45 | "Enumerate the support of the probability distribution: every (root, tree) pair over N nodes." 46 | return enumerate_dtrees(self.N) 47 | 48 | 49 | def enumerate_dtrees(n): 50 | "Enumerate all rooted directed spanning trees (arborescences) over n nodes as (root, frozenset of (head, modifier) edges) pairs; there are n**(n-1) of them." 51 | 52 | # Implementation: We use a rejection-based strategy where the (two) outer 53 | # loops "propose" singly connected rooted graphs and the body of loops 54 | # checks whether the singly connected graph is also acyclic (i.e., a tree). 55 | proposals = 0; accepts = 0 56 | for root in range(n): 57 | 58 | # Define for each node a set of possible `parents` 59 | # 60 | # Outer loop: Each node picks a parent (that's not itself). Since a 61 | # `root` has been chosen, we ensure here that no graph will have an edge 62 | # pointing into it (the root gets no parent). (In other words, the outer loop is over singly 63 | # connected graphs: every non-root node has exactly one parent.) 64 | parents = [] 65 | for j in range(n): 66 | if j == root: 67 | d = [None] # `None` is a sentinel value; the important thing is 68 | # that the root's set of candidate parents has size one. 
69 | else: 70 | d = set(range(n)) - {j} # minor improvement: drop `j` from the 71 | # set: a tree can't have a self loop. 72 | parents.append(d) 73 | 74 | # Iterate over the Cartesian product of possible parents to get singly 75 | # connected graphs (technically, with no self loops) 76 | for A in product(*parents): 77 | proposals += 1 78 | # `A[m]` is the proposed parent of node m (None for the root). 79 | edges = [(h,m) for m, h in enumerate(A) if m != root and h != None] 80 | 81 | if is_arborescence(edges): 82 | accepts += 1 83 | yield root, frozenset(edges) 84 | # Sanity checks (reached only once the generator is exhausted): n roots times (n-1)**(n-1) parent choices are proposed; n**(n-1) rooted trees are accepted (Cayley's n**(n-2) labeled trees times n choices of root). 85 | assert proposals == n * (n-1)**(n-1) 86 | assert accepts == n ** (n-1) 87 | 88 | 89 | def is_arborescence(edges): 90 | "Check if `edges` forms an arborescence (a directed tree)." 91 | return nx.algorithms.tree.recognition.is_arborescence(nx.DiGraph(list(edges))) # NOTE(review): with no edges (n == 1) the graph has no nodes, which networkx may reject instead of returning True; confirm. 92 | -------------------------------------------------------------------------------- /spanning_tree/matrix_tree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def matrix_tree_theorem(A, r): 5 | """ 6 | Compute the log-partition function and its gradients (marginal probabilities) 7 | for a non-projective dependency parser given a weighted adjacency matrix 8 | `A: head × modifier ↦ log-weight` (note: diagonal ignored) 9 | and root weight vector `r: head ↦ log-weight`. Returns `(lnz, dr, dA)`: log Z and its gradients with respect to `r` (root marginals) and `A` (edge marginals). 10 | 11 | References 12 | 13 | - Koo, Globerson, Carreras and Collins (EMNLP'07) 14 | Structured Prediction Models via the Matrix-Tree Theorem 15 | https://www.aclweb.org/anthology/D07-1015 16 | 17 | - Smith and Smith (EMNLP'07) 18 | Probabilistic Models of Nonprojective Dependency Trees 19 | https://www.aclweb.org/anthology/D07-1014 20 | 21 | - McDonald & Satta (IWPT'07) 22 | On the Complexity of Non-Projective Data-Driven Dependency Parsing 23 | https://www.aclweb.org/anthology/W07-2216 24 | """ 25 | 26 | # Numerical stability trick: We use an extension of the log-sum-exp and 27 | # exp-normalize tricks to our log-det-exp setting. [I haven't seen 28 | # this trick elsewhere.] 
29 | # 30 | # Note: The `exp` function below is point-wise exponential, not the matrix 31 | # exponential! 32 | # 33 | # for any value c, 34 | # 35 | # log(det(exp(c) * exp(A - c))) 36 | # = log(exp(c)^n * det(exp(A - c))) 37 | # = c*n + log(det(exp(A - c))) 38 | # 39 | # Furthermore, the gradient with respect to A is unchanged by the shift (⊙ denotes the element-wise product): 40 | # 41 | # ∇ log(det(exp(c) * exp(A - c))) 42 | # = ∇ [ c*n + log(det(exp(A - c))) ] 43 | # = exp(A - c)⁻ᵀ ⊙ exp(A - c) 44 | # 45 | # Vector version of the trick: for any n-dimensional vector `c`, 46 | # 47 | # log(det(diag(exp(c)) * exp(A - c))) 48 | # = log(product(exp(c)) * det(exp(A - c))) 49 | # = sum(c) + log(det(exp(A - c))) 50 | # 51 | # Although it is a generalization of the scalar trick, I don't 52 | # think the extra parameters make much of a difference. 53 | # Much like the log-sum-exp trick, the goal is only to avoid overflow, 54 | # and the scalar is enough to do that. 55 | # Shift by the largest log-weight so that every exponentiated entry below is at most 1. 56 | c = max(r.max(), A.max()) 57 | # NOTE(review): A.max() includes A's diagonal, which is ignored; a very large diagonal entry would inflate c and could underflow the remaining entries. 58 | r = np.exp(r - c) 59 | A = np.exp(A - c) 60 | np.fill_diagonal(A, 0) 61 | 62 | L = np.diag(A.sum(axis=0)) - A # Laplacian: column sums of A (total incoming weight of each modifier) on the diagonal, minus A 63 | L[0,:] = r # Koo et al.'s efficiency trick: overwrite row 0 with the root weights 64 | 65 | lnz = np.linalg.slogdet(L)[1] + c*len(r) # undo the shift: each tree's weight has N factors of exp(-c) (one root weight and N-1 edge weights); slogdet's sign is discarded 66 | # d lnz / d L = inv(L).T; r sits in row 0 of L, so the root marginals are r * dL[0,:]. 67 | dL = np.linalg.inv(L).T 68 | dr = r * dL[0,:] 69 | # A[h,m] enters L as +A[h,m] at L[m,m] and -A[h,m] at L[h,m], except where those entries lie in row 0 (overwritten by r); the factor A[h,m] converts to a gradient w.r.t. the log-weight. 70 | dA = A * 0 # NOTE(review): the O(N^2) Python loop below could be vectorized with NumPy broadcasting. 71 | N = len(r) 72 | for h in range(N): 73 | for m in range(N): 74 | dA[h,m] = A[h,m] * (dL[m,m] * (m!=0) - dL[h,m] * (h!=0)) 75 | 76 | return lnz, dr, dA 77 | -------------------------------------------------------------------------------- /spanning_tree/test_matrix_tree.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test suite for the matrix-tree theorem: checks the tree enumerator, compares against brute-force enumeration, and runs self-consistency checks.
3 | """ 4 | import numpy as np 5 | from arsenal.maths import quick_fdcheck 6 | 7 | from matrix_tree import matrix_tree_theorem as mt 8 | from brute_force import brute_force, enumerate_dtrees 9 | 10 | # NOTE(review): the test_* helpers below take parameters, so pytest would try to resolve `N`, `A`, `r` as fixtures; they are driven by `test()` instead. Confirm how this suite is meant to be run. 11 | def test_enumerate_dtrees(N): # Check that enumerate_dtrees(N) yields distinct trees and the expected number of them. 12 | print('[test enumerate dtrees]') 13 | S = set() 14 | for t in enumerate_dtrees(N): 15 | assert t not in S, [t, S] 16 | S.add(t) 17 | print(f'number of trees over N={N} nodes -> {len(S)}') 18 | # Cayley's formula: there are N**(N-2) labeled trees on N vertices; with N choices of root that gives N**(N-1) rooted trees. 19 | assert len(S) == N ** (N-1) 20 | 21 | 22 | def test_mt_bf(A, r): 23 | """ 24 | Compare the computation of log Z and the root and edge marginals via the matrix-tree 25 | theorem and brute-force search. Warning: brute-force search is exponential in 26 | the number of nodes, so this check should only be used on small graphs. 27 | """ 28 | 29 | bf = brute_force(A, r) 30 | [bf_lnz, bf_dr, bf_dA] = bf.lnz, bf.R, bf.M 31 | [mt_lnz, mt_dr, mt_dA] = mt(A, r) 32 | 33 | print() 34 | print(f'bf logZ= {bf_lnz:g}') 35 | print(f'mt logZ= {mt_lnz:g}') 36 | 37 | np.testing.assert_allclose(bf_lnz, mt_lnz) 38 | 39 | print('p(i = root(t))') 40 | print('\n'.join(f' {x}' for x in str(mt_dr).split('\n'))) 41 | # NOTE(review): the label printed below looks garbled (perhaps an `<h,m>` was lost); it only affects console output. 42 | print('p( in t)') 43 | print('\n'.join(f' {x}' for x in str(mt_dA).split('\n'))) 44 | 45 | np.testing.assert_allclose(bf_dr, mt_dr) 46 | np.testing.assert_allclose(bf_dA, mt_dA) 47 | 48 | 49 | def test_mt_self_test(A, r): 50 | """ 51 | Run self-consistency tests: finite-difference gradient checks, and a check that each node's root and incoming-edge marginals sum to one. 52 | """ 53 | [_, dr, dA] = mt(A, r) 54 | 55 | # Finite-difference gradient test. 56 | assert quick_fdcheck(lambda: mt(A, r)[0], r, dr, verbose=False).max_rel_err <= 0.0001 57 | assert quick_fdcheck(lambda: mt(A, r)[0], A, dA, verbose=False).max_rel_err <= 0.0001 58 | 59 | # An additional self-consistency check is that, for every node m, the root and 60 | # incoming-edge marginals sum to one: p(m = root) + sum_h p((h, m) in T) = 1. 
61 | d = dA.sum(axis=0) + dr 62 | assert np.allclose(d, 1) 63 | 64 | 65 | def test(): 66 | "Run all checks on random Gaussian instances; brute force is only used for small N." 67 | for N in [2,3,4,5]: 68 | test_enumerate_dtrees(N) 69 | 70 | for N in [2,3,4,5]: 71 | A = np.random.normal(size=(N,N)) 72 | r = np.random.normal(size=N) 73 | test_mt_bf(A, r) 74 | # Larger graphs: self-consistency checks only (brute force would be too slow). 75 | for N in [10,20,50]: 76 | A = np.random.normal(size=(N,N)) 77 | r = np.random.normal(size=N) 78 | test_mt_self_test(A, r) 79 | 80 | 81 | if __name__ == '__main__': 82 | test() 83 | --------------------------------------------------------------------------------