├── .gitignore ├── LICENSE ├── README.md ├── setup.py ├── tests.py └── tree_expectations ├── __init__.py ├── applications.py ├── brute_force.py ├── expectation.py ├── matrix_tree_theorem.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 rz279 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tree Expectations 2 | This library contains an implementation for efficient 3 | computation of zeroth-, first-, and second-order expectations 4 | under spanning tree models. 5 | A detailed description of these algorithms, including proofs of correctness and runtime analysis, can be found in 6 | ["Efficient Computation of Expectations under Spanning Tree Distributions"](https://arxiv.org/abs/2008.12988). 7 | 8 | 9 | ## Citation 10 | 11 | This code accompanies the paper _Efficient Computation of Expectations under Spanning Tree Distributions_. Please cite as: 12 | 13 | ```bibtex 14 | @article{zmigrod-etal-2020-efficient, 15 | title = "Efficient Computation of Expectations under Spanning Tree Distributions", 16 | author = "Ran Zmigrod and Tim Vieira and Ryan Cotterell", 17 | journal = "Transactions of the Association for Computational Linguistics", 18 | year = "2020", 19 | url = "https://arxiv.org/abs/2008.12988", 20 | } 21 | ``` 22 | 23 | ## Requirements and Installation 24 | 25 | * Python version >= 3.6 26 | * PyTorch version >= 1.6.0 27 | 28 | Installation: 29 | ```bash 30 | git clone https://github.com/rycolab/tree_expectations 31 | cd tree_expectations 32 | pip install -e .
33 | ``` 34 | 35 | ## Documentation Style 36 | Variable names: 37 | 38 | w: input matrix 39 | rho: root weight vector 40 | A: adjacency matrix 41 | q: multiplicatively decomposable function 42 | r, s, f: additively decomposable functions 43 | 44 | We assume that w, r, and s are structured with the root weight vector (rho) 45 | along the diagonal and the adjacency matrix (A) everywhere else. 46 | 47 | We give type annotations and use the following dimensions: 48 | 49 | N = 'number of nodes' 50 | E = 'number of edges (typically N^2)' 51 | R = 'Dimensionality of r function' 52 | S = 'Dimensionality of s function' 53 | F = 'Dimensionality of f function' -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='tree_expectations', 4 | version='1.0', 5 | description='Tree Expectations', 6 | author='Ran Zmigrod', 7 | url='https://github.com/rycolab/tree_expectations', 8 | install_requires=[ 9 | 'numpy', 10 | 'torch' 11 | ], 12 | packages=find_packages(), 13 | ) 14 | -------------------------------------------------------------------------------- /tests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | 4 | from tree_expectations.matrix_tree_theorem import matrix_tree_theorem 5 | from tree_expectations.expectation import zeroth_order, zeroth_order_grad 6 | from tree_expectations.expectation import first_order, first_order_grad 7 | from tree_expectations.expectation import second_order, covariance 8 | from tree_expectations.brute_force import bf_mtt, bf_zeroth, bf_first, bf_second 9 | from tree_expectations.brute_force import bf_shannon_entropy, bf_kl 10 | from tree_expectations.applications import shannon_entropy, entropy_grad 11 | from tree_expectations.applications import kl_divergence, kl_grad 12 | from tree_expectations.applications import ge_objective, ge_grad 13 | 14 | 15 | class TreeExpectationTests(unittest.TestCase): 16 | def test_matrix_tree_theorem(self): 17 | """ 18 | Test Matrix Tree Theorem 19 | """ 20 | for n in range(3, 6): 21 | for _ in range(5): 22 | P = torch.exp(torch.randn(n, n)).double() 23 | P = P.clone().detach().requires_grad_(True) 24 | P_brute = P.clone().detach().requires_grad_(True) 25 | Z = matrix_tree_theorem(P) 26 | Z_brute = bf_mtt(P_brute) 27 | self.assertTrue(torch.allclose(Z, Z_brute)) 28 | 29 | def test_zeroth(self): 30 | """ 31 | Test random zeroth-order expectations 32 | """ 33 | for n in range(3, 6): 34 | for _ in range(3): 35 | w = torch.exp(torch.randn(n, n)).double().requires_grad_(True) 36 | q = torch.exp(torch.randn(n, n)).double().requires_grad_(True) 37 | w_brute = w.clone().detach().requires_grad_(True) 38 | q_brute = q.clone().detach().requires_grad_(True) 39 | e = zeroth_order(w, q) 40 | e_brute = bf_zeroth(w_brute, q_brute) 41 | self.assertTrue(torch.allclose(e, e_brute)) 42 | 43 | def test_first(self): 44 | """ 45 | Test random first-order expectations 46 | """ 47 | for n in range(3, 6): 48 | for rdim in range(3, 6): 49 | for _ in range(3): 50 | w = torch.exp(torch.randn(n, n)).double().requires_grad_(True) 51 | r = torch.exp(torch.randn(n, n, rdim)).double().requires_grad_(True) 52 | w_brute = w.clone().detach().requires_grad_(True) 53 | r_brute = r.clone().detach().requires_grad_(True) 54 | e = first_order(w, r) 55 | e_brute = bf_first(w_brute, r_brute) 56 | self.assertTrue(torch.allclose(e, e_brute)) 57
| 58 | def test_second(self): 59 | """ 60 | Test random second-order expectations 61 | """ 62 | for n in range(3, 6): 63 | for rdim in range(3, 6): 64 | for sdim in range(3, 6): 65 | for _ in range(3): 66 | w = torch.exp(torch.randn(n, n)).double().requires_grad_(True) 67 | r = torch.exp(torch.randn(n, n, rdim)).double().requires_grad_(True) 68 | s = torch.exp(torch.randn(n, n, sdim)).double().requires_grad_(True) 69 | w_brute = w.clone().detach().requires_grad_(True) 70 | r_brute = r.clone().detach().requires_grad_(True) 71 | s_brute = s.clone().detach().requires_grad_(True) 72 | e = second_order(w, r, s) 73 | e_brute = bf_second(w_brute, r_brute, s_brute) 74 | self.assertTrue(torch.allclose(e, e_brute)) 75 | 76 | def test_zeroth_grad(self): 77 | """ 78 | Test random gradients of zeroth-order expectations 79 | """ 80 | for n in range(3, 10): 81 | for _ in range(5): 82 | w = torch.exp(torch.randn(n, n)).double().requires_grad_(True) 83 | q = torch.randn(n, n).double().requires_grad_(True) 84 | 85 | e = zeroth_order(w, q) 86 | true_grad = torch.autograd.grad(e, [w], retain_graph=True, create_graph=True)[0] 87 | 88 | grad = zeroth_order_grad(w, q) 89 | self.assertTrue(torch.allclose(grad, true_grad, rtol=1e-4)) 90 | 91 | def test_first_grad_as_second(self): 92 | """ 93 | Test random gradients of first-order expectations 94 | """ 95 | for n in range(3, 10): 96 | for rdim in range(3, 10): 97 | w = torch.exp(torch.randn(n, n)).double().requires_grad_(True) 98 | r = torch.exp(torch.randn(n, n, rdim)).double().requires_grad_(True) 99 | e = first_order(w, r) 100 | e_grad = torch.zeros((rdim, n, n)).double() 101 | for i in range(rdim): 102 | e_grad[i] = torch.autograd.grad(e[i], [w], retain_graph=True, create_graph=True)[0] 103 | e_grad = e_grad.reshape(rdim, n*n) 104 | s = torch.zeros((n, n, n, n)).double() 105 | for i in range(n): 106 | for j in range(n): 107 | s[i, j, i, j] = 1.
/ w[i, j] 108 | s = s.reshape(n, n, n*n) 109 | cov = covariance(w, r, s) 110 | grad = first_order_grad(w, r) 111 | self.assertTrue(torch.allclose(e_grad, cov)) 112 | self.assertTrue(torch.allclose(e_grad, grad)) 113 | 114 | def test_entropy(self): 115 | """ 116 | Test value of the Shannon entropy 117 | """ 118 | for n in range(3, 6): 119 | for _ in range(3): 120 | w = torch.exp(torch.randn(n, n)).double().requires_grad_(True) 121 | ent = shannon_entropy(w) 122 | true_ent = bf_shannon_entropy(w) 123 | self.assertTrue(torch.allclose(true_ent, ent)) 124 | 125 | def test_entropy_grad(self): 126 | """ 127 | Test gradient of the Shannon entropy 128 | """ 129 | for n in range(3, 10): 130 | for _ in range(3): 131 | w = torch.exp(torch.randn(n, n)).double().requires_grad_(True) 132 | ent = shannon_entropy(w) 133 | true_grad = torch.autograd.grad(ent, [w], retain_graph=True, create_graph=True)[0] 134 | true_grad = true_grad.reshape(n * n) 135 | grad = entropy_grad(w) 136 | self.assertTrue(torch.allclose(true_grad, grad, rtol=1e-4)) 137 | 138 | def test_kl(self): 139 | """ 140 | Test value of the KL Divergence 141 | """ 142 | for n in range(3, 6): 143 | for _ in range(3): 144 | w_p = torch.exp(torch.randn(n, n)).double().requires_grad_(True) 145 | w_q = torch.exp(torch.randn(n, n)).double().requires_grad_(True) 146 | kl = kl_divergence(w_p, w_q) 147 | true_kl = bf_kl(w_p, w_q) 148 | self.assertTrue(torch.allclose(true_kl, kl, rtol=1e-5)) 149 | 150 | def test_kl_grad(self): 151 | """ 152 | Test gradient of the KL Divergence 153 | """ 154 | for n in range(3, 8): 155 | for _ in range(3): 156 | w_p = torch.exp(torch.randn(n, n)).double().requires_grad_(True) 157 | w_q = torch.exp(torch.randn(n, n)).double().requires_grad_(True) 158 | ent = kl_divergence(w_p, w_q) 159 | true_grad = torch.autograd.grad(ent, [w_p], retain_graph=True, create_graph=True)[0] 160 | true_grad = true_grad.reshape(n * n) 161 | grad = kl_grad(w_p, w_q) 162 | self.assertTrue(torch.allclose(true_grad, grad)) 163 | 164 | def test_ge_grad(self): 165 | """ 166 | Test gradient of the Generalized Expectation Criterion 167 | """ 168 | for n in range(3, 10): 169 | for sdim in range(2, 6): 170 | for _ in range(3): 171 | w = torch.exp(torch.randn(n, n)).double().requires_grad_(True) 172 | s = torch.exp(torch.randn(n, n, sdim)).double().requires_grad_(True) 173 | target = torch.randn(sdim).double().requires_grad_(True) 174 | ge = ge_objective(w, s, target) 175 | true_grad = torch.autograd.grad(ge, [w], retain_graph=True, create_graph=True)[0] 176 | true_grad = true_grad.reshape(n * n) 177 | grad = ge_grad(w, s, target) 178 | self.assertTrue(torch.allclose(true_grad, grad)) 179 | 180 | 181 | if __name__ == '__main__': 182 | unittest.main() 183 | -------------------------------------------------------------------------------- /tree_expectations/__init__.py: -------------------------------------------------------------------------------- 1 | # Sizing Types 2 | N = 'number of nodes' 3 | E = 'number of egdes (typically N^2)' 4 | R = 'Dimensionality of r function' 5 | S = 'Dimensionality of s function' 6 | F = 'Dimensionality of f function' 7 | -------------------------------------------------------------------------------- /tree_expectations/applications.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from tree_expectations import N, R, F 5 | from tree_expectations.matrix_tree_theorem import matrix_tree_theorem, _dz_base 6 | from tree_expectations.expectation 
import zeroth_order, first_order, first_order_grad 7 | from tree_expectations.utils import device 8 | 9 | 10 | # Renyi Entropy 11 | def renyi(w: 'Tensor[N, N]', alpha: float): 12 | w_alpha = w.pow(1 - alpha) 13 | return zeroth_order(w, w_alpha) 14 | 15 | 16 | # RISK - Section 5.1 of the paper 17 | def risk(w: 'Tensor[N, N]', r: 'Tensor[N, N, R]') -> 'Tensor[R]': 18 | """ 19 | Compute the risk of W with respect to additively decomposable function r. 20 | Function has a runtime of O(N^3+R'N^2) 21 | Note that r is a constant with respect to w 22 | """ 23 | return first_order(w, r) 24 | 25 | 26 | def risk_grad(w: 'Tensor[N, N]', r: 'Tensor[N, N, R]') -> 'Tensor[R, N^2]': 27 | """ 28 | Compute the risk of w with respect to additively decomposable function r. 29 | Function has a runtime of O(N^3 min(R, N R') ) 30 | Note that r is a constant with respect to w 31 | """ 32 | return first_order_grad(w, r) 33 | 34 | 35 | # Shannon Entropy - Section 5.2 of the paper 36 | def shannon_entropy(w: 'Tensor[N, N]') -> 'Tensor[1]': 37 | """ 38 | Compute the Shannon entropy of w. 39 | Function has a runtime of O(N^3) 40 | """ 41 | n = w.size(0) 42 | Z = matrix_tree_theorem(w) 43 | logw = w.clone() 44 | logw[logw == 0] = 1 45 | logw = torch.log(Z) / n - torch.log(logw) 46 | return first_order(w, logw.unsqueeze(-1)) 47 | 48 | 49 | def entropy_grad(w: 'Tensor[N, N]') -> 'Tensor[N^2]': 50 | """ 51 | Compute the gradient of the Shannon entropy of w. 52 | Function has a runtime of O(N^3) 53 | """ 54 | n = w.size(0) 55 | logw = w.clone() 56 | logw[logw == 0] = 1 57 | Z = matrix_tree_theorem(w) 58 | r = torch.log(Z) / n - torch.log(logw) 59 | dz = _dz_base(w) * Z 60 | mu = dz * w 61 | x = mu.sum() / (n * Z) - 1 62 | return first_order_grad(w, r.unsqueeze(-1)) + dz.reshape(n * n) * x 63 | 64 | 65 | def smith_eisner_shannon_entropy(w: 'Tensor[N, N]') -> 'Tensor[1]': 66 | """ 67 | Compute the Shannon entropy of W using method described in: 68 | https://www.cs.jhu.edu/~jason/papers/smith+eisner.emnlp07.pdf 69 | Function has a runtime of O(N^4) 70 | """ 71 | n = w.size(0) 72 | h = torch.tensor(0).double().to(device) 73 | log_w = w.clone() 74 | log_w[log_w == 0] = 1. 75 | log_w = torch.log(log_w) 76 | Z = matrix_tree_theorem(w) 77 | for i in range(n): 78 | w_mod = torch.ones((n, n)).double().to(device) 79 | w_mod[:, i] = log_w[:, i] 80 | h += matrix_tree_theorem(w * w_mod) 81 | return torch.log(Z) - h / Z 82 | 83 | 84 | # KL Divergence - Section 5.3 of the paper 85 | def kl_divergence(w_p: 'Tensor[N, N]', w_q: 'Tensor[N, N]') -> 'Tensor[1]': 86 | """ 87 | Compute the KL divergence of w_p and w_q. 88 | Function has a runtime of O(N^3) 89 | """ 90 | n = w_p.size(0) 91 | log_w_p = w_p.clone() 92 | log_w_q = w_q.clone() 93 | log_w_p[log_w_p == 0] = 1 94 | log_w_q[log_w_q == 0] = 1 95 | Z_p = matrix_tree_theorem(w_p) 96 | Z_q = matrix_tree_theorem(w_q) 97 | r = torch.log(log_w_p) - torch.log(log_w_q) + \ 98 | (torch.log(Z_q) - torch.log(Z_p)) / n 99 | return first_order(w_p, r.unsqueeze(-1)) 100 | 101 | 102 | def kl_grad(w_p: 'Tensor[N, N]', w_q: 'Tensor[N, N]') -> 'Tensor[N^2]': 103 | """ 104 | Compute the gradient of the KL divergence of w_p and w_q. 
105 | Function has a runtime of O(N^3) 106 | """ 107 | n = w_p.size(0) 108 | log_w_p = w_p.clone() 109 | log_w_q = w_q.clone() 110 | log_w_p[log_w_p == 0] = 1 111 | log_w_q[log_w_q == 0] = 1 112 | Z_p = matrix_tree_theorem(w_p) 113 | Z_q = matrix_tree_theorem(w_q) 114 | r = torch.log(log_w_p) - torch.log(log_w_q) + \ 115 | (torch.log(Z_q) - torch.log(Z_p)) / n 116 | dz = _dz_base(w_p) * Z_p 117 | x = torch.ones(1).double() - (dz * w_p).sum() / (n * Z_p) 118 | return first_order_grad(w_p, r.unsqueeze(-1)) + dz.reshape(n * n) * x 119 | 120 | 121 | # Generalized Expectation Criterion - Section 5.4 of the paper 122 | def ge_objective( 123 | w: 'Tensor[N, N]', 124 | f: 'Tensor[N, N, F]', 125 | target: 'Tensor[F]' 126 | ) -> 'Tensor[1]': 127 | """ 128 | Compute the Generalized-Expected criterion of w with respect to 129 | additively decomposable function s and a target. 130 | Function has a runtime of O(N^3 + N^2 F') 131 | """ 132 | e_s = first_order(w, f) 133 | distance = e_s - target 134 | return 0.5 * distance @ distance 135 | 136 | 137 | def ge_grad( 138 | w: 'Tensor[N, N]', 139 | f: 'Tensor[N, N, F]', 140 | target: 'Tensor[F]' 141 | ) -> 'Tensor[N^2]': 142 | """ 143 | Compute the Generalized-Expected criterion of w with respect to 144 | additively decomposable function s and a target. 145 | Function has a runtime of O(N^3 + N^2 F) 146 | """ 147 | first = first_order(w, f) 148 | residual = first - target 149 | f = torch.einsum("ijs,s->ij", f, residual) 150 | return first_order_grad(w, f).unsqueeze(0) 151 | 152 | 153 | -------------------------------------------------------------------------------- /tree_expectations/brute_force.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import List 4 | 5 | from tree_expectations import F, N, R, S 6 | from tree_expectations.utils import device 7 | 8 | 9 | # Scores 10 | def _prod_score(r: 'Tensor[N, N, R]', tree: 'Tensor[N]') -> 'Tensor[R]': 11 | n = r.size(0) 12 | return r[tree, torch.arange(n)].prod(0) 13 | 14 | 15 | def _sum_score(r: 'Tensor[N, N, R]', tree: 'Tensor[N]') -> 'Tensor[R]': 16 | n = r.size(0) 17 | return r[tree, torch.arange(n)].sum(0) 18 | 19 | 20 | def _enumerate_trees(A: 'Tensor[N, N]', root: int, root_weight: float) -> List: 21 | n = A.size(0) 22 | 23 | def enum_dst(weight, included, rest, excluded): 24 | if len(included) == n: 25 | return [(rest, weight)] 26 | dsts = [] 27 | new_excluded = list(excluded) 28 | for i in included: 29 | for j in range(n): 30 | weight_ij = A[i, j] 31 | if j not in included and (i, j) not in excluded and weight_ij: 32 | new_excluded += [(i, j)] 33 | dsts += enum_dst(weight * weight_ij, included + [j], 34 | rest + [(i, j, weight_ij)], new_excluded) 35 | return dsts 36 | return enum_dst(root_weight, [root], [], []) 37 | 38 | 39 | def all_multi_root_trees(w: 'Tensor[N, N]') -> List: 40 | """ 41 | Compute all spanning trees of w that contain at least one root. 42 | Warning: this method is very inefficient. 43 | It should only be used on small examples, e.g., for testing purposes. 
44 | """ 45 | n = w.size(0) 46 | rho = torch.diag(w) 47 | A = w * (torch.ones(1) - torch.eye(n)).to(device) 48 | new_A = torch.zeros((n + 1, n + 1)) 49 | new_A[1:, 1:] = A 50 | new_A[0, 1:] = rho 51 | dsts = [] 52 | unrooted_dsts = [] 53 | for i in range(n): 54 | unrooted_dsts += _enumerate_trees(new_A, i, 1) 55 | for tree, weight in unrooted_dsts: 56 | t = - torch.ones(n) 57 | for i, j, _ in tree[1:]: 58 | if i == 0: 59 | t[j - 1] = j - 1 60 | else: 61 | t[j - 1] = i - 1 62 | dsts.append((t, weight)) 63 | return dsts 64 | 65 | 66 | def all_single_root_trees(w: 'Tensor[N, N]') -> List: 67 | """ 68 | Compute all spanning trees of w that contain only one root. 69 | Warning: this method is very inefficient. 70 | It should only be used on small examples, e.g., for testing purposes. 71 | """ 72 | n = w.size(0) 73 | rho = torch.diag(w) 74 | A = w * (torch.ones(1) - torch.eye(n)).to(device) 75 | dsts = [] 76 | for root, weight in enumerate(rho): 77 | if weight: 78 | rooted_dsts = _enumerate_trees(A, root, weight) 79 | for r_tree, weight in rooted_dsts: 80 | tree = - torch.ones(rho.size(0), dtype=torch.long) 81 | tree[root] = root 82 | for i, j, _ in r_tree: 83 | tree[j] = i 84 | dsts += [(tree, weight)] 85 | return dsts 86 | 87 | 88 | def bf_mtt(w: 'Tensor[N, N]') -> 'Tensor[1]': 89 | """ 90 | Compute the sum of costs over all spanning trees in w. 91 | Warning: this method is very inefficient. 92 | It should only be used on small examples, e.g., for testing purposes. 93 | """ 94 | Z = torch.tensor(0).double().to(device) 95 | for _, weight in all_single_root_trees(w): 96 | Z += weight 97 | return Z 98 | 99 | 100 | def bf_zeroth(w: 'Tensor[N, N]', q: 'Tensor[N, N]') -> 'Tensor[1]': 101 | """ 102 | Compute the zeroth-order expectation over all spanning trees in w 103 | given a multiplicatively decomposable function q. 104 | Warning: this method is very inefficient. 105 | It should only be used on small examples, e.g., for testing purposes. 106 | """ 107 | e = torch.tensor(0).double().to(device) 108 | Z = torch.tensor(0).double().to(device) 109 | for tree, weight in all_single_root_trees(w): 110 | Z += weight 111 | e += weight * _prod_score(q, tree) 112 | return e / Z 113 | 114 | 115 | def bf_first(w: 'Tensor[N, N]', r: 'Tensor[N, N, R]') -> 'Tensor[R]': 116 | """ 117 | Compute the first-order expectation over all spanning trees in w 118 | given an additively decomposable function r. 119 | Warning: this method is very inefficient. 120 | It should only be used on small examples, e.g., for testing purposes. 121 | """ 122 | rdim = r.size(-1) 123 | e = torch.zeros(rdim).double().to(device) 124 | Z = torch.tensor(0).double().to(device) 125 | for tree, weight in all_single_root_trees(w): 126 | Z += weight 127 | e += weight * _sum_score(r, tree) 128 | return e / Z 129 | 130 | 131 | def bf_second( 132 | w: 'Tensor[N, N]', 133 | r: 'Tensor[N, N, R]', 134 | s: 'Tensor[N, N, S]' 135 | ) -> 'Tensor[R, S]': 136 | """ 137 | Compute the second-order expectation over all spanning trees in w 138 | given additively decomposable functions r and s. 139 | Warning: this method is very inefficient. 140 | It should only be used on small examples, e.g., for testing purposes. 
141 | """ 142 | rdim = r.size(-1) 143 | sdim = s.size(-1) 144 | e = torch.zeros((rdim, sdim)).double().to(device) 145 | Z = bf_mtt(w) 146 | for tree, weight in all_single_root_trees(w): 147 | e += weight / Z * torch.ger(_sum_score(r, tree), _sum_score(s, tree)) 148 | return e 149 | 150 | 151 | # Renyi Entropy 152 | def bf_renyi_entropy(w: 'Tensor[N, N]', alpha: float) -> 'Tensor[1]': 153 | """ 154 | Compute the Renyi entropy of w. 155 | Warning: this method is very inefficient. 156 | It should only be used on small examples, e.g., for testing purposes. 157 | """ 158 | Z = torch.zeros(1).double().to(device) 159 | H = torch.zeros(1).double().to(device) 160 | for _, weight in all_single_root_trees(w): 161 | Z += weight 162 | H += torch.pow(weight, alpha) 163 | return (torch.log(H) - alpha * torch.log(Z)) / (1 - alpha) 164 | 165 | 166 | # RISK 167 | def bf_risk(w: 'Tensor[N, N]', r: 'Tensor[N, N, R]') -> 'Tensor[R]': 168 | """ 169 | Compute the risk of w with respect to additively decomposable function r. 170 | Warning: this method is very inefficient. 171 | It should only be used on small examples, e.g., for testing purposes. 172 | """ 173 | Z = torch.zeros(1).double().to(device) 174 | risk = torch.zeros(1).double().to(device) 175 | for tree, weight in all_single_root_trees(w): 176 | Z += weight 177 | risk += weight * _sum_score(r, tree) 178 | return risk / Z 179 | 180 | 181 | # Shannon Entropy 182 | def bf_shannon_entropy(w: 'Tensor[N, N]') -> 'Tensor[1]': 183 | """ 184 | Compute the Shannon entropy of w. 185 | Warning: this method is very inefficient. 186 | It should only be used on small examples, e.g., for testing purposes. 187 | """ 188 | Z = torch.zeros(1).double().to(device) 189 | H = torch.zeros(1).double().to(device) 190 | for _, weight in all_single_root_trees(w): 191 | Z += weight 192 | H += weight * torch.log(weight) 193 | return torch.log(Z) - H / Z 194 | 195 | 196 | # KL Divergence 197 | def bf_kl(w_p: 'Tensor[N, N]', w_q: 'Tensor[N, N]') -> 'Tensor[1]': 198 | """ 199 | Compute the KL divergence between the distributions of w and X. 200 | Warning: this method is very inefficient. 201 | It should only be used on small examples, e.g., for testing purposes. 202 | """ 203 | total = torch.zeros(1).double().to(device) 204 | Z_w = torch.zeros(1).double().to(device) 205 | Z_x = torch.zeros(1).double().to(device) 206 | for tree, weight in all_single_root_trees(w_p): 207 | weight_q = _prod_score(w_q, tree) 208 | Z_w += weight 209 | Z_x += weight_q 210 | total += weight * (torch.log(weight) - torch.log(weight_q)) 211 | return torch.log(Z_x) - torch.log(Z_w) + total / Z_w 212 | 213 | 214 | # Generalized Expectation 215 | def bf_ge(w: 'Tensor[N, N]', f: 'Tensor[N, N, F]', target: 'Tensor[F]') -> 'Tensor[1]': 216 | """ 217 | Compute the Generalized-Expected criterion of w with respect to additively decomposable 218 | function s and a target. 219 | Warning: this method is very inefficient. 220 | It should only be used on small examples, e.g., for testing purposes. 
221 | """ 222 | e = torch.zeros(f.size(-1)) 223 | Z = torch.tensor(0).double().to(device) 224 | for tree, weight in all_single_root_trees(w): 225 | Z += weight 226 | e += weight * _sum_score(f, tree) 227 | residual = e - target 228 | return 0.5 * residual @ residual 229 | -------------------------------------------------------------------------------- /tree_expectations/expectation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from tree_expectations import N, R, S 5 | from tree_expectations.matrix_tree_theorem import lap, dlap, matrix_tree_theorem, _dz_base 6 | 7 | 8 | def _mu(w: 'Tensor[N, N]', B: 'Tensor[N, N]' = None) -> 'Tensor[N, N]': 9 | return w * _dz_base(w, B) 10 | 11 | 12 | def zeroth_order(w: 'Tensor[N, N]', q: 'Tensor[N, N]') -> 'Tensor[1]': 13 | """ 14 | Compute the zeroth-order expectation of multiplicatively decomsosable 15 | function q. 16 | This algorithm is E_0 in the paper. 17 | This function has a runtime of O(N^3). 18 | """ 19 | w_q = w * q 20 | return matrix_tree_theorem(w_q) / matrix_tree_theorem(w) 21 | 22 | 23 | def first_order(w: 'Tensor[N, N]', r: 'Tensor[N, N, R]') -> 'Tensor[R]': 24 | """ 25 | Compute the first-order expectation of additively decomsosable function r. 26 | This algorithm is E_1 in the paper. 27 | This function has a runtime of O(N^3 + N^2 R'). 28 | """ 29 | n = w.size(0) 30 | B = torch.inverse(lap(w)).t() 31 | mu = _mu(w, B) 32 | return (mu.unsqueeze(-1) * r).reshape(n*n, -1).sum(0) 33 | 34 | 35 | def second_order( 36 | w: 'Tensor[N, N]', 37 | r: 'Tensor[N, N, R]', 38 | s: 'Tensor[N, N, S]' 39 | ) -> 'Tensor[R, S]': 40 | """ 41 | Compute the second-order expectation of additively decomsosable 42 | functions r and s. 43 | This algorithm is E_2 in the paper. 44 | This function has a runtime of: 45 | O(N^3 (R' + S') + R S + N^2 min(R, N R') min(S, n S')) 46 | """ 47 | n = w.size(0) 48 | rdim = r.size(-1) 49 | sdim = s.size(-1) 50 | B = torch.inverse(lap(w)).t() 51 | mu = _mu(w, B) 52 | e_r = (mu.unsqueeze(-1) * r).reshape(n*n, -1).sum(0) 53 | e_s = (mu.unsqueeze(-1) * s).reshape(n*n, -1).sum(0) 54 | rhat = torch.zeros((n, n, rdim)).double().requires_grad_(True) 55 | shat = torch.zeros((n, n, sdim)).double().requires_grad_(True) 56 | e_t = torch.ger(e_r, e_s) 57 | for i in range(n): 58 | for j in range(n): 59 | for k in range(n): 60 | for i_, j_, dL in dlap(i, j): 61 | rhat[k, j_] += B[i_, k] * dL * w[i, j] * r[i, j] 62 | shat[j_, k] += B[i_, k] * dL * w[i, j] * s[i, j] 63 | for i in range(n): 64 | for j in range(n): 65 | e_t += mu[i, j] * torch.ger(r[i, j], s[i, j]) - torch.ger(rhat[i, j], shat[i, j]) 66 | return e_t 67 | 68 | 69 | def covariance( 70 | w: 'Tensor[N, N]', 71 | r: 'Tensor[N, N, R]', 72 | s: 'Tensor[N, N, S]' 73 | ) -> 'Tensor[R, S]': 74 | """ 75 | Compute the covariance between additively decomsosable functions r and s. 
76 | This function has a runtime of: 77 | O(N^3 (R' + S') + N^2 min(R, N R') min(S, N S')) 78 | """ 79 | n = w.size(0) 80 | rdim = r.size(-1) 81 | sdim = s.size(-1) 82 | B = torch.inverse(lap(w)).t() 83 | mu = _mu(w, B) 84 | rhat = torch.zeros((n, n, rdim)).double() 85 | shat = torch.zeros((n, n, sdim)).double() 86 | cov = torch.zeros((rdim, sdim)).double() 87 | for i in range(n): 88 | for j in range(n): 89 | for k in range(n): 90 | for i_, j_, dL in dlap(i, j): 91 | rhat[k, j_] += B[i_, k] * dL * w[i, j] * r[i, j] 92 | shat[j_, k] += B[i_, k] * dL * w[i, j] * s[i, j] 93 | for i in range(n): 94 | for j in range(n): 95 | cov[:, :] += mu[i, j] * torch.ger(r[i, j], s[i, j]) - torch.ger(rhat[i, j], shat[i, j]) 96 | return cov 97 | 98 | 99 | def zeroth_order_grad(w: 'Tensor[N, N]', q: 'Tensor[N, N]') -> 'Tensor[N, N]': 100 | """ 101 | Compute the gradient of a zeroth-order expectation. 102 | This assumes that q does not depend on w; if this is not the case, 103 | the gradient of q must be added to the result. 104 | This relates to Proposition 4 in the paper. 105 | This function has a runtime of O(N^3). 106 | """ 107 | wq = w * q 108 | B = torch.inverse(lap(w)).t() 109 | B_q = torch.inverse(lap(wq)).t() 110 | base = _dz_base(w, B) 111 | base_q = _dz_base(wq, B_q) * q 112 | grad = zeroth_order(w, q) * (base_q - base) 113 | return grad 114 | 115 | 116 | def first_order_grad(w: 'Tensor[N, N]', r: 'Tensor[N, N, R]') -> 'Tensor[R, N, N]': 117 | """ 118 | Compute the gradient of a first-order expectation. 119 | This assumes that r does not depend on w; if this is not the case, 120 | the gradient of r must be added to the result. 121 | This relates to Theorem 3 in the paper. 122 | This function has a runtime of O(N^3 min(R, N R')). 123 | """ 124 | n = w.size(0) 125 | rdim = r.size(-1) 126 | B = torch.inverse(lap(w)).t() 127 | base = _dz_base(w, B) 128 | rhat = torch.zeros((n, n, rdim)).double() 129 | shat = torch.zeros((n, n, n, n)).double() 130 | for i in range(n): 131 | for j in range(n): 132 | for k in range(n): 133 | for i_, j_, dL in dlap(i, j): 134 | rhat[k, j_] += B[i_, k] * dL * w[i, j] * r[i, j] 135 | shat[j_, k, i, j] += B[i_, k] * dL 136 | shat = shat.reshape(n, n, n*n) 137 | e_rs = torch.zeros((rdim, n, n)).double() 138 | hats = torch.zeros((rdim, n*n)).double() 139 | for i in range(n): 140 | for j in range(n): 141 | e_rs[:, i, j] += base[i, j] * r[i, j] 142 | hats[:, :] += torch.ger(rhat[i, j], shat[i, j]) 143 | e_rs = e_rs.reshape(rdim, n*n) 144 | return e_rs - hats 145 | -------------------------------------------------------------------------------- /tree_expectations/matrix_tree_theorem.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from tree_expectations import N 5 | from tree_expectations.utils import device 6 | 7 | 8 | def laplacian(A: 'Tensor[N, N]', rho: 'Tensor[N]') -> 'Tensor[N, N]': 9 | """ 10 | Compute the root-weighted Laplacian. 11 | Function has a runtime of O(N^2). 12 | """ 13 | L = -A + torch.diag_embed(A.sum(dim=0)).to(device) 14 | L[0] = rho 15 | return L 16 | 17 | 18 | def lap(w: 'Tensor[N, N]') -> 'Tensor[N, N]': 19 | """ 20 | Compute the root-weighted Laplacian. 21 | Function has a runtime of O(N^2).
22 | """ 23 | rho = torch.diag(w) 24 | A = w * (torch.tensor(1).double().to(device) - torch.eye(w.size(0)).double().to(device)) 25 | return laplacian(A, rho) 26 | 27 | 28 | def dlap(k, l): 29 | """ 30 | Index over sparsity in the Jacobian of the Laplacian 31 | 32 | Given (k,l), return (i,j) such that for a 33 | 34 | dL[i,j] / ∂A[k,l] ≠ 0 35 | 36 | Done in O(1), see Proposition of paper (see README) 37 | """ 38 | if k == l: 39 | return [(0, k, 1.)] 40 | out = [] 41 | if l != 0: 42 | out.append((l, l, 1.)) 43 | if k != 0: 44 | out.append((k, l, -1.)) 45 | return out 46 | 47 | 48 | def adj(A: 'Tensor[N, N]') -> 'Tensor[N, N]': 49 | """ 50 | Compute the adjugate of a matrix A. 51 | The adjugate can be used for calculating the derivative of a determinants. 52 | Function has a runtime of O(N^3). 53 | """ 54 | Ad = torch.slogdet(A) 55 | Ad = Ad[0] * torch.exp(Ad[1]) 56 | return Ad * torch.inverse(A).t() 57 | 58 | 59 | def matrix_tree_theorem(w: 'Tensor[N, N]', use_log: bool = False) -> 'Tensor[1]': 60 | """ 61 | Compute the sum over all spanning trees in W using the Matrix--Tree Theorem. 62 | Function has a runtime of O(N^3). 63 | This relates to Section 2, Proposition 1 of paper (see README) 64 | """ 65 | r = torch.diag(w) 66 | A = w * (torch.tensor(1).double().to(device) - torch.eye(w.size(0)).double().to(device)) 67 | sign, logZ = torch.slogdet(laplacian(A, r)) 68 | return logZ if use_log else sign * torch.exp(logZ) 69 | 70 | 71 | def _dz_base(w: 'Tensor[N, N]', B: 'Tensor[N, N]' = None): 72 | """ 73 | Evaluate the base of the derivative of the Matrix--Tree Theorem 74 | using a possibly cached transpoed inverse Laplacian matrix. 75 | To get the full derivative, we must multiply by Z. 76 | To get mu, we must multiply by w. 77 | This function runs in O(N^2) if B is provided and O(N^3) otherwise. 78 | """ 79 | n = w.size(0) 80 | if B is None: 81 | B = torch.inverse(lap(w)).t() 82 | base = torch.zeros((n, n)).double().requires_grad_(True) 83 | for i in range(n): 84 | for j in range(n): 85 | for i_, j_, dL in dlap(i, j): 86 | base[i, j] += B[i_, j_] * dL 87 | return base 88 | -------------------------------------------------------------------------------- /tree_expectations/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 5 | --------------------------------------------------------------------------------