├── Example.ipynb ├── README.md ├── fastlri ├── base │ ├── cfg.py │ ├── exceptions.py │ ├── nonterminal.py │ ├── production.py │ └── symbol.py └── parsing │ └── parser.py ├── setup.py └── test └── test_lri.py /Example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "id": "f9ee1605", 7 | "metadata": {}, 8 | "source": [ 9 | "### Example usage\n", 10 | "First, define a weighted context free grammar as follows:" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "1858fc62", 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "N → fruit\t0.5\n", 24 | "N → flies\t0.25\n", 25 | "N → banana\t0.25\n", 26 | "S → NP VP\t1.0\n", 27 | "V → like\t0.5\n", 28 | "V → flies\t0.5\n", 29 | "NP → N N\t0.25\n", 30 | "NP → Det N\t0.25\n", 31 | "NP → Adj N\t0.25\n", 32 | "NP → Det NP\t0.25\n", 33 | "VP → V NP\t1.0\n", 34 | "Adj → green\t1.0\n", 35 | "Adv → like\t1.0\n", 36 | "Det → a\t1.0\n", 37 | "AdvP → Adv NP\t1.0\n" 38 | ] 39 | } 40 | ], 41 | "source": [ 42 | "from fastlri.parsing.parser import Parser\n", 43 | "from fastlri.base.cfg import CFG\n", 44 | "from fastlri.base.nonterminal import S, NT\n", 45 | "from fastlri.base.symbol import Sym\n", 46 | "\n", 47 | "# define the nonterminals of the grammar\n", 48 | "NP = NT(\"NP\")\n", 49 | "VP = NT(\"VP\")\n", 50 | "Det = NT(\"Det\")\n", 51 | "N = NT(\"N\")\n", 52 | "PP = NT(\"PP\")\n", 53 | "V = NT(\"V\")\n", 54 | "Adj = NT(\"Adj\")\n", 55 | "Adv = NT(\"Adv\")\n", 56 | "AdvP = NT(\"AdvP\")\n", 57 | "\n", 58 | "# define the terminals of the grammar\n", 59 | "fruit = Sym(\"fruit\")\n", 60 | "flies = Sym(\"flies\")\n", 61 | "like = Sym(\"like\")\n", 62 | "a = Sym(\"a\")\n", 63 | "green = Sym(\"green\")\n", 64 | "banana = Sym(\"banana\")\n", 65 | "\n", 66 | "# define the rules of the grammar\n", 67 | "cfg = CFG()\n", 68 | "cfg.add(1, cfg.S, NP, VP)\n", 69 | "cfg.add(0.25, NP, Det, N)\n", 70 | "cfg.add(0.25, NP, Det, NP)\n", 71 | "cfg.add(0.25, NP, N, N)\n", 72 | "cfg.add(0.25, NP, Adj, N)\n", 73 | "cfg.add(1, VP, V, NP)\n", 74 | "cfg.add(1, AdvP, Adv, NP)\n", 75 | "cfg.add(0.5, N, fruit)\n", 76 | "cfg.add(0.25, N, flies)\n", 77 | "cfg.add(0.25, N, banana)\n", 78 | "cfg.add(0.5, V, flies)\n", 79 | "cfg.add(0.5, V, like)\n", 80 | "cfg.add(1, Det, a)\n", 81 | "cfg.add(1, Adj, green)\n", 82 | "cfg.add(1, Adv, like)\n", 83 | "\n", 84 | "print(cfg)" 85 | ] 86 | }, 87 | { 88 | "attachments": {}, 89 | "cell_type": "markdown", 90 | "id": "07c79cd4", 91 | "metadata": {}, 92 | "source": [ 93 | "Then, create a parser for this CFG:" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 2, 99 | "id": "b7076495", 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "parser = Parser(cfg)" 104 | ] 105 | }, 106 | { 107 | "attachments": {}, 108 | "cell_type": "markdown", 109 | "id": "5aa3b97d", 110 | "metadata": {}, 111 | "source": [ 112 | " Now you can parse input strings using CKY:" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 3, 118 | "id": "6fdd1e66", 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "data": { 123 | "text/plain": [ 124 | "0.000244140625" 125 | ] 126 | }, 127 | "execution_count": 3, 128 | "metadata": {}, 129 | "output_type": "execute_result" 130 | } 131 | ], 132 | "source": [ 133 | "parser.cky(\"fruit flies like a green banana\")" 134 | ] 135 | }, 136 | { 137 | "attachments": {}, 138 | "cell_type": "markdown", 139 | "id": "4b7fae32", 140 | "metadata": {}, 141 | "source": [ 142 | "Similarly for lri and fast lri:" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 4, 148 | "id": "45ef18c1", 149 | "metadata": {}, 150 | "outputs": [ 151 | { 152 | "data": { 153 | "text/plain": [ 154 | "0.000244140625" 155 | ] 156 | }, 157 | "execution_count": 4, 158 | "metadata": {}, 159 | "output_type": "execute_result" 160 | } 161 | ], 162 | "source": [ 163 | "parser.lri(\"fruit flies like a green banana\")" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 5, 169 | "id": "b936664e", 170 | "metadata": {}, 171 | "outputs": [ 172 | { 173 | "data": { 174 | "text/plain": [ 175 | "0.000244140625" 176 | ] 177 | }, 178 | "execution_count": 5, 179 | "metadata": {}, 180 | "output_type": "execute_result" 181 | } 182 | ], 183 | "source": [ 184 | "parser.lri_fast(\"fruit flies like a green banana\")" 185 | ] 186 | }, 187 | { 188 | "attachments": {}, 189 | "cell_type": "markdown", 190 | "id": "732b3068", 191 | "metadata": {}, 192 | "source": [ 193 | "For a prefix that has no rooted parse tree under the CFG, cky will return 0, while lri returns a positive probability:" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 6, 199 | "id": "c46879d0", 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "data": { 204 | "text/plain": [ 205 | "0.0" 206 | ] 207 | }, 208 | "execution_count": 6, 209 | "metadata": {}, 210 | "output_type": "execute_result" 211 | } 212 | ], 213 | "source": [ 214 | "parser.cky(\"fruit flies like\")" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 7, 220 | "id": "20d8a2ac", 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "data": { 225 | "text/plain": [ 226 | "0.015625" 227 | ] 228 | }, 229 | "execution_count": 7, 230 | "metadata": {}, 231 | "output_type": "execute_result" 232 | } 233 | ], 234 | "source": [ 235 | "parser.lri_fast(\"fruit flies like\")" 236 | ] 237 | }, 238 | { 239 | "attachments": {}, 240 | "cell_type": "markdown", 241 | "id": "055d6853", 242 | "metadata": {}, 243 | "source": [ 244 | "It is also possible to get the full dynamic programming chart by setting a flag:" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 8, 250 | "id": "43bc3276", 251 | "metadata": {}, 252 | "outputs": [ 253 | { 254 | "data": { 255 | "text/plain": [ 256 | "defaultdict(.()>,\n", 257 | " {(VP, 0, 0): 0.0,\n", 258 | " (VP, 1, 1): 0.5,\n", 259 | " (NP, 0, 0): 0.125,\n", 260 | " (NP, 1, 1): 0.0625,\n", 261 | " (Adj, 0, 0): 0.0,\n", 262 | " (Adj, 1, 1): 0.0,\n", 263 | " (AdvP, 0, 0): 0.0,\n", 264 | " (AdvP, 1, 1): 0.0,\n", 265 | " (V, 0, 0): 0.0,\n", 266 | " (V, 1, 1): 0.5,\n", 267 | " (N, 0, 0): 0.5,\n", 268 | " (N, 1, 1): 0.25,\n", 269 | " (Adv, 0, 0): 0.0,\n", 270 | " (Adv, 1, 1): 0.0,\n", 271 | " (S, 0, 0): 0.125,\n", 272 | " (S, 1, 1): 0.0625,\n", 273 | " (Det, 0, 0): 0.0,\n", 274 | " (Det, 1, 1): 0.0,\n", 275 | " (VP, 0, 1): 0.0,\n", 276 | " (NP, 0, 1): 0.03125,\n", 277 | " (Adj, 0, 1): 0.0,\n", 278 | " (AdvP, 0, 1): 0.0,\n", 279 | " (V, 0, 1): 0.0,\n", 280 | " (N, 0, 1): 0.0,\n", 281 | " (Adv, 0, 1): 0.0,\n", 282 | " (S, 0, 1): 0.03125,\n", 283 | " (Det, 0, 1): 0.0})" 284 | ] 285 | }, 286 | "execution_count": 8, 287 | "metadata": {}, 288 | "output_type": "execute_result" 289 | } 290 | ], 291 | "source": [ 292 | "parser.lri_fast(\"fruit flies\", chart=True)" 293 | ] 294 | } 295 | ], 296 | "metadata": { 297 | "kernelspec": { 298 | "display_name": "Python 3 (ipykernel)", 299 | "language": "python", 300 | "name": "python3" 301 | }, 302 | "language_info": { 303 | "codemirror_mode": { 304 | "name": "ipython", 305 | "version": 3 306 | }, 307 | "file_extension": ".py", 308 | "mimetype": "text/x-python", 309 | "name": "python", 310 | "nbconvert_exporter": "python", 311 | "pygments_lexer": "ipython3", 312 | "version": "3.10.11" 313 | } 314 | }, 315 | "nbformat": 4, 316 | "nbformat_minor": 5 317 | } 318 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # prefix-parsing 2 | 3 | Code accompanying the ACL 2023 publication "[A Fast Algorithm for Computing Prefix Probabilities](https://aclanthology.org/2023.acl-short.6/)". 4 | 5 | This repository contains implementations for parsing weighted context free grammars (WCFGs) in chomsky normal form (CNF). 6 | 7 | A context-free grammar is in CNF if the all the rules are in one of the following forms: 8 | ``` 9 | S -> ε 10 | X -> Y Z 11 | X -> a 12 | ``` 13 | Where S is the distinguished start non-terminal, X, Y, and Z are non-terminals, and a is a terminal. 14 | 15 | The methods can be found under `src/parsing/parser.py` implement: 16 | - The CKY algorithm ([Kasami, 1965](https://www.ideals.illinois.edu/items/100444); [Younger, 17 | 1967](https://doi.org/https://doi.org/10.1016/S0019-9958(67)80007-X); [Cocke and Schwartz, 1969](https://www.softwarepreservation.org/projects/FORTRAN/CockeSchwartz_ProgLangCompilers.pdf)) for parsing a string under a WCFG in CNF; 18 | - The LRI algorithm ([Jelinek and Lafferty, 1991](https://aclanthology.org/J91-3004)) for finding the weight of a prefix string under a WCFG in CNF; 19 | - An improved version of the LRI algorithm ([Nowak and Cotterell, 2023](https://arxiv.org/abs/2306.02303)) using additional memoization. 20 | 21 | --- 22 | To start, run: 23 | ```bash 24 | $ git clone git@github.com:rycolab/prefix-parsing.git 25 | $ cd prefix-parsing 26 | $ pip install -e . 27 | ``` 28 | To unit test, run: 29 | ``` 30 | pytest . 31 | ``` 32 | --- 33 | 34 | ## Example usage 35 | First, define a weighted context free grammar as follows: 36 | ```python 37 | from fastlri.parsing.parser import Parser 38 | from fastlri.base.cfg import CFG 39 | from fastlri.base.nonterminal import NT 40 | from fastlri.base.symbol import Sym 41 | 42 | # define the nonterminals of the grammar 43 | NP = NT("NP") 44 | VP = NT("VP") 45 | Det = NT("Det") 46 | N = NT("N") 47 | PP = NT("PP") 48 | V = NT("V") 49 | Adj = NT("Adj") 50 | Adv = NT("Adv") 51 | AdvP = NT("AdvP") 52 | 53 | # define the terminals of the grammar 54 | fruit = Sym("fruit") 55 | flies = Sym("flies") 56 | like = Sym("like") 57 | a = Sym("a") 58 | green = Sym("green") 59 | banana = Sym("banana") 60 | 61 | # define the rules of the grammar 62 | cfg = CFG() 63 | cfg.add(1, cfg.S, NP, VP) 64 | cfg.add(0.25, NP, Det, N) 65 | cfg.add(0.25, NP, Det, NP) 66 | cfg.add(0.25, NP, N, N) 67 | cfg.add(0.25, NP, Adj, N) 68 | cfg.add(1, VP, V, NP) 69 | cfg.add(1, AdvP, Adv, NP) 70 | cfg.add(0.5, N, fruit) 71 | cfg.add(0.25, N, flies) 72 | cfg.add(0.25, N, banana) 73 | cfg.add(0.5, V, flies) 74 | cfg.add(0.5, V, like) 75 | cfg.add(1, Det, a) 76 | cfg.add(1, Adj, green) 77 | cfg.add(1, Adv, like) 78 | ``` 79 | 80 | Alternatively, grammars can more easily be defined directly from strings, where non-terminals need to be capitalized or start with an '@', whereas terminals are lower case: 81 | ```python 82 | cfg = CFG.from_string(""" 83 | 1.0: S -> NP VP 84 | 0.5: N -> fruit 85 | 0.25: N -> flies 86 | 0.25: N -> banana 87 | 0.5: V -> like 88 | 0.5: V -> flies 89 | 0.25: NP -> N N 90 | 0.25: NP -> Det N 91 | 0.25: NP -> Adj N 92 | 0.25: NP -> Det NP 93 | 1.0: VP -> V NP 94 | 1.0: Adj -> green 95 | 1.0: Adv -> like 96 | 1.0: Det -> a 97 | 1.0: AdvP -> Adv NP 98 | """, start = 'S') 99 | ``` 100 | 101 | Then, create a parser for this CFG: 102 | ```python 103 | parser = Parser(cfg) 104 | ``` 105 | 106 | Now you can parse input strings using CKY: 107 | ```python 108 | parser.cky("fruit flies like a green banana") 109 | ``` 110 | ``` 111 | 0.000244140625 112 | ``` 113 | 114 | Similarly for lri and fast lri: 115 | ```python 116 | parser.lri("fruit flies like a green banana") 117 | ``` 118 | ``` 119 | 0.000244140625 120 | ``` 121 | 122 | ```python 123 | parser.lri_fast("fruit flies like a green banana") 124 | ``` 125 | ``` 126 | 0.000244140625 127 | ``` 128 | 129 | For a prefix that has no rooted parse tree under the CFG, cky will return 0, while lri returns a positive probability: 130 | 131 | ```python 132 | parser.cky("fruit flies like") 133 | ``` 134 | ``` 135 | 0.0 136 | ``` 137 | 138 | ```python 139 | parser.lri_fast("fruit flies like") 140 | ``` 141 | ``` 142 | 0.015625 143 | ``` 144 | 145 | It is also possible to get the full dynamic programming chart by setting a flag: 146 | 147 | ```python 148 | parser.lri_fast("fruit flies", chart=True) 149 | ``` 150 | ``` 151 | defaultdict(.()>, 152 | {(N, 0, 1): 0.5, 153 | (N, 1, 2): 0.25, 154 | (Adv, 0, 1): 0.0, 155 | (Adv, 1, 2): 0.0, 156 | (Det, 0, 1): 0.0, 157 | (Det, 1, 2): 0.0, 158 | (VP, 0, 1): 0.0, 159 | (VP, 1, 2): 0.5, 160 | (NP, 0, 1): 0.125, 161 | (NP, 1, 2): 0.0625, 162 | (S, 0, 1): 0.125, 163 | (S, 1, 2): 0.0625, 164 | (AdvP, 0, 1): 0.0, 165 | (AdvP, 1, 2): 0.0, 166 | (Adj, 0, 1): 0.0, 167 | (Adj, 1, 2): 0.0, 168 | (V, 0, 1): 0.0, 169 | (V, 1, 2): 0.5, 170 | (N, 0, 2): 0.0, 171 | (Adv, 0, 2): 0.0, 172 | (Det, 0, 2): 0.0, 173 | (VP, 0, 2): 0.0, 174 | (NP, 0, 2): 0.03125, 175 | (S, 0, 2): 0.03125, 176 | (AdvP, 0, 2): 0.0, 177 | (Adj, 0, 2): 0.0, 178 | (V, 0, 2): 0.0}) 179 | ``` 180 | 181 | --- 182 | ## Cite 183 | 184 | If you use this code or the underlying algorithm in your own work, please cite our publication as follows: 185 | 186 | Franz Nowak and Ryan Cotterell. 2023. A Fast Algorithm for Computing Prefix Probabilities. In _Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)_, pages 57–69, Toronto, Canada. Association for Computational Linguistics. 187 | 188 | ``` 189 | @inproceedings{nowak-cotterell-2023-fast, 190 | title = "A Fast Algorithm for Computing Prefix Probabilities", 191 | author = "Nowak, Franz and 192 | Cotterell, Ryan", 193 | editor = "Rogers, Anna and 194 | Boyd-Graber, Jordan and 195 | Okazaki, Naoaki", 196 | booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)", 197 | month = jul, 198 | year = "2023", 199 | address = "Toronto, Canada", 200 | publisher = "Association for Computational Linguistics", 201 | url = "https://aclanthology.org/2023.acl-short.6", 202 | doi = "10.18653/v1/2023.acl-short.6", 203 | pages = "57--69", 204 | abstract = "Multiple algorithms are known for efficiently calculating the prefix probability of a string under a probabilistic context-free grammar (PCFG). Good algorithms for the problem have a runtime cubic in the length of the input string. However, some proposed algorithms are suboptimal with respect to the size of the grammar. This paper proposes a new speed-up of Jelinek and Lafferty{'}s (1991) algorithm, which runs in $O(n^3|N|^3 + |N|^4)$, where n is the input length and |N| is the number of non-terminals in the grammar. In contrast, our speed-up runs in $O(n^2|N|^3 + n^3|N|^2)$.", 205 | } 206 | ``` 207 | 208 | 209 | 210 | ## Contact 211 | 212 | For any questions or problems, please file an [issue](https://github.com/rycolab/prefix-parsing/issues) or email [fnowak@ethz.ch](mailto:fnowak@ethz.ch). 213 | -------------------------------------------------------------------------------- /fastlri/base/cfg.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict as dd 2 | from fastlri.base.nonterminal import NT, S 3 | from fastlri.base.symbol import Sym 4 | from fastlri.base.production import Production 5 | from fastlri.base.exceptions import InvalidProduction 6 | 7 | class CFG: 8 | 9 | def __init__(self, _S = S): 10 | # alphabet of terminal symbols Σ 11 | self.Sigma = set([]) 12 | 13 | # non-terminal symbols V 14 | self.V = set([_S]) 15 | 16 | # production rules of the form V × (Σ ∪ V)* × R 17 | self._P = dd(lambda: 0.0) 18 | 19 | # unique start non-terminal symbol S 20 | self.S = _S 21 | 22 | @property 23 | def P(self): 24 | for p, w in self._P.items(): 25 | yield p, w 26 | 27 | @property 28 | def terminal(self): 29 | """Returns terminal productions of the CFG.""" 30 | for p, w in self.P: 31 | (head, body) = p 32 | if len(body) == 1 and isinstance(body[0], Sym): 33 | yield p, w 34 | 35 | @property 36 | def binary(self): 37 | """Returns binary productions of the CFG.""" 38 | for p, w in self.P: 39 | (head, body) = p 40 | if len(body) == 2 and isinstance(body[0], NT) \ 41 | and isinstance(body[1], NT): 42 | yield p, w 43 | 44 | @property 45 | def ordered_V(self): 46 | """Returns a list of nonterminals ordered by alphabetical index.""" 47 | V = list(self.V) 48 | V.sort(key=lambda a: str(a.X)) 49 | return V 50 | 51 | @property 52 | def in_cnf(self): 53 | """Checks if grammar is in CNF.""" 54 | for p, w in self.P: 55 | (head, body) = p 56 | if head == self.S and body == (): 57 | # S → ε 58 | continue 59 | elif head in self.V and len(body) == 2 and all([elem in self.V \ 60 | and elem != self.S for elem in body]): 61 | # A → B C 62 | continue 63 | elif head in self.V and len(body) == 1 and body[0] in self.Sigma: 64 | # A → a 65 | continue 66 | else: 67 | return False 68 | return True 69 | 70 | @property 71 | def is_pcfg(self) -> bool: 72 | """Returns whether the grammar is locally normalized.""" 73 | for head in self.V: 74 | total = 0 75 | for p, w in self.P: 76 | if p.head == head: 77 | total += w 78 | if total != 1: 79 | return False 80 | return True 81 | 82 | def add(self, w, head, *body): 83 | """Add a rule to the CFG.""" 84 | if not isinstance(head, NT): 85 | raise InvalidProduction 86 | 87 | self.V.add(head) 88 | 89 | for elem in body: 90 | if isinstance(elem, NT): 91 | self.V.add(elem) 92 | elif isinstance(elem, Sym): 93 | self.Sigma.add(elem) 94 | elif elem != (): 95 | raise InvalidProduction 96 | 97 | self._P[Production(head, body)] += w 98 | 99 | @staticmethod 100 | def from_string(string, comment="#", start='S'): 101 | import re 102 | if isinstance(start, str): start = NT(start) 103 | cfg = CFG(_S = start) 104 | string = string.replace('->', '→') # synonym for the arrow 105 | for line in string.split('\n'): 106 | line = line.strip() 107 | if not line or line.startswith(comment): continue 108 | try: 109 | [(w, lhs, rhs)] = re.findall('(.*):\s*(\S+)\s*→\s*(.*)$', line) 110 | lhs = lhs.strip() 111 | rhs = rhs.strip().split() 112 | 113 | rhs_ = [] 114 | for x in rhs: 115 | if x[0].isupper() or x[0].startswith('@'): 116 | rhs_.append(NT(x)) 117 | else: 118 | rhs_.append(Sym(x)) 119 | cfg.add(float(w), NT(lhs), *rhs_) 120 | 121 | except ValueError as e: 122 | raise ValueError(f'bad input line:\n{line}') 123 | return cfg 124 | 125 | def __str__(self): 126 | return "\n".join(f"{w}: \t {p}" for (p, w) in sorted(self.P, \ 127 | key=lambda x: (len(str(x[0].head)), str(x[0].head), \ 128 | len(str(x[0]))))) 129 | -------------------------------------------------------------------------------- /fastlri/base/exceptions.py: -------------------------------------------------------------------------------- 1 | class InvalidProduction(Exception): 2 | pass -------------------------------------------------------------------------------- /fastlri/base/nonterminal.py: -------------------------------------------------------------------------------- 1 | class NT: 2 | 3 | def __init__(self, X, label=None, n=None): 4 | self._X = X 5 | self._label = label 6 | 7 | @property 8 | def X(self): 9 | return self._X 10 | 11 | def copy(self): 12 | return NT(self.X) 13 | 14 | def __repr__(self): 15 | return f'{self.X}' 16 | 17 | def __hash__(self): 18 | return hash(self.X) 19 | 20 | def __eq__(self, other): 21 | return isinstance(other, NT) and self.X == other.X 22 | 23 | S = NT("S") -------------------------------------------------------------------------------- /fastlri/base/production.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | class Production(namedtuple("Production", "head, body")): 4 | 5 | def __repr__(self): 6 | return str(self.head) + " → " + " ".join(map(str, self.body)) -------------------------------------------------------------------------------- /fastlri/base/symbol.py: -------------------------------------------------------------------------------- 1 | class Sym: 2 | def __init__(self, sym): 3 | self.sym = sym 4 | 5 | def __str__(self): 6 | return str(self.sym) 7 | 8 | def __repr__(self): 9 | return str(self.sym) 10 | 11 | def __hash__(self): 12 | return hash(self.sym) 13 | 14 | def __eq__(self, other): 15 | return isinstance(other, Sym) and self.sym == other.sym 16 | 17 | ε = Sym("ε") 18 | -------------------------------------------------------------------------------- /fastlri/parsing/parser.py: -------------------------------------------------------------------------------- 1 | from fastlri.base.symbol import Sym 2 | import numpy as np 3 | from collections import defaultdict as dd 4 | from fastlri.base.production import Production 5 | from fastlri.base.nonterminal import S 6 | 7 | class Parser: 8 | def __init__(self, cfg): 9 | self.cfg = cfg 10 | 11 | def cky(self, input, chart=False): 12 | """Calculates the chart of substring probabilities using CKY. Requires CNF.""" 13 | assert self.cfg.in_cnf 14 | 15 | # convert input string to list 16 | if type(input) == str: 17 | input = [Sym(token) for token in input.split()] 18 | N = len(input) 19 | 20 | # initialization 21 | β = dd(lambda: 0.0) 22 | β[0, self.cfg.S, 0] = self.cfg._P[self.cfg.S, ()] 23 | 24 | # terminal productions 25 | for (head, body), w in self.cfg.terminal: 26 | for k in range(N): 27 | if body[0] == input[k]: 28 | β[k, head, k+1] += w 29 | 30 | # binary productions 31 | for l in range(2, N+1): 32 | for i in range(N-l+1): 33 | k = i + l 34 | for j in range(i+1, k): 35 | for p, w in self.cfg.binary: 36 | X, Y, Z = p.head, p.body[0], p.body[1] 37 | β[i, X, k] += β[i, Y, j] * β[j, Z, k] * w 38 | return β if chart else β[0, S, N] 39 | 40 | def cky_fast(self, input, chart=False): 41 | """A faster version of CKY for dense grammars. Requires CNF.""" 42 | assert self.cfg.in_cnf 43 | 44 | # convert input string to list 45 | if type(input) == str: 46 | input = [Sym(token) for token in input.split()] 47 | N = len(input) 48 | 49 | # initialization 50 | β = dd(lambda: 0.0) 51 | β[0, self.cfg.S, 0] = self.cfg._P[self.cfg.S, ()] 52 | 53 | # create an index from NT triplets to binary production weights 54 | W = dd(lambda: 0.0) 55 | for p, w in self.cfg.binary: 56 | X, Y, Z = p.head, p.body[0], p.body[1] 57 | W[X, Y, Z] = w 58 | 59 | # terminal productions 60 | for (head, body), w in self.cfg.terminal: 61 | for k in range(N): 62 | if body[0] == input[k]: 63 | β[k, head, k+1] += w 64 | 65 | # binary productions 66 | for l in range(2, N+1): 67 | for i in range(N-l+1): 68 | k = i + l 69 | for Y in self.cfg.V: 70 | for Z in self.cfg.V: 71 | γ = 0.0 72 | for j in range(i+1, k): 73 | γ += β[i, Y, j] * β[j, Z, k] 74 | for X in self.cfg.V: 75 | β[i, X, k] += γ * W[X, Y, Z] 76 | return β if chart else β[0, S, N] 77 | 78 | def plc(self): 79 | """Computes the left-corner expectations. Requires CNF.""" 80 | assert self.cfg.in_cnf 81 | 82 | # get canonical index over non-terminals 83 | V = self.cfg.ordered_V 84 | V_idx = {X:i for i,X in enumerate(V)} 85 | 86 | # calculate the matrix P of one-step derivations 87 | P = np.zeros((len(V),len(V))) 88 | for p, w in self.cfg.binary: 89 | X, Y = V_idx[p.head], V_idx[p.body[0]] 90 | P[X, Y] += w 91 | 92 | # compute the closure over derivations 93 | P_L = np.linalg.inv(np.eye(len(V), len(V)) - P) 94 | return P_L 95 | 96 | def lri(self, input, chart=False): 97 | """Original LRI algorithm by Jelinek and Lafferty (1991). Requires CNF.""" 98 | assert self.cfg.in_cnf 99 | 100 | # convert input string to list 101 | if type(input) == str: 102 | input = [Sym(token) for token in input.split()] 103 | 104 | # initialization 105 | N = len(input) 106 | V = self.cfg.ordered_V 107 | V_idx = {X:i for i, X in enumerate(V)} 108 | ppre = dd(lambda: 0.0) 109 | for k in range(N+1): 110 | for X in self.cfg.V: 111 | ppre[k, X, k] = 1 112 | 113 | # precompute β using CKY 114 | β = self.cky(input, chart=True) 115 | 116 | # precompute E 117 | E = dd(lambda: 0.0) 118 | P_L = self.plc() 119 | for X in self.cfg.V: 120 | for Y in self.cfg.V: 121 | E[X, Y] = P_L[V_idx[X], V_idx[Y]] 122 | 123 | # precompute E2 124 | E2 = dd(lambda: 0.0) 125 | for X in self.cfg.V: 126 | for (head, body), w in self.cfg.binary: 127 | Y2, Y, Z = head, body[0], body[1] 128 | E2[X, Y, Z] += E[X, Y2] * self.cfg._P[Production(Y2, (Y, Z))] 129 | 130 | # compute base case 131 | for X in self.cfg.V: 132 | for k in range(N): 133 | for (head, body), w in self.cfg.terminal: 134 | Y, v = head, body[0] 135 | if v == input[k]: 136 | ppre[k, X, k+1] += E[X, Y] * w 137 | 138 | # compute prefix probability 139 | for l in range(2, N+1): 140 | for i in range(N-l+1): 141 | k = i + l 142 | for j in range(i+1, k): 143 | for X in self.cfg.V: 144 | for Y in self.cfg.V: 145 | for Z in self.cfg.V: 146 | ppre[i, X, k] += E2[X, Y, Z] * β[i, Y, j] \ 147 | * ppre[j, Z, k] 148 | 149 | return ppre if chart else ppre[0, S, N] 150 | 151 | def lri_fast(self, input, chart=False): 152 | """Faster prefix parsing algorithm by Nowak and Cotterell (2023). Requires CNF.""" 153 | assert self.cfg.in_cnf 154 | 155 | # convert input string to list 156 | if type(input) == str: 157 | input = [Sym(token) for token in input.split()] 158 | 159 | # initialization 160 | N = len(input) 161 | V = self.cfg.ordered_V 162 | V_idx = {X:i for i,X in enumerate(V)} 163 | ppre = dd(lambda: 0.0) 164 | for k in range(N+1): 165 | for X in self.cfg.V: 166 | ppre[k, X, k] = 1 167 | 168 | # precompute β using CKY 169 | β = self.cky_fast(input, chart=True) 170 | 171 | # precompute E 172 | E = dd(lambda: 0.0) 173 | P_L = self.plc() 174 | for X in self.cfg.V: 175 | for Y in self.cfg.V: 176 | E[X, Y] = P_L[V_idx[X], V_idx[Y]] 177 | 178 | # precompute γ and δ 179 | γ = dd(lambda: 0.0) 180 | δ = dd(lambda: 0.0) 181 | for i in range(N): 182 | for j in range(N): 183 | for p, w in self.cfg.binary: 184 | X, Y, Z = p.head, p.body[0], p.body[1] 185 | γ[i, j, X, Z] += w * β[i, Y, j] 186 | for X in self.cfg.V: 187 | for Y in self.cfg.V: 188 | for Z in self.cfg.V: 189 | δ[i, j, X, Z] += E[X, Y] * γ[i, j, Y, Z] 190 | 191 | # compute base case 192 | for X in self.cfg.V: 193 | for i in range(N): 194 | for p, w in self.cfg.terminal: 195 | Y, v = p.head, p.body[0] 196 | if v == input[i]: 197 | ppre[i, X, i+1] += E[X, Y] * w 198 | 199 | # compute prefix probability 200 | for l in range(2, N+1): 201 | for i in range(N-l+1): 202 | k = i + l 203 | for j in range(i+1, k): 204 | for X in self.cfg.V: 205 | for Z in self.cfg.V: 206 | ppre[i, X, k] += δ[i, j, X, Z] * ppre[j, Z, k] 207 | 208 | return ppre if chart else ppre[0, S, N] 209 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | install_requires = [ 4 | "numpy", 5 | "pytest", 6 | ] 7 | 8 | 9 | setup( 10 | name="fastlri", 11 | install_requires=install_requires, 12 | version="0.1", 13 | scripts=[], 14 | packages=["fastlri"], 15 | ) 16 | -------------------------------------------------------------------------------- /test/test_lri.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from fastlri.base.cfg import CFG 3 | from fastlri.base.nonterminal import NT 4 | from fastlri.base.symbol import Sym 5 | from fastlri.parsing.parser import Parser 6 | 7 | 8 | def get_simple_cfg(): 9 | Y = NT("Y") 10 | Z = NT("Z") 11 | a = Sym("a") 12 | 13 | cfg = CFG() 14 | cfg.add(1, cfg.S, Y, Z) 15 | cfg.add(0.5, Y, Z, Y) 16 | cfg.add(0.5, Y, a) 17 | cfg.add(1, Z, a) 18 | 19 | return cfg 20 | 21 | def get_complex_cfg(): 22 | NP = NT("NP") 23 | VP = NT("VP") 24 | Det = NT("Det") 25 | N = NT("N") 26 | PP = NT("PP") 27 | V = NT("V") 28 | Adj = NT("Adj") 29 | Adv = NT("Adv") 30 | AdvP = NT("AdvP") 31 | fruit = Sym("fruit") 32 | flies = Sym("flies") 33 | like = Sym("like") 34 | a = Sym("a") 35 | green = Sym("green") 36 | banana = Sym("banana") 37 | 38 | cfg = CFG() 39 | cfg.add(1, cfg.S, NP, VP) 40 | cfg.add(0.25, NP, Det, N) 41 | cfg.add(0.25, NP, Det, NP) 42 | cfg.add(0.25, NP, N, N) 43 | cfg.add(0.25, NP, Adj, N) 44 | cfg.add(1, VP, V, NP) 45 | cfg.add(1, AdvP, Adv, NP) 46 | cfg.add(0.5, N, fruit) 47 | cfg.add(0.25, N, flies) 48 | cfg.add(0.25, N, banana) 49 | cfg.add(0.5, V, flies) 50 | cfg.add(0.5, V, like) 51 | cfg.add(1, Det, a) 52 | cfg.add(1, Adj, green) 53 | cfg.add(1, Adv, like) 54 | 55 | return cfg 56 | 57 | class TestLri: 58 | def test_plc(self): 59 | parser = Parser(get_simple_cfg()) 60 | P_L = parser.plc() 61 | correct = np.array([[1. , 1. , 0.5],\ 62 | [0. , 1. , 0.5],\ 63 | [0. , 0. , 1. ]]) 64 | assert np.all(P_L == correct) 65 | 66 | def test_cky(self): 67 | parser = Parser(get_complex_cfg()) 68 | pins = parser.cky("fruit flies", chart=True) 69 | assert pins[(0, NT("N"), 1)] == 0.5 70 | assert pins[(1, NT("N"), 2)] == 0.25 71 | assert pins[(0, NT("NP"), 2)] == 0.03125 72 | 73 | def test_cky_fast(self): 74 | parser = Parser(get_complex_cfg()) 75 | pins = parser.cky_fast("fruit flies", chart=True) 76 | assert pins[(0, NT("N"), 1)] == 0.5 77 | assert pins[(1, NT("N"), 2)] == 0.25 78 | assert pins[(0, NT("NP"), 2)] == 0.03125 79 | 80 | 81 | def test_lri(self): 82 | parser = Parser(get_simple_cfg()) 83 | ppre = parser.lri("a a a", chart=True) 84 | assert ppre[(0, parser.cfg.S, 0)] == 1.0 85 | assert ppre[(1, parser.cfg.S, 1)] == 1.0 86 | assert ppre[(0, parser.cfg.S, 1)] == 1.0 87 | assert ppre[(0, parser.cfg.S, 2)] == 1.0 88 | assert ppre[(0, parser.cfg.S, 3)] == 0.5 89 | 90 | parser = Parser(get_complex_cfg()) 91 | ppre = parser.lri("fruit flies", chart=True) 92 | assert ppre[(0, parser.cfg.S, 0)] == 1.0 93 | assert ppre[(1, parser.cfg.S, 1)] == 1.0 94 | assert ppre[(0, parser.cfg.S, 1)] == 0.125 95 | assert ppre[(1, parser.cfg.S, 2)] == 0.0625 96 | assert ppre[(0, parser.cfg.S, 2)] == 0.03125 97 | 98 | def test_lri_fast(self): 99 | parser = Parser(get_simple_cfg()) 100 | ppre = parser.lri_fast("a a a", chart=True) 101 | assert ppre[(0, parser.cfg.S, 0)] == 1.0 102 | assert ppre[(1, parser.cfg.S, 1)] == 1.0 103 | assert ppre[(0, parser.cfg.S, 1)] == 1.0 104 | assert ppre[(0, parser.cfg.S, 2)] == 1.0 105 | assert ppre[(0, parser.cfg.S, 3)] == 0.5 106 | 107 | parser = Parser(get_complex_cfg()) 108 | ppre = parser.lri_fast("fruit flies", chart=True) 109 | assert ppre[(0, parser.cfg.S, 0)] == 1.0 110 | assert ppre[(1, parser.cfg.S, 1)] == 1.0 111 | assert ppre[(0, parser.cfg.S, 1)] == 0.125 112 | assert ppre[(1, parser.cfg.S, 2)] == 0.0625 113 | assert ppre[(0, parser.cfg.S, 2)] == 0.03125 114 | --------------------------------------------------------------------------------