├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── NOTICE ├── README.md ├── THIRD-PARTY.md ├── grammars ├── calculator.lark └── python.lark ├── incremental_parsing ├── __init__.py ├── evaluation │ ├── __init__.py │ ├── evaluate_stack.py │ ├── gen_tables.sql │ └── text_cuts.py ├── generation │ ├── __init__.py │ ├── constrained_generation.py │ ├── lex_earley_worker.py │ ├── logits_processor.py │ ├── probability_trie.py │ ├── stack_completion_right_context.py │ └── utils.py ├── lex_earley │ ├── __init__.py │ ├── branch_guard │ │ ├── __init__.py │ │ ├── banned_regex_set.py │ │ ├── combined_guard.py │ │ ├── lexer_branch_guard.py │ │ └── single_char_banned_regex_set.py │ ├── earley_base.py │ ├── earley_nfa.py │ ├── earley_trie.py │ ├── incremental_pattern.py │ ├── lark_grammar.py │ ├── lex_earley.py │ ├── lexer.py │ ├── middle_earley.py │ ├── native_earley_nfa.py │ ├── native_earley_trie.py │ ├── python_lex_wrapper.py │ ├── simple_bnf.py │ └── test.py ├── regex_ixes │ ├── __init__.py │ ├── regex_compile.py │ ├── regex_nfa.py │ ├── regex_parse.py │ └── regex_tree.py └── utils │ ├── __init__.py │ ├── flags_to_regex_flags.py │ ├── indexable_container.py │ ├── lookback_trie.py │ └── simple_nfa.py ├── mypy.ini ├── notebooks ├── create_figures.ipynb ├── create_parse_hierarchy_viz_calc_lang.ipynb ├── create_parse_hierarchy_viz_python.ipynb ├── example_parsing.ipynb ├── interactive_constrained_generation.ipynb ├── interactive_recognition.py ├── paper_examples.ipynb └── timing_graphs.ipynb ├── pyproject.toml ├── requirements.txt ├── scripts ├── constrained_generation_nice_cuts.sh └── constrained_generation_random_cuts.sh └── src ├── bridge.rs ├── bridge ├── bnf.rs └── charts.rs ├── grammar.rs ├── grammar ├── bnf.rs └── names.rs ├── lib.rs ├── parser.rs └── parser └── earley.rs /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. 
Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 
3 | version = 3 4 | 5 | [[package]] 6 | name = "autocfg" 7 | version = "1.3.0" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" 10 | 11 | [[package]] 12 | name = "bitflags" 13 | version = "2.5.0" 14 | source = "registry+https://github.com/rust-lang/crates.io-index" 15 | checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" 16 | 17 | [[package]] 18 | name = "cfg-if" 19 | version = "1.0.0" 20 | source = "registry+https://github.com/rust-lang/crates.io-index" 21 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 22 | 23 | [[package]] 24 | name = "heck" 25 | version = "0.4.1" 26 | source = "registry+https://github.com/rust-lang/crates.io-index" 27 | checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" 28 | 29 | [[package]] 30 | name = "incremental_parsing_rust" 31 | version = "0.1.0" 32 | dependencies = [ 33 | "pyo3", 34 | ] 35 | 36 | [[package]] 37 | name = "indoc" 38 | version = "2.0.5" 39 | source = "registry+https://github.com/rust-lang/crates.io-index" 40 | checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" 41 | 42 | [[package]] 43 | name = "libc" 44 | version = "0.2.155" 45 | source = "registry+https://github.com/rust-lang/crates.io-index" 46 | checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" 47 | 48 | [[package]] 49 | name = "lock_api" 50 | version = "0.4.12" 51 | source = "registry+https://github.com/rust-lang/crates.io-index" 52 | checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" 53 | dependencies = [ 54 | "autocfg", 55 | "scopeguard", 56 | ] 57 | 58 | [[package]] 59 | name = "memoffset" 60 | version = "0.9.1" 61 | source = "registry+https://github.com/rust-lang/crates.io-index" 62 | checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" 63 | dependencies = [ 64 | "autocfg", 65 | ] 66 | 67 | [[package]] 68 | name = "once_cell" 69 | version = "1.19.0" 70 | source = "registry+https://github.com/rust-lang/crates.io-index" 71 | checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" 72 | 73 | [[package]] 74 | name = "parking_lot" 75 | version = "0.12.3" 76 | source = "registry+https://github.com/rust-lang/crates.io-index" 77 | checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" 78 | dependencies = [ 79 | "lock_api", 80 | "parking_lot_core", 81 | ] 82 | 83 | [[package]] 84 | name = "parking_lot_core" 85 | version = "0.9.10" 86 | source = "registry+https://github.com/rust-lang/crates.io-index" 87 | checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" 88 | dependencies = [ 89 | "cfg-if", 90 | "libc", 91 | "redox_syscall", 92 | "smallvec", 93 | "windows-targets", 94 | ] 95 | 96 | [[package]] 97 | name = "portable-atomic" 98 | version = "1.6.0" 99 | source = "registry+https://github.com/rust-lang/crates.io-index" 100 | checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" 101 | 102 | [[package]] 103 | name = "proc-macro2" 104 | version = "1.0.85" 105 | source = "registry+https://github.com/rust-lang/crates.io-index" 106 | checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23" 107 | dependencies = [ 108 | "unicode-ident", 109 | ] 110 | 111 | [[package]] 112 | name = "pyo3" 113 | version = "0.21.2" 114 | source = "registry+https://github.com/rust-lang/crates.io-index" 115 | checksum = 
"a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" 116 | dependencies = [ 117 | "cfg-if", 118 | "indoc", 119 | "libc", 120 | "memoffset", 121 | "parking_lot", 122 | "portable-atomic", 123 | "pyo3-build-config", 124 | "pyo3-ffi", 125 | "pyo3-macros", 126 | "unindent", 127 | ] 128 | 129 | [[package]] 130 | name = "pyo3-build-config" 131 | version = "0.21.2" 132 | source = "registry+https://github.com/rust-lang/crates.io-index" 133 | checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" 134 | dependencies = [ 135 | "once_cell", 136 | "target-lexicon", 137 | ] 138 | 139 | [[package]] 140 | name = "pyo3-ffi" 141 | version = "0.21.2" 142 | source = "registry+https://github.com/rust-lang/crates.io-index" 143 | checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" 144 | dependencies = [ 145 | "libc", 146 | "pyo3-build-config", 147 | ] 148 | 149 | [[package]] 150 | name = "pyo3-macros" 151 | version = "0.21.2" 152 | source = "registry+https://github.com/rust-lang/crates.io-index" 153 | checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" 154 | dependencies = [ 155 | "proc-macro2", 156 | "pyo3-macros-backend", 157 | "quote", 158 | "syn", 159 | ] 160 | 161 | [[package]] 162 | name = "pyo3-macros-backend" 163 | version = "0.21.2" 164 | source = "registry+https://github.com/rust-lang/crates.io-index" 165 | checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" 166 | dependencies = [ 167 | "heck", 168 | "proc-macro2", 169 | "pyo3-build-config", 170 | "quote", 171 | "syn", 172 | ] 173 | 174 | [[package]] 175 | name = "quote" 176 | version = "1.0.36" 177 | source = "registry+https://github.com/rust-lang/crates.io-index" 178 | checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" 179 | dependencies = [ 180 | "proc-macro2", 181 | ] 182 | 183 | [[package]] 184 | name = "redox_syscall" 185 | version = "0.5.1" 186 | source = "registry+https://github.com/rust-lang/crates.io-index" 187 | checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e" 188 | dependencies = [ 189 | "bitflags", 190 | ] 191 | 192 | [[package]] 193 | name = "scopeguard" 194 | version = "1.2.0" 195 | source = "registry+https://github.com/rust-lang/crates.io-index" 196 | checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" 197 | 198 | [[package]] 199 | name = "smallvec" 200 | version = "1.13.2" 201 | source = "registry+https://github.com/rust-lang/crates.io-index" 202 | checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" 203 | 204 | [[package]] 205 | name = "syn" 206 | version = "2.0.66" 207 | source = "registry+https://github.com/rust-lang/crates.io-index" 208 | checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5" 209 | dependencies = [ 210 | "proc-macro2", 211 | "quote", 212 | "unicode-ident", 213 | ] 214 | 215 | [[package]] 216 | name = "target-lexicon" 217 | version = "0.12.14" 218 | source = "registry+https://github.com/rust-lang/crates.io-index" 219 | checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" 220 | 221 | [[package]] 222 | name = "unicode-ident" 223 | version = "1.0.12" 224 | source = "registry+https://github.com/rust-lang/crates.io-index" 225 | checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" 226 | 227 | [[package]] 228 | name = "unindent" 229 | version = "0.2.3" 230 | source = "registry+https://github.com/rust-lang/crates.io-index" 231 | 
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" 232 | 233 | [[package]] 234 | name = "windows-targets" 235 | version = "0.52.5" 236 | source = "registry+https://github.com/rust-lang/crates.io-index" 237 | checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" 238 | dependencies = [ 239 | "windows_aarch64_gnullvm", 240 | "windows_aarch64_msvc", 241 | "windows_i686_gnu", 242 | "windows_i686_gnullvm", 243 | "windows_i686_msvc", 244 | "windows_x86_64_gnu", 245 | "windows_x86_64_gnullvm", 246 | "windows_x86_64_msvc", 247 | ] 248 | 249 | [[package]] 250 | name = "windows_aarch64_gnullvm" 251 | version = "0.52.5" 252 | source = "registry+https://github.com/rust-lang/crates.io-index" 253 | checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" 254 | 255 | [[package]] 256 | name = "windows_aarch64_msvc" 257 | version = "0.52.5" 258 | source = "registry+https://github.com/rust-lang/crates.io-index" 259 | checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" 260 | 261 | [[package]] 262 | name = "windows_i686_gnu" 263 | version = "0.52.5" 264 | source = "registry+https://github.com/rust-lang/crates.io-index" 265 | checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" 266 | 267 | [[package]] 268 | name = "windows_i686_gnullvm" 269 | version = "0.52.5" 270 | source = "registry+https://github.com/rust-lang/crates.io-index" 271 | checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" 272 | 273 | [[package]] 274 | name = "windows_i686_msvc" 275 | version = "0.52.5" 276 | source = "registry+https://github.com/rust-lang/crates.io-index" 277 | checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" 278 | 279 | [[package]] 280 | name = "windows_x86_64_gnu" 281 | version = "0.52.5" 282 | source = "registry+https://github.com/rust-lang/crates.io-index" 283 | checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" 284 | 285 | [[package]] 286 | name = "windows_x86_64_gnullvm" 287 | version = "0.52.5" 288 | source = "registry+https://github.com/rust-lang/crates.io-index" 289 | checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" 290 | 291 | [[package]] 292 | name = "windows_x86_64_msvc" 293 | version = "0.52.5" 294 | source = "registry+https://github.com/rust-lang/crates.io-index" 295 | checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" 296 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "incremental_parsing_rust" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [lib] 7 | name = "incremental_parsing_rust" 8 | crate-type = ["cdylib"] 9 | 10 | [profile.release] 11 | debug = true 12 | 13 | [dependencies.pyo3] 14 | version = "0.21.2" 15 | features = ["extension-module", "abi3-py39"] 16 | 17 | #[dependencies.regex-automata] 18 | #features = ["dfa-search", "std", "syntax", "dfa-build"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Incremental Parser 2 | 3 | This repository contains code to perform incremental parsing for Python for constrained generation of code by LLMs. 4 | It is a reference implementation of the following paper and please cite it if you find the repo useful. 5 | 6 | "[Constrained Decoding for Fill-in-the-Middle Code Language Models via Efficient Left and Right Quotienting of Context-Sensitive Grammars](https://arxiv.org/pdf/2402.17988.pdf)," 7 | Daniel Melcer, Nathan Fulton, Sanjay Krishna Gouda and Haifeng Qian, *arXiv:2402.17988*, 2024. 8 | 9 | ## Installation 10 | 11 | For exact reproducibility, we are using CUDA version 12.0, driver version 525.85.12, on a A100 GPU. 12 | Python version 3.9.19, and the Rust toolchain version 1.79.0. 13 | We use the version of Santacoder that was released on December 20, 2022. 14 | 15 | 1. Install everything in requirements.txt. 16 | 2. Run `maturin develop --release` in your virtual environment. 17 | - Note that IDEs sometimes have trouble with automatic code completion for this. 18 | As long as you get the message `Installed incremental_parsing_rust-0.1.0`, the library is installed in the virtual environment. 19 | 3. Edit `scripts/constrained_generation_{random/nice}_cuts.sh` so that `results_path` is an absolute path. 20 | The program will create a (large) folder at this path. 21 | Also, edit the loop max, device name, and min/max data indices so that it fits your hardware and eventually loops 22 | through data indices 0 through 9999. 23 | 4. `PYTHONPATH=. scripts/constrained_generation_random_cuts.sh` 24 | - Read the documentation for `hapless` (`hap --help`) for information about process management 25 | 5. When done, edit the source path at the bottom of `incremental_parsing/evaluation/evaluate_stack.py` to match the 26 | results path. 27 | Edit the destination path to be somewhere you want a csv file to be created. 28 | 6. Import the csv file into a sqlite table named `stack`, and then use `incremental_parsing/evaluation/gen_tables.sql` 29 | to obtain the numbers from the paper. 30 | 31 | You can also use the following interactive scripts in the `notebooks` directory: 32 | 33 | - `create_parse_hierarchy_viz_python.ipynb` creates a parse hierarchy from left and right contexts, and outputs a 34 | visualization of this. 35 | Note that there might be multiple active branches with different parse hierarchies; the visualizer requires you to 36 | select one branch. 37 | - `create_parse_hierarchy_viz_calc_lang.ipynb` is the same for a much simpler language, a calculator language with 38 | tuples. 39 | It is significantly easier to inspect the output and understand what is going on here. 40 | - `interactive_constrained_generation.ipynb` generates code, and shows all the left contexts which are considered to be 41 | a member of the quotient language, plus their scores from the LLM. 42 | - `interactive_recognition.py` lets you type and see whether some text is in the quotient language, is incrementally 43 | parsable, or cannot be a prefix of a member of the quotient language. 44 | - `paper_examples.ipynb` reproduces code generation examples. 45 | 46 | 47 | ## Security 48 | 49 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 
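
As a concrete illustration of step 6 in the installation instructions above, the CSV can also be loaded from the command line. This is an untested sketch: `results.db` and the CSV path are placeholder names, and the paper's numbers were produced with the PyCharm SQL import tool instead (see the comment at the top of `gen_tables.sql`).

```bash
# Hypothetical paths; point the .import at the CSV written by evaluate_stack.py.
sqlite3 results.db <<'EOF'
.mode csv
.import /absolute/path/to/results.csv stack
EOF
```

Because `.import` creates the `stack` table from the CSV header and stores every field as text, you may need to declare numeric column types (or add casts) before the comparisons in `gen_tables.sql` behave as intended; `median()` and `stdev()` are also not built into SQLite and must be provided by an extension.
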
50 | 51 | 52 | ## License 53 | 54 | This project is licensed under the Apache-2.0 License. 55 | 56 | -------------------------------------------------------------------------------- /THIRD-PARTY.md: -------------------------------------------------------------------------------- 1 | This repository includes the following third-party software: 2 | 3 | - Lark - a parsing toolkit for Python (https://github.com/lark-parser/lark) 4 | - `grammars/python.lark` 5 | - Sections of `incremental_parsing/lex_earley/python_lex_wrapper.py` - `PythonLexWrapper` - `process_single_token` 6 | 7 | Copyright © 2017 Erez Shinan 8 | 9 | - Nullable Rule Calculation (https://github.com/jeffreykegler/old_kollos/blob/master/notes/misc/loup2.md) 10 | - `incremental_parsing/lex_earley/simple_bnf.py` - `SimpleBNF` - `get_nullable_rules` 11 | 12 | Copyright © 2014 Jeffrey Kegler 13 | 14 | Permission is hereby granted, free of charge, to any person obtaining a copy of 15 | this software and associated documentation files (the "Software"), to deal in 16 | the Software without restriction, including without limitation the rights to 17 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 18 | the Software, and to permit persons to whom the Software is furnished to do so, 19 | subject to the following conditions: 20 | 21 | The above copyright notice and this permission notice shall be included in all 22 | copies or substantial portions of the Software. 23 | 24 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 25 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 26 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 27 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 28 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 29 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 30 | -------------------------------------------------------------------------------- /grammars/calculator.lark: -------------------------------------------------------------------------------- 1 | start: expression 2 | 3 | comma_expr: "," expression 4 | | "," expression comma_expr 5 | 6 | expression: "(" expression ")" 7 | | "(" expression comma_expr ")" 8 | | expression binop expression 9 | | unop expression 10 | | NUMBER 11 | 12 | NUMBER : /\d+/ 13 | 14 | binop : "-"|"+"|"*"|"/" 15 | unop : "-" 16 | -------------------------------------------------------------------------------- /grammars/python.lark: -------------------------------------------------------------------------------- 1 | // Python 3 grammar for Lark 2 | // From https://github.com/lark-parser/lark/blob/master/lark/grammars/python.lark 3 | // But with "_" identifiers fixed 4 | 5 | // This grammar should parse all python 3.x code successfully. 6 | 7 | // Adapted from: https://docs.python.org/3/reference/grammar.html 8 | 9 | // Start symbols for the grammar: 10 | // single_input is a single interactive statement; 11 | // file_input is a module or sequence of commands read from an input file; 12 | // eval_input is the input for the eval() functions. 13 | // NB: compound_stmt in single_input is followed by extra NEWLINE! 
14 | // 15 | 16 | single_input: _NEWLINE | simple_stmt | compound_stmt _NEWLINE 17 | file_input: (_NEWLINE | stmt)* 18 | eval_input: testlist _NEWLINE* 19 | 20 | decorator: "@" dotted_name [ "(" [arguments] ")" ] _NEWLINE 21 | decorators: decorator+ 22 | decorated: decorators (classdef | funcdef | async_funcdef) 23 | 24 | async_funcdef: "async" funcdef 25 | funcdef: "def" name "(" [parameters] ")" ["->" test] ":" suite 26 | 27 | parameters: slash_params_no_default ("," typedparam)* ("," paramvalue)* ("," star_etc)? ","? 28 | | slash_params_with_default ("," paramvalue)* ("," star_etc)? ","? 29 | | typedparam ("," typedparam)* ("," paramvalue)* ("," star_etc)? ","? 30 | | paramvalue ("," paramvalue)* ("," star_etc)? ","? 31 | | star_etc ","? 32 | SLASH: "/" 33 | slash_params_no_default: typedparam ("," typedparam)* "," SLASH 34 | 35 | slash_params_with_default: typedparam ("," paramvalue)+ "," SLASH 36 | | paramvalue ("," paramvalue)+ "," SLASH 37 | 38 | 39 | star_etc: "*" typedparam ("," param_maybe_default)* ["," kwds] 40 | | "*" ("," param_maybe_default)+ ["," kwds] 41 | | kwds 42 | 43 | 44 | kwds: "**" typedparam 45 | 46 | ?paramvalue: typedparam ("=" test) 47 | ?typedparam: name (":" test)? 48 | 49 | param_maybe_default: typedparam ("=" test)? 50 | 51 | 52 | lambdef: "lambda" [lambda_params] ":" test 53 | lambdef_nocond: "lambda" [lambda_params] ":" test_nocond 54 | lambda_params: lambda_paramvalue ("," lambda_paramvalue)* ["," [lambda_starparams | lambda_kwparams]] 55 | | lambda_starparams 56 | | lambda_kwparams 57 | ?lambda_paramvalue: name ("=" test)? 58 | lambda_starparams: "*" [name] ("," lambda_paramvalue)* ["," [lambda_kwparams]] 59 | lambda_kwparams: "**" name ","? 60 | 61 | 62 | ?stmt: simple_stmt | compound_stmt 63 | ?simple_stmt: small_stmt (";" small_stmt)* [";"] _NEWLINE 64 | ?small_stmt: (expr_stmt | assign_stmt | del_stmt | pass_stmt | flow_stmt | import_stmt | global_stmt | nonlocal_stmt | assert_stmt) 65 | expr_stmt: testlist_star_expr 66 | assign_stmt: annassign | augassign | assign 67 | 68 | annassign: single_target ":" test ["=" test] 69 | assign: (star_targets "=")+ (yield_expr|testlist_star_expr) 70 | augassign: single_target augassign_op (yield_expr|testlist) 71 | !augassign_op: "+=" | "-=" | "*=" | "@=" | "/=" | "%=" | "&=" | "|=" | "^=" | "<<=" | ">>=" | "**=" | "//=" 72 | ?testlist_star_expr: test_or_star_expr 73 | | test_or_star_expr ("," test_or_star_expr)+ ","? -> tuple 74 | | test_or_star_expr "," -> tuple 75 | 76 | // For normal and annotated assignments, additional restrictions enforced by the interpreter 77 | del_stmt: "del" exprlist 78 | pass_stmt: "pass" 79 | ?flow_stmt: break_stmt | continue_stmt | return_stmt | raise_stmt | yield_stmt 80 | break_stmt: "break" 81 | continue_stmt: "continue" 82 | return_stmt: "return" [testlist] 83 | yield_stmt: yield_expr 84 | raise_stmt: "raise" [test ["from" test]] 85 | import_stmt: import_name | import_from 86 | import_name: "import" dotted_as_names 87 | // note below: the ("." | "...") is necessary because "..." is tokenized as ELLIPSIS 88 | import_from: "from" (dots? dotted_name | dots) "import" ("*" | "(" import_as_names ")" | import_as_names) 89 | !dots: "." | ".." | ("..." dots?) 90 | import_as_name: name ["as" name] 91 | dotted_as_name: dotted_name ["as" name] 92 | import_as_names: import_as_name ("," import_as_name)* [","] 93 | dotted_as_names: dotted_as_name ("," dotted_as_name)* 94 | dotted_name: name ("." 
name)* 95 | global_stmt: "global" name ("," name)* 96 | nonlocal_stmt: "nonlocal" name ("," name)* 97 | assert_stmt: "assert" test ["," test] 98 | 99 | ?compound_stmt: if_stmt | while_stmt | for_stmt | try_stmt | match_stmt 100 | | with_stmt | funcdef | classdef | decorated | async_stmt 101 | async_stmt: "async" (funcdef | with_stmt | for_stmt) 102 | if_stmt: "if" test ":" suite elifs ["else" ":" suite] 103 | elifs: elif_* 104 | elif_: "elif" test ":" suite 105 | while_stmt: "while" test ":" suite ["else" ":" suite] 106 | for_stmt: "for" star_targets "in" testlist ":" suite ["else" ":" suite] 107 | try_stmt: "try" ":" suite except_clauses ["else" ":" suite] [finally] 108 | | "try" ":" suite finally -> try_finally 109 | finally: "finally" ":" suite 110 | except_clauses: except_clause+ 111 | except_clause: "except" [test ["as" name]] ":" suite 112 | // NB compile.c makes sure that the default except clause is last 113 | 114 | 115 | with_stmt: "with" with_items ":" suite 116 | with_items: with_item ("," with_item)* 117 | with_item: test ["as" exprlist] 118 | 119 | match_stmt: "match" test ":" _NEWLINE _INDENT case+ _DEDENT 120 | 121 | case: "case" pattern ["if" test] ":" suite 122 | 123 | ?pattern: sequence_item_pattern "," _sequence_pattern -> sequence_pattern 124 | | as_pattern 125 | ?as_pattern: or_pattern ("as" NAME)? 126 | ?or_pattern: closed_pattern ("|" closed_pattern)* 127 | ?closed_pattern: literal_pattern 128 | | NAME -> capture_pattern 129 | | "_" -> any_pattern 130 | | attr_pattern 131 | | "(" as_pattern ")" 132 | | "[" _sequence_pattern "]" -> sequence_pattern 133 | | "(" (sequence_item_pattern "," _sequence_pattern)? ")" -> sequence_pattern 134 | | "{" (mapping_item_pattern ("," mapping_item_pattern)* ","?)?"}" -> mapping_pattern 135 | | "{" (mapping_item_pattern ("," mapping_item_pattern)* ",")? "**" NAME ","? "}" -> mapping_star_pattern 136 | | class_pattern 137 | 138 | literal_pattern: inner_literal_pattern 139 | 140 | ?inner_literal_pattern: "None" -> const_none 141 | | "True" -> const_true 142 | | "False" -> const_false 143 | | STRING -> string 144 | | number 145 | 146 | attr_pattern: NAME ("." NAME)+ -> value 147 | 148 | name_or_attr_pattern: NAME ("." NAME)* -> value 149 | 150 | mapping_item_pattern: (literal_pattern|attr_pattern) ":" as_pattern 151 | 152 | _sequence_pattern: (sequence_item_pattern ("," sequence_item_pattern)* ","?)? 153 | ?sequence_item_pattern: as_pattern 154 | | "*" NAME -> star_pattern 155 | 156 | class_pattern: name_or_attr_pattern "(" [arguments_pattern ","?] ")" 157 | arguments_pattern: pos_arg_pattern ["," keyws_arg_pattern] 158 | | keyws_arg_pattern -> no_pos_arguments 159 | 160 | pos_arg_pattern: as_pattern ("," as_pattern)* 161 | keyws_arg_pattern: keyw_arg_pattern ("," keyw_arg_pattern)* 162 | keyw_arg_pattern: NAME "=" as_pattern 163 | 164 | 165 | 166 | suite: simple_stmt | _NEWLINE _INDENT stmt+ _DEDENT 167 | 168 | ?test: or_test ("if" or_test "else" test)? 
169 | | lambdef 170 | | assign_expr 171 | 172 | assign_expr: name ":=" test 173 | 174 | ?test_nocond: or_test | lambdef_nocond 175 | 176 | ?or_test: and_test ("or" and_test)* 177 | ?and_test: not_test_ ("and" not_test_)* 178 | ?not_test_: "not" not_test_ -> not_test 179 | | comparison 180 | ?comparison: expr (comp_op expr)* 181 | star_expr: "*" expr 182 | 183 | ?expr: or_expr 184 | ?or_expr: xor_expr ("|" xor_expr)* 185 | ?xor_expr: and_expr ("^" and_expr)* 186 | ?and_expr: shift_expr ("&" shift_expr)* 187 | ?shift_expr: arith_expr (_shift_op arith_expr)* 188 | ?arith_expr: term (_add_op term)* 189 | ?term: factor (_mul_op factor)* 190 | ?factor: _unary_op factor | power 191 | 192 | !_unary_op: "+"|"-"|"~" 193 | !_add_op: "+"|"-" 194 | !_shift_op: "<<"|">>" 195 | !_mul_op: "*"|"@"|"/"|"%"|"//" 196 | // <> isn't actually a valid comparison operator in Python. It's here for the 197 | // sake of a __future__ import described in PEP 401 (which really works :-) 198 | !comp_op: "<"|">"|"=="|">="|"<="|"<>"|"!="|"in"|"not" "in"|"is"|"is" "not" 199 | 200 | ?power: await_expr ("**" factor)? 201 | ?await_expr: AWAIT? atom_expr 202 | AWAIT: "await" 203 | 204 | ?atom_expr: atom_expr "(" [arguments] ")" -> funccall 205 | | atom_expr "[" subscriptlist "]" -> getitem 206 | | atom_expr "." name -> getattr 207 | | atom 208 | 209 | ?atom: "(" yield_expr ")" 210 | | "(" _tuple_inner? ")" -> tuple 211 | | "(" comprehension{test_or_star_expr} ")" -> tuple_comprehension 212 | | "[" _exprlist? "]" -> list 213 | | "[" comprehension{test_or_star_expr} "]" -> list_comprehension 214 | | "{" _dict_exprlist? "}" -> dict 215 | | "{" comprehension{key_value} "}" -> dict_comprehension 216 | | "{" _exprlist "}" -> set 217 | | "{" comprehension{test} "}" -> set_comprehension 218 | | name -> var 219 | | number 220 | | string_concat 221 | | "(" test ")" 222 | | "..." -> ellipsis 223 | | "None" -> const_none 224 | | "True" -> const_true 225 | | "False" -> const_false 226 | 227 | 228 | star_targets: star_target ("," star_target)* ","? 229 | 230 | star_target: "*"? target_with_star_atom 231 | 232 | target_with_star_atom: atom_expr "." name 233 | | atom_expr "[" subscriptlist "]" 234 | | star_atom 235 | 236 | star_atom: name 237 | | "(" target_with_star_atom ")" 238 | | "(" [star_target ","] ")" 239 | | "(" star_target ("," star_target)+ ","? ")" 240 | | "[" [star_target ("," star_target)* ","?] "]" 241 | 242 | single_target: name 243 | | atom_expr "." 
name 244 | | atom_expr "[" subscriptlist "]" 245 | | "(" single_target ")" 246 | 247 | 248 | 249 | 250 | ?string_concat: string+ 251 | 252 | _tuple_inner: test_or_star_expr (("," test_or_star_expr)+ [","] | ",") 253 | 254 | ?test_or_star_expr: test 255 | | star_expr 256 | 257 | ?subscriptlist: subscript 258 | | subscript (("," subscript)+ [","] | ",") -> subscript_tuple 259 | ?subscript: test | ([test] ":" [test] [sliceop]) -> slice 260 | sliceop: ":" [test] 261 | ?exprlist: (expr|star_expr) 262 | | (expr|star_expr) (("," (expr|star_expr))+ [","]|",") 263 | ?testlist: test | testlist_tuple 264 | testlist_tuple: test (("," test)+ [","] | ",") 265 | _dict_exprlist: (key_value | "**" expr) ("," (key_value | "**" expr))* [","] 266 | 267 | key_value: test ":" test 268 | 269 | _exprlist: test_or_star_expr ("," test_or_star_expr)* [","] 270 | 271 | classdef: "class" name ["(" [arguments] ")"] ":" suite 272 | 273 | arguments: nonkwargs ["," kwargs] ["," kwstarargs] [","] 274 | | kwargs ["," kwstarargs] [","] 275 | | kwstarargs [","] 276 | | comprehension{test} 277 | 278 | nonkwargs: nonkwarg ("," nonkwarg)* 279 | 280 | nonkwarg: "*" test 281 | | test 282 | 283 | kwargs: kwarg ("," kwarg)* 284 | 285 | // For some reason *args are allowed after keyword args, but not **args 286 | kwarg: "*" test 287 | | name "=" test 288 | 289 | kwstarargs: kwstararg ("," kwstararg)* 290 | 291 | kwstararg: "**" test 292 | | name "=" test 293 | 294 | 295 | comprehension{comp_result}: comp_result comp_for (comp_for | comp_if)* 296 | comp_for: [ASYNC] "for" star_targets "in" or_test 297 | ASYNC: "async" 298 | ?comp_if: "if" test_nocond 299 | 300 | // not used in grammar, but may appear in "node" passed from Parser to Compiler 301 | encoding_decl: name 302 | 303 | yield_expr: "yield" [testlist] 304 | | "yield" "from" test -> yield_from 305 | 306 | number: DEC_NUMBER | HEX_NUMBER | BIN_NUMBER | OCT_NUMBER | FLOAT_NUMBER | IMAG_NUMBER 307 | string: STRING | LONG_STRING 308 | 309 | // Other terminals 310 | 311 | _NEWLINE: ( /(\r?\n|\r)[\t ]*/ | COMMENT )+ 312 | 313 | %ignore /[\t \f]+/ // WS 314 | %ignore /\\[\t \f]*(\r?\n|\r)/ // LINE_CONT 315 | %ignore COMMENT 316 | %declare _INDENT _DEDENT 317 | 318 | 319 | // Python terminals 320 | 321 | !name: NAME | "match" | "case" | "_" 322 | NAME: /[^\W\d]\w*/ 323 | COMMENT: /#[^\n\r]*/ 324 | 325 | STRING: /([ubf]?r?|r[ubf])(?:"(?:[^\\\r\n"]|\\.|\\\r\n)*"|'(?:[^\n\r\\']|\\.|\\\r\n)*')/is 326 | 327 | // Ugh a backslash before double quotes (but not single quotes) becomes escaping here which is why the regex looks ugly 328 | LONG_STRING: /([ubf]?r?|r[ubf])(?:"""(?:(?:[^\\\"]|\\.)|"(?:[^\\\"]|\\.)|""(?:[^\\\"]|\\.))*"""|'''(?:(?:[^\\']|\\.)|'(?:[^\\']|\\.)|''(?:[^\\']|\\.))*''')/is 329 | 330 | _SPECIAL_DEC: "0".."9" ("_"? "0".."9" )* 331 | DEC_NUMBER: "1".."9" ("_"? "0".."9" )* 332 | | "0" ("_"? "0" )* 333 | // Technically should have the lookahead /(?![1-9])/ but probably not actually necessary; the parser should 334 | // disallow two numbers next to each other 335 | HEX_NUMBER.2: "0" ("x" | "X") ("_"? ("0".."9" | "a".."f" | "A".."F"))+ 336 | OCT_NUMBER.2: "0" ("o" | "O") ("_"? "0".."7" )+ 337 | BIN_NUMBER.2: "0" ("b" | "B") ("_"? "0".."1" )+ 338 | 339 | _EXP: ("e"|"E") ["+" | "-"] _SPECIAL_DEC 340 | DECIMAL: "." _SPECIAL_DEC | _SPECIAL_DEC "." _SPECIAL_DEC? 341 | FLOAT_NUMBER.2: _SPECIAL_DEC _EXP | DECIMAL _EXP? 
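// (The ".2" suffix on the numeric terminals above and below is Lark's terminal-priority
// syntax: when more than one terminal could match, the higher-priority one is preferred
// over the default-priority DEC_NUMBER.)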
342 | IMAG_NUMBER.2: (_SPECIAL_DEC | FLOAT_NUMBER) ("J" | "j") 343 | 344 | 345 | // Comma-separated list (with an optional trailing comma) 346 | cs_list{item}: item ("," item)* ","? 347 | _cs_list{item}: item ("," item)* ","? -------------------------------------------------------------------------------- /incremental_parsing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/incremental-parsing/24e501190b5eff2abbf6680c2d2a36ef37781402/incremental_parsing/__init__.py -------------------------------------------------------------------------------- /incremental_parsing/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/incremental-parsing/24e501190b5eff2abbf6680c2d2a36ef37781402/incremental_parsing/evaluation/__init__.py -------------------------------------------------------------------------------- /incremental_parsing/evaluation/evaluate_stack.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import csv 3 | import json 4 | import multiprocessing 5 | import sys 6 | import traceback 7 | from pathlib import Path 8 | from typing import Dict, Union, List 9 | 10 | from tqdm import tqdm 11 | 12 | NUM_CUTS = 10 13 | NUM_SEEDS = 1 14 | 15 | 16 | def parse_top_level_file(top_level_file: Path): 17 | """ 18 | For a given data index, compute all the results under this index 19 | """ 20 | results: List[Dict[str, Union[str, int, None]]] 21 | try: 22 | if top_level_file.is_file(): 23 | filename_parts = top_level_file.name.split(".") 24 | if len(filename_parts) == 3 and filename_parts[1] in ("pyparse", "earleyparse") \ 25 | and filename_parts[2] == "err": 26 | results = [] 27 | for cut_num in range(NUM_CUTS): 28 | for seed_num in range(NUM_SEEDS): 29 | results.append({ 30 | "data_idx": int(filename_parts[0]), 31 | "cut_idx": cut_num, 32 | "seed_idx": seed_num, 33 | "parse_orig": 0 if filename_parts[1] == "pyparse" else 1, 34 | "parse_earley": 0, 35 | "successful_generation": 0, 36 | "parse_constrained": 0, 37 | "parse_unconstrained": 0, 38 | "differs": 0 39 | }) 40 | 41 | return results 42 | 43 | elif len(filename_parts) == 4 and filename_parts[3] == "err": 44 | raise Exception("Error during generation, you should check that out: " + top_level_file.name) 45 | else: 46 | raise Exception("Unexpected filename format: " + top_level_file.name) 47 | else: 48 | results = [] 49 | data_idx = int(top_level_file.name) 50 | for cut_path in top_level_file.iterdir(): 51 | cut_idx = int(cut_path.name) 52 | for seed_path in cut_path.iterdir(): 53 | seed_idx = int(seed_path.name) 54 | 55 | prefix = (seed_path / "prefix").read_text() 56 | suffix = (seed_path / "suffix").read_text() 57 | 58 | constrained = seed_path / "constrained" 59 | unconstrained = seed_path / "unconstrained" 60 | unconstrained_checked = seed_path / "unconstrained_checked" 61 | 62 | error_message = None 63 | 64 | is_complete = unconstrained.exists() 65 | if not is_complete: 66 | continue # Allow us to use this script during active generation 67 | 68 | 69 | prefix_suffix_constrained = None 70 | constrained_hit_max_attempts = False 71 | if not constrained.exists(): 72 | # This "full_constrained" refers to constrained generation which might 73 | # be cut off at a wrong spot 74 | constrained_parse = False 75 | constrained_clipped_generation = False 76 | 77 | if not (seed_path / "full_constrained").exists(): 78 | 
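                            # Neither the clipped ("constrained") nor the full ("full_constrained")
                            # output was written, so the run must have exhausted its retry budget.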
assert (seed_path / "constrained_max_num_attempts").exists() 79 | constrained_hit_max_attempts = True 80 | else: 81 | constrained_text = constrained.read_text() 82 | prefix_suffix_constrained = prefix + constrained_text + suffix 83 | constrained_clipped_generation = True 84 | 85 | try: 86 | ast.parse(prefix_suffix_constrained) 87 | constrained_parse = True 88 | except SyntaxError as e: 89 | error_message = f"{e.__class__.__name__}: {e.msg}" 90 | constrained_parse = False 91 | except (ValueError, MemoryError) as e: 92 | error_message = f"{e.__class__.__name__}: {e.args[0]}" 93 | constrained_parse = False 94 | 95 | unconstrained_text = unconstrained.read_text() 96 | prefix_suffix_unconstrained = prefix + unconstrained_text + suffix 97 | 98 | try: 99 | ast.parse(prefix_suffix_unconstrained) 100 | unconstrained_parse = True 101 | except (SyntaxError, MemoryError, ValueError): 102 | unconstrained_parse = False 103 | 104 | unconstrained_checked_exists = unconstrained_checked.exists() 105 | 106 | if (seed_path / "stats.json").exists(): 107 | stats = json.loads((seed_path / "stats.json").read_text()) 108 | else: 109 | stats = {} 110 | 111 | results.append({ 112 | "data_idx": data_idx, 113 | "cut_idx": cut_idx, 114 | "seed_idx": seed_idx, 115 | "parse_orig": 1, 116 | "parse_earley": 1, 117 | "successful_generation": int(not constrained_hit_max_attempts), 118 | "successful_clipped_generation": int(constrained_clipped_generation), 119 | "parse_constrained": int(constrained_parse), 120 | "parse_unconstrained": int(unconstrained_parse), 121 | "parse_unconstrained_checked": int(unconstrained_checked_exists), 122 | "differs": int(prefix_suffix_constrained != prefix_suffix_unconstrained), 123 | "error_message": error_message, 124 | "prefix_size": len(prefix), 125 | "suffix_size": len(suffix), 126 | **stats 127 | }) 128 | 129 | return results 130 | except BaseException as e: 131 | return traceback.format_exc() 132 | 133 | 134 | def stack_results(results_dir: Path, summary_csv: Path): 135 | with summary_csv.open("w") as f, multiprocessing.Pool(20) as pool: 136 | writer = csv.DictWriter(f, fieldnames=["data_idx", "cut_idx", "seed_idx", "parse_orig", "parse_earley", 137 | "successful_generation", "successful_clipped_generation", 138 | "parse_constrained", "parse_unconstrained", 139 | "parse_unconstrained_checked", 140 | "differs", "num_branches_in_middle", 141 | "num_branches_after_first_suffix_lexeme", 142 | "error_message", 143 | "pre_time", "total_constrained_time", "num_constrained_tokens", 144 | "num_constrained_chars", "num_constrained_output_chars", 145 | "num_unconstrained_chars", "num_unconstrained_output_chars", 146 | "num_unconstrained_tokens", "unconstrained_generation_time", 147 | "constrained_overhead_p50", "constrained_overhead_p90", 148 | "constrained_overhead_mean", "constrained_overhead_num", 149 | "constrained_overhead_eval_p50", "constrained_overhead_eval_p90", 150 | "constrained_overhead_eval_mean", "constrained_overhead_eval_num", 151 | "unconstrained_checking_overhead_p50", 152 | "unconstrained_checking_overhead_p90", 153 | "unconstrained_checking_overhead_mean", 154 | "unconstrained_checking_overhead_num", 155 | "prefix_size", "suffix_size"]) 156 | writer.writeheader() 157 | 158 | results = pool.imap_unordered(parse_top_level_file, results_dir.iterdir(), chunksize=20) 159 | 160 | for result in tqdm(results): 161 | if isinstance(result, str): 162 | print(result) 163 | exit(1) 164 | 165 | for r in result: 166 | result_plus_defaults = { 167 | "num_branches_in_middle": None, 168 | 
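                        # Fallback values for fields that the error-file rows and runs without
                        # stats.json never set; real values from `r` override them via the ** expansion.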
"num_branches_after_first_suffix_lexeme": None, 169 | "successful_clipped_generation": 0, 170 | "error_message": None, 171 | **r 172 | } 173 | writer.writerow(result_plus_defaults) 174 | 175 | 176 | if __name__ == "__main__": 177 | src = Path(sys.argv[1]) 178 | assert src.is_dir() 179 | if len(sys.argv) == 2: 180 | dest = src.parent / f"{src.stem}.csv" 181 | else: 182 | dest = Path(sys.argv[2]) 183 | 184 | assert not dest.exists() 185 | 186 | stack_results(src, dest) 187 | -------------------------------------------------------------------------------- /incremental_parsing/evaluation/gen_tables.sql: -------------------------------------------------------------------------------- 1 | -- INSTRUCTIONS 2 | -- evaluate_stack.py takes the output folder of stack_completion_right_context.py and makes a nice CSV file out of it 3 | -- Creating a database out of the results makes it easy to work with them instead of needing to write redundant code 4 | -- I use the Pycharm SQL import tool to create a SQLite table named `stack` 5 | -- Then, all the numbers in the paper come from the results of running these statements. 6 | 7 | -- data_idx: Index in the-stack-smol-xl/data/python 8 | -- cut_idx: Random seed used to chop up the file into prefix & suffix 9 | -- seed_idx: If the LLM/sampling itself uses randomness, can put multiple seeds. We always used seed 0 10 | -- parse_orig: 1 if the file at data_idx is valid Python 3 11 | -- parse_earley: 1 if the file at data_idx is valid Python 3, and in our subset of Python 12 | -- succesful_generation: 1 if there were no errors in generation. 13 | -- If not equal to parse_earley, there is a bug in the program (look for a file named data_idx.cut_idx.seed_idx.err) 14 | -- successful_clipped_generation: Did the algorithm identify a left context as a member of the quotient language 15 | -- parse_[un]constrained: Does the Python parser like the result? 16 | -- differs: Does the text context differ? 17 | -- error_message: What error message does the Python parser give when reading the constrained output? 
18 | -- num_branches_after_first_suffix_lexeme: How many sub-languages created by the LCFL 19 | -- Divide this by the number of sub-languages created for indentation 20 | -- num_branches_in_middle: How many sub-languages after lexing the right context 21 | 22 | -- Table body 23 | SELECT COUNT(*), parse_constrained, parse_unconstrained, parse_unconstrained_checked 24 | FROM stack 25 | WHERE parse_earley = 1 26 | GROUP BY parse_constrained, parse_unconstrained, parse_unconstrained_checked; 27 | 28 | -- Bottom margin 29 | SELECT COUNT(*), parse_constrained 30 | FROM stack 31 | WHERE parse_earley = 1 32 | GROUP BY parse_constrained; 33 | 34 | -- Side margin 35 | SELECT COUNT(*), parse_unconstrained, parse_unconstrained_checked 36 | FROM stack 37 | WHERE parse_earley = 1 38 | GROUP BY parse_unconstrained, parse_unconstrained_checked; 39 | 40 | -- Breakdown of failure cases 41 | SELECT COUNT(*), parse_unconstrained, parse_unconstrained_checked, successful_clipped_generation 42 | FROM stack 43 | WHERE parse_earley = 1 and parse_constrained = 0 and successful_generation = 1 44 | GROUP BY stack.parse_unconstrained, parse_unconstrained_checked, successful_clipped_generation; 45 | 46 | SELECT median(constrained_overhead_mean), median(constrained_overhead_eval_mean), median(unconstrained_checking_overhead_mean) 47 | FROM stack WHERE parse_unconstrained_checked = 1 and parse_constrained = 1; 48 | 49 | -- What error messages are given show up 50 | SELECT error_message, COUNT(*) as c 51 | FROM stack 52 | WHERE error_message NOT NULL 53 | GROUP BY error_message 54 | ORDER by c DESC; 55 | 56 | -- Specific cases with error messages for debugging 57 | SELECT stack.data_idx, stack.cut_idx, stack.error_message 58 | from stack 59 | where error_message NOT null; 60 | 61 | -- How many branches are due to LCFL (divide by 2 because of extra indentation branches) 62 | SELECT stack.num_branches_after_first_suffix_lexeme / 2, COUNT(*) 63 | FROM stack 64 | WHERE parse_earley = 1 65 | GROUP BY stack.num_branches_after_first_suffix_lexeme; 66 | 67 | SELECT AVG(stack.num_branches_after_first_suffix_lexeme / 2), median(stack.num_branches_after_first_suffix_lexeme / 2), stdev(stack.num_branches_after_first_suffix_lexeme/2) 68 | FROM stack 69 | WHERE parse_earley = 1; 70 | 71 | -- How many branches due to LCFL + Indentation 72 | SELECT stack.num_branches_in_middle, COUNT(*) 73 | FROM stack 74 | WHERE parse_earley = 1 75 | GROUP BY stack.num_branches_in_middle; 76 | 77 | -------------------------------------------------------------------------------- /incremental_parsing/evaluation/text_cuts.py: -------------------------------------------------------------------------------- 1 | import io 2 | import random 3 | import tokenize 4 | from tokenize import TokenInfo 5 | from typing import Tuple, List, Optional 6 | 7 | 8 | def cut_text_random(text: str, min_cut_percentage: float, max_cut_percentage: float, cut_amount: float, cut_no_more_than: Optional[int] = None) -> Tuple[ 9 | str, str, str]: 10 | """ 11 | Take a random point "p" between min_cut_percentage and max_cut_percentage of the way through the file. 12 | Split it into 3 parts: before p, from p to p+cut_amount, and from p+cut_amount to the end. 13 | Note that p+cut_amount might be > 1, in which case middle will be smaller than cut_amount and suffix will be empty 14 | cut_no_more_than is a limit on the amount of text (in characters) that can be cut. 
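    For example (the exact split depends on the random seed): with a 100-character text,
    min_cut_percentage=0.25, max_cut_percentage=0.75 and cut_amount=0.1, the cut point falls
    somewhere between positions 25 and 74, the middle piece is the next 10 characters
    (cut_amount is interpreted as a fraction of the text length), and the suffix is the rest.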
15 | """ 16 | cut_percentage = min_cut_percentage + (max_cut_percentage - min_cut_percentage) * random.random() 17 | cut_index = int(cut_percentage * len(text)) 18 | cut_amount = cut_amount * len(text) 19 | if cut_no_more_than is not None: 20 | cut_amount = min(cut_amount, cut_no_more_than) 21 | 22 | cut_end = cut_index + int(cut_amount) 23 | prefix = text[:cut_index] 24 | middle = text[cut_index:cut_end] 25 | suffix = text[cut_end:] 26 | return prefix, middle, suffix 27 | 28 | 29 | IndentationRun = List[Tuple[int, int]] # A set of (start_idx, end_idx) with the same indentation 30 | 31 | 32 | def get_runs_of_same_indentation(tokens: List[TokenInfo]) -> List[IndentationRun]: 33 | """ 34 | Output including (3, 6) would mean that we can cut the string before token 3, 4, or 5 35 | Output including (3, 4) means only valid cut is before token 3 36 | 37 | Input: 38 | 0: Pass 39 | 1: Indent ----| 40 | 2: Pass | 41 | 3: Indent -| | 42 | 4: Pass | | 43 | 5: Dedent -| | 44 | 6: Pass | 45 | 7: Dedent ----| 46 | 8: Pass 47 | 9: Indent -| 48 | 10:Pass | 49 | 11:Dedent -| 50 | 51 | So the output would include these four groups [index may be off by one, need to double-check] 52 | ((0, 1), (8, 9)) 53 | ((2, 3), (6, 7)) 54 | ((4, 5)) 55 | ((10, 11)) 56 | The idea being that all tokens that belong to a group are related to each other by indentation level 57 | """ 58 | finished_runs: List[IndentationRun] = [] 59 | run_stack: List[IndentationRun] = [] 60 | current_run: IndentationRun = [] 61 | current_run_start = 0 62 | 63 | for i, token in enumerate(tokens): 64 | if token.type == tokenize.INDENT: 65 | if current_run_start < i: 66 | current_run.append((current_run_start, i)) 67 | run_stack.append(current_run) 68 | current_run = [] 69 | current_run_start = i + 1 70 | elif token.type == tokenize.DEDENT: 71 | if current_run_start < i: 72 | current_run.append((current_run_start, i)) 73 | finished_runs.append(current_run) 74 | current_run = run_stack.pop() 75 | current_run_start = i + 1 76 | 77 | assert len(run_stack) == 0 78 | if current_run_start < len(tokens): 79 | current_run.append((current_run_start, len(tokens))) 80 | 81 | finished_runs.append(current_run) 82 | 83 | return finished_runs 84 | 85 | 86 | def extract_absolute_positions(text: str, tokens: List[TokenInfo]) -> List[Tuple[int, int]]: 87 | """ 88 | Surprisingly it's quite weird going from line/col to absolute pos. 
89 | This function does so (giving the left/right bounds of a token) 90 | """ 91 | str_line_offsets = [0] 92 | running_str_offset = 0 93 | for line in text.split("\n"): 94 | str_line_offsets.append(running_str_offset) 95 | running_str_offset += len(line) + 1 96 | 97 | token_positions: List[Tuple[int, int]] = [] 98 | 99 | for token in tokens: 100 | start_pos = str_line_offsets[token.start[0]] + token.start[1] 101 | end_pos = str_line_offsets[token.end[0]] + token.end[1] 102 | 103 | assert text[start_pos:end_pos] == token.string, \ 104 | f"{token.start[0]}:{token.start[1]}-{token.end[0]}:{token.end[1]}: " \ 105 | f"{text[start_pos:end_pos]} != {token.string}" 106 | token_positions.append((start_pos, end_pos)) 107 | 108 | return token_positions 109 | 110 | 111 | def select_random_points_from_indentation_run(runs: IndentationRun, k: int) -> List[int]: 112 | """Pick k points from within the same indentation run (see get_runs_of_same_indentation)""" 113 | run_choices = random.choices(population=runs, weights=[r[1] - r[0] for r in runs], k=k) 114 | return [random.randint(r[0], r[1] - 1) for r in run_choices] 115 | 116 | 117 | def run_weights(runs: List[IndentationRun]) -> List[int]: 118 | return [sum(r[1] - r[0] for r in run) for run in runs] 119 | 120 | 121 | def cut_text_same_indentation_levels(text: str) -> Tuple[str, str, str]: 122 | """ 123 | An attempt to create a method of chopping up text that is hopefully more reminiscent of what insertion points would 124 | look like in practice. 125 | The main things being that the prefix ends at the same indentation level the suffix begins, 126 | and that the suffix begins on a token boundary. 127 | """ 128 | text_bytes = bytes(text, 'utf-8') 129 | byte_io = io.BytesIO(text_bytes) 130 | tokens = list(tokenize.tokenize(byte_io.readline))[1:] 131 | token_posns = extract_absolute_positions(text, tokens) 132 | 133 | runs = get_runs_of_same_indentation(tokens) 134 | left_token_num, right_token_num = select_random_points_from_indentation_run( 135 | random.choices(runs, weights=run_weights(runs), k=1)[0], 2) 136 | 137 | if left_token_num > right_token_num: 138 | left_token_num, right_token_num = right_token_num, left_token_num 139 | 140 | left_token_start, left_token_end = token_posns[left_token_num] 141 | left_cutoff_point = random.randint(left_token_start, left_token_end) 142 | right_token_start, right_token_end = token_posns[right_token_num] 143 | right_cutoff_point = random.randint(right_token_start, right_token_end) 144 | return text[:left_cutoff_point], text[left_cutoff_point:right_cutoff_point], text[right_cutoff_point:] 145 | -------------------------------------------------------------------------------- /incremental_parsing/generation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/incremental-parsing/24e501190b5eff2abbf6680c2d2a36ef37781402/incremental_parsing/generation/__init__.py -------------------------------------------------------------------------------- /incremental_parsing/generation/logits_processor.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import itertools 3 | from typing import List, Tuple, Optional 4 | 5 | import torch 6 | from transformers import LogitsProcessor 7 | 8 | from incremental_parsing.generation.lex_earley_worker import LexEarleyWorker 9 | 10 | 11 | class MaxNumAtempts(Exception): 12 | pass 13 | 14 | 15 | class LexEarleyLogitsProcessor(LogitsProcessor): 16 | def 
__init__(self, worker: LexEarleyWorker, beam_size: int, eof_id: int, debug: bool = False, max_iter_attempts: Optional[int] = None): 17 | self.beam_size = beam_size 18 | self.worker = worker 19 | self.eof_id = eof_id 20 | self.debug = debug 21 | self.overall_token_times: List[datetime.timedelta] = [] 22 | self.per_evaluation_times: List[datetime.timedelta] = [] 23 | self.max_iter_attempts = max_iter_attempts 24 | 25 | # The following hack accounts for santacoder being weird about its pad IDs 26 | pad_token_input = worker.tokenizer("").input_ids 27 | if len(pad_token_input) == 1 and pad_token_input[0] in worker.tokenizer.all_special_ids: 28 | self.pad_token_id = pad_token_input[0] 29 | else: 30 | self.pad_token_id = worker.tokenizer.pad_token_id 31 | 32 | def get_and_reset_times(self) -> Tuple[List[datetime.timedelta], List[datetime.timedelta]]: 33 | ot, pe = self.overall_token_times, self.per_evaluation_times 34 | self.overall_token_times = [] 35 | self.per_evaluation_times = [] 36 | return ot, pe 37 | 38 | def __call__(self, input_ids, scores): 39 | assert len(scores.shape) == 2 40 | 41 | if self.pad_token_id is not None: 42 | scores[:, self.pad_token_id] = float("-inf") 43 | 44 | scores_log_softmax = torch.log_softmax(scores, dim=1) 45 | 46 | eof_scores = scores[:, self.eof_id].cpu().numpy().tolist() 47 | eof_scores_log_softmax = scores_log_softmax[:, self.eof_id].cpu().numpy().tolist() 48 | scores[:, self.eof_id] = float("-inf") 49 | 50 | total_toks = 0 51 | result = torch.full_like(scores, float("-inf")) 52 | 53 | time_start = datetime.datetime.now() 54 | 55 | for _ in (range(self.max_iter_attempts) if self.max_iter_attempts is not None else itertools.count()): 56 | # Keep going until we have enough tokens that the parser likes 57 | top_k = torch.topk(scores, self.beam_size, dim=1) 58 | top_k_softmax = torch.topk(scores_log_softmax, self.beam_size, dim=1) 59 | top_k_log_softmax_values_cpu = top_k_softmax.values.cpu().numpy().tolist() 60 | 61 | input_ids_cpu = input_ids.cpu().numpy().tolist() 62 | top_k_indices_cpu = top_k.indices.cpu().numpy().tolist() 63 | 64 | total_eof_valid = 0 65 | 66 | for i in range(scores.shape[0]): 67 | per_evaluation_start_time = datetime.datetime.now() 68 | results, eof_valid, redundant_branches = self.worker.check_toks(prefix=input_ids_cpu[i], 69 | possible_generations=top_k_indices_cpu[ 70 | i], 71 | scores=top_k_log_softmax_values_cpu[i], 72 | eof_score=eof_scores_log_softmax[i]) 73 | 74 | results_gpu = torch.as_tensor(results, dtype=torch.bool, device=scores.device) 75 | result[i][top_k.indices[i]] = torch.where(results_gpu, top_k.values[i], float("-inf")) 76 | 77 | if eof_valid: 78 | result[i, self.eof_id] = eof_scores[i] 79 | total_eof_valid += 1 80 | 81 | total_toks += sum(results) 82 | 83 | self.per_evaluation_times.append(datetime.datetime.now() - per_evaluation_start_time) 84 | 85 | if self.debug: 86 | print(self.worker.tokenizer.decode(input_ids[i][self.worker.num_ignored_tokens:])) 87 | print(f"({sum(results)})") 88 | 89 | # Hopefully we don't get stuck here? 
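            # The check below requires at least one parser-approved continuation (or a valid EOF) when
            # beam_size == 1, and at least 2 * beam_size approved candidates across the batch when beam
            # search is used. If too few survive, the just-checked top-k tokens are masked to -inf so that
            # the next-best candidates get evaluated on the following loop iteration.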
90 | if total_toks + total_eof_valid < (1 if self.beam_size == 1 else self.beam_size * 2): 91 | if self.debug: 92 | print("Not enough allowable tokens, generating more") 93 | for i in range(scores.shape[0]): 94 | scores[i][top_k.indices[i]] = float("-inf") 95 | else: 96 | break 97 | else: 98 | # Max num iterations reached, just send an EOF 99 | if self.debug: 100 | raise MaxNumAtempts() 101 | result[:,self.eof_id] = 1 102 | 103 | if self.debug: 104 | print("-----") 105 | 106 | self.overall_token_times.append(datetime.datetime.now() - time_start) 107 | 108 | return result 109 | -------------------------------------------------------------------------------- /incremental_parsing/generation/probability_trie.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict 2 | 3 | 4 | class TokenProbabilityTrieNode: 5 | """ 6 | Keep track of cumulative probability, given marginal probabilities of each token 7 | """ 8 | def __init__(self, running_sum_logprob: float, hypothesis_length: int): 9 | self.running_sum_logprob = running_sum_logprob 10 | self.children: Dict[int, TokenProbabilityTrieNode] = {} 11 | self.running_eof_sum_logprob: Optional[float] = None 12 | self.hypothesis_length = hypothesis_length 13 | 14 | def add_child(self, child_token: int, child_logprob: float) -> "TokenProbabilityTrieNode": 15 | self.children[child_token] = TokenProbabilityTrieNode( 16 | self.running_sum_logprob + child_logprob, self.hypothesis_length + 1) 17 | return self.children[child_token] 18 | 19 | def get_child_or_default(self, child_token: int): 20 | if child_token in self.children: 21 | return self.children[child_token] 22 | else: 23 | return self.add_child(child_token, float("-inf")) 24 | 25 | def set_eof_probability(self, eof_sum_logprob: float): 26 | self.running_eof_sum_logprob = eof_sum_logprob + self.running_sum_logprob 27 | 28 | @property 29 | def eof_score(self) -> Optional[float]: 30 | if self.running_eof_sum_logprob is None: 31 | return None 32 | return self.running_eof_sum_logprob / (self.hypothesis_length + 1) 33 | -------------------------------------------------------------------------------- /incremental_parsing/generation/utils.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import datetime 3 | from typing import Tuple, Callable, Optional, Sequence, List 4 | 5 | import numpy as np 6 | import torch 7 | from ansi.colour import bg, fg 8 | from transformers import PreTrainedTokenizer 9 | 10 | 11 | def tokenizer_int64(tokenizer: PreTrainedTokenizer, text: str) -> Tuple[torch.Tensor, torch.Tensor]: 12 | val = tokenizer(text, return_tensors="pt") 13 | return val["input_ids"].type(torch.int64), val["attention_mask"].type(torch.int64) 14 | 15 | 16 | def create_balanced_context(pre_input_ids: torch.Tensor, 17 | pre_attention_mask: torch.Tensor, 18 | post_input_ids: torch.Tensor, 19 | post_attention_mask: torch.Tensor, 20 | tokenizer: PreTrainedTokenizer, 21 | max_generation_length: int, 22 | device: str) -> Tuple[torch.Tensor, torch.Tensor]: 23 | max_tokens_for_context = tokenizer.model_max_length - max_generation_length - 3 24 | 25 | fim_prefix_tokens = tokenizer("", return_tensors="pt") 26 | fim_suffix_tokens = tokenizer("", return_tensors="pt") 27 | fim_middle_tokens = tokenizer("", return_tensors="pt") 28 | 29 | shorter_context = min(pre_input_ids.shape[1], post_input_ids.shape[1]) 30 | 31 | if shorter_context * 2 > max_tokens_for_context: 32 | # The context is too long, and both sides 
have over half the allowed amount 33 | # Restrict each side to half 34 | max_context_each_side = max_tokens_for_context // 2 35 | else: 36 | # One side has less than half of the allowed amount; allocate the rest of the allowed amount to the longer side 37 | max_context_each_side = max_tokens_for_context - shorter_context 38 | 39 | # Suffix might be empty, but even then this likely helps to prevent the LLM from babbling 40 | 41 | input_ids = torch.concat(( 42 | fim_prefix_tokens["input_ids"], 43 | pre_input_ids[:, -max_context_each_side:], 44 | fim_suffix_tokens["input_ids"], 45 | post_input_ids[:, :max_context_each_side], 46 | fim_middle_tokens["input_ids"] 47 | ), dim=1).to(device) 48 | attention_mask = torch.concat(( 49 | fim_prefix_tokens["attention_mask"], 50 | pre_attention_mask[:, -max_context_each_side:], 51 | fim_suffix_tokens["attention_mask"], 52 | post_attention_mask[:, :max_context_each_side], 53 | fim_middle_tokens["attention_mask"] 54 | ), dim=1).to(device) 55 | 56 | return input_ids, attention_mask 57 | 58 | 59 | colors = [ 60 | (bg.red, fg.black), 61 | (bg.green, fg.black), 62 | (bg.yellow, fg.black), 63 | (bg.blue, fg.black), 64 | (bg.magenta, fg.black), 65 | (bg.cyan, fg.black), 66 | (bg.white, fg.black), 67 | (bg.black, fg.white) 68 | ] 69 | 70 | 71 | def color_idx(idx: int) -> Callable[[str], str]: 72 | b, f = colors[idx % len(colors)] 73 | return lambda s: b(f(s)) 74 | 75 | 76 | def try_incremental_unconstrained(prefix: str, middle: str, suffix: str) -> Tuple[ 77 | Optional[str], Sequence[datetime.timedelta]]: 78 | """ 79 | Check whether every prefix of unconstrained generation parses 80 | (for a performance comparison to constrained generation) 81 | """ 82 | longest_match: Optional[str] = None 83 | times: List[datetime.timedelta] = [] 84 | 85 | for i in range(len(middle) + 1): 86 | start_time = datetime.datetime.now() 87 | full_text = prefix + middle[:i] + suffix 88 | 89 | # noinspection PyBroadException 90 | try: 91 | ast.parse(full_text) 92 | except: 93 | pass 94 | else: 95 | longest_match = middle[:i] 96 | 97 | end_time = datetime.datetime.now() 98 | times.append(end_time - start_time) 99 | 100 | return longest_match, times 101 | 102 | 103 | def get_p50_p90_mean_count(times: Sequence[datetime.timedelta]) -> Tuple[Optional[float], Optional[float], Optional[float], float]: 104 | if len(times) == 0: 105 | return None, None, None, 0 106 | 107 | time_seconds = [time.total_seconds() for time in times] 108 | return float(np.quantile(time_seconds, .5)), float(np.quantile(time_seconds, .9)), sum(time_seconds) / len(time_seconds), len(time_seconds) 109 | -------------------------------------------------------------------------------- /incremental_parsing/lex_earley/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/incremental-parsing/24e501190b5eff2abbf6680c2d2a36ef37781402/incremental_parsing/lex_earley/__init__.py -------------------------------------------------------------------------------- /incremental_parsing/lex_earley/branch_guard/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/incremental-parsing/24e501190b5eff2abbf6680c2d2a36ef37781402/incremental_parsing/lex_earley/branch_guard/__init__.py -------------------------------------------------------------------------------- /incremental_parsing/lex_earley/branch_guard/banned_regex_set.py: 
-------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional, FrozenSet, Sequence 2 | 3 | from incremental_parsing.lex_earley.branch_guard.lexer_branch_guard import LexerBranchGuard 4 | from incremental_parsing.lex_earley.incremental_pattern import IncrementalPattern 5 | 6 | 7 | class LexerBranchGuardBannedRegexSet(LexerBranchGuard): 8 | """ 9 | This branch guard is used to make sure there aren't any longer matches 10 | I.E. if we have generated "abc" so far, and have emitted a branch with a token for identifiers, 11 | then "identifiers which are >=4 chars" would be a banned regex. If we see "d" next, then "abcd" would be the 12 | leftmost-longest identifier, and so we should delete the "abc" branch. 13 | """ 14 | def __init__(self, banned_regex_set: Tuple[IncrementalPattern, ...], min_banned_match_length: int): 15 | self.banned_pattern_set = banned_regex_set 16 | self.min_banned_match_length = min_banned_match_length 17 | 18 | def branch_allowed( 19 | self, text: str, start_index: int, branch_guard_state: Tuple[Tuple[FrozenSet[int], ...], int]) -> \ 20 | Tuple[Optional[bool], Optional[Tuple[Tuple[FrozenSet[int], ...], int]]]: 21 | token_nfas: Sequence[FrozenSet[int]] 22 | token_nfas, current_length = branch_guard_state 23 | 24 | for string_idx in range(start_index, len(text)): 25 | new_token_nfas = [] 26 | seen_any_partial = False 27 | 28 | current_length += 1 29 | for pattern, nfa in zip(self.banned_pattern_set, token_nfas): 30 | next_nfa = pattern.step_forwards(nfa, text[string_idx]) 31 | if current_length >= self.min_banned_match_length and any(n in pattern.final_states for n in next_nfa): 32 | return False, None # Seen a full match of a banned token longer than the min len -> branch is bad 33 | elif next_nfa: 34 | seen_any_partial = True 35 | new_token_nfas.append(next_nfa) 36 | 37 | if not seen_any_partial: 38 | return True, None # All banned tokens have dropped out, this branch is definitely okay 39 | else: 40 | token_nfas = new_token_nfas 41 | 42 | return None, (tuple(token_nfas), current_length) 43 | 44 | def eof_allowed(self, current_state: Tuple[Tuple[FrozenSet[int], ...], int]) -> bool: 45 | # If we matched a banned pattern before, we would have already rejected the branch 46 | return True 47 | -------------------------------------------------------------------------------- /incremental_parsing/lex_earley/branch_guard/combined_guard.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Any, Tuple, Optional 2 | 3 | from incremental_parsing.lex_earley.branch_guard.lexer_branch_guard import LexerBranchGuard 4 | 5 | 6 | class LexerBranchGuardCombined(LexerBranchGuard): 7 | """ 8 | So that a branch can have more than one branch guard 9 | """ 10 | def __init__(self, branch_guards: Iterable[LexerBranchGuard]): 11 | self.branch_guards = tuple(branch_guards) 12 | 13 | @staticmethod 14 | def combine_branch_guard_states(states: Iterable[Any]) -> Any: 15 | s = tuple(states) 16 | return s, (True,) * len(s) 17 | 18 | def branch_allowed(self, text: str, start_index: int, branch_guard_state: Any) -> Tuple[Optional[bool], Any]: 19 | inner_states, still_consider = branch_guard_state 20 | next_inner_states, next_still_consider = [], [] 21 | any_failed = False 22 | any_ambiguous = False 23 | for guard, inner_state, consider in zip(self.branch_guards, inner_states, still_consider): 24 | if not consider: 25 | next_inner_states.append(inner_state) 26 | 
next_still_consider.append(consider) 27 | else: 28 | inner_res, inner_next_state = guard.branch_allowed(text, start_index, inner_state) 29 | next_inner_states.append(inner_next_state) 30 | if inner_res: 31 | next_still_consider.append(False) 32 | elif inner_res is None: 33 | next_still_consider.append(True) 34 | any_ambiguous = True 35 | else: 36 | next_still_consider.append(True) 37 | any_failed = True 38 | 39 | if any_failed: 40 | result = False 41 | elif any_ambiguous: 42 | result = None 43 | else: 44 | result = True 45 | 46 | return result, (tuple(next_inner_states), tuple(next_still_consider)) 47 | 48 | def eof_allowed(self, branch_guard_state: Any) -> bool: 49 | inner_states, still_considers = branch_guard_state 50 | return all((not consider) or branch.eof_allowed(inner_state) 51 | for branch, inner_state, consider in zip(self.branch_guards, inner_states, still_considers)) 52 | 53 | def replace(self, branch_guard_state: Any) -> Tuple[Optional["LexerBranchGuard"], Any]: 54 | inner_states, still_considers = branch_guard_state 55 | if all(still_considers): 56 | return self, branch_guard_state 57 | 58 | rep_branches, rep_inner_states = [], [] 59 | for branch, inner_state, still_consider in zip(self.branch_guards, inner_states, still_considers): 60 | if not still_consider: 61 | continue 62 | 63 | repl_branch, repl_state = branch.replace(inner_state) 64 | if not repl_branch: 65 | continue 66 | 67 | rep_branches.append(repl_branch) 68 | rep_inner_states.append(repl_state) 69 | 70 | if len(rep_branches) == 0: 71 | return None, None 72 | elif len(rep_branches) == 1: 73 | return rep_branches[0], rep_inner_states[0] 74 | else: 75 | return LexerBranchGuardCombined(rep_branches), (tuple(rep_inner_states), (True,) * len(rep_inner_states)) 76 | -------------------------------------------------------------------------------- /incremental_parsing/lex_earley/branch_guard/lexer_branch_guard.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Any, Tuple, Optional 3 | 4 | 5 | class LexerBranchGuard(abc.ABC): 6 | @abc.abstractmethod 7 | def branch_allowed(self, text: str, start_index: int, branch_guard_state: Any) -> Tuple[Optional[bool], Any]: 8 | """ 9 | Is the current lexer branch allowed, given the lookahead text? 10 | Invariant: 11 | t1, t2 : str 12 | t = t1 + t2 13 | bgs: branch_guard_state 14 | assert branch_allowed(t, bgs) == branch_allowed(t2, branch_allowed(t1, bgs)[1]) 15 | :param text: Lookahead text to determine whether the branch is allowed by the lexer 16 | :param start_index: Start index of the lookahead text 17 | :param branch_guard_state: Some internal state of the branch guard 18 | :return: True if the branch is definitely allowed and we no longer need to check for future text. 19 | None if the branch is allowed so far, but there is some future text that would rule out the branch. 20 | False if the branch is not allowed and should be pruned. 21 | """ 22 | pass 23 | 24 | @abc.abstractmethod 25 | def eof_allowed(self, branch_guard_state: Any) -> bool: 26 | """ 27 | Are we allowed to reach the end of the file here? 28 | """ 29 | pass 30 | 31 | def replace(self, branch_guard_state: Any) -> Tuple[Optional["LexerBranchGuard"], Any]: 32 | """ 33 | Can this be replaced by a different branch guard? 
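        Returning (None, None) means the guard is no longer needed; returning a different (guard, state)
        pair substitutes a simpler guard. The default implementation below keeps the guard and its state
        unchanged (see LexerBranchGuardCombined.replace for a non-trivial override).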
34 | """ 35 | return self, branch_guard_state 36 | -------------------------------------------------------------------------------- /incremental_parsing/lex_earley/branch_guard/single_char_banned_regex_set.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional, FrozenSet, Sequence 2 | 3 | from incremental_parsing.lex_earley.branch_guard.lexer_branch_guard import LexerBranchGuard 4 | from incremental_parsing.lex_earley.incremental_pattern import IncrementalPattern 5 | 6 | 7 | class LexerBranchGuardBannedRegexSetRealBehavior(LexerBranchGuard): 8 | """ 9 | Implements the behavior like the Python lexer, as opposed to true maximal munch. 10 | Otherwise, see banned_regex_set.py's documentation 11 | """ 12 | 13 | def __init__(self, banned_regex_set: Tuple[IncrementalPattern, ...], min_banned_match_length: int): 14 | self.banned_pattern_set = banned_regex_set 15 | self.min_banned_match_length = min_banned_match_length 16 | 17 | def branch_allowed(self, text: str, start_index: int, branch_guard_state: Tuple[Tuple[FrozenSet[int], ...], int]) -> \ 18 | Tuple[Optional[bool], Optional[Tuple[Tuple[FrozenSet[int], ...], int]]]: 19 | token_nfas: Sequence[FrozenSet[int]] 20 | token_nfas, current_length = branch_guard_state 21 | 22 | for string_idx in range(start_index, len(text)): 23 | new_token_nfas = [] 24 | current_length += 1 25 | for pattern, nfa in zip(self.banned_pattern_set, token_nfas): 26 | next_nfa = pattern.step_forwards(nfa, text[string_idx]) 27 | new_token_nfas.append(next_nfa) 28 | 29 | if current_length < self.min_banned_match_length: 30 | token_nfas = new_token_nfas 31 | continue 32 | elif current_length == self.min_banned_match_length: 33 | if any(len(nfa) > 0 for nfa in new_token_nfas): 34 | return False, None # Sees a partial match of a token exactly one longer 35 | else: 36 | return True, None # No such partial match exists 37 | else: 38 | assert False, "length should not be greater than match length" 39 | 40 | return None, (tuple(token_nfas), current_length) 41 | 42 | def eof_allowed(self, current_state: Tuple[Tuple[FrozenSet[int], ...], int]) -> bool: 43 | # If we matched a banned pattern before, we would have already rejected the branch 44 | return True 45 | -------------------------------------------------------------------------------- /incremental_parsing/lex_earley/earley_base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import NamedTuple, Tuple, Iterable, List, Callable, Union, Dict, Sequence 3 | 4 | from incremental_parsing.lex_earley.lexer import Token 5 | from incremental_parsing.lex_earley.simple_bnf import SimpleBNF, BNFElement, BNFTerminal, BNFNonterminal 6 | 7 | 8 | class LexEarleyState(NamedTuple): 9 | span_start: int 10 | rule_name: str 11 | production_index: int 12 | position: int 13 | max_position: int 14 | 15 | def next_element(self, grammar: SimpleBNF) -> BNFElement: 16 | """ 17 | Will error if is_complete is true 18 | """ 19 | return grammar.rules[self.rule_name].productions[self.production_index].elements[self.position] 20 | 21 | def is_complete(self) -> bool: 22 | return self.position == self.max_position 23 | 24 | def advance(self) -> 'LexEarleyState': 25 | return self._replace(position=self.position + 1) 26 | 27 | def to_str(self, grammar: SimpleBNF, only_elements: bool = False): 28 | production = grammar.rules[self.rule_name].productions[self.production_index] 29 | prev_prods = " ".join(e.name for e 
in production.elements[:self.position]) 30 | after_prods = " ".join(e.name for e in production.elements[self.position:]) 31 | if only_elements: 32 | return f"{prev_prods} • {after_prods}" 33 | else: 34 | return f"{self.rule_name} -> {prev_prods} • {after_prods} ({self.span_start})" 35 | 36 | def reverse_position(self) -> "LexEarleyState": 37 | return self._replace(position=(self.max_position - self.position)) 38 | 39 | 40 | # Various ways that an Earley state could be created 41 | 42 | @dataclass(frozen=True) 43 | class TopLevel: 44 | pass 45 | 46 | 47 | @dataclass(frozen=True) 48 | class Scanned: 49 | from_chart_idx: int 50 | from_state_idx: int 51 | 52 | 53 | @dataclass(frozen=True) 54 | class Predicted: 55 | from_chart_idx: int 56 | from_state_idx: int 57 | 58 | 59 | @dataclass(frozen=True) 60 | class PredictedNullableCompletion: 61 | from_chart_idx: int 62 | from_state_idx: int 63 | 64 | 65 | @dataclass(frozen=True) 66 | class Completed: 67 | finished_rule_chart_idx: int 68 | finished_rule_state_idx: int 69 | complete_into_chart_idx: int 70 | complete_into_state_idx: int 71 | 72 | 73 | StateCreationMethod = Union[TopLevel, Scanned, Predicted, PredictedNullableCompletion, Completed] 74 | 75 | StatePlusCreationMethods = Tuple[LexEarleyState, Tuple[StateCreationMethod, ...]] 76 | 77 | 78 | # A somewhat textbook implementation of the Earley Algorithm for the rest of the file 79 | 80 | class LexEarleyAlgorithmChart(NamedTuple): 81 | states: Tuple[StatePlusCreationMethods, ...] 82 | 83 | def reverse_state_positions(self): 84 | return LexEarleyAlgorithmChart( 85 | states=tuple((state.reverse_position(), creation_methods) for state, creation_methods in self.states) 86 | ) 87 | 88 | def get_states_and_creation_methods(self) -> Sequence[Tuple[LexEarleyState, Iterable[StateCreationMethod]]]: 89 | return self.states 90 | 91 | def __len__(self): 92 | return len(self.states) 93 | 94 | def __getitem__(self, idx: int) -> Tuple[LexEarleyState, Iterable[StateCreationMethod]]: 95 | return self.states[idx] 96 | 97 | 98 | def scan_all(grammar: SimpleBNF, prev_chart: Iterable[StatePlusCreationMethods], prev_chart_idx: int, symbol: Token) \ 99 | -> List[Tuple[LexEarleyState, StateCreationMethod]]: 100 | next_states_ordered: List[Tuple[LexEarleyState, StateCreationMethod]] = [] 101 | for state_idx, (state, _) in enumerate(prev_chart): 102 | if not state.is_complete(): 103 | next_element = state.next_element(grammar) 104 | if isinstance(next_element, BNFTerminal): 105 | if next_element.name == symbol.name: 106 | next_states_ordered.append((state.advance(), Scanned(from_chart_idx=prev_chart_idx, 107 | from_state_idx=state_idx))) 108 | 109 | return next_states_ordered 110 | 111 | 112 | def predictor_completer(prev_charts: Sequence[LexEarleyAlgorithmChart], 113 | items_from_scanner: Sequence[Tuple[LexEarleyState, StateCreationMethod]], 114 | bnf: SimpleBNF) -> Tuple[LexEarleyAlgorithmChart, Iterable[str]]: 115 | """ 116 | :return: all states of this chart, and all possible terminals which could show up next 117 | """ 118 | states: Dict[LexEarleyState, int] = dict() # To index in processed_states_ordered 119 | states_ordered: List[Tuple[LexEarleyState, List[StateCreationMethod]]] = [] 120 | processed_state_idx = 0 121 | allowed_next_symbols = set() 122 | 123 | def adder(earley_state: LexEarleyState, creation_method: StateCreationMethod): 124 | if earley_state in states: 125 | # Already exists, add the creation method 126 | states_ordered[states[earley_state]][1].append(creation_method) 127 | else: 128 | # New 
state 129 | states[earley_state] = len(states_ordered) 130 | states_ordered.append((earley_state, [creation_method])) 131 | 132 | for state, creation_method in items_from_scanner: 133 | adder(state, creation_method) 134 | 135 | while processed_state_idx < len(states_ordered): 136 | state_to_process, _creation_methods = states_ordered[processed_state_idx] 137 | 138 | if state_to_process.is_complete(): 139 | items_from_span_start: Sequence[Tuple[LexEarleyState, Sequence[StateCreationMethod]]] 140 | 141 | if state_to_process.span_start == len(prev_charts): 142 | items_from_span_start = states_ordered 143 | else: 144 | items_from_span_start = prev_charts[state_to_process.span_start].states 145 | 146 | states_in_span_start = (item[0] for item in items_from_span_start) 147 | completer( 148 | items_in_span_start=states_in_span_start, 149 | rule_name=state_to_process.rule_name, 150 | adder=adder, 151 | grammar=bnf, 152 | current_chart_idx=len(prev_charts), 153 | span_start_idx=state_to_process.span_start, 154 | current_proc_state_idx=processed_state_idx 155 | ) 156 | else: 157 | next_element = state_to_process.next_element(bnf) 158 | if isinstance(next_element, BNFNonterminal): 159 | predictor( 160 | state=state_to_process, 161 | next_rule_name=next_element.name, 162 | grammar=bnf, 163 | current_chart_idx=len(prev_charts), 164 | adder=adder, 165 | predicted_from_state_idx=processed_state_idx 166 | ) 167 | else: 168 | allowed_next_symbols.add(next_element.name) 169 | 170 | processed_state_idx += 1 171 | 172 | immutable_states_ordered = tuple( 173 | (state, tuple(creation_methods)) 174 | for state, creation_methods in states_ordered 175 | ) 176 | 177 | return LexEarleyAlgorithmChart(states=immutable_states_ordered), allowed_next_symbols 178 | 179 | 180 | Adder = Callable[[LexEarleyState, StateCreationMethod], None] 181 | 182 | 183 | def completer(items_in_span_start: Iterable[LexEarleyState], 184 | rule_name, 185 | adder: Adder, 186 | grammar: SimpleBNF, 187 | span_start_idx: int, 188 | current_chart_idx: int, 189 | current_proc_state_idx: int): 190 | for complete_into_state_idx, previous_state in enumerate(items_in_span_start): 191 | if not previous_state.is_complete(): 192 | expected_item = previous_state.next_element(grammar) 193 | if isinstance(expected_item, BNFNonterminal): 194 | if expected_item.name == rule_name: 195 | adder(previous_state.advance(), 196 | Completed( 197 | finished_rule_chart_idx=current_chart_idx, 198 | finished_rule_state_idx=current_proc_state_idx, 199 | complete_into_chart_idx=span_start_idx, 200 | complete_into_state_idx=complete_into_state_idx 201 | )) 202 | 203 | 204 | def predictor(state: LexEarleyState, next_rule_name: str, grammar: SimpleBNF, current_chart_idx: int, adder: Adder, 205 | predicted_from_state_idx: int): 206 | matching_rule = grammar.rules[next_rule_name] 207 | 208 | for i, production in enumerate(matching_rule.productions): 209 | adder(LexEarleyState( 210 | span_start=current_chart_idx, 211 | rule_name=next_rule_name, 212 | production_index=i, 213 | position=0, 214 | max_position=len(production.elements) 215 | ), Predicted( 216 | from_state_idx=predicted_from_state_idx, 217 | from_chart_idx=current_chart_idx 218 | )) 219 | 220 | if next_rule_name in grammar.nullable_rules: 221 | adder(state.advance(), PredictedNullableCompletion( 222 | from_chart_idx=current_chart_idx, 223 | from_state_idx=predicted_from_state_idx, 224 | )) 225 | 226 | 227 | def process_token(grammar: SimpleBNF, charts: Sequence[LexEarleyAlgorithmChart], token: Token) -> Tuple[ 228 | 
LexEarleyAlgorithmChart, Iterable[str]]: 229 | scanned_rules = scan_all(grammar=grammar, prev_chart=charts[-1].states, prev_chart_idx=len(charts) - 1, 230 | symbol=token) 231 | return predictor_completer(charts, scanned_rules, grammar) 232 | 233 | 234 | def initial_chart_allowed_tokens(grammar: SimpleBNF) -> Tuple[LexEarleyAlgorithmChart, Iterable[str]]: 235 | initial_earley_states = [] 236 | for initial_rule in grammar.top_level_rules: 237 | for prod_idx, production in enumerate(grammar.rules[initial_rule].productions): 238 | initial_earley_states.append((LexEarleyState( 239 | span_start=0, 240 | rule_name=initial_rule, 241 | production_index=prod_idx, 242 | position=0, 243 | max_position=len(production.elements) 244 | ), TopLevel())) 245 | return predictor_completer(prev_charts=(), items_from_scanner=initial_earley_states, bnf=grammar) 246 | 247 | 248 | def charts_completable(charts: Sequence[LexEarleyAlgorithmChart]) -> bool: 249 | """ 250 | Returns true if the state is completable. 251 | """ 252 | return len(charts[-1].states) > 0 253 | -------------------------------------------------------------------------------- /incremental_parsing/lex_earley/earley_trie.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from typing import Dict, Tuple, Optional, overload, List, Generator, Sequence 3 | 4 | from typing_extensions import Self 5 | 6 | from incremental_parsing.lex_earley.earley_base import process_token, charts_completable, LexEarleyAlgorithmChart, \ 7 | initial_chart_allowed_tokens 8 | from incremental_parsing.lex_earley.lexer import Token 9 | from incremental_parsing.lex_earley.simple_bnf import SimpleBNF 10 | from incremental_parsing.utils.lookback_trie import LookbackTrieNode 11 | 12 | 13 | class AbstractEarleyTrieNode(metaclass=ABCMeta): 14 | @classmethod 15 | @abstractmethod 16 | def create_root(cls, grammar: SimpleBNF): 17 | pass 18 | 19 | @abstractmethod 20 | def get_child(self, token: Token) -> Self: 21 | pass 22 | 23 | @abstractmethod 24 | def is_completable(self) -> bool: 25 | pass 26 | 27 | @abstractmethod 28 | def is_complete(self) -> bool: 29 | pass 30 | 31 | @overload 32 | def __getitem__(self, index: int) -> LexEarleyAlgorithmChart: 33 | ... 34 | 35 | @overload 36 | def __getitem__(self, index: slice) -> List[LexEarleyAlgorithmChart]: 37 | ... 
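    # Note on indexing (inferred from EarleyTrieNode below): chart 0 is the initial chart and chart i is the
    # chart produced after the first i tokens of the current prefix, so self[-1] is the chart for the full
    # prefix seen so far.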
38 | 39 | @abstractmethod 40 | def __getitem__(self, item): 41 | pass 42 | 43 | @property 44 | @abstractmethod 45 | def allowed_token_names(self) -> Sequence[str]: 46 | pass 47 | 48 | 49 | class EarleyTrieNode(AbstractEarleyTrieNode): 50 | """ 51 | A simple structure to cache parse results to avoid needless computation 52 | """ 53 | 54 | @classmethod 55 | def create_root(cls, grammar: SimpleBNF): 56 | initial_chart, allowed_tokens = initial_chart_allowed_tokens(grammar) 57 | root_chart = LookbackTrieNode.create_root_node() 58 | return cls(grammar=grammar, charts=root_chart.get_child(initial_chart), allowed_token_names=tuple(allowed_tokens)) 59 | 60 | def __init__(self, grammar: SimpleBNF, 61 | charts: LookbackTrieNode[LexEarleyAlgorithmChart], 62 | allowed_token_names: Tuple[str, ...]): 63 | self.grammar = grammar 64 | self.charts = charts 65 | self.children: Dict[Token, EarleyTrieNode] = {} 66 | self._allowed_token_names = allowed_token_names # This parameter doesn't really matter for the initial chart 67 | 68 | def get_child(self, token: Token) -> "EarleyTrieNode": 69 | if token not in self.children: 70 | assert not token.loose_behavior, ("Use Earley NFA for loose behavior. " 71 | "Potentially modify NFA so that it can be built incrementally?") 72 | next_chart, allowed_tokens = process_token(self.grammar, self.charts, token) 73 | self.children[token] = EarleyTrieNode(self.grammar, 74 | self.charts.get_child(next_chart), 75 | tuple(allowed_tokens)) 76 | 77 | return self.children[token] 78 | 79 | def is_completable(self): 80 | return charts_completable(self.charts) 81 | 82 | def is_complete(self) -> bool: 83 | last_chart = self[-1] 84 | return any(state.is_complete() and state.rule_name in self.grammar.top_level_rules 85 | and state.span_start == 0 for state, _ in 86 | last_chart.states) 87 | 88 | @property 89 | def allowed_token_names(self) -> Tuple[str, ...]: 90 | return self._allowed_token_names 91 | 92 | def __getitem__(self, item): 93 | return self.charts.__getitem__(item) 94 | 95 | def __len__(self): 96 | return len(self.charts) 97 | 98 | 99 | class DummyEarleyTrieNode(AbstractEarleyTrieNode): 100 | @classmethod 101 | def create_root(cls, grammar: SimpleBNF): 102 | raise NotImplemented 103 | 104 | def __init__(self, 105 | parent: Optional["DummyEarleyTrieNode"], 106 | this_token: Optional[Token], 107 | allowed_token_names: Tuple[str, ...]): 108 | self.parent = parent 109 | self.this_token = this_token 110 | self._allowed_token_names = allowed_token_names 111 | self.children: Dict[Token, DummyEarleyTrieNode] = {} 112 | 113 | def get_child(self, token: Token) -> "DummyEarleyTrieNode": 114 | if token not in self.children: 115 | self.children[token] = DummyEarleyTrieNode(self, token, self.allowed_token_names) 116 | 117 | return self.children[token] 118 | 119 | def is_completable(self) -> bool: 120 | return True 121 | 122 | def is_complete(self) -> bool: 123 | return True 124 | 125 | @property 126 | def allowed_token_names(self) -> Tuple[str, ...]: 127 | return self._allowed_token_names 128 | 129 | def __getitem__(self, item): 130 | raise NotImplementedError("Dummy object") 131 | 132 | def get_reverse_token_sequence(self) -> Generator[Token, None, None]: 133 | node = self 134 | while True: 135 | if node.parent is not None: 136 | assert node.this_token is not None 137 | yield node.this_token 138 | node = node.parent 139 | else: 140 | break 141 | -------------------------------------------------------------------------------- /incremental_parsing/lex_earley/incremental_pattern.py: 
-------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Tuple, NamedTuple, FrozenSet, Optional 3 | 4 | import regex 5 | from lark.utils import get_regexp_width 6 | from regex import Pattern 7 | 8 | from incremental_parsing.regex_ixes.regex_compile import compile 9 | 10 | 11 | class FullMatchResult(NamedTuple): 12 | """ 13 | See documentation for IncrementalPattern::fullmatch and MatchResult 14 | """ 15 | is_full_match: bool 16 | is_partial_match: bool 17 | is_inextensible_match: bool 18 | 19 | def to_partial_match_result(self, string_len: int) -> "MatchResult": 20 | return MatchResult(is_full_match=self.is_full_match, 21 | is_partial_match=self.is_partial_match, 22 | is_inextensible_match=self.is_inextensible_match, 23 | match_end=string_len) 24 | 25 | 26 | class MatchResult(NamedTuple): 27 | """ 28 | See documentation for IncrementalPattern::match 29 | """ 30 | is_full_match: bool # Does the string match this lexeme 31 | is_partial_match: bool # Is there some string s (including empty), such that if you concatenate this string with s, 32 | # it matches the lexeme 33 | is_inextensible_match: bool # Is is_full_match true, and there not a non-empty string s such that this string 34 | # concat with s matches the lexeme 35 | match_end: int 36 | 37 | def to_full_match_result(self, string_len: int) -> "FullMatchResult": 38 | if self.match_end == string_len: 39 | return FullMatchResult(is_full_match=self.is_full_match, 40 | is_partial_match=self.is_partial_match, 41 | is_inextensible_match=self.is_inextensible_match) 42 | else: 43 | return FullMatchResult(False, False, False) 44 | 45 | 46 | class IncrementalPattern(abc.ABC): 47 | """The implementations of IncrementalPatternString are fairly simple, it is worth looking through that""" 48 | @abc.abstractmethod 49 | def fullmatch(self, text: str, pos: int = 0, endpos: Optional[int] = None) -> FullMatchResult: 50 | """ 51 | Does the entirety of text match this pattern 52 | """ 53 | pass 54 | 55 | @abc.abstractmethod 56 | def match(self, text: str, pos: int = 0, endpos: Optional[int] = None) -> MatchResult: 57 | """ 58 | Is there some prefix of text that matches this pattern? If so, what is the longest one. 59 | """ 60 | pass 61 | 62 | @abc.abstractmethod 63 | def sort_order(self) -> Tuple[int, ...]: 64 | pass 65 | 66 | @abc.abstractmethod 67 | def is_nullable(self) -> bool: 68 | pass 69 | 70 | # The rest of this interface essentially deals with NFA states directly 71 | 72 | @property 73 | @abc.abstractmethod 74 | def initial_states(self) -> FrozenSet[int]: 75 | pass 76 | 77 | @property 78 | @abc.abstractmethod 79 | def final_states(self) -> FrozenSet[int]: 80 | pass 81 | 82 | @abc.abstractmethod 83 | def step_forwards_any(self, states: FrozenSet[int]) -> FrozenSet[int]: 84 | """ 85 | What states are reachable after performing a single step forward with _any_ character? 86 | """ 87 | pass 88 | 89 | @abc.abstractmethod 90 | def step_forwards(self, states: FrozenSet[int], text: str) -> FrozenSet[int]: 91 | pass 92 | 93 | @abc.abstractmethod 94 | def step_backwards(self, states: FrozenSet[int], text: str) -> FrozenSet[int]: 95 | pass 96 | 97 | @abc.abstractmethod 98 | def reachable_forward(self, states: FrozenSet[int]) -> FrozenSet[int]: 99 | pass 100 | 101 | @abc.abstractmethod 102 | def is_extensible(self, states: FrozenSet[int]) -> bool: 103 | pass 104 | 105 | 106 | class IncrementalPatternString(IncrementalPattern): 107 | """ 108 | For exact string matches, using a whole NFA is overkill. 
This implements a lightweight version of IncrementalPattern 109 | """ 110 | def __init__(self, pattern: str): 111 | self.pattern = pattern 112 | 113 | def fullmatch(self, text: str, pos: int = 0, endpos: Optional[int] = None) -> FullMatchResult: 114 | endpos = endpos if endpos is not None else len(text) 115 | if (endpos - pos) > len(self.pattern): 116 | return FullMatchResult(False, False, False) 117 | 118 | endpoint = min(len(text), endpos, pos + len(self.pattern)) 119 | relevant_text = text[pos:endpoint] 120 | if relevant_text == self.pattern: 121 | return FullMatchResult(True, True, True) 122 | elif self.pattern.startswith(relevant_text): 123 | return FullMatchResult(is_full_match=False, is_partial_match=True, is_inextensible_match=False) 124 | else: 125 | return FullMatchResult(False, False, False) 126 | 127 | def match(self, text: str, pos: int = 0, endpos: Optional[int] = None) -> MatchResult: 128 | endpos = endpos if endpos is not None else len(text) 129 | endpoint = min(len(text), endpos, pos + len(self.pattern)) 130 | relevant_text = text[pos:endpoint] 131 | if relevant_text == self.pattern: 132 | return MatchResult(True, True, True, endpoint) 133 | elif self.pattern.startswith(relevant_text): 134 | return MatchResult(is_full_match=False, is_partial_match=True, is_inextensible_match=False, 135 | match_end=endpoint) 136 | elif relevant_text.startswith(self.pattern): 137 | assert False, "This should not happen, covered by the == case" 138 | else: 139 | return MatchResult(False, False, False, -1) 140 | 141 | def sort_order(self) -> Tuple[int, ...]: 142 | return 0, len(self.pattern) 143 | 144 | def is_nullable(self) -> bool: 145 | return len(self.pattern) == 0 146 | 147 | @property 148 | def initial_states(self) -> FrozenSet[int]: 149 | return frozenset({0}) 150 | 151 | @property 152 | def final_states(self) -> FrozenSet[int]: 153 | return frozenset({len(self.pattern)}) 154 | 155 | def step_forwards(self, states: FrozenSet[int], text: str) -> FrozenSet[int]: 156 | result = [] 157 | for state in states: 158 | if self.pattern[state:].startswith(text): 159 | result.append(state + len(text)) 160 | 161 | return frozenset(result) 162 | 163 | def step_forwards_any(self, states: FrozenSet[int]) -> FrozenSet[int]: 164 | return frozenset(i + 1 for i in states if i < len(self.pattern)) 165 | 166 | def step_backwards(self, states: FrozenSet[int], text: str) -> FrozenSet[int]: 167 | result = [] 168 | for state in states: 169 | if self.pattern[:state].endswith(text): 170 | result.append(state - len(text)) 171 | 172 | return frozenset(result) 173 | 174 | def reachable_forward(self, states: FrozenSet[int]) -> FrozenSet[int]: 175 | if len(states) == 0: 176 | return frozenset() 177 | min_state = min(states) 178 | return frozenset(range(min_state, len(self.pattern) + 1)) 179 | 180 | def is_extensible(self, states: FrozenSet[int]) -> bool: 181 | return any(s < len(self.pattern) for s in states) 182 | 183 | 184 | # These are BANNED in regular expressions that we use 185 | LAZY_QUANTIFIERS = regex.compile("([\\*\\+\\?]|\\{\\d+(,(\\d+)?)?})\\?") 186 | 187 | 188 | class IncrementalPatternRegex(IncrementalPattern): 189 | 190 | def __init__(self, pattern: Pattern): 191 | self.pattern = pattern 192 | if LAZY_QUANTIFIERS.search(pattern.pattern) is not None: 193 | raise ValueError(f"Pattern {pattern.pattern} contains a lazy quantifier") 194 | (self.min_pattern_width, self.max_pattern_width) = get_regexp_width(pattern.pattern) 195 | 196 | self.nfa_regex = compile(pattern.pattern, pattern.flags) 197 | 198 | def 
fullmatch(self, text: str, pos: int = 0, endpos: Optional[int] = None) -> FullMatchResult: 199 | endpos = endpos if endpos is not None else len(text) 200 | 201 | # noinspection PyArgumentList 202 | m = self.pattern.fullmatch(text, pos, endpos, partial=True) # type:ignore[call-overload] 203 | if m: 204 | if m.partial: 205 | return FullMatchResult(is_full_match=False, is_partial_match=True, is_inextensible_match=False) 206 | elif len(text) == self.max_pattern_width: 207 | return FullMatchResult(True, True, True) 208 | else: 209 | return FullMatchResult(is_full_match=True, is_partial_match=True, is_inextensible_match=False) 210 | else: 211 | return FullMatchResult(False, False, False) 212 | 213 | def match(self, text: str, pos: int = 0, endpos: Optional[int] = None) -> MatchResult: 214 | endpos = endpos if endpos is not None else len(text) 215 | 216 | # noinspection PyArgumentList 217 | m = self.pattern.match(text, pos, endpos, partial=True) # type:ignore[call-overload] 218 | if m: 219 | if m.partial: 220 | return MatchResult(is_full_match=False, is_partial_match=True, is_inextensible_match=False, 221 | match_end=m.end()) 222 | elif len(text) == self.max_pattern_width: 223 | return MatchResult(True, True, True, match_end=m.end()) 224 | else: 225 | return MatchResult(is_full_match=True, is_partial_match=True, is_inextensible_match=False, 226 | match_end=m.end()) 227 | else: 228 | return MatchResult(False, False, False, match_end=-1) 229 | 230 | def sort_order(self) -> Tuple[int, ...]: 231 | return 10, self.max_pattern_width 232 | 233 | def is_nullable(self) -> bool: 234 | return self.min_pattern_width == 0 235 | 236 | @property 237 | def initial_states(self) -> FrozenSet[int]: 238 | return self.nfa_regex.start_states 239 | 240 | @property 241 | def final_states(self) -> FrozenSet[int]: 242 | return self.nfa_regex.end_states 243 | 244 | def step_forwards_any(self, states: FrozenSet[int]) -> FrozenSet[int]: 245 | return self.nfa_regex.step_forward_any(states) 246 | 247 | def step_forwards(self, states: FrozenSet[int], text: str) -> FrozenSet[int]: 248 | for char in text: 249 | states = self.nfa_regex.step_forward(states, char) 250 | return states 251 | 252 | def step_backwards(self, states: FrozenSet[int], text: str) -> FrozenSet[int]: 253 | for char in text[::-1]: 254 | states = self.nfa_regex.step_backward(states, char) 255 | return states 256 | 257 | def reachable_forward(self, states: FrozenSet[int]) -> FrozenSet[int]: 258 | return self.nfa_regex.get_reachable_forward(states) 259 | 260 | def is_extensible(self, states: FrozenSet[int]) -> bool: 261 | return self.nfa_regex.is_extensible(states) 262 | -------------------------------------------------------------------------------- /incremental_parsing/lex_earley/lark_grammar.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Tuple, List 3 | 4 | import lark 5 | import regex 6 | from lark import Lark 7 | from lark.grammar import Rule, Terminal, NonTerminal 8 | from lark.indenter import PythonIndenter 9 | from lark.lexer import Pattern as LarkPattern, Token as LarkToken 10 | from lark.lexer import PatternRE as LarkPatternRE 11 | from lark.lexer import PatternStr as LarkPatternStr 12 | 13 | from incremental_parsing.lex_earley.incremental_pattern import IncrementalPattern, IncrementalPatternString, \ 14 | IncrementalPatternRegex 15 | from incremental_parsing.lex_earley.lex_earley import LexEarleyAlgorithmContext 16 | from 
incremental_parsing.lex_earley.lexer import IncrementalLexer 17 | from incremental_parsing.lex_earley.python_lex_wrapper import PythonLexWrapper 18 | from incremental_parsing.lex_earley.simple_bnf import BNFTerminal, BNFNonterminal, BNFElement, SimpleBNFProduction, \ 19 | SimpleBNFRule, \ 20 | SimpleBNF 21 | from incremental_parsing.utils.flags_to_regex_flags import flags_to_regex_flags 22 | 23 | 24 | # Utilities to convert from Lark's internal representation to our internal representation 25 | 26 | 27 | def get_name_of_rule(rule_origin: NonTerminal): 28 | origin_name = rule_origin.name 29 | if isinstance(origin_name, LarkToken): 30 | return origin_name.value 31 | elif isinstance(origin_name, str): 32 | return origin_name 33 | else: 34 | raise ValueError(origin_name) 35 | 36 | 37 | def pattern_to_pattern(p: LarkPattern) -> IncrementalPattern: 38 | if isinstance(p, LarkPatternStr): 39 | return IncrementalPatternString(p.value) 40 | elif isinstance(p, LarkPatternRE): 41 | return IncrementalPatternRegex(regex.compile(p.value, flags_to_regex_flags(p.flags))) 42 | else: 43 | raise ValueError(p) 44 | 45 | 46 | def rule_to_simple_bnf(r: Rule) -> Tuple[str, SimpleBNFProduction]: 47 | bnf_list: List[BNFElement] = [] 48 | for symbol in r.expansion: 49 | if isinstance(symbol, Terminal): 50 | bnf_list.append(BNFTerminal(symbol.name)) 51 | else: 52 | bnf_list.append(BNFNonterminal(symbol.name)) 53 | 54 | return get_name_of_rule(r.origin), SimpleBNFProduction(tuple(bnf_list)) 55 | 56 | 57 | def lark_to_lex_earley_context_python(l: Lark) -> LexEarleyAlgorithmContext: 58 | tokens = {} 59 | for terminal_def in l.terminals: 60 | tokens[terminal_def.name] = pattern_to_pattern(terminal_def.pattern) 61 | 62 | # Otherwise we will interpret '''' then end of file as two strings 63 | # Will still get subsumed by long-strings, which is what we want 64 | tokens["BAD_STRING_1"] = IncrementalPatternString("''''") 65 | tokens["BAD_STRING_2"] = IncrementalPatternString('""""') 66 | 67 | lexer = IncrementalLexer(tokens) 68 | lexer_wrapped = PythonLexWrapper(lexer, l.ignore_tokens) 69 | 70 | rules = defaultdict(list) 71 | for rule in l.rules: 72 | rule_name, rule_body = rule_to_simple_bnf(rule) 73 | rules[rule_name].append(rule_body) 74 | 75 | rules_bnf = {name: SimpleBNFRule(tuple(productions)) for name, productions in rules.items()} 76 | grammar = SimpleBNF(rules_bnf, tuple(l.options.start)) 77 | 78 | return LexEarleyAlgorithmContext(grammar=grammar, lexer=lexer_wrapped) 79 | 80 | 81 | def get_python_context(): 82 | kwargs = dict(postlex=PythonIndenter(), start='file_input') 83 | l = lark.Lark.open("../../grammars/python.lark", rel_to=__file__, **kwargs) 84 | return lark_to_lex_earley_context_python(l) 85 | 86 | 87 | def lark_to_context(l: Lark, use_true_leftmost_longest: bool = False) -> LexEarleyAlgorithmContext: 88 | tokens = {} 89 | for terminal_def in l.terminals: 90 | tokens[terminal_def.name] = pattern_to_pattern(terminal_def.pattern) 91 | 92 | lexer = IncrementalLexer(tokens, use_true_leftmost_longest) 93 | rules = defaultdict(list) 94 | for rule in l.rules: 95 | rule_name, rule_body = rule_to_simple_bnf(rule) 96 | rules[rule_name].append(rule_body) 97 | 98 | rules_bnf = {name: SimpleBNFRule(tuple(productions)) for name, productions in rules.items()} 99 | grammar = SimpleBNF(rules_bnf, tuple(l.options.start)) 100 | 101 | return LexEarleyAlgorithmContext(grammar=grammar, lexer=lexer) 102 | 103 | 104 | def get_calc_lang_context(): 105 | l = lark.Lark.open("../../grammars/calculator.lark", start='start', 
rel_to=__file__) 106 | return lark_to_context(l, True) 107 | 108 | 109 | if __name__ == '__main__': 110 | get_python_context() 111 | -------------------------------------------------------------------------------- /incremental_parsing/lex_earley/native_earley_nfa.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | from incremental_parsing.lex_earley.earley_nfa import AbstractEarleyNFA 4 | from incremental_parsing.lex_earley.middle_earley import ContainsStatesAndCreationMethods 5 | from incremental_parsing.lex_earley.native_earley_trie import NativeEarleyChart 6 | from incremental_parsing.lex_earley.simple_bnf import SimpleBNF 7 | from incremental_parsing.utils.indexable_container import IndexableContainer 8 | from incremental_parsing.utils.simple_nfa import SimpleNFA 9 | 10 | from incremental_parsing._native import NativeGrammar, NativeEarleyCharts 11 | 12 | 13 | class NativeEarleyNFA(AbstractEarleyNFA): 14 | def __init__(self, grammar: NativeGrammar, charts: NativeEarleyCharts, state_mapping: Dict[int, int], reverse_state_mapping: List[int]): 15 | self._charts = charts 16 | self._grammar = grammar 17 | self._state_mapping = state_mapping 18 | self._reverse_state_mapping = reverse_state_mapping 19 | self._cache = {} 20 | 21 | @classmethod 22 | def create(cls, grammar: SimpleBNF, nfa: SimpleNFA[str, str]): 23 | native_grammar = grammar.to_native() 24 | 25 | reverse_state_mapping = list(nfa.states) 26 | state_mapping = {state: idx for idx, state in enumerate(reverse_state_mapping)} 27 | 28 | transitions = [] 29 | for start_state, outgoing_transitions in nfa.atom_transitions_forward.items(): 30 | for nonterminal, dests in outgoing_transitions.items(): 31 | for dest in dests: 32 | transitions.append((state_mapping[start_state], state_mapping[dest], nonterminal)) 33 | 34 | start_states = [state_mapping[s] for s in nfa.start_states] 35 | 36 | charts = NativeEarleyCharts.create_earley_nfa(native_grammar, len(state_mapping), start_states, 37 | transitions) 38 | 39 | return cls(native_grammar, charts, state_mapping, reverse_state_mapping) 40 | 41 | @property 42 | def charts(self) -> "NativeEarleyNFA": 43 | return self 44 | 45 | def __getitem__(self, key: int) -> ContainsStatesAndCreationMethods: 46 | if key in self._cache: 47 | return self._cache[key] 48 | 49 | element = NativeEarleyChart(self._grammar, self._charts, self._state_mapping[key], self._reverse_state_mapping) 50 | self._cache[key] = element 51 | return element 52 | -------------------------------------------------------------------------------- /incremental_parsing/lex_earley/native_earley_trie.py: -------------------------------------------------------------------------------- 1 | from builtins import str 2 | from typing import Tuple, overload, List, Sequence, Dict, Iterable 3 | from typing_extensions import Self 4 | 5 | from incremental_parsing.lex_earley.earley_base import LexEarleyAlgorithmChart, LexEarleyState, StateCreationMethod, \ 6 | TopLevel, Scanned, Predicted, PredictedNullableCompletion, Completed 7 | from incremental_parsing.lex_earley.earley_trie import AbstractEarleyTrieNode 8 | from incremental_parsing.lex_earley.lexer import Token 9 | 10 | from incremental_parsing._native import NativeGrammar, NativeEarleyCharts 11 | 12 | from incremental_parsing.lex_earley.simple_bnf import SimpleBNF 13 | from incremental_parsing.utils.indexable_container import IndexableContainer 14 | 15 | 16 | def deserialize_state_creation_method(info: List[int], chart_map: 
IndexableContainer[int]) -> List[StateCreationMethod]: 17 | current_index = 0 18 | result = [] 19 | while current_index < len(info): 20 | val_at_idx = info[current_index] 21 | if val_at_idx == 0: 22 | result.append(TopLevel()) 23 | current_index += 1 24 | elif val_at_idx == 1: 25 | result.append(Scanned(chart_map[info[current_index + 1]], info[current_index + 2])) 26 | current_index += 3 27 | elif val_at_idx == 2: 28 | result.append(Predicted(chart_map[info[current_index+1]], info[current_index+2])) 29 | current_index += 3 30 | elif val_at_idx == 3: 31 | result.append(PredictedNullableCompletion(chart_map[info[current_index+1]], info[current_index+2])) 32 | current_index += 3 33 | elif val_at_idx == 4: 34 | result.append(Completed(chart_map[info[current_index+1]], info[current_index+2], chart_map[info[current_index+3]], info[current_index+4])) 35 | current_index += 5 36 | 37 | return result 38 | 39 | 40 | class IdentMap: 41 | def __init__(self): 42 | pass 43 | 44 | def __getitem__(self, item): 45 | return item 46 | 47 | 48 | class NativeEarleyChart: 49 | def __init__(self, grammar: NativeGrammar, charts: NativeEarleyCharts, index: int, chart_map: IndexableContainer[int] = IdentMap()): 50 | self.grammar = grammar 51 | self.charts = charts 52 | self.index = index 53 | self.chart_map = chart_map 54 | self.cache = [None for _ in range(len(self))] 55 | 56 | def __getitem__(self, item: int) -> Tuple[LexEarleyState, Iterable[StateCreationMethod]]: 57 | if self.cache[item] is not None: 58 | return self.cache[item] 59 | 60 | span_start, name, prod_idx, dot_idx, prod_length, creation_methods = self.charts.get_earley_state(self.grammar, self.index, item) 61 | result = (LexEarleyState( 62 | span_start=self.chart_map[span_start], 63 | rule_name=name, 64 | position=dot_idx, 65 | max_position=prod_length, 66 | production_index=prod_idx 67 | ), deserialize_state_creation_method(creation_methods, self.chart_map)) 68 | 69 | self.cache[item] = result 70 | return result 71 | 72 | def __len__(self): 73 | return self.charts.get_chart_len(self.index) 74 | 75 | def get_states_and_creation_methods(self) -> Sequence[Tuple[LexEarleyState, Iterable[StateCreationMethod]]]: 76 | return [self[i] for i in range(len(self))] 77 | 78 | def __repr__(self): 79 | return list(self[i] for i in range(len(self))).__repr__() 80 | 81 | 82 | class NativeEarleyTrieNode(AbstractEarleyTrieNode): 83 | @classmethod 84 | def create_root(cls, grammar: SimpleBNF): 85 | native_grammar = grammar.to_native() 86 | charts, root_chart_num, allowed_terminals, completable, complete = NativeEarleyCharts.create_initial_earley_charts(native_grammar) 87 | return cls( 88 | native_grammar=native_grammar, 89 | native_charts=charts, 90 | chart_num=root_chart_num, 91 | allowed_token_names=allowed_terminals, 92 | completable=completable, 93 | complete=complete, 94 | depth=1, 95 | cache={}, 96 | orig_grammar=grammar, 97 | ) 98 | 99 | def __init__(self, native_grammar: NativeGrammar, native_charts: NativeEarleyCharts, chart_num: int, allowed_token_names: Sequence[str], completable: bool, complete: bool, depth: int, cache: Dict[int, NativeEarleyChart], orig_grammar: SimpleBNF): 100 | self.native_grammar = native_grammar 101 | self.native_charts = native_charts 102 | self.chart_num = chart_num 103 | self._allowed_token_names = allowed_token_names 104 | self.children : Dict[str, NativeEarleyTrieNode] = {} 105 | self._completable = completable 106 | self._complete = complete 107 | self._depth = depth 108 | self._cache = cache 109 | self.orig_grammar = orig_grammar 
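# --- Illustrative note, not part of the original source file ---------------------
# Each NativeEarleyTrieNode memoizes at most one child per token name, so walking
# the same token prefix twice reuses the already-built native Earley charts instead
# of re-running the parser. A minimal usage sketch, assuming `grammar` is a
# SimpleBNF and `tok` is a lexed Token for that grammar:
#
#     root = NativeEarleyTrieNode.create_root(grammar)
#     child = root.get_child(tok)     # parses one token; cached in root.children
#     child is root.get_child(tok)    # True: the second call returns the cached node
#     child.is_completable()          # could some continuation still complete a parse?
#     child.is_complete()             # is the consumed token sequence itself a full parse?
# ---------------------------------------------------------------------------------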
110 | 111 | def get_child(self, token: Token) -> Self: 112 | if token.name not in self.children: 113 | assert not token.loose_behavior 114 | next_chart_num, next_allowed_token_names, completable, complete = self.native_charts.parse(self.native_grammar, self.chart_num, token.name) 115 | self.children[token.name] = NativeEarleyTrieNode(self.native_grammar, self.native_charts, next_chart_num, next_allowed_token_names, completable, complete, self._depth + 1, self._cache, self.orig_grammar) 116 | 117 | return self.children[token.name] 118 | 119 | def is_completable(self) -> bool: 120 | return self._completable 121 | 122 | def is_complete(self) -> bool: 123 | return self._complete 124 | 125 | @overload 126 | def __getitem__(self, index: int) -> LexEarleyAlgorithmChart: 127 | ... 128 | 129 | @overload 130 | def __getitem__(self, index: slice) -> List[LexEarleyAlgorithmChart]: 131 | ... 132 | 133 | def __getitem__(self, item): 134 | if isinstance(item, int): 135 | if item in self._cache: 136 | return self._cache[item] 137 | 138 | element = NativeEarleyChart(self.native_grammar, self.native_charts, item) 139 | self._cache[item] = element 140 | return element 141 | raise NotImplemented 142 | 143 | def __len__(self): 144 | return self._depth 145 | 146 | @property 147 | def allowed_token_names(self) -> Sequence[str]: 148 | return self._allowed_token_names -------------------------------------------------------------------------------- /incremental_parsing/lex_earley/simple_bnf.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from collections import defaultdict 3 | from typing import NamedTuple, Union, Tuple, Dict, AbstractSet, Iterable, FrozenSet, Set, Container, \ 4 | DefaultDict, List 5 | from numpy.ma.testutils import assert_equal 6 | from incremental_parsing._native import BridgeBNFElement, BridgeBNFProduction, BridgeBNFRule, BridgeBNFGrammar 7 | 8 | 9 | class BNFTerminal(NamedTuple): 10 | name: str 11 | 12 | def __str__(self): 13 | return self.name 14 | 15 | def to_bridge(self): 16 | return BridgeBNFElement.Terminal(self.name) 17 | 18 | 19 | class BNFNonterminal(NamedTuple): 20 | name: str 21 | 22 | def __str__(self): 23 | return self.name 24 | 25 | def to_bridge(self): 26 | return BridgeBNFElement.Nonterminal(self.name) 27 | 28 | 29 | BNFElement = Union[BNFTerminal, BNFNonterminal] 30 | 31 | 32 | class SimpleBNFProduction(NamedTuple): 33 | elements: Tuple[BNFElement, ...] 34 | 35 | def __str__(self): 36 | if len(self.elements) == 0: 37 | return 'λ' 38 | else: 39 | return ' '.join(map(str, self.elements)) 40 | 41 | def reverse(self) -> "SimpleBNFProduction": 42 | return SimpleBNFProduction(tuple(reversed(self.elements))) 43 | 44 | def to_bridge(self) -> BridgeBNFProduction: 45 | return BridgeBNFProduction(tuple(element.to_bridge() for element in self.elements)) 46 | 47 | 48 | class SimpleBNFRule(NamedTuple): 49 | productions: Tuple[SimpleBNFProduction, ...] 50 | 51 | def __str__(self): 52 | return '\n | '.join(map(str, self.productions)) 53 | 54 | def reverse(self) -> "SimpleBNFRule": 55 | return SimpleBNFRule(tuple(prod.reverse() for prod in self.productions)) 56 | 57 | def to_bridge(self) -> BridgeBNFProduction: 58 | return BridgeBNFRule(tuple(prod.to_bridge() for prod in self.productions)) 59 | 60 | 61 | class SimpleBNF: 62 | rules: Dict[str, SimpleBNFRule] 63 | nullable_rules: AbstractSet[str] 64 | top_level_rules: Tuple[str, ...] 
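# Illustrative example, not part of the original source file: given the small grammar
#
#     expr : expr PLUS term
#          | term
#     term : NUMBER
#          | λ
#
# with top_level_rules = ("expr",), both rules are reachable from the start symbol,
# and both are nullable: `term` directly, because it has an empty production, and
# `expr` transitively, through its production consisting only of the nullable `term`.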
65 | 66 | def __init__(self, rules: Dict[str, SimpleBNFRule], top_level_rules: Tuple[str, ...], 67 | remove_unreachable_rules=True): 68 | self.reachable_rules = frozenset(self.get_reachable_rules(top_level_rules, rules)) 69 | if remove_unreachable_rules: 70 | self.rules = {rule_name: rule for rule_name, rule in rules.items() if rule_name in self.reachable_rules} 71 | else: 72 | self.rules = rules 73 | self.nullable_rules = frozenset(self.get_nullable_rules(self.rules)) 74 | self.top_level_rules = top_level_rules # Always reachable 75 | 76 | @staticmethod 77 | def get_reachable_rules(top_level_rules: Iterable[str], rules: Dict[str, SimpleBNFRule]) -> AbstractSet[str]: 78 | """ 79 | Returns a list of rules that are reachable from the top level rules. 80 | """ 81 | processed_reachable_rules = set() 82 | recently_reachable_rules = list(top_level_rules) 83 | 84 | while len(recently_reachable_rules) > 0: 85 | rule_to_check = recently_reachable_rules.pop() 86 | 87 | if rule_to_check not in processed_reachable_rules: 88 | processed_reachable_rules.add(rule_to_check) 89 | for production in rules[rule_to_check].productions: 90 | for element in production.elements: 91 | if isinstance(element, BNFNonterminal): 92 | recently_reachable_rules.append(element.name) 93 | 94 | return processed_reachable_rules 95 | 96 | @staticmethod 97 | def get_nullable_rules(rules: Dict[str, SimpleBNFRule]) -> AbstractSet[str]: 98 | """ 99 | Returns a list of rules that are nullable. 100 | Essentially the algorithm in 101 | https://github.com/jeffreykegler/old_kollos/blob/master/notes/misc/loup2.md 102 | 103 | We don't care if any lexer tokens are nullable here- a nullable lexer token will 104 | still become a parser token; the BNF rule which produces this is not nullable 105 | """ 106 | nullable_rules = set() 107 | recently_nullable_rules = [] 108 | 109 | rules_referenced_by_rules = defaultdict(set) 110 | 111 | for rule_name, body in rules.items(): 112 | for production in body.productions: 113 | if len(production.elements) == 0: 114 | nullable_rules.add(rule_name) 115 | recently_nullable_rules.append(rule_name) 116 | else: 117 | for element in production.elements: 118 | if isinstance(element, BNFNonterminal): 119 | rules_referenced_by_rules[element.name].add(rule_name) 120 | 121 | while len(recently_nullable_rules) > 0: 122 | rule_to_check = recently_nullable_rules.pop() 123 | for rule_referenced_by_rule in rules_referenced_by_rules[rule_to_check]: 124 | if rule_referenced_by_rule not in nullable_rules: 125 | for production in rules[rule_referenced_by_rule].productions: 126 | if all(isinstance(element, BNFNonterminal) and element.name in nullable_rules 127 | for element in production.elements): 128 | nullable_rules.add(rule_referenced_by_rule) 129 | recently_nullable_rules.append(rule_referenced_by_rule) 130 | break 131 | 132 | return nullable_rules 133 | 134 | def get_all_final_terminals(self) -> FrozenSet[str]: 135 | final_terminals: Set[str] = set() 136 | 137 | processed: Set[str] = set() 138 | to_process: Set[str] = set(self.top_level_rules) 139 | while len(to_process) > 0: 140 | p = to_process.pop() 141 | if p in processed: 142 | continue 143 | processed.add(p) 144 | 145 | for production in self.rules[p].productions: 146 | last_element = production.elements[-1] 147 | if isinstance(last_element, BNFTerminal): 148 | final_terminals.add(last_element.name) 149 | else: 150 | to_process.add(last_element.name) 151 | 152 | return frozenset(final_terminals) 153 | 154 | def to_bnf_ending_in(self, permissible_final_terminals: 
Container[str]): 155 | """ 156 | Returns a BNF in which all paths must end in one of the permissible_final_terminals 157 | """ 158 | last_productions_referenced_by_rules: DefaultDict[str, Set[Tuple[str, int, int]]] = defaultdict(set) 159 | 160 | acceptable_productions: Set[Tuple[str, int, int]] = set() 161 | seen_dest_rules: Set[str] = set() 162 | dest_rule_queue: List[str] = [] 163 | 164 | def potential_last_elements(production: SimpleBNFProduction): 165 | for idx, element in reversed(list(enumerate(production.elements))): 166 | yield idx, element 167 | if isinstance(element, BNFNonterminal) and element.name in self.nullable_rules: 168 | continue 169 | else: 170 | break 171 | 172 | for rule_name, body in self.rules.items(): 173 | for prod_idx, production in enumerate(body.productions): 174 | for end_idx, element in potential_last_elements(production): 175 | if isinstance(element, BNFNonterminal): 176 | last_productions_referenced_by_rules[element.name].add((rule_name, prod_idx, end_idx)) 177 | else: 178 | if element.name in permissible_final_terminals: 179 | acceptable_productions.add((rule_name, prod_idx, end_idx)) 180 | if rule_name not in seen_dest_rules: 181 | seen_dest_rules.add(rule_name) 182 | dest_rule_queue.append(rule_name) 183 | 184 | while len(dest_rule_queue) != 0: 185 | dest_rule = dest_rule_queue.pop() 186 | for referring_rule, referring_prod_idx, referring_prod_end_idx in last_productions_referenced_by_rules[ 187 | dest_rule]: 188 | last_element = self.rules[referring_rule].productions[referring_prod_idx].elements[ 189 | referring_prod_end_idx] 190 | if isinstance(last_element, BNFTerminal) and last_element.name not in permissible_final_terminals: 191 | assert False 192 | acceptable_productions.add((referring_rule, referring_prod_idx, referring_prod_end_idx)) 193 | if referring_rule not in seen_dest_rules: 194 | seen_dest_rules.add(referring_rule) 195 | dest_rule_queue.append(referring_rule) 196 | 197 | modified_bnf_elements: DefaultDict[str, List[SimpleBNFProduction]] = defaultdict(list) 198 | for rule_name, prod_idx, prod_end_idx in acceptable_productions: 199 | original_production = self.rules[rule_name].productions[prod_idx] 200 | # Let's say that the production is A B C, and C is nullable in the original BNF 201 | # We add two productions: A B-final, and A B C-final, where B-final and C-final are both guaranteed to end 202 | # in a permissible final terminal 203 | # Note that C-final is no longer nullable 204 | last_element = original_production.elements[prod_end_idx] 205 | modified_last_element: BNFElement 206 | if isinstance(last_element, BNFTerminal): 207 | assert last_element.name in permissible_final_terminals 208 | modified_last_element = last_element 209 | else: 210 | modified_last_element = BNFNonterminal(name=(last_element.name + "/final")) 211 | 212 | modified_bnf_elements[rule_name].append(SimpleBNFProduction( 213 | original_production.elements[:prod_end_idx] + (modified_last_element,) 214 | )) 215 | 216 | rules: Dict[str, SimpleBNFRule] = self.rules.copy() 217 | for name, productions in modified_bnf_elements.items(): 218 | rules[name + "/final"] = SimpleBNFRule(tuple(productions)) 219 | 220 | top_level_rules = tuple(tlr + "/final" for tlr in self.top_level_rules if tlr in modified_bnf_elements) 221 | return SimpleBNF(rules, top_level_rules) 222 | 223 | def reverse(self) -> "SimpleBNF": 224 | return SimpleBNF( 225 | rules={rule_name: rule.reverse() for rule_name, rule in self.rules.items()}, 226 | top_level_rules=self.top_level_rules, 227 | 
remove_unreachable_rules=False # This transformation doesn't change reachability 228 | ) 229 | 230 | def __str__(self): 231 | return "\n\n".join(f"{name} : {rule}" for name, rule in self.rules.items()) 232 | 233 | def to_bridge(self): 234 | return BridgeBNFGrammar(rules={name: rule.to_bridge() for (name, rule) in self.rules.items()}, 235 | top_level_rules=self.top_level_rules) 236 | 237 | def to_native(self): 238 | return self.to_bridge().to_native() 239 | 240 | 241 | class TestBNFEndingIn(unittest.TestCase): 242 | def test_calc_lang(self): 243 | from incremental_parsing.lex_earley.lark_grammar import get_calc_lang_context 244 | from incremental_parsing.lex_earley.lex_earley import LexEarleyAlgorithmContext 245 | from lex_earley import lex_earley_parse 246 | 247 | calc_lang_context = get_calc_lang_context() 248 | assert_equal(lex_earley_parse(calc_lang_context, "(1-2)"), True) 249 | assert_equal(lex_earley_parse(calc_lang_context, "(1-2)+2"), True) 250 | assert_equal(lex_earley_parse(calc_lang_context, "(1-2)+2+(2)"), True) 251 | 252 | modified_context = LexEarleyAlgorithmContext(calc_lang_context.grammar.to_bnf_ending_in(["RPAR"]), 253 | calc_lang_context.lexer) 254 | 255 | assert_equal(lex_earley_parse(modified_context, "(1-2)"), True) 256 | assert_equal(lex_earley_parse(modified_context, "(1-2)+2"), False) 257 | assert_equal(lex_earley_parse(modified_context, "(1-2)+2+(2)"), True) 258 | 259 | modified_context = LexEarleyAlgorithmContext(calc_lang_context.grammar.to_bnf_ending_in(["NUMBER"]), 260 | calc_lang_context.lexer) 261 | 262 | assert_equal(lex_earley_parse(modified_context, "(1-2)"), False) 263 | assert_equal(lex_earley_parse(modified_context, "(1-2)+2"), True) 264 | assert_equal(lex_earley_parse(modified_context, "(1-2)+2+(2)"), False) 265 | -------------------------------------------------------------------------------- /incremental_parsing/lex_earley/test.py: -------------------------------------------------------------------------------- 1 | from incremental_parsing.lex_earley.lark_grammar import get_python_context 2 | from incremental_parsing.lex_earley.lex_earley import lex_earley_init, lex_earley_run, is_completable, is_complete, \ 3 | lex_earley_to_middle, LexEarleyAlgorithmState, LexEarleyAlgorithmPrefixState 4 | 5 | 6 | def run_tests(): 7 | context = get_python_context() 8 | init_state = lex_earley_init(context) 9 | 10 | def assert_state(state, completable, complete): 11 | if completable: 12 | assert state is not None 13 | assert is_completable(state) 14 | 15 | if complete: 16 | assert is_complete(context, state) 17 | else: 18 | assert not is_complete(context, state) 19 | else: 20 | assert state is None or not is_completable(state) 21 | 22 | def assert_text(text, completable, complete): 23 | state = lex_earley_run(context, init_state, text) 24 | assert_state(state, completable, complete) 25 | assert_text_context("", text, "", completable, complete) 26 | 27 | def assert_text_context(prefix, middle, suffix, completable, complete): 28 | # There are numerous more or less equivalent ways to actually invoke the parser 29 | # Test all of them 30 | 31 | # No left context, no lookahead 32 | state: LexEarleyAlgorithmState = \ 33 | lex_earley_to_middle(context=context, state=init_state, suffix=suffix, middle_lookahead="") 34 | state = lex_earley_run(context=context, state=state, value=(prefix + middle)) 35 | assert_state(state, completable, complete) 36 | 37 | # No left context, with lookahead 38 | # In some cases, a known lookahead can be more efficient 39 | state = 
lex_earley_to_middle(context=context, state=init_state, suffix=suffix, 40 | middle_lookahead=(prefix + middle)) 41 | state = lex_earley_run(context=context, state=state, value=(prefix + middle)) 42 | assert_state(state, completable, complete) 43 | 44 | # With left context, no lookahead 45 | # In some cases, a known lookahead can be more efficient 46 | state = lex_earley_run(context=context, state=init_state, value=prefix) 47 | assert isinstance(state, LexEarleyAlgorithmPrefixState) 48 | state = lex_earley_to_middle(context=context, state=state, suffix=suffix, middle_lookahead="") 49 | state = lex_earley_run(context=context, state=state, value=middle) 50 | assert_state(state, completable, complete) 51 | 52 | # With left context, and lookahead 53 | state = lex_earley_init(context, lookahead=prefix) 54 | state = lex_earley_run(context=context, state=state, value=prefix) 55 | assert isinstance(state, LexEarleyAlgorithmPrefixState) 56 | state = lex_earley_to_middle(context=context, state=state, suffix=suffix, middle_lookahead=middle) 57 | state = lex_earley_run(context=context, state=state, value=middle) 58 | assert_state(state, completable, complete) 59 | 60 | assert_text("f(a and", True, False) 61 | assert_text("f(a andy", False, False) 62 | assert_text("f(aandy", True, False) 63 | assert_text("f(aand and b)\n", True, True) 64 | assert_text("f(aand and b andy)", False, False) 65 | assert_text("""# foo 66 | assert bar(a) == "b" 67 | assert bar(c) == "d" 68 | """, True, True) 69 | 70 | assert_text(""" 71 | blah = 1 72 | """, True, False) 73 | 74 | assert_text(""" 75 | blah = 1 76 | f""", False, False) 77 | 78 | assert_text(""" 79 | if a: 80 | blah = 1 81 | """, True, False) 82 | 83 | assert_text(""" 84 | if a: 85 | blah = 1 86 | f""", False, False) 87 | 88 | assert_text(""" 89 | if a: 90 | blah = 1 91 | f""", True, False) 92 | 93 | assert_text(""" 94 | if a: 95 | blah = 1 96 | f""", False, False) 97 | 98 | assert_text(""" 99 | if a: 100 | blah = ( 101 | f""", True, False) 102 | 103 | assert_text(""" 104 | ''' 105 | asdf\\\n 106 | ''' 107 | """, True, True) 108 | 109 | assert_text(""" 110 | if foo: 111 | ''' 112 | blah 113 | ''' 114 | """, True, True) 115 | 116 | assert_text(""" 117 | if foo: 118 | ''' 119 | blah 120 | ''' 121 | return 122 | """, False, False) 123 | 124 | assert_text("assert find_literals('The quick brown fox jumps over the lazy dog.', 'fox') == ('fox', 16, 19)\n", True, True) 125 | 126 | assert_text_context("0", "o", "r 1\n", True, False) 127 | assert_text_context("0", "or", " 1\n", False, False) 128 | assert_text_context("0 ", "o", "r 1\n", True, True) 129 | assert_text_context("def foo(", "asdf)", " pass\n", True, False) 130 | assert_text_context("def foo(", "asdf):", " pass\n", True, True) 131 | 132 | 133 | if __name__ == "__main__": 134 | run_tests() -------------------------------------------------------------------------------- /incremental_parsing/regex_ixes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/incremental-parsing/24e501190b5eff2abbf6680c2d2a36ef37781402/incremental_parsing/regex_ixes/__init__.py -------------------------------------------------------------------------------- /incremental_parsing/regex_ixes/regex_compile.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from incremental_parsing.regex_ixes.regex_nfa import regex_match, ATOM_PATTERN_TYPE 4 | from incremental_parsing.regex_ixes.regex_parse 
import parse_regex 5 | from incremental_parsing.utils.simple_nfa import SimpleNFA, SimpleNFAMutable 6 | 7 | 8 | def compile(regex: str, flags: int) -> SimpleNFA[ATOM_PATTERN_TYPE, str]: 9 | r = parse_regex(regex, flags) 10 | n: SimpleNFAMutable[ATOM_PATTERN_TYPE] = SimpleNFAMutable() 11 | r.to_nfa(n, 0, 1) 12 | return n.finalize(regex_match) 13 | 14 | 15 | class TestRegexNFA(unittest.TestCase): 16 | def test_regex_comment(self): 17 | r = compile(r'#[^\n]*', 0) 18 | self.assertFalse(r.fullmatch("")) 19 | self.assertFalse(r.fullmatch(' #')) 20 | self.assertTrue(r.fullmatch('#')) 21 | self.assertFalse(r.fullmatch('#\n')) 22 | self.assertTrue(r.fullmatch('# asdfaweproiajweioprjaoisejf')) 23 | 24 | def test_regex_linecont(self): 25 | r = compile(r'\\[\t \f]*\r?\n', 0) 26 | 27 | self.assertFalse(r.fullmatch("")) 28 | self.assertFalse(r.fullmatch('\\\na')) 29 | self.assertTrue(r.fullmatch('\\\n')) 30 | self.assertTrue(r.fullmatch('\\ \n')) 31 | self.assertTrue(r.fullmatch('\\ \t\t\n')) 32 | self.assertTrue(r.fullmatch('\\ \t\t\r\n')) 33 | self.assertFalse(r.fullmatch('\\ \t\r\t\r\n')) 34 | 35 | def test_regex_newline(self): 36 | r = compile(r'((\r?\n[\t ]*|#[^\n]*))+', 0) 37 | self.assertFalse(r.fullmatch("")) 38 | self.assertFalse(r.fullmatch('a')) 39 | self.assertFalse(r.fullmatch('a\n')) 40 | self.assertFalse(r.fullmatch('\na')) 41 | self.assertTrue(r.fullmatch('\n')) 42 | self.assertTrue(r.fullmatch('#hello\n')) 43 | self.assertFalse(r.fullmatch('#hello\nfoo')) 44 | -------------------------------------------------------------------------------- /incremental_parsing/regex_ixes/regex_nfa.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | import regex 4 | 5 | ATOM_PATTERN_TYPE = Tuple[str, int] 6 | CACHED_ATOM_REGEXES: Dict[ATOM_PATTERN_TYPE, regex.Pattern] = dict() 7 | CACHED_VALUES: Dict[Tuple[ATOM_PATTERN_TYPE, str], bool] = dict() 8 | 9 | 10 | def regex_match(pattern: ATOM_PATTERN_TYPE, char: str) -> bool: 11 | val = CACHED_VALUES.get((pattern, char)) 12 | if val is not None: 13 | return val 14 | 15 | pat = CACHED_ATOM_REGEXES.get(pattern) 16 | if pat is not None: 17 | val = (pat.fullmatch(char) is not None) 18 | CACHED_VALUES[(pattern, char)] = val 19 | return val 20 | 21 | pat = regex.compile(pattern[0], flags=pattern[1]) 22 | CACHED_ATOM_REGEXES[pattern] = pat 23 | val = (pat.fullmatch(char) is not None) 24 | CACHED_VALUES[(pattern, char)] = val 25 | return val 26 | -------------------------------------------------------------------------------- /incremental_parsing/regex_ixes/regex_parse.py: -------------------------------------------------------------------------------- 1 | import re 2 | import unittest 3 | from typing import Tuple 4 | 5 | from incremental_parsing.regex_ixes.regex_tree import RegexSequence, RegexAlternates, RegexGroup, RegexAtom, RegexNode, \ 6 | RegexRepeat 7 | 8 | 9 | def parse_regex(regex: str, flags: int = 0) -> RegexNode: 10 | assert not regex.startswith('^') 11 | assert not regex.endswith('$') 12 | regex_tree, remaining_regex = parse_regex_alternates(regex, flags) 13 | assert remaining_regex == "" 14 | return regex_tree 15 | 16 | 17 | def parse_regex_alternates(regex: str, flags: int) -> Tuple[RegexNode, str]: 18 | elems = [] 19 | while regex != "" and regex[0] != ')': 20 | elem, regex = parse_regex_sequence(regex, flags) 21 | elems.append(elem) 22 | if regex == "" or regex[0] == ')': 23 | break 24 | elif regex[0] == '|': 25 | regex = regex[1:] 26 | continue 27 | else: 28 | raise 
ValueError("Expected | or ), got " + regex) 29 | 30 | if not elems: 31 | return RegexSequence(()), regex 32 | elif len(elems) == 1: 33 | return elems[0], regex 34 | else: 35 | return RegexAlternates(tuple(elems)), regex 36 | 37 | 38 | def parse_regex_sequence(regex: str, flags: int) -> Tuple[RegexNode, str]: 39 | seq = [] 40 | while regex != "" and regex[0] not in '|)': 41 | if regex[0] == "(": 42 | elem, regex = parse_group(regex, flags) 43 | elif regex[0] == "[": 44 | elem, regex = parse_char_class(regex, flags) 45 | else: 46 | elem, regex = parse_char(regex, flags) 47 | 48 | elem, regex = parse_operator(elem, regex, flags) 49 | seq.append(elem) 50 | 51 | if len(seq) == 1: 52 | return seq[0], regex 53 | else: 54 | return RegexSequence(tuple(seq)), regex 55 | 56 | 57 | def parse_group(regex: str, flags: int) -> Tuple[RegexNode, str]: 58 | assert regex[0] == '(' 59 | regex = regex[1:] 60 | if regex.startswith('?:'): 61 | prefix = '?:' 62 | regex = regex[2:] 63 | elif regex.startswith('?'): 64 | raise NotImplementedError(f"Regex group ({regex}") 65 | else: 66 | prefix = '' 67 | 68 | inside, regex = parse_regex_alternates(regex, flags) 69 | assert regex[0] == ')' 70 | regex = regex[1:] 71 | return RegexGroup(inside, prefix), regex 72 | 73 | 74 | def parse_char_class(regex: str, flags: int) -> Tuple[RegexNode, str]: 75 | assert regex[0] == '[' 76 | regex = regex[1:] 77 | char_class_text = "[" 78 | 79 | while True: 80 | assert "]" in regex, "Unclosed character class" 81 | 82 | pre, regex = regex.split(']', 1) 83 | char_class_text += pre + "]" 84 | 85 | if len(pre) == 0 or pre[-1] != '\\': 86 | break 87 | else: 88 | # We just saw (and added) a \], didn't close the char class 89 | continue 90 | 91 | return RegexAtom(char_class_text, flags), regex 92 | 93 | 94 | def parse_char(regex: str, flags: int) -> Tuple[RegexNode, str]: 95 | if regex[0] == '\\': 96 | if regex[1].isdigit(): 97 | if regex[2].isdigit(): 98 | return RegexAtom(regex[:3], flags), regex[3:] 99 | return RegexAtom(regex[:2], flags), regex[2:] 100 | else: 101 | return RegexAtom(regex[0], flags), regex[1:] 102 | 103 | 104 | number = re.compile(r"\d+") 105 | 106 | 107 | def parse_number(regex: str) -> Tuple[int, str]: 108 | match = number.match(regex) 109 | if match is None: 110 | raise ValueError(f"Invalid number: {regex}") 111 | return int(match.group()), regex[match.end():] 112 | 113 | 114 | def parse_operator(elem: RegexNode, regex: str, flags: int) -> Tuple[RegexNode, str]: 115 | if regex == "": 116 | return elem, "" 117 | elif regex[0] == "*": 118 | if len(regex) > 1 and regex[1] == "?": 119 | raise NotImplementedError(f"Lazy regex operator *?") 120 | return RegexRepeat(elem, 0, None), regex[1:] 121 | elif regex[0] == "+": 122 | if len(regex) > 1 and regex[1] == "?": 123 | raise NotImplementedError(f"Lazy regex operator +?") 124 | return RegexRepeat(elem, 1, None), regex[1:] 125 | elif regex[0] == "?": 126 | if len(regex) > 1 and regex[1] == "?": 127 | raise NotImplementedError(f"Lazy regex operator ??") 128 | return RegexRepeat(elem, 0, 1), regex[1:] 129 | elif regex[0] == "{": 130 | regex = regex[1:] 131 | 132 | if regex[0] == "}": 133 | # Make sure we don't have x{} 134 | raise ValueError(f"Empty number in {regex}") 135 | 136 | if regex[0] == ",": 137 | min_repeat = 0 138 | else: 139 | min_repeat, regex = parse_number(regex) 140 | 141 | if regex[0] == ",": 142 | regex = regex[1:] 143 | if regex[0] == "}": 144 | max_repeat = None 145 | else: 146 | max_repeat, regex = parse_number(regex) 147 | elif regex[0] == "}": 148 | 
max_repeat = min_repeat 149 | else: 150 | raise ValueError(f"Invalid number in {regex}") 151 | 152 | assert regex[0] == "}" 153 | regex = regex[1:] 154 | if len(regex) > 0 and regex[0] == "?": 155 | raise NotImplementedError(f"Lazy regex operator") 156 | 157 | return RegexRepeat(elem, min_repeat, max_repeat), regex 158 | else: 159 | return elem, regex 160 | 161 | 162 | class TestRegexParses(unittest.TestCase): 163 | def test_regex_parse_repeat(self): 164 | self.assertEqual(str(parse_regex('a{,}')), 'a*') 165 | self.assertEqual(str(parse_regex('a{1,}')), 'a+') 166 | 167 | def test_bad_parses(self): 168 | self.assertRaises(ValueError, parse_regex, 'a{}') 169 | self.assertRaises(NotImplementedError, parse_regex, 'a{3,4}?') 170 | self.assertRaises(NotImplementedError, parse_regex, 'a??') 171 | self.assertRaises(NotImplementedError, parse_regex, 'a*?') 172 | self.assertRaises(NotImplementedError, parse_regex, 'a+?') 173 | self.assertRaises(ValueError, parse_regex, 'a{asdf}') 174 | self.assertRaises(ValueError, parse_regex, 'a{,asdf}') 175 | self.assertRaises(ValueError, parse_regex, 'a{1asdf}') 176 | self.assertRaises(IndexError, parse_regex, 'a(b') 177 | self.assertRaises(AssertionError, parse_regex, 'a)b') 178 | self.assertRaises(AssertionError, parse_regex, 'a)') 179 | self.assertRaises(AssertionError, parse_regex, '[abc\\]') 180 | self.assertRaises(NotImplementedError, parse_regex, '(?>a)') 181 | -------------------------------------------------------------------------------- /incremental_parsing/regex_ixes/regex_tree.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Tuple, Optional 3 | 4 | from incremental_parsing.utils.simple_nfa import SimpleNFAMutable 5 | 6 | 7 | class RegexNode(abc.ABC): 8 | @abc.abstractmethod 9 | def __str__(self): 10 | pass 11 | 12 | @abc.abstractmethod 13 | def to_nfa(self, nfa: SimpleNFAMutable, start_ix: int, end_ix: int): 14 | pass 15 | 16 | 17 | class RegexSequence(RegexNode): 18 | def __init__(self, nodes: Tuple[RegexNode, ...]): 19 | self.nodes = nodes 20 | 21 | def __str__(self): 22 | return ''.join(map(str, self.nodes)) 23 | 24 | def to_nfa(self, nfa: SimpleNFAMutable, start_ix: int, end_ix: int): 25 | if len(self.nodes) == 0: 26 | nfa.add_eps_transition(start_ix, end_ix) 27 | else: 28 | endpoints = [start_ix] 29 | for i in range(len(self.nodes) - 1): 30 | endpoints.append(nfa.add_state()) 31 | endpoints.append(end_ix) 32 | 33 | for i, node in enumerate(self.nodes): 34 | node.to_nfa(nfa, endpoints[i], endpoints[i + 1]) 35 | 36 | 37 | class RegexRepeat(RegexNode): 38 | def __init__(self, node: RegexNode, min_repeat: int, max_repeat: Optional[int]): 39 | self.node = node 40 | self.min_repeat = min_repeat 41 | self.max_repeat = max_repeat 42 | 43 | def __str__(self): 44 | if self.min_repeat == 0 and self.max_repeat == 0: 45 | return '' 46 | elif self.min_repeat == 0 and self.max_repeat is None: 47 | return f'{self.node}*' 48 | elif self.min_repeat == 1 and self.max_repeat is None: 49 | return f'{self.node}+' 50 | elif self.min_repeat == 0 and self.max_repeat == 1: 51 | return f'{self.node}?' 
52 | elif self.min_repeat == 0: 53 | return f'{self.node}{{,{self.max_repeat}}}' 54 | elif self.max_repeat is None: 55 | return f'{self.node}{{{self.min_repeat},}}' 56 | elif self.min_repeat == self.max_repeat: 57 | return f'{self.node}{{{self.min_repeat}}}' 58 | else: 59 | return f'{self.node}{{{self.min_repeat},{self.max_repeat}}}' 60 | 61 | def to_nfa(self, nfa: SimpleNFAMutable, start_idx: int, end_idx: int): 62 | idx_after_mandatory_repeats = start_idx 63 | 64 | for i in range(self.min_repeat): 65 | next_start_idx = nfa.add_state() 66 | self.node.to_nfa(nfa, idx_after_mandatory_repeats, next_start_idx) 67 | idx_after_mandatory_repeats = next_start_idx 68 | 69 | nfa.add_eps_transition(idx_after_mandatory_repeats, end_idx) 70 | 71 | if self.max_repeat == self.min_repeat: 72 | return 73 | else: 74 | idx_after_optional_repeats = nfa.add_state() 75 | self.node.to_nfa(nfa, idx_after_mandatory_repeats, idx_after_optional_repeats) 76 | nfa.add_eps_transition(idx_after_optional_repeats, end_idx) 77 | if self.max_repeat is None: 78 | nfa.add_eps_transition(idx_after_optional_repeats, idx_after_mandatory_repeats) 79 | else: 80 | for i in range(self.max_repeat - self.min_repeat - 1): 81 | next_idx_after_optional_repeat = nfa.add_state() 82 | self.node.to_nfa(nfa, idx_after_optional_repeats, next_idx_after_optional_repeat) 83 | nfa.add_eps_transition(next_idx_after_optional_repeat, end_idx) 84 | idx_after_optional_repeats = next_idx_after_optional_repeat 85 | 86 | 87 | class RegexAtom(RegexNode): 88 | def __init__(self, value: str, flags: int): 89 | self.value = value 90 | self.flags = flags 91 | 92 | def __str__(self): 93 | return self.value 94 | 95 | def to_nfa(self, nfa: SimpleNFAMutable, start_ix: int, end_ix: int): 96 | nfa.add_atom_transition(start_ix, end_ix, (self.value, self.flags)) 97 | 98 | 99 | class RegexAlternates(RegexNode): 100 | def __init__(self, nodes: Tuple[RegexNode, ...]): 101 | self.nodes = nodes 102 | 103 | def __str__(self): 104 | return '|'.join(map(str, self.nodes)) 105 | 106 | def to_nfa(self, nfa: SimpleNFAMutable, start_ix: int, end_ix: int): 107 | for node in self.nodes: 108 | node.to_nfa(nfa, start_ix, end_ix) 109 | 110 | 111 | class RegexGroup(RegexNode): 112 | def __init__(self, node: RegexNode, group_prefix: str): 113 | self.node = node 114 | self.group_prefix = group_prefix 115 | 116 | def __str__(self): 117 | return f'({self.group_prefix}{self.node})' 118 | 119 | def to_nfa(self, nfa: SimpleNFAMutable, start_ix: int, end_ix: int): 120 | self.node.to_nfa(nfa, start_ix, end_ix) 121 | -------------------------------------------------------------------------------- /incremental_parsing/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/incremental-parsing/24e501190b5eff2abbf6680c2d2a36ef37781402/incremental_parsing/utils/__init__.py -------------------------------------------------------------------------------- /incremental_parsing/utils/flags_to_regex_flags.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | import regex 4 | 5 | flag_map = { 6 | 'i': regex.I, 7 | 'm': regex.M, 8 | 's': regex.S, 9 | 'l': regex.L, 10 | 'u': regex.U, 11 | 'x': regex.X 12 | } 13 | 14 | 15 | def flags_to_regex_flags(re_flags: Iterable[str]) -> int: 16 | current_flags = 0 17 | for flag in re_flags: 18 | current_flags |= flag_map[flag] 19 | 20 | return current_flags 21 | 
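# Illustrative usage, not part of the original module: the helper simply ORs together
# the `regex`-module flags named by Lark-style flag characters.
#
#     >>> flags_to_regex_flags("is") == (regex.I | regex.S)
#     True
#     >>> flags_to_regex_flags("") == 0
#     True
#
# An unknown flag character raises KeyError, since flag_map is indexed directly.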
-------------------------------------------------------------------------------- /incremental_parsing/utils/indexable_container.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol, TypeVar 2 | 3 | ReturnType = TypeVar("ReturnType", covariant=True) 4 | 5 | 6 | class IndexableContainer(Protocol[ReturnType]): 7 | def __getitem__(self, key: int) -> ReturnType: 8 | ... 9 | -------------------------------------------------------------------------------- /incremental_parsing/utils/lookback_trie.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from random import random 3 | from typing import Generic, TypeVar, Optional, Tuple, List, Sequence, overload, Dict 4 | 5 | import math 6 | 7 | T = TypeVar('T') 8 | 9 | 10 | class LookbackTrieNode(Generic[T], Sequence[T]): 11 | depth: int 12 | """ 13 | Depth from root; root node has depth 0 14 | """ 15 | 16 | value: Optional[T] 17 | 18 | parent: Optional["LookbackTrieNode[T]"] 19 | 20 | parent_pointer_stack: Tuple[Tuple[int, "LookbackTrieNode[T]"], ...] 21 | """ 22 | Each element represents the depth of a parent node, and the node itself. 23 | These are the non-bottom levels of a skip list 24 | """ 25 | 26 | children: Dict[T, "LookbackTrieNode[T]"] 27 | 28 | @staticmethod 29 | def create_root_node() -> "LookbackTrieNode[T]": 30 | return LookbackTrieNode(0, None, None, ()) 31 | 32 | def __init__(self, depth: int, parent: Optional["LookbackTrieNode[T]"], 33 | value: Optional[T], 34 | parent_pointer_stack: Tuple[Tuple[int, "LookbackTrieNode[T]"], ...]): 35 | self.depth = depth 36 | self.parent = parent 37 | self.value = value 38 | self.parent_pointer_stack = parent_pointer_stack 39 | self.children = {} 40 | 41 | def __len__(self) -> int: 42 | return self.depth 43 | 44 | def get_node_at_depth(self, goal_depth: int): 45 | """ 46 | First value is at depth 1 47 | """ 48 | if goal_depth < 0: 49 | raise IndexError("depth must be nonnegative") 50 | elif goal_depth > self.depth: 51 | raise IndexError("n must be less than or equal to the depth of the node") 52 | 53 | current_node = self 54 | current_node_depth = self.depth 55 | 56 | while current_node_depth > goal_depth: 57 | assert current_node.parent is not None 58 | best_parent = current_node.parent 59 | best_parent_depth = current_node_depth - 1 60 | 61 | for possible_ancestor_depth, possible_ancestor in best_parent.parent_pointer_stack: 62 | if possible_ancestor_depth >= goal_depth: 63 | best_parent = possible_ancestor 64 | best_parent_depth = possible_ancestor_depth 65 | else: 66 | break 67 | 68 | current_node = best_parent 69 | current_node_depth = best_parent_depth 70 | 71 | assert current_node_depth == goal_depth 72 | return current_node 73 | 74 | def get_value_at_index(self, value_index: int): 75 | if value_index < 0: 76 | raise IndexError("value_index must be nonnegative") 77 | elif value_index >= self.depth: 78 | raise IndexError("value_index must be less than the depth of the node") 79 | 80 | return self.get_node_at_depth(value_index + 1).value 81 | 82 | def get_child(self, value: T) -> "LookbackTrieNode[T]": 83 | if value in self.children: 84 | return self.children[value] 85 | 86 | num_levels = math.floor(-math.log2(random())) 87 | 88 | parent_pointer_stack = [(self.depth, self) for _ in range(num_levels)] 89 | parent_pointer_stack.extend(self.parent_pointer_stack[num_levels:]) 90 | 91 | new_node = LookbackTrieNode(self.depth + 1, self, value, tuple(parent_pointer_stack)) 92 | 
self.children[value] = new_node 93 | return new_node 94 | 95 | def get_n_suffix(self, n: int) -> List[T]: 96 | if n < 0: 97 | raise IndexError("n must be nonnegative") 98 | elif n > self.depth: 99 | raise IndexError("n must be less than or equal to the depth of the node") 100 | 101 | sequence = [] 102 | current_node: LookbackTrieNode[T] = self 103 | for _ in range(n): 104 | assert current_node.parent is not None 105 | assert current_node.value is not None 106 | sequence.append(current_node.value) 107 | current_node = current_node.parent 108 | return list(reversed(sequence)) 109 | 110 | def get_full_sequence(self) -> List[T]: 111 | return self.get_n_suffix(self.depth) 112 | 113 | @overload 114 | def __getitem__(self, index: int) -> T: 115 | ... 116 | 117 | @overload 118 | def __getitem__(self, index: slice) -> List[T]: 119 | ... 120 | 121 | def __getitem__(self, index): 122 | if isinstance(index, int): 123 | real_index = index if index >= 0 else self.depth + index 124 | 125 | return self.get_value_at_index(real_index) 126 | elif isinstance(index, slice): 127 | start, stop, step = index.indices(self.depth) 128 | real_start = start if start >= 0 else self.depth + start 129 | real_stop = stop if stop >= 0 else self.depth + stop 130 | 131 | if step != 1: 132 | raise ValueError("slice step must be 1") 133 | 134 | if real_start >= real_stop: 135 | return [] 136 | 137 | end_node = self.get_node_at_depth(real_stop) 138 | 139 | length = real_stop - real_start 140 | return end_node.get_n_suffix(length) 141 | else: 142 | raise NotImplementedError() 143 | 144 | 145 | class LookbackTrieNodeTest(unittest.TestCase): 146 | 147 | def test_get_full_sequence(self): 148 | node: LookbackTrieNode[str] = LookbackTrieNode.create_root_node() 149 | node = node.get_child('a').get_child('b') 150 | node.get_child("f") 151 | node = node.get_child('c') 152 | node.get_child("g") 153 | self.assertEqual("".join(node.get_full_sequence()), "abc") 154 | 155 | def test_indexing(self): 156 | node: LookbackTrieNode[str] = LookbackTrieNode.create_root_node() 157 | for i in range(100): 158 | node = node.get_child(str(i)) 159 | 160 | for i in range(100): 161 | self.assertEqual(node[i], str(i)) 162 | 163 | full_result = list(str(i) for i in range(100)) 164 | for i in range(-100, 100): 165 | for j in range(-100, 100): 166 | self.assertEqual(node[i:j], full_result[i:j]) 167 | 168 | for i in range(-100, 100): 169 | self.assertEqual(node[i:], full_result[i:]) 170 | self.assertEqual(node[:i], full_result[:i]) 171 | 172 | self.assertEqual(node[100:], []) 173 | self.assertEqual(node[:], full_result) 174 | 175 | def test_reasonable_speed(self): 176 | node: LookbackTrieNode[str] = LookbackTrieNode.create_root_node() 177 | for i in range(100000): 178 | node = node.get_child(str(i)) 179 | 180 | for i in range(100000): 181 | self.assertEqual(node[i], str(i)) 182 | -------------------------------------------------------------------------------- /incremental_parsing/utils/simple_nfa.py: -------------------------------------------------------------------------------- 1 | from collections import deque, defaultdict 2 | from typing import TypeVar, Generic, AbstractSet, Mapping, Callable, Dict, FrozenSet, Set, Sequence, Iterable, \ 3 | DefaultDict 4 | 5 | P = TypeVar("P") # Pattern Type 6 | A = TypeVar("A") # Atom Type 7 | 8 | 9 | class SimpleNFA(Generic[P, A]): 10 | def __init__(self, states: AbstractSet[int], start_states: AbstractSet[int], end_states: AbstractSet[int], 11 | atom_transitions_forward: Mapping[int, Mapping[P, AbstractSet[int]]], 12 | 
atom_transitions_backward: Mapping[int, Mapping[P, AbstractSet[int]]], 13 | matcher: Callable[[P, A], bool]): 14 | 15 | self.states = frozenset(states) 16 | self.start_states = frozenset(start_states) 17 | self.end_states = frozenset(end_states) 18 | 19 | self.atom_transitions_forward = {state: {atom: frozenset(dests) 20 | for atom, dests in atom_transitions_forward[state].items()} 21 | for state in atom_transitions_forward} 22 | self.atom_transitions_backward = {state: {atom: frozenset(dests) 23 | for atom, dests in atom_transitions_backward[state].items()} 24 | for state in atom_transitions_backward} 25 | self.matcher = matcher 26 | 27 | def step_state(self, state: int, atom: A, transition_map: Dict[int, Dict[P, FrozenSet[int]]]): 28 | if state not in transition_map: 29 | return frozenset() 30 | 31 | next_states: Set[int] = set() 32 | 33 | for possible_atom in transition_map[state]: 34 | if self.matcher(possible_atom, atom): 35 | next_states.update(transition_map[state][possible_atom]) 36 | 37 | return frozenset(next_states) 38 | 39 | def step_forward(self, state_set: FrozenSet[int], atom: A) -> FrozenSet[int]: 40 | if len(state_set) == 0: 41 | return frozenset() 42 | return frozenset.union(*[self.step_state(state, atom, self.atom_transitions_forward) for state in state_set]) 43 | 44 | def step_forward_any(self, state_set: FrozenSet[int]) -> FrozenSet[int]: 45 | ret_set: Set[int] = set() 46 | for state in state_set: 47 | for atom, dests in self.atom_transitions_forward[state].items(): 48 | ret_set.update(dests) 49 | 50 | return frozenset(ret_set) 51 | 52 | def fullmatch(self, string: Sequence[A]) -> bool: 53 | current_states = self.start_states 54 | for char in string: 55 | current_states = self.step_forward(current_states, char) 56 | if not current_states: 57 | return False 58 | 59 | return any(current_state in self.end_states for current_state in current_states) 60 | 61 | def step_backward(self, state_set: FrozenSet[int], atom: A) -> FrozenSet[int]: 62 | if len(state_set) == 0: 63 | return frozenset() 64 | return frozenset.union(*[self.step_state(state, atom, self.atom_transitions_backward) for state in state_set]) 65 | 66 | def get_reachable_forward(self, state_set: FrozenSet[int]) -> FrozenSet[int]: 67 | return frozenset(get_reachable(state_set, self.atom_transitions_forward)) 68 | 69 | def is_extensible(self, state_set: FrozenSet[int]) -> bool: 70 | return any((state in self.atom_transitions_forward 71 | and any(len(outgoing) > 0 for _, outgoing in self.atom_transitions_forward[state].items())) 72 | for state in state_set) 73 | 74 | 75 | def get_reachable(start_set: Iterable[int], mapping_set: Mapping[int, Mapping[P, AbstractSet[int]]]) -> \ 76 | Set[int]: 77 | reachable = set() 78 | frontier = deque(start_set) 79 | while len(frontier) > 0: 80 | state = frontier.popleft() 81 | if state in reachable: 82 | continue 83 | 84 | reachable.add(state) 85 | if state in mapping_set: 86 | for atom, dests in mapping_set[state].items(): 87 | frontier.extend(dests) 88 | 89 | return reachable 90 | 91 | 92 | class SimpleNFAMutable(Generic[P]): 93 | 94 | def __init__(self): 95 | self.state_counter = 2 96 | self.states = {0, 1} 97 | self.start_states = {0} 98 | self.end_states = {1} 99 | self.eps_transitions_forward: DefaultDict[int, Set[int]] = defaultdict(set) 100 | self.eps_transitions_backward: DefaultDict[int, Set[int]] = defaultdict(set) 101 | self.atom_transitions_forward: DefaultDict[int, DefaultDict[P, Set[int]]] = defaultdict( 102 | lambda: defaultdict(set)) 103 | 
self.atom_transitions_backward: DefaultDict[int, DefaultDict[P, Set[int]]] = defaultdict( 104 | lambda: defaultdict(set)) 105 | 106 | def add_state(self) -> int: 107 | snum = self.state_counter 108 | self.state_counter += 1 109 | self.states.add(snum) 110 | return snum 111 | 112 | def add_eps_transition(self, origin: int, dest: int): 113 | self.eps_transitions_forward[origin].add(dest) 114 | self.eps_transitions_backward[dest].add(origin) 115 | 116 | def add_atom_transition(self, origin: int, dest: int, pattern: P): 117 | self.atom_transitions_forward[origin][pattern].add(dest) 118 | self.atom_transitions_backward[dest][pattern].add(origin) 119 | 120 | def eliminate_eps_transition(self, origin: int, dest: int): 121 | if origin == dest: 122 | return 123 | 124 | if dest in self.atom_transitions_forward: 125 | for atom, final_dests in self.atom_transitions_forward[dest].items(): 126 | for final_dest in final_dests: 127 | self.atom_transitions_forward[origin][atom].add(final_dest) 128 | self.atom_transitions_backward[final_dest][atom].add(origin) 129 | 130 | if origin in self.start_states: 131 | self.start_states.add(dest) 132 | 133 | if dest in self.end_states: 134 | self.end_states.add(origin) 135 | 136 | def compute_backwards_eps_reachability(self, orig_state: int) -> Set[int]: 137 | reachable = set() 138 | frontier = deque([orig_state]) 139 | while len(frontier) > 0: 140 | state = frontier.popleft() 141 | if state in reachable: 142 | continue 143 | else: 144 | reachable.add(state) 145 | 146 | if state in self.eps_transitions_backward: 147 | for dest in self.eps_transitions_backward[state]: 148 | frontier.append(dest) 149 | 150 | return reachable 151 | 152 | def eliminate_all_eps_transitions(self): 153 | backwards_eps_map = {state: self.compute_backwards_eps_reachability(state) for state in 154 | self.eps_transitions_backward.keys()} 155 | for state, backwards_eps in backwards_eps_map.items(): 156 | for backwards_eps_state in backwards_eps: 157 | if state != backwards_eps_state: 158 | self.eliminate_eps_transition(backwards_eps_state, state) 159 | 160 | self.eps_transitions_forward.clear() 161 | self.eps_transitions_backward.clear() 162 | 163 | def get_reachable_states(self): 164 | reachable_states_forward: Set[int] = get_reachable(self.start_states, self.atom_transitions_forward) 165 | reachable_states_backward: Set[int] = get_reachable(self.end_states, self.atom_transitions_backward) 166 | return reachable_states_forward.intersection(reachable_states_backward) 167 | 168 | def remove_states(self, states_to_remove: Set[int]): 169 | assert len(self.eps_transitions_forward) == 0, "Can only remove states if there are no eps transitions; call " \ 170 | "eliminate_all_eps_transitions first" 171 | 172 | new_states = set() 173 | new_start_states = set() 174 | new_end_states = set() 175 | new_eps_transitions_forward: DefaultDict[int, Set[int]] = defaultdict(set) 176 | new_eps_transitions_backward: DefaultDict[int, Set[int]] = defaultdict(set) 177 | new_atom_transitions_forward: DefaultDict[int, DefaultDict[P, Set[int]]] = defaultdict( 178 | lambda: defaultdict(set)) 179 | new_atom_transitions_backward: DefaultDict[int, DefaultDict[P, Set[int]]] = defaultdict( 180 | lambda: defaultdict(set)) 181 | 182 | for state in self.states: 183 | if state in states_to_remove: 184 | continue 185 | 186 | new_states.add(state) 187 | if state in self.start_states: 188 | new_start_states.add(state) 189 | if state in self.end_states: 190 | new_end_states.add(state) 191 | 192 | for dest in 
self.eps_transitions_forward[state]: 193 | if dest not in states_to_remove: 194 | new_eps_transitions_forward[state].add(dest) 195 | new_eps_transitions_backward[dest].add(state) 196 | 197 | for atom, dests in self.atom_transitions_forward[state].items(): 198 | for dest in dests: 199 | if dest not in states_to_remove: 200 | new_atom_transitions_forward[state][atom].add(dest) 201 | new_atom_transitions_backward[dest][atom].add(state) 202 | 203 | self.states = new_states 204 | self.start_states = new_start_states 205 | self.end_states = new_end_states 206 | self.eps_transitions_forward = new_eps_transitions_forward 207 | self.eps_transitions_backward = new_eps_transitions_backward 208 | self.atom_transitions_forward = new_atom_transitions_forward 209 | self.atom_transitions_backward = new_atom_transitions_backward 210 | 211 | def remove_unreachable_states(self): 212 | reachable_states = self.get_reachable_states() 213 | self.remove_states(self.states - reachable_states) 214 | 215 | def finalize(self, matcher: Callable[[P, A], bool]): 216 | self.eliminate_all_eps_transitions() 217 | self.remove_unreachable_states() 218 | return SimpleNFA(self.states, self.start_states, self.end_states, 219 | self.atom_transitions_forward, self.atom_transitions_backward, 220 | matcher) 221 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | 3 | disable_error_code = import 4 | check_untyped_defs = True -------------------------------------------------------------------------------- /notebooks/create_figures.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "initial_id", 7 | "metadata": { 8 | "collapsed": true, 9 | "ExecuteTime": { 10 | "end_time": "2023-08-25T18:16:14.041421Z", 11 | "start_time": "2023-08-25T18:16:12.991872Z" 12 | } 13 | }, 14 | "outputs": [], 15 | "source": [ 16 | "import matplotlib.pyplot as plt\n", 17 | "import pandas as pd\n", 18 | "import numpy as np" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "outputs": [], 25 | "source": [ 26 | "PATH_TO_CSV = \"path/to/results.csv\"" 27 | ], 28 | "metadata": { 29 | "collapsed": false, 30 | "ExecuteTime": { 31 | "end_time": "2023-08-25T18:16:14.046823Z", 32 | "start_time": "2023-08-25T18:16:14.041165Z" 33 | } 34 | }, 35 | "id": "ebf4258605d82e9" 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 3, 40 | "outputs": [], 41 | "source": [ 42 | "df = pd.read_csv(PATH_TO_CSV)" 43 | ], 44 | "metadata": { 45 | "collapsed": false, 46 | "ExecuteTime": { 47 | "end_time": "2023-08-25T18:16:14.191685Z", 48 | "start_time": "2023-08-25T18:16:14.043678Z" 49 | } 50 | }, 51 | "id": "f841ad18d056789c" 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 4, 56 | "outputs": [], 57 | "source": [ 58 | "def filter_percentile(df, col, percentile):\n", 59 | " threshold = df[col].quantile(percentile)\n", 60 | " return df[df[col] < threshold]" 61 | ], 62 | "metadata": { 63 | "collapsed": false, 64 | "ExecuteTime": { 65 | "end_time": "2023-08-25T18:16:14.197129Z", 66 | "start_time": "2023-08-25T18:16:14.191890Z" 67 | } 68 | }, 69 | "id": "2cb04fc4f5713094" 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 5, 74 | "outputs": [], 75 | "source": [ 76 | "df[\"context_size\"] = df[\"prefix_size\"] + df[\"suffix_size\"]\n", 77 | "df = 
df[df[\"constrained_overhead_eval_p50\"].notnull()]\n" 78 | ], 79 | "metadata": { 80 | "collapsed": false, 81 | "ExecuteTime": { 82 | "end_time": "2023-08-25T18:16:14.204731Z", 83 | "start_time": "2023-08-25T18:16:14.196928Z" 84 | } 85 | }, 86 | "id": "820afa523f8ff143" 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 6, 91 | "outputs": [], 92 | "source": [ 93 | "def plot_with_line(df, raw_df, xcol, ycol, scatter_args, line_args):\n", 94 | " plt.scatter(df[xcol], df[ycol], **scatter_args)\n", 95 | " plt.plot(raw_df[xcol], np.poly1d(np.polyfit(df[xcol], df[ycol], 1))(raw_df[xcol]), **line_args)" 96 | ], 97 | "metadata": { 98 | "collapsed": false, 99 | "ExecuteTime": { 100 | "end_time": "2023-08-25T18:16:14.209789Z", 101 | "start_time": "2023-08-25T18:16:14.205372Z" 102 | } 103 | }, 104 | "id": "b421ec9a4911a859" 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "outputs": [], 110 | "source": [ 111 | "plot_with_line(df, df, \"context_size\", \"constrained_overhead_eval_p50\", {\"alpha\": 0.05}, {\"color\": \"blue\", \"label\": \"Constrained\"})\n", 112 | "plot_with_line(df, df, \"context_size\", \"unconstrained_checking_overhead_p50\", {\"alpha\": 0.05}, {\"color\": \"orange\", \"label\": \"Checked Unconstrained\"})\n", 113 | "plt.xlabel(\"Context length (chars)\")\n", 114 | "plt.ylabel(\"Time per token (s)\")\n", 115 | "plt.xlim(0)\n", 116 | "plt.ylim((0, 0.1))\n", 117 | "plt.title(\"Per-token Incremental Quotient Language Check (p50)\")\n", 118 | "plt.legend()\n", 119 | "plt.savefig(\"per_token_eval_performance_p50.png\", dpi=600, format=\"png\")" 120 | ], 121 | "metadata": { 122 | "collapsed": false 123 | }, 124 | "id": "ed8a8407d45f3959" 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "outputs": [], 130 | "source": [ 131 | "plot_with_line(df, df, \"context_size\", \"constrained_overhead_eval_p90\", {\"alpha\": 0.05}, {\"color\": \"blue\", \"label\": \"Constrained\"})\n", 132 | "plot_with_line(df, df, \"context_size\", \"unconstrained_checking_overhead_p90\", {\"alpha\": 0.05}, {\"color\": \"orange\", \"label\": \"Checked Unconstrained\"})\n", 133 | "plt.xlabel(\"Context length (chars)\")\n", 134 | "plt.ylabel(\"Time per token (s)\")\n", 135 | "plt.xlim(0)\n", 136 | "plt.ylim((0, 0.3))\n", 137 | "plt.title(\"Per-token Incremental Quotient Language Check (p90)\")\n", 138 | "plt.legend()\n", 139 | "plt.savefig(\"per_token_eval_performance_p90.png\", dpi=600, format=\"png\")" 140 | ], 141 | "metadata": { 142 | "collapsed": false 143 | }, 144 | "id": "8a9efb555cb19b5c" 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "outputs": [], 150 | "source": [ 151 | "plot_with_line(df, df, \"context_size\", \"pre_time\", {\"alpha\": 0.05, \"color\":\"olive\"}, {\"color\":\"green\"})\n", 152 | "plt.xlabel(\"Context length (chars)\")\n", 153 | "plt.ylabel(\"Time (s)\")\n", 154 | "plt.xlim(0)\n", 155 | "plt.ylim((0, 500))\n", 156 | "plt.title(\"One-time Processing for Constrained Generation\")\n", 157 | "plt.savefig(\"one_time_processing.png\", dpi=600, format=\"png\")" 158 | ], 159 | "metadata": { 160 | "collapsed": false 161 | }, 162 | "id": "3df59a877ee2bd97" 163 | } 164 | ], 165 | "metadata": { 166 | "kernelspec": { 167 | "display_name": "Python 3", 168 | "language": "python", 169 | "name": "python3" 170 | }, 171 | "language_info": { 172 | "codemirror_mode": { 173 | "name": "ipython", 174 | "version": 2 175 | }, 176 | "file_extension": ".py", 177 | "mimetype": "text/x-python", 178 | "name": 
"python", 179 | "nbconvert_exporter": "python", 180 | "pygments_lexer": "ipython2", 181 | "version": "2.7.6" 182 | } 183 | }, 184 | "nbformat": 4, 185 | "nbformat_minor": 5 186 | } 187 | -------------------------------------------------------------------------------- /notebooks/create_parse_hierarchy_viz_calc_lang.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "outputs": [], 7 | "source": [ 8 | "%load_ext autoreload\n", 9 | "%autoreload 2" 10 | ], 11 | "metadata": { 12 | "collapsed": false, 13 | "ExecuteTime": { 14 | "end_time": "2023-08-24T16:51:37.764859Z", 15 | "start_time": "2023-08-24T16:51:37.724207Z" 16 | } 17 | } 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "Requirement already satisfied: pygraphviz in /Users/dmelcer/miniconda3/envs/IncrementalParsing/lib/python3.9/site-packages (1.9)\r\n" 28 | ] 29 | } 30 | ], 31 | "source": [ 32 | "!pip install pygraphviz" 33 | ], 34 | "metadata": { 35 | "collapsed": false, 36 | "ExecuteTime": { 37 | "end_time": "2023-08-24T16:51:38.981049Z", 38 | "start_time": "2023-08-24T16:51:37.753346Z" 39 | } 40 | } 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "outputs": [], 46 | "source": [ 47 | "from incremental_parsing.lex_earley.middle_earley import create_middle_bnf, create_parse_hierarchy\n", 48 | "from incremental_parsing.lex_earley.lark_grammar import get_calc_lang_context\n", 49 | "from incremental_parsing.lex_earley.lex_earley import lex_earley_init, lex_earley_run, _to_suffix_parser_state, force_eof\n", 50 | "from incremental_parsing.lex_earley.earley_nfa import EarleyNFA, tokens_to_nfa\n", 51 | "import networkx as nx" 52 | ], 53 | "metadata": { 54 | "collapsed": false, 55 | "ExecuteTime": { 56 | "end_time": "2023-08-24T16:51:39.172830Z", 57 | "start_time": "2023-08-24T16:51:38.981963Z" 58 | } 59 | } 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 4, 64 | "outputs": [], 65 | "source": [ 66 | "context = get_calc_lang_context()\n", 67 | "init_state = lex_earley_init(context)" 68 | ], 69 | "metadata": { 70 | "collapsed": false, 71 | "ExecuteTime": { 72 | "end_time": "2023-08-24T16:51:39.198803Z", 73 | "start_time": "2023-08-24T16:51:39.173176Z" 74 | } 75 | } 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 5, 80 | "outputs": [], 81 | "source": [ 82 | "prefix = \"(1\"\n", 83 | "suffix = \"\"\n", 84 | "pre_branch_num = 0\n", 85 | "post_branch_num = 0" 86 | ], 87 | "metadata": { 88 | "collapsed": false, 89 | "ExecuteTime": { 90 | "end_time": "2023-08-24T16:51:39.214941Z", 91 | "start_time": "2023-08-24T16:51:39.200003Z" 92 | } 93 | } 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 6, 98 | "outputs": [], 99 | "source": [ 100 | "pre = lex_earley_run(context=context, state=init_state, value=prefix)\n", 101 | "g_pre = create_parse_hierarchy(context.grammar, pre.branches[pre_branch_num].parser_state.earley_trie.charts, (len(pre.branches[pre_branch_num].parser_state.earley_trie.charts) - 1,))\n", 102 | "ag = nx.nx_agraph.to_agraph(g_pre.reverse()) # Formalization assumes arrows go in one direction, implementation has arrows go in other direction\n", 103 | "ag.layout(prog=\"dot\")\n", 104 | "ag.draw(\"pre.png\")\n", 105 | "suf = _to_suffix_parser_state(context=context, state=pre, suffix=suffix, make_dummy_trie=True)\n", 106 | "post = 
lex_earley_run(context=context, state=suf, value=suffix)\n", 107 | "post_eof = force_eof(context, post)\n", 108 | "branch = post_eof.branches[post_branch_num]\n", 109 | "tokens_reverse = list(branch.suffix_state.parser_state.earley_trie.get_reverse_token_sequence())\n", 110 | "token_nfa, final_states = tokens_to_nfa(tokens_reverse)\n", 111 | "earley_nfa = EarleyNFA(context.grammar.reverse(), token_nfa)\n", 112 | "g_post = create_parse_hierarchy(context.grammar, earley_nfa.charts, final_states, reverse_state_positions=True)\n", 113 | "ag = nx.nx_agraph.to_agraph(g_post.reverse())\n", 114 | "ag.layout(prog=\"dot\")\n", 115 | "ag.draw(\"post.png\")\n" 116 | ], 117 | "metadata": { 118 | "collapsed": false, 119 | "ExecuteTime": { 120 | "end_time": "2023-08-24T16:51:39.596882Z", 121 | "start_time": "2023-08-24T16:51:39.216128Z" 122 | } 123 | } 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 7, 128 | "outputs": [ 129 | { 130 | "name": "stdout", 131 | "output_type": "stream", 132 | "text": [ 133 | "comma_expr : COMMA expression\n", 134 | " | COMMA expression comma_expr\n", 135 | "\n", 136 | "expression : LPAR expression RPAR\n", 137 | " | LPAR expression comma_expr RPAR\n", 138 | " | expression binop expression\n", 139 | " | unop expression\n", 140 | " | NUMBER\n", 141 | "\n", 142 | "binop : MINUS\n", 143 | " | PLUS\n", 144 | " | STAR\n", 145 | " | SLASH\n", 146 | "\n", 147 | "unop : MINUS\n", 148 | "\n", 149 | "start<0-> : expression<0->\n", 150 | "\n", 151 | "expression<0-> : RPAR\n", 152 | " | comma_expr RPAR\n", 153 | " | expression<0-> binop expression\n", 154 | " | expression<1-> RPAR\n", 155 | " | expression<1-> comma_expr RPAR\n", 156 | "\n", 157 | "expression<1-> : binop expression\n", 158 | " | expression<1-> binop expression\n", 159 | "('start<0->',)\n" 160 | ] 161 | } 162 | ], 163 | "source": [ 164 | "r = create_middle_bnf(context.grammar, g_pre, g_post, (len(pre.branches[pre_branch_num].parser_state.earley_trie.charts) - 1,), final_states)\n", 165 | "print(str(r))\n", 166 | "print(r.top_level_rules)" 167 | ], 168 | "metadata": { 169 | "collapsed": false, 170 | "ExecuteTime": { 171 | "end_time": "2023-08-24T16:51:39.614994Z", 172 | "start_time": "2023-08-24T16:51:39.597473Z" 173 | } 174 | } 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 8, 179 | "outputs": [ 180 | { 181 | "name": "stdout", 182 | "output_type": "stream", 183 | "text": [ 184 | "comma_expr : COMMA expression\n", 185 | " | COMMA expression comma_expr\n", 186 | "\n", 187 | "expression : LPAR expression RPAR\n", 188 | " | LPAR expression comma_expr RPAR\n", 189 | " | expression binop expression\n", 190 | " | unop expression\n", 191 | " | NUMBER\n", 192 | "\n", 193 | "binop : MINUS\n", 194 | " | PLUS\n", 195 | " | STAR\n", 196 | " | SLASH\n", 197 | "\n", 198 | "unop : MINUS\n", 199 | "\n", 200 | "expression<1-> : binop<2-> expression\n", 201 | " | λ\n", 202 | " | binop expression\n", 203 | " | expression<1-> binop expression\n", 204 | "\n", 205 | "expression<0-> : RPAR\n", 206 | " | expression<1-> RPAR\n", 207 | " | expression<1-> comma_expr RPAR\n", 208 | " | expression<0-> binop expression\n", 209 | " | comma_expr RPAR\n", 210 | " | comma_expr<2-> RPAR\n", 211 | "\n", 212 | "comma_expr<2-> : COMMA expression\n", 213 | " | COMMA expression comma_expr\n", 214 | "\n", 215 | "binop<2-> : SLASH\n", 216 | " | PLUS\n", 217 | " | MINUS\n", 218 | " | STAR\n", 219 | "\n", 220 | "start<0-> : expression<0->\n", 221 | "('start<0->',)\n" 222 | ] 223 | } 224 | ], 225 | "source": [ 226 | "from 
incremental_parsing.lex_earley.middle_earley import create_bnf_direct\n", 227 | "\n", 228 | "relevant_charts = pre.branches[pre_branch_num].parser_state.earley_trie.charts\n", 229 | "r = create_bnf_direct(grammar=context.grammar, final_chart_indices=(len(relevant_charts) - 1,), charts=relevant_charts, is_right_context=False)\n", 230 | "print(str(r))\n", 231 | "print(r.top_level_rules)" 232 | ], 233 | "metadata": { 234 | "collapsed": false, 235 | "ExecuteTime": { 236 | "end_time": "2023-08-24T16:51:39.631Z", 237 | "start_time": "2023-08-24T16:51:39.613359Z" 238 | } 239 | } 240 | } 241 | ], 242 | "metadata": { 243 | "kernelspec": { 244 | "display_name": "Python 3", 245 | "language": "python", 246 | "name": "python3" 247 | }, 248 | "language_info": { 249 | "codemirror_mode": { 250 | "name": "ipython", 251 | "version": 2 252 | }, 253 | "file_extension": ".py", 254 | "mimetype": "text/x-python", 255 | "name": "python", 256 | "nbconvert_exporter": "python", 257 | "pygments_lexer": "ipython2", 258 | "version": "2.7.6" 259 | } 260 | }, 261 | "nbformat": 4, 262 | "nbformat_minor": 0 263 | } 264 | -------------------------------------------------------------------------------- /notebooks/example_parsing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "source": [ 7 | "%load_ext autoreload\n", 8 | "%autoreload 2" 9 | ], 10 | "metadata": { 11 | "collapsed": false 12 | }, 13 | "outputs": [] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "source": [ 19 | "from incremental_parsing.lex_earley.lark_grammar import get_python_context\n", 20 | "from incremental_parsing.lex_earley.lex_earley import lex_earley_init, lex_earley_run, lex_earley_to_middle, is_complete\n", 21 | "from incremental_parsing.evaluation.text_cuts import cut_text_random\n", 22 | "\n", 23 | "import datasets\n", 24 | "import random" 25 | ], 26 | "metadata": { 27 | "collapsed": false 28 | }, 29 | "outputs": [] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "source": [ 35 | "context = get_python_context()\n", 36 | "init_state = lex_earley_init(context)\n", 37 | "dataset = datasets.load_dataset(\"bigcode/the-stack-smol-xl\", data_dir=\"data/python\")[\"train\"]" 38 | ], 39 | "metadata": { 40 | "collapsed": false 41 | }, 42 | "outputs": [] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "outputs": [], 47 | "source": [ 48 | "DATA_IDX = 0\n", 49 | "CUT_IDX = 0" 50 | ], 51 | "metadata": { 52 | "collapsed": false 53 | }, 54 | "execution_count": null 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "source": [ 60 | "# Load data from the stack\n", 61 | "data = dataset[DATA_IDX][\"content\"]\n", 62 | "random.seed(hash((DATA_IDX, CUT_IDX)) % (2 ** 32))\n", 63 | "prefix, hidden_middle, suffix = cut_text_random(data, 0, .9, .2, None)\n", 64 | "suffix += \"\\n\" # This line is important\n", 65 | "pre_branch_num = 0\n", 66 | "post_branch_num = 0\n" 67 | ], 68 | "metadata": { 69 | "collapsed": false 70 | }, 71 | "outputs": [] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "outputs": [], 76 | "source": [ 77 | "print(prefix)" 78 | ], 79 | "metadata": { 80 | "collapsed": false 81 | }, 82 | "execution_count": null 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "source": [ 88 | "# Or manually specify a prefix and suffix\n", 89 | "if False:\n", 90 | " prefix = \"\"\n", 91 | " suffix = \"else d\\n\"\n", 92 | " pre_branch_num = 0\n", 
93 | " post_branch_num = 0" 94 | ], 95 | "metadata": { 96 | "collapsed": false 97 | }, 98 | "outputs": [] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "source": [ 104 | "pre = lex_earley_run(context=context, state=init_state, value=prefix)" 105 | ], 106 | "metadata": { 107 | "collapsed": false 108 | }, 109 | "outputs": [] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "outputs": [], 114 | "source": [ 115 | "mid = lex_earley_to_middle(context, pre, suffix, \"\")" 116 | ], 117 | "metadata": { 118 | "collapsed": false 119 | }, 120 | "execution_count": null 121 | }, 122 | { 123 | "cell_type": "code", 124 | "outputs": [], 125 | "source": [ 126 | "post = lex_earley_run(context, mid, hidden_middle);" 127 | ], 128 | "metadata": { 129 | "collapsed": false 130 | }, 131 | "execution_count": null 132 | }, 133 | { 134 | "cell_type": "code", 135 | "outputs": [], 136 | "source": [ 137 | "is_complete(context, post)" 138 | ], 139 | "metadata": { 140 | "collapsed": false 141 | }, 142 | "execution_count": null 143 | }, 144 | { 145 | "cell_type": "code", 146 | "outputs": [], 147 | "source": [], 148 | "metadata": { 149 | "collapsed": false 150 | } 151 | } 152 | ], 153 | "metadata": { 154 | "kernelspec": { 155 | "display_name": "Python 3", 156 | "language": "python", 157 | "name": "python3" 158 | }, 159 | "language_info": { 160 | "codemirror_mode": { 161 | "name": "ipython", 162 | "version": 2 163 | }, 164 | "file_extension": ".py", 165 | "mimetype": "text/x-python", 166 | "name": "python", 167 | "nbconvert_exporter": "python", 168 | "pygments_lexer": "ipython2", 169 | "version": "2.7.6" 170 | } 171 | }, 172 | "nbformat": 4, 173 | "nbformat_minor": 0 174 | } 175 | -------------------------------------------------------------------------------- /notebooks/interactive_constrained_generation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "source": [ 7 | "from transformers import AutoTokenizer, AutoModelForCausalLM\n", 8 | "\n", 9 | "from incremental_parsing.generation.constrained_generation import prefix_suffix_constrained_generation\n", 10 | "from incremental_parsing.lex_earley.lark_grammar import get_python_context\n", 11 | "from ansi.colour import bg, fg\n", 12 | "import datetime" 13 | ], 14 | "metadata": { 15 | "collapsed": false 16 | }, 17 | "outputs": [] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "source": [ 23 | "%load_ext autoreload\n", 24 | "%autoreload 2" 25 | ], 26 | "metadata": { 27 | "collapsed": false 28 | }, 29 | "outputs": [] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "source": [ 35 | "BEAM_SIZE = 1\n", 36 | "MAX_GENERATION_LENGTH = 200" 37 | ], 38 | "metadata": { 39 | "collapsed": false 40 | }, 41 | "outputs": [] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "source": [ 47 | "# NOTE: If you get a warning in this cell about weights not being fine-tuned, use an older version of the transformers library\n", 48 | "MODEL_NAME = \"bigcode/santacoder\"\n", 49 | "DEVICE = \"cuda:0\"\n", 50 | "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n", 51 | "model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True).to(DEVICE)\n", 52 | "context = get_python_context()" 53 | ], 54 | "metadata": { 55 | "collapsed": false 56 | }, 57 | "outputs": [] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "source": [ 
63 | "import random\n", 64 | "import datasets\n", 65 | "from incremental_parsing.evaluation.text_cuts import cut_text_random\n", 66 | "import transformers\n", 67 | "\n", 68 | "idx = 0\n", 69 | "cut = 0\n", 70 | "\n", 71 | "dataset = datasets.load_dataset(\"bigcode/the-stack-smol-xl\", data_dir=\"data/python\")[\"train\"]\n", 72 | "data = dataset[idx][\"content\"]\n", 73 | "random.seed(hash((idx, cut)) % (2 ** 32))\n", 74 | "prefix_text, middle, suffix_text = cut_text_random(data, 0, .9, .2, None)\n", 75 | "suffix_text += \"\\n\"\n", 76 | "\n", 77 | "\n", 78 | "begin = datetime.datetime.now()\n", 79 | "transformers.set_seed(hash((idx, cut, 0)) % (2 ** 32))\n", 80 | "\n", 81 | "middle_text, *_ = prefix_suffix_constrained_generation(\n", 82 | " tokenizer=tokenizer, model=model, context=context, prefix_text=prefix_text,\n", 83 | " suffix_text=suffix_text, beam_size=BEAM_SIZE, max_generation_length=MAX_GENERATION_LENGTH, device=DEVICE,\n", 84 | " debug=True\n", 85 | ")\n", 86 | "end = datetime.datetime.now()\n", 87 | "td = end - begin\n", 88 | "\n", 89 | "if middle_text is None:\n", 90 | " print(\"Generation failed\")\n", 91 | "else:\n", 92 | " print(prefix_text + bg.boldgreen(fg.black(middle_text)) + suffix_text)\n", 93 | "\n", 94 | "print(f\"{int(td.total_seconds())}.{td.microseconds // 1000:03} seconds elapsed\")" 95 | ], 96 | "metadata": { 97 | "collapsed": false 98 | }, 99 | "outputs": [] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "source": [ 105 | "import ast\n", 106 | "ast.parse(prefix_text + middle_text + suffix_text)" 107 | ], 108 | "metadata": { 109 | "collapsed": false 110 | }, 111 | "outputs": [] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "source": [ 117 | "# Sometimes it won't display properly if the source file used the wrong line endings. 
Use these next two cells to display result in that case\n", 118 | "print(prefix_text.replace(\"\\r\", \"\\n\") + bg.boldgreen(fg.black(middle_text.replace(\"\\r\", \"\\n\")) + suffix_text.replace(\"\\r\", \"\\n\")))" 119 | ], 120 | "metadata": { 121 | "collapsed": false 122 | }, 123 | "outputs": [] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "source": [ 129 | "print(data.replace(\"\\r\", \"\\n\"))" 130 | ], 131 | "metadata": { 132 | "collapsed": false 133 | }, 134 | "outputs": [] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "outputs": [], 139 | "source": [], 140 | "metadata": { 141 | "collapsed": false 142 | } 143 | } 144 | ], 145 | "metadata": { 146 | "kernelspec": { 147 | "display_name": "Python 3", 148 | "language": "python", 149 | "name": "python3" 150 | }, 151 | "language_info": { 152 | "codemirror_mode": { 153 | "name": "ipython", 154 | "version": 2 155 | }, 156 | "file_extension": ".py", 157 | "mimetype": "text/x-python", 158 | "name": "python", 159 | "nbconvert_exporter": "python", 160 | "pygments_lexer": "ipython2", 161 | "version": "2.7.6" 162 | } 163 | }, 164 | "nbformat": 4, 165 | "nbformat_minor": 0 166 | } 167 | -------------------------------------------------------------------------------- /notebooks/interactive_recognition.py: -------------------------------------------------------------------------------- 1 | import curses 2 | from typing import List, Tuple 3 | 4 | from incremental_parsing.lex_earley.earley_nfa import EarleyNFA 5 | from incremental_parsing.lex_earley.earley_trie import EarleyTrieNode 6 | from incremental_parsing.lex_earley.lark_grammar import get_python_context 7 | from incremental_parsing.lex_earley.lex_earley import lex_earley_init, LexEarleyAlgorithmState, is_complete, \ 8 | is_completable, lex_earley_step, lex_earley_to_middle 9 | 10 | 11 | def run(stdscr): 12 | context = get_python_context() 13 | current_state: LexEarleyAlgorithmState = lex_earley_init(context, earley_class=EarleyTrieNode) 14 | 15 | if True: 16 | current_state = lex_earley_to_middle(context, current_state, "", "", 17 | earley_class=EarleyTrieNode, earley_nfa_class=EarleyNFA) 18 | 19 | tokens: List[Tuple[str, LexEarleyAlgorithmState]] = [] 20 | 21 | while True: 22 | stdscr.addstr(0, 0, "Input:") 23 | 24 | input_str = "".join(t[0] for t in tokens) 25 | stdscr.addstr(1, 0, input_str) 26 | 27 | stdscr.addstr("\n=====\n") 28 | 29 | if current_state is None: 30 | stdscr.addstr("Invalid") 31 | elif is_complete(context, current_state): 32 | stdscr.addstr("Complete") 33 | elif is_completable(current_state): 34 | stdscr.addstr("Completable") 35 | else: 36 | stdscr.addstr("Invalid") 37 | 38 | c = stdscr.getch() 39 | stdscr.clear() 40 | 41 | add = False 42 | char = None 43 | 44 | if 32 <= c < 127: 45 | char = chr(c) 46 | add = True 47 | elif c in (curses.KEY_ENTER, 10, 13): 48 | add = True 49 | char = "\n" 50 | elif c in (curses.KEY_BACKSPACE, 127): 51 | if len(tokens) > 0: 52 | _, current_state = tokens.pop() 53 | 54 | if current_state is None: 55 | add = False 56 | 57 | if add: 58 | assert char is not None 59 | tokens.append((char, current_state)) 60 | current_state = lex_earley_step(context, current_state, char) 61 | # assert len(current_state.branches) <= 2 62 | 63 | 64 | if __name__ == '__main__': 65 | curses.wrapper(run) 66 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = 
["maturin>=1.0,<2.0"] 3 | build-backend = "maturin" 4 | 5 | [tool.maturin] 6 | module-name = "incremental_parsing._native" 7 | features = ["pyo3/extension-module"] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | urllib3 2 | transformers==4.30.2 3 | torch 4 | torchvision 5 | torchaudio 6 | regex 7 | lark~=1.1.5 8 | datasets 9 | tokenizers 10 | accelerate 11 | bitsandbytes 12 | scipy 13 | tqdm 14 | hapless 15 | networkx 16 | mypy 17 | types-urllib3 18 | types-regex 19 | types-tqdm 20 | jupyter 21 | matplotlib 22 | ansi 23 | cchardet 24 | pandas 25 | numpy 26 | maturin 27 | scikit-learn -------------------------------------------------------------------------------- /scripts/constrained_generation_nice_cuts.sh: -------------------------------------------------------------------------------- 1 | for i in {0..39}; do 2 | device_num=$((i % 8)) 3 | device_name="cuda:$device_num" 4 | 5 | RUST_BACKTRACE=1 hap run -- python incremental_parsing/generation/stack_completion_right_context.py \ 6 | --min-seed 0 --max-seed 0 --min-data-idx $((i * 250)) --max-data-idx $((i * 250 + 249)) \ 7 | --min-cut-idx 0 --max-cut-idx 9 --beam-size 1 --nice-cuts \ 8 | --model-name "bigcode/santacoder" --dataset-name "bigcode/the-stack-smol-xl" --dataset-path "data/python" \ 9 | --results-path "/path/to/results" \ 10 | --device $device_name --max-tokens 500 --num-token-finding-attempts 50 11 | done -------------------------------------------------------------------------------- /scripts/constrained_generation_random_cuts.sh: -------------------------------------------------------------------------------- 1 | for i in {0..39}; do 2 | device_num=$((i % 8)) 3 | device_name="cuda:$device_num" 4 | 5 | RUST_BACKTRACE=1 hap run -- python incremental_parsing/generation/stack_completion_right_context.py \ 6 | --min-seed 0 --max-seed 0 --min-data-idx $((i * 250)) --max-data-idx $((i * 250 + 249)) \ 7 | --min-cut-idx 0 --max-cut-idx 9 --beam-size 1 \ 8 | --model-name "bigcode/santacoder" --dataset-name "bigcode/the-stack-smol-xl" --dataset-path "data/python" \ 9 | --results-path "/path/to/results" \ 10 | --device $device_name --max-tokens 500 --num-token-finding-attempts 50 11 | done -------------------------------------------------------------------------------- /src/bridge.rs: -------------------------------------------------------------------------------- 1 | pub mod bnf; 2 | pub mod charts; 3 | -------------------------------------------------------------------------------- /src/bridge/bnf.rs: -------------------------------------------------------------------------------- 1 | use crate::grammar::bnf::{BNFElement, BNFProduction, BNFRule, NonterminalIdx, TerminalIdx, BNF}; 2 | use crate::grammar::names::BNFNames; 3 | use pyo3::prelude::*; 4 | use std::collections::{HashMap, HashSet}; 5 | 6 | pub fn extension_bnf(m: &Bound<'_, PyModule>) -> PyResult<()> { 7 | m.add_class::()?; 8 | m.add_class::()?; 9 | m.add_class::()?; 10 | m.add_class::()?; 11 | m.add_class::()?; 12 | Ok(()) 13 | } 14 | 15 | #[pyclass] 16 | #[derive(Clone, Debug)] 17 | pub enum BridgeBNFElement { 18 | Terminal { name: String }, 19 | Nonterminal { name: String }, 20 | } 21 | 22 | #[pymethods] 23 | impl BridgeBNFElement { 24 | fn __repr__(&self) -> String { 25 | format!("{:?}", self) 26 | } 27 | } 28 | 29 | impl BridgeBNFElement { 30 | pub fn collect_terminals(&self, terminals: &mut HashSet) { 31 | match self { 32 | 
BridgeBNFElement::Terminal { name } => { 33 | terminals.insert(name.clone()); 34 | } 35 | BridgeBNFElement::Nonterminal { name: _ } => {} 36 | } 37 | } 38 | 39 | pub fn to_native( 40 | &self, 41 | terminal_map: &HashMap, 42 | nonterminal_map: &HashMap, 43 | ) -> BNFElement { 44 | match self { 45 | BridgeBNFElement::Terminal { name } => BNFElement::BNFTerminal(terminal_map[name]), 46 | BridgeBNFElement::Nonterminal { name } => { 47 | BNFElement::BNFNonterminal(nonterminal_map[name]) 48 | } 49 | } 50 | } 51 | } 52 | 53 | #[pyclass] 54 | #[derive(Clone, Debug)] 55 | pub struct BridgeBNFProduction { 56 | #[pyo3(get, set)] 57 | pub elements: Vec, 58 | } 59 | 60 | #[pymethods] 61 | impl BridgeBNFProduction { 62 | #[new] 63 | pub fn new(elements: Vec) -> Self { 64 | Self { elements } 65 | } 66 | 67 | pub fn __repr__(&self) -> String { 68 | format!("{:?}", self) 69 | } 70 | } 71 | 72 | impl BridgeBNFProduction { 73 | pub fn collect_terminals(&self, terminals: &mut HashSet) { 74 | for element in &self.elements { 75 | element.collect_terminals(terminals); 76 | } 77 | } 78 | 79 | pub fn to_native( 80 | &self, 81 | terminal_map: &HashMap, 82 | nonterminal_map: &HashMap, 83 | ) -> BNFProduction { 84 | BNFProduction { 85 | elements: self 86 | .elements 87 | .iter() 88 | .map(|element| element.to_native(terminal_map, nonterminal_map)) 89 | .collect(), 90 | } 91 | } 92 | } 93 | 94 | #[pyclass] 95 | #[derive(Clone, Debug)] 96 | pub struct BridgeBNFRule { 97 | #[pyo3(get, set)] 98 | productions: Vec, 99 | } 100 | 101 | #[pymethods] 102 | impl BridgeBNFRule { 103 | #[new] 104 | pub fn new(productions: Vec) -> Self { 105 | Self { productions } 106 | } 107 | 108 | pub fn __repr__(&self) -> String { 109 | format!("{:?}", self) 110 | } 111 | } 112 | 113 | impl BridgeBNFRule { 114 | pub fn collect_terminals(&self, terminals: &mut HashSet) { 115 | for production in &self.productions { 116 | production.collect_terminals(terminals); 117 | } 118 | } 119 | 120 | pub fn to_native(&self, names: &BNFNames) -> BNFRule { 121 | BNFRule { 122 | productions: self 123 | .productions 124 | .iter() 125 | .map(|production| { 126 | production.to_native(names.terminal_map(), names.nonterminal_map()) 127 | }) 128 | .collect(), 129 | } 130 | } 131 | } 132 | 133 | #[pyclass] 134 | #[derive(Clone, Debug)] 135 | pub struct BridgeBNFGrammar { 136 | #[pyo3(get, set)] 137 | pub rules: HashMap, 138 | 139 | #[pyo3(get, set)] 140 | pub top_level_rules: Vec, 141 | } 142 | 143 | #[pymethods] 144 | impl BridgeBNFGrammar { 145 | #[new] 146 | pub fn new(rules: HashMap, top_level_rules: Vec) -> Self { 147 | Self { 148 | rules, 149 | top_level_rules, 150 | } 151 | } 152 | 153 | pub fn __repr__(&self) -> String { 154 | format!("{:?}", self) 155 | } 156 | 157 | pub fn to_native(&self) -> NativeGrammar { 158 | self.clone().into() 159 | } 160 | } 161 | 162 | impl BridgeBNFGrammar { 163 | pub fn collect_terminals_nonterminals(&self) -> (Vec, Vec) { 164 | let mut terminals: HashSet = HashSet::new(); 165 | 166 | for rule in self.rules.values() { 167 | rule.collect_terminals(&mut terminals); 168 | } 169 | 170 | let nonterminals: Vec = self.rules.keys().cloned().collect(); 171 | 172 | (terminals.into_iter().collect(), nonterminals) 173 | } 174 | } 175 | 176 | #[pyclass(frozen)] 177 | #[derive(Clone)] 178 | pub struct NativeGrammar { 179 | pub grammar: BNF, 180 | pub names: BNFNames, 181 | } 182 | 183 | #[pymethods] 184 | impl NativeGrammar { 185 | fn __repr__(&self) -> String { 186 | "".to_string() 187 | } 188 | } 189 | 190 | impl From for NativeGrammar { 
191 | fn from(bnf: BridgeBNFGrammar) -> Self { 192 | let (terminals, nonterminals) = bnf.collect_terminals_nonterminals(); 193 | let names = BNFNames::new(terminals, nonterminals); 194 | 195 | let rules = names 196 | .nonterminals() 197 | .iter() 198 | .map(|ntname| bnf.rules[ntname].to_native(&names)) 199 | .collect(); 200 | 201 | let top_level_rules = bnf 202 | .top_level_rules 203 | .iter() 204 | .map(|ntname| names.nonterminal_idx(ntname)) 205 | .collect(); 206 | 207 | NativeGrammar { 208 | grammar: BNF::new(rules, top_level_rules), 209 | names, 210 | } 211 | } 212 | } 213 | -------------------------------------------------------------------------------- /src/bridge/charts.rs: -------------------------------------------------------------------------------- 1 | use crate::bridge::bnf::NativeGrammar; 2 | use crate::grammar::bnf::TerminalIdx; 3 | use crate::parser::earley::EarleyStateCreationMethod::TopLevel; 4 | use crate::parser::earley::{get_all_initial_earley_states, EarleyChartIdx, EarleyCollection}; 5 | use pyo3::prelude::*; 6 | use pyo3::types::PyTuple; 7 | use pyo3::{Bound, PyResult}; 8 | use std::collections::HashSet; 9 | 10 | pub fn extension_charts(m: &Bound<'_, PyModule>) -> PyResult<()> { 11 | m.add_class::()?; 12 | Ok(()) 13 | } 14 | 15 | #[pyclass] 16 | pub struct NativeEarleyCharts { 17 | earley_collection: EarleyCollection, 18 | } 19 | 20 | #[pymethods] 21 | impl NativeEarleyCharts { 22 | #[staticmethod] 23 | pub fn create_initial_earley_charts<'a>( 24 | python: Python<'a>, 25 | grammar: &NativeGrammar, 26 | ) -> Py { 27 | let mut collection = EarleyCollection::new(); 28 | let root_chart = collection.add_chart(); 29 | for state in get_all_initial_earley_states(&grammar.grammar, root_chart) { 30 | collection.add_state(&grammar.grammar, root_chart, state, TopLevel); 31 | } 32 | 33 | let charts = NativeEarleyCharts { 34 | earley_collection: collection, 35 | }; 36 | let this_chart = &charts.earley_collection[root_chart]; 37 | 38 | let allowed_terminals: HashSet = 39 | this_chart.allowed_terminals(&grammar.grammar); 40 | // Terminals are already deduped, so okay to just use vec 41 | let allowed_terminal_names: Vec<&String> = allowed_terminals 42 | .into_iter() 43 | .map(|terminal| grammar.names.terminal(terminal)) 44 | .collect(); 45 | let completeable = this_chart.is_completable(); 46 | let complete = this_chart.is_complete(&grammar.grammar); 47 | 48 | let result_tup = ( 49 | charts, 50 | root_chart.0, 51 | allowed_terminal_names, 52 | completeable, 53 | complete, 54 | ); 55 | 56 | result_tup.into_py(python) 57 | } 58 | 59 | pub fn parse<'a>( 60 | &mut self, 61 | python: Python<'a>, 62 | grammar: &NativeGrammar, 63 | chart_idx: usize, 64 | token_name: String, 65 | ) -> Py { 66 | let src_chart_idx = EarleyChartIdx(chart_idx); 67 | 68 | let this_chart_idx = self.earley_collection.add_chart(); 69 | let terminal_idx = grammar.names.terminal_idx(&token_name); 70 | if let Some(terminal_idx) = terminal_idx { 71 | // If terminal isn't in the grammar, adding a link won't do anything anyway 72 | self.earley_collection.add_link( 73 | &grammar.grammar, 74 | terminal_idx, 75 | src_chart_idx, 76 | this_chart_idx, 77 | ); 78 | } 79 | 80 | let this_chart = &self.earley_collection[this_chart_idx]; 81 | 82 | let allowed_terminals: HashSet = 83 | this_chart.allowed_terminals(&grammar.grammar); 84 | let allowed_terminal_names: Vec<&String> = allowed_terminals 85 | .into_iter() 86 | .map(|terminal| grammar.names.terminal(terminal)) 87 | .collect(); 88 | let completable = 
this_chart.is_completable(); 89 | let complete = this_chart.is_complete(&grammar.grammar); 90 | 91 | let result_tup = ( 92 | this_chart_idx.0, 93 | allowed_terminal_names, 94 | completable, 95 | complete, 96 | ); 97 | result_tup.into_py(python) 98 | } 99 | 100 | pub fn get_chart_len(&self, chart_idx: usize) -> usize { 101 | self.earley_collection[EarleyChartIdx(chart_idx)] 102 | .states 103 | .len() 104 | } 105 | 106 | pub fn get_earley_state( 107 | &self, 108 | python: Python<'_>, 109 | native_grammar: &NativeGrammar, 110 | chart_idx: usize, 111 | state_idx: usize, 112 | ) -> Py { 113 | let (state, creation_methods) = 114 | &self.earley_collection[EarleyChartIdx(chart_idx)].states[state_idx]; 115 | let nonterminal_name = native_grammar.names.nonterminal(state.nonterminal); 116 | 117 | let mut creation_methods_serialized = vec![]; 118 | for creation_method in creation_methods { 119 | creation_method.serialize(&mut creation_methods_serialized); 120 | } 121 | 122 | let result_tup = ( 123 | state.span_start.0, 124 | nonterminal_name, 125 | state.production_index, 126 | state.dot_index, 127 | state.production_length, 128 | creation_methods_serialized, 129 | ); 130 | result_tup.into_py(python) 131 | } 132 | 133 | #[staticmethod] 134 | pub fn create_earley_nfa<'a>( 135 | grammar: &NativeGrammar, 136 | num_charts: usize, 137 | start_charts: Vec, 138 | transitions: Vec<(usize, usize, String)>, 139 | ) -> NativeEarleyCharts { 140 | let mut collection = EarleyCollection::new(); 141 | 142 | for _ in 0..num_charts { 143 | collection.add_chart(); 144 | } 145 | 146 | for start_chart in start_charts { 147 | for state in 148 | get_all_initial_earley_states(&grammar.grammar, EarleyChartIdx(start_chart)) 149 | { 150 | collection.add_state( 151 | &grammar.grammar, 152 | EarleyChartIdx(start_chart), 153 | state, 154 | TopLevel, 155 | ); 156 | } 157 | } 158 | 159 | for (origin_chart, dest_chart, terminal) in transitions { 160 | if let Some(terminal_idx) = grammar.names.terminal_idx(&terminal) { 161 | collection.add_link( 162 | &grammar.grammar, 163 | terminal_idx, 164 | EarleyChartIdx(origin_chart), 165 | EarleyChartIdx(dest_chart), 166 | ); 167 | } 168 | } 169 | 170 | NativeEarleyCharts { 171 | earley_collection: collection, 172 | } 173 | } 174 | } 175 | -------------------------------------------------------------------------------- /src/grammar.rs: -------------------------------------------------------------------------------- 1 | pub mod bnf; 2 | pub mod names; 3 | -------------------------------------------------------------------------------- /src/grammar/bnf.rs: -------------------------------------------------------------------------------- 1 | use crate::grammar::names::BNFNames; 2 | use std::collections::{HashMap, HashSet}; 3 | use std::fmt::Display; 4 | use std::hash::Hash; 5 | use std::ops::Index; 6 | 7 | #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] 8 | pub struct TerminalIdx(pub usize); 9 | #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] 10 | pub struct NonterminalIdx(pub usize); 11 | 12 | #[derive(Hash, Eq, Copy, Clone, PartialEq)] 13 | pub enum BNFElement { 14 | BNFTerminal(TerminalIdx), 15 | BNFNonterminal(NonterminalIdx), 16 | } 17 | 18 | impl BNFElement { 19 | pub fn to_string( 20 | &self, 21 | names: &BNFNames, 22 | ) -> String { 23 | match self { 24 | BNFElement::BNFTerminal(t) => names.terminal(*t).to_string(), 25 | BNFElement::BNFNonterminal(nt) => names.nonterminal(*nt).to_string(), 26 | } 27 | } 28 | } 29 | 30 | #[derive(Clone, Eq, PartialEq, Hash)] 31 | pub struct 
BNFProduction { 32 | pub elements: Vec, 33 | } 34 | 35 | impl BNFProduction { 36 | pub fn reverse(&self) -> Self { 37 | BNFProduction { 38 | elements: self.elements.iter().cloned().rev().collect(), 39 | } 40 | } 41 | 42 | pub fn to_string( 43 | &self, 44 | names: &BNFNames, 45 | ) -> String { 46 | self.elements 47 | .iter() 48 | .map(|e| e.to_string(names)) 49 | .collect::>() 50 | .join(" ") 51 | } 52 | } 53 | 54 | #[derive(Clone, Eq, PartialEq, Hash)] 55 | pub struct BNFRule { 56 | pub productions: Vec, 57 | } 58 | 59 | impl BNFRule { 60 | pub fn reverse(&self) -> Self { 61 | BNFRule { 62 | productions: self.productions.iter().map(|p| p.reverse()).collect(), 63 | } 64 | } 65 | } 66 | 67 | #[derive(Clone)] 68 | pub struct BNF { 69 | rules: Vec, 70 | nullable_rules: HashSet, 71 | top_level_rules: Vec, 72 | } 73 | 74 | fn get_nullable_rules(rules: &Vec) -> HashSet { 75 | let mut reverse_rule_mapping = HashMap::new(); 76 | for (key, value) in rules.iter().enumerate() { 77 | for production in value.productions.iter() { 78 | for element in production.elements.iter() { 79 | match element { 80 | BNFElement::BNFNonterminal(nt) => { 81 | reverse_rule_mapping 82 | .entry(*nt) 83 | .or_insert(Vec::new()) 84 | .push(NonterminalIdx(key)); 85 | } 86 | BNFElement::BNFTerminal(_) => {} 87 | } 88 | } 89 | } 90 | } 91 | 92 | let mut nullable_rules = HashSet::new(); 93 | let mut queue = Vec::new(); 94 | 95 | for (key, rule) in rules.iter().enumerate() { 96 | if rule.productions.iter().any(|p| p.elements.is_empty()) { 97 | let key = NonterminalIdx(key); 98 | nullable_rules.insert(key); 99 | queue.push(key); 100 | } 101 | } 102 | 103 | while let Some(rule_key) = queue.pop() { 104 | if let Some(referencer_rules) = reverse_rule_mapping.get(&rule_key) { 105 | for referencer_rule in referencer_rules.iter() { 106 | if !nullable_rules.contains(referencer_rule) { 107 | for referencer_rule_production in rules[referencer_rule.0].productions.iter() { 108 | if referencer_rule_production 109 | .elements 110 | .iter() 111 | .all(|element| match element { 112 | BNFElement::BNFNonterminal(ntidx) => nullable_rules.contains(ntidx), 113 | _ => false, 114 | }) 115 | { 116 | nullable_rules.insert(*referencer_rule); 117 | queue.push(*referencer_rule); 118 | } 119 | } 120 | } 121 | } 122 | } 123 | } 124 | 125 | nullable_rules 126 | } 127 | 128 | impl BNF { 129 | pub fn new(rules: Vec, top_level_rules: Vec) -> Self { 130 | let nullable_rules = get_nullable_rules(&rules); 131 | 132 | BNF { 133 | rules, 134 | nullable_rules, 135 | top_level_rules, 136 | } 137 | } 138 | 139 | pub fn nullable_rules(&self) -> &HashSet { 140 | &self.nullable_rules 141 | } 142 | 143 | pub fn top_level_rules(&self) -> &Vec { 144 | &self.top_level_rules 145 | } 146 | } 147 | 148 | impl Index for BNF { 149 | type Output = BNFRule; 150 | 151 | fn index(&self, index: NonterminalIdx) -> &Self::Output { 152 | &self.rules[index.0] 153 | } 154 | } 155 | -------------------------------------------------------------------------------- /src/grammar/names.rs: -------------------------------------------------------------------------------- 1 | use crate::grammar::bnf::{NonterminalIdx, TerminalIdx}; 2 | use std::collections::HashMap; 3 | use std::fmt::Display; 4 | use std::hash::Hash; 5 | 6 | #[derive(Debug, Clone)] 7 | pub struct BNFNames { 8 | terminals: Vec, 9 | nonterminals: Vec, 10 | terminal_map: HashMap, 11 | nonterminal_map: HashMap, 12 | } 13 | 14 | impl BNFNames { 15 | pub fn new(terminals: Vec, nonterminals: Vec) -> Self { 16 | let terminal_map = terminals 17 
| .iter() 18 | .cloned() 19 | .enumerate() 20 | .map(|(i, t)| (t, TerminalIdx(i))) 21 | .collect(); 22 | let nonterminal_map = nonterminals 23 | .iter() 24 | .cloned() 25 | .enumerate() 26 | .map(|(i, t)| (t, NonterminalIdx(i))) 27 | .collect(); 28 | 29 | Self { 30 | terminals, 31 | nonterminals, 32 | terminal_map, 33 | nonterminal_map, 34 | } 35 | } 36 | 37 | pub fn terminal(&self, idx: TerminalIdx) -> &T { 38 | &self.terminals[idx.0] 39 | } 40 | 41 | pub fn nonterminal(&self, idx: NonterminalIdx) -> &NT { 42 | &self.nonterminals[idx.0] 43 | } 44 | 45 | pub fn terminals(&self) -> &Vec { 46 | &self.terminals 47 | } 48 | 49 | pub fn nonterminals(&self) -> &Vec { 50 | &self.nonterminals 51 | } 52 | 53 | pub fn terminal_idx(&self, t: &T) -> Option { 54 | self.terminal_map.get(t).cloned() 55 | } 56 | 57 | pub fn nonterminal_idx(&self, t: &NT) -> NonterminalIdx { 58 | self.nonterminal_map[t] 59 | } 60 | 61 | pub fn terminal_map(&self) -> &HashMap { 62 | &self.terminal_map 63 | } 64 | 65 | pub fn nonterminal_map(&self) -> &HashMap { 66 | &self.nonterminal_map 67 | } 68 | } 69 | 70 | enum NonterminalName { 71 | Simple(String), 72 | Prefix(usize, Box), 73 | Suffix(usize, Box), 74 | PrefixSuffix(usize, Box, usize), 75 | } 76 | 77 | impl Display for NonterminalName { 78 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 79 | match self { 80 | NonterminalName::Simple(name) => write!(f, "{}", name), 81 | NonterminalName::Prefix(prefix, name) => write!(f, "{}<{}->", name, prefix), 82 | NonterminalName::Suffix(suffix, name) => write!(f, "{}<-{}>", name, suffix), 83 | NonterminalName::PrefixSuffix(prefix, name, suffix) => { 84 | write!(f, "{}<{}-{}>", name, prefix, suffix) 85 | } 86 | } 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | mod bridge; 2 | mod grammar; 3 | mod parser; 4 | 5 | use crate::bridge::bnf::extension_bnf; 6 | use crate::bridge::charts::extension_charts; 7 | use pyo3::prelude::*; 8 | 9 | #[pymodule] 10 | #[pyo3(name = "_native")] 11 | fn incremental_parsing(m: &Bound<'_, PyModule>) -> PyResult<()> { 12 | extension_bnf(m)?; 13 | extension_charts(m)?; 14 | Ok(()) 15 | } 16 | -------------------------------------------------------------------------------- /src/parser.rs: -------------------------------------------------------------------------------- 1 | pub mod earley; 2 | --------------------------------------------------------------------------------