├── .gitignore ├── LICENSE ├── README.md ├── configs ├── casename_classification │ ├── casename.kogpt2.e2.yaml │ ├── casename.kogpt2.e3.yaml │ ├── casename.kogpt2.yaml │ ├── casename.lcube-base.e2.yaml │ ├── casename.lcube-base.e3.yaml │ └── casename.lcube-base.yaml ├── ljp │ ├── civil │ │ ├── ljp.civil.kogpt2.e2.yaml │ │ ├── ljp.civil.kogpt2.e3.yaml │ │ ├── ljp.civil.kogpt2.yaml │ │ ├── ljp.civil.lcube-base.e2.yaml │ │ ├── ljp.civil.lcube-base.e3.yaml │ │ └── ljp.civil.lcube-base.yaml │ └── criminal │ │ ├── ljp.criminal.kogpt2.e2.yaml │ │ ├── ljp.criminal.kogpt2.e3.yaml │ │ ├── ljp.criminal.kogpt2.yaml │ │ ├── ljp.criminal.lcube-base.e2.yaml │ │ ├── ljp.criminal.lcube-base.e3.yaml │ │ └── ljp.criminal.lcube-base.yaml ├── statute_classification │ ├── statute.kogpt2.e2.yaml │ ├── statute.kogpt2.e3.yaml │ ├── statute.kogpt2.yaml │ ├── statute.lcube-base.e2.yaml │ ├── statute.lcube-base.e3.yaml │ └── statute.lcube-base.yaml └── summarization │ ├── summarization.kogpt2.yaml │ ├── summarization.lcube-base.yaml │ └── summarization.legal-mt5s.test.yaml ├── lbox_open ├── constants │ ├── __init__.py │ └── constants_fie.py ├── data_module │ ├── __init__.py │ └── data_precedent.py ├── datasets_script │ └── lbox_open.py ├── metric │ ├── exact_match.py │ └── rouge_metric_utils.py ├── model │ ├── generative_baseline_model.py │ └── model_optimizer.py ├── openprompt_wrapper │ ├── __init__.py │ ├── data_utils │ │ └── __init__.py │ ├── pipeline_base.py │ └── plms │ │ ├── __init__.py │ │ ├── lm.py │ │ ├── mt5_additional_special_tokens.json │ │ └── utils.py ├── parser │ ├── __init__.py │ ├── output_parser.py │ └── output_parser_utils.py ├── pipeline │ ├── __init__.py │ └── lbox_open_pipeline.py ├── template │ ├── __init__.py │ ├── prompt_generation_utils.py │ └── prompt_templates.py └── utils │ ├── __init__.py │ └── general_utils.py ├── requirements.txt ├── run_model.py └── scripts ├── predict_summarization.sh ├── test_casename.sh ├── test_ljp_civil.sh ├── test_ljp_criminal.sh ├── test_statute.sh ├── test_summarization.sh ├── train_casename.sh ├── train_ljp_civil.sh ├── train_ljp_criminal.sh ├── train_statute.sh └── train_summarization.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .idea/ 132 | configs/summarization/summarization.legal-mt5s.yaml 133 | saved 134 | configs/summarization/summarization.legal-mt5s.predict.yaml 135 | logs/ 136 | data/ 137 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 
23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. 
Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. 
For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. 
identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. 
Nothing in this Public License constitutes or may be interpreted
384 | as a limitation upon, or waiver of, any privileges and immunities
385 | that apply to the Licensor or You, including from the legal
386 | processes of any jurisdiction or authority.
387 |
388 | =======================================================================
389 |
390 | Creative Commons is not a party to its public
391 | licenses. Notwithstanding, Creative Commons may elect to apply one of
392 | its public licenses to material it publishes and in those instances
393 | will be considered the “Licensor.” The text of the Creative Commons
394 | public licenses is dedicated to the public domain under the CC0 Public
395 | Domain Dedication. Except for the limited purpose of indicating that
396 | material is shared under a Creative Commons public license or as
397 | otherwise permitted by the Creative Commons policies published at
398 | creativecommons.org/policies, Creative Commons does not authorize the
399 | use of the trademark "Creative Commons" or any other trademark or logo
400 | of Creative Commons without its prior written consent including,
401 | without limitation, in connection with any unauthorized modifications
402 | to any of its public licenses or any other arrangements,
403 | understandings, or agreements concerning use of licensed material. For
404 | the avoidance of doubt, this paragraph does not form part of the
405 | public licenses.
406 |
407 | Creative Commons may be contacted at creativecommons.org.
408 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LBox Open
2 |
3 | A multi-task benchmark for Korean legal language understanding and judgement prediction by [LBox](https://lbox.kr)
4 |
5 | # Authors
6 |
7 | - [Wonseok Hwang](mailto:wonseok.hwang@lbox.kr)
8 | - [Dongjun Lee](mailto:dongjun.lee@lbox.kr)
9 | - [Kyoungyeon Cho](mailto:kycho@lbox.kr)
10 | - [Hanuhl Lee](mailto:leehanuhl@lbox.kr)
11 | - [Minjoon Seo](mailto:minjoon@lbox.kr)
12 |
13 | # Updates
14 | - Dec 2, 2022: We release [additional 1024 examples of `drunk driving` cases](https://cdn.lbox.kr/public/dataset/lbox-open/precedent_benchmark_dataset/judgement_prediction/ljp_criminal_drunk_driving_plus_1024.jsonl) for the `ljp_criminal` task. Compared to the `ljp_criminal` data, it includes the parses extracted from the facts (blood alcohol level, driving distance, type of car, previous criminal history) and the suspension-of-execution period. See also [this issue](https://github.com/lbox-kr/lbox-open/issues/10). The data shall be integrated into `ljp_criminal` in the next release.
15 |
16 | - Dec 2, 2022: We will present our recent work ["Data-efficient End-to-end Information Extraction for Statistical Legal Analysis"](https://arxiv.org/abs/2211.01692) at the [NLLP workshop @ EMNLP22](https://nllpw.org/workshop/)!
17 |
18 | - Nov 8, 2022: We release `legal-mt5-small`, a domain-adapted mt5-small trained on `precedent_corpus`. We also release the `legal-mt5-small` fine-tuned on the `summarization` dataset. Both models can be downloaded from [here](https://drive.google.com/file/d/1lZaUtDPCkAOcwaxBzFo-QHecGAQendOd/view?usp=share_link)! To use the models, run `cd [project-dir]; tar xvfz legal-mt5-small.tar.gz`.
19 | - Oct 25, 2022: [`act_on_special_cases_concerning_the_settlement_of_traffic_accidents_corpus`](https://cdn.lbox.kr/public/dataset/lbox-open/precedent_benchmark_dataset/act_on_special_cases_concerning_the_settlement_of_traffic_accidents_corpus/act_on_special_cases_concerning_the_settlement_of_traffic_accidents_corpus.jsonl) corpus (교통사고처리특례법위반(치상)) has been released. The corpus consists of 768 criminal cases. The corpus will be integrated into `precedent corpus` in the future (some overlap between `precedent corpus` and this corpus is expected). See also [this issue](https://github.com/lbox-kr/lbox-open/issues/9).
20 | - Oct 18, 2022: We release three new datasets: `casename_classification_plus`, `statute_classification_plus`, and `summarization_plus`!
21 | - Oct 2, 2022: [`defamation corpus-v0.1`](https://cdn.lbox.kr/public/dataset/lbox-open/precedent_benchmark_dataset/defamation_corpus/defamation_corpus.jsonl) has been added. The corpus consists of 1,536 criminal cases related to "defamation (명예훼손)". The corpus will be integrated into `precedent corpus` in the future (at the moment, there can be some overlap between `precedent corpus` and `defamation corpus-v0.1`). See also [this issue](https://github.com/lbox-kr/lbox-open/issues/4#issue-1393652876).
22 | - Sep 2022: Our paper is accepted for publication in the NeurIPS 2022 Datasets and Benchmarks track! There will be major updates to the paper, the datasets, and the models soon! Meanwhile, one can check the most recent version of our paper on [OpenReview](https://openreview.net/forum?id=TaARsI_Iio).
23 | - Jun 2022: We release `lbox-open-v0.2`!
24 |   - Two legal judgement prediction tasks, `ljp_criminal` and `ljp_civil`, are added to LBox Open.
25 |   - `LCube-base`, an LBox legal language model with 124M parameters, is added.
26 |   - The baseline scores and their training/test scripts are added.
27 |   - Other updates
28 |     - Some missing values in the `facts` fields of `casename_classification` and `statute_classification` are updated.
29 |     - `case_corpus` is renamed to `precedent_corpus`.
30 | - Mar 2022: We release `lbox-open-v0.1`!
31 |
32 | # Paper
33 |
34 | [A Multi-Task Benchmark for Korean Legal Language Understanding and Judgement Prediction](https://arxiv.org/abs/2206.05224)
35 |
36 | # Benchmarks
37 |
38 | - Last updated on Oct 18, 2022
39 |
40 | | **Model** | casename | statute | ljp-criminal | ljp-civil | summarization |
41 | |-------------------|----------------|----------------|-----------------------------------------------------------------------|----------------|------------------|
42 | |                   | EM             | EM             | F1-fine<br>F1-imprisonment w/ labor<br>F1-imprisonment w/o labor | EM             | R1<br>R2<br>RL   |
43 | | KoGPT2            | $78.5 \pm 0.3$ | $85.7 \pm 0.8$ | $49.9 \pm 1.7$<br>$67.5 \pm 1.1$<br>$69.2 \pm 1.6$ | $66.0 \pm 0.5$ | $47.2$<br>$39.1$<br>$45.7$ |
44 | | KoGPT2 + `d.a.`   | $81.9 \pm 0.2$ | $89.4 \pm 0.5$ | $49.8$<br>$65.4$<br>$70.1$ | $64.7 \pm 1.1$ | $49.2$<br>$40.9$<br>$47.7$ |
45 | | LCube-base (ours) | $81.1 \pm 0.3$ | $87.6 \pm 0.5$ | $46.4 \pm 2.8$<br>$69.3 \pm 0.3$<br>$70.3 \pm 0.7$ | $67.6 \pm 1.3$ | $46.0$<br>$37.7$<br>$44.5$ |
46 | | LCube-base + `d.a.` (ours) | $82.7 \pm 0.6$ | $89.3 \pm 0.4$ | $48.1 \pm 1.2$<br>$67.4 \pm 1.5$<br>$69.9 \pm 1.1$ | $60.9 \pm 1.1$ | $47.8$<br>$39.5$<br>$46.4$ |
47 | | mt5-small         | $81.0 \pm 1.3$ | $87.2 \pm 0.3$ | $49.1 \pm 1.3$<br>$66.6 \pm 0.6$<br>$69.8 \pm 1.0$ | $68.9 \pm 0.8$ | $56.2$<br>$47.8$<br>$54.7$ |
48 | | mt5-small + `d.a.`| $82.2 \pm 0.2$ | $88.8 \pm 0.5$ | $51.8 \pm 0.7$<br>$68.9 \pm 0.3$<br>$70.3 \pm 0.7$ | $69.1 \pm 0.1$ | $56.2$<br>$47.7$<br>$54.8$ |
49 |
50 | - The errors are estimated from three independent experiments performed with different random seeds.
51 | - ROUGE scores are computed at the word level.
52 | - `d.a.` stands for domain adaptation, an additional pre-training with the `Precedent` corpus only.
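For reference, the word-level ROUGE used in the table is implemented in `lbox_open/metric/rouge_metric_utils.py`. As a rough illustration of the metric (a minimal sketch over whitespace tokens, not the repository's exact implementation):

```python
from collections import Counter

def rouge1_f1(prediction: str, reference: str) -> float:
    """Word-level ROUGE-1 F1: unigram overlap over whitespace-separated tokens."""
    pred_tokens, ref_tokens = prediction.split(), reference.split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Clipped unigram overlap between the two token multisets.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)

# rouge1_f1("피고인을 징역 6월에 처한다", "피고인을 징역 1년에 처한다") == 0.75
```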
53 |
54 | # Dataset
55 |
56 | ## How to use the dataset
57 |
58 | We use the [`datasets`](https://github.com/huggingface/datasets) library from `HuggingFace`.
59 |
60 | ```python
61 | # !pip install datasets
62 | from datasets import load_dataset
63 |
64 | # casename classification task
65 | data_cn = load_dataset("lbox/lbox_open", "casename_classification")
66 | data_cn_plus = load_dataset("lbox/lbox_open", "casename_classification_plus")
67 |
68 | # statute classification task
69 | data_st = load_dataset("lbox/lbox_open", "statute_classification")
70 | data_st_plus = load_dataset("lbox/lbox_open", "statute_classification_plus")
71 |
72 | # legal judgement prediction tasks
73 | data_ljp_criminal = load_dataset("lbox/lbox_open", "ljp_criminal")
74 | data_ljp_civil = load_dataset("lbox/lbox_open", "ljp_civil")
75 |
76 | # case summarization task
77 | data_summ = load_dataset("lbox/lbox_open", "summarization")
78 | data_summ_plus = load_dataset("lbox/lbox_open", "summarization_plus")
79 |
80 | # precedent corpus
81 | data_corpus = load_dataset("lbox/lbox_open", "precedent_corpus")
82 |
83 |
84 | ```
85 |
86 | - [Explore the dataset on Colab](https://colab.research.google.com/drive/1R4T91Ix__-4rjtxATh7JeTX69zYrmWy0?usp=sharing)
87 |
88 | ## Dataset Description
89 | ### `precedent_corpus`
90 | - Korean legal precedent corpus.
91 | - The corpus consists of 150k cases.
92 | - About 80k from [LAW OPEN DATA](https://www.law.go.kr/LSO/main.do) and 70k from the LBox database.
93 |
94 | - Example
95 | ```json
96 | {
97 |   "id": 99990,
98 |   "precedent": "주문\n피고인을 징역 6개월에 처한다.\n다만, 이 판결 확정일로부터 1년간 위 형의 집행을 유예한다.\n\n이유\n범 죄 사 실\n1. 사기\n피고인은 2020. 12. 15. 16:00경 경북 칠곡군 B에 있는 피해자 C이 운영하는 ‘D’에서, 마치 정상적으로 대금을 지급할 것처럼 행세하면서 피해자에게 술을 주문하였다.\n그러나 사실 피고인은 수중에 충분한 현금이나 신용카드 등 결제 수단을 가지고 있지 않아 정상적으로 대금을 지급할 의사나 능력이 없었다.\n그럼에도 피고인은 위와 같이 피해자를 기망하여 이에 속은 피해자로부터 즉석에서 합계 8,000원 상당의 술을 교부받았다.\n2. 공무집행방해\n피고인은 제1항 기재 일시·장소에서, ‘손님이 술값을 지불하지 않고 있다’는 내용의 112신고를 접수하고 현장에 출동한 칠곡경찰서 E지구대 소속 경찰관 F로부터 술값을 지불하고 귀가할 것을 권유받자, “징역가고 싶은데 무전취식했으니 유치장에 넣어 달라”고 말하면서 순찰차에 타려고 하였다. 이에 경찰관들이 수회 귀가 할 것을 재차 종용하였으나, 피고인은 경찰관들을 향해 “내가 돌로 순찰차를 찍으면 징역갑니까?, 내여경 엉덩이 발로 차면 들어갈 수 있나?”라고 말하고, 이를 제지하는 F의 가슴을 팔꿈치로 수회 밀쳐 폭행하였다.\n이로써 피고인은 경찰관의 112신고사건 처리에 관한 정당한 직무집행을 방해하였다. 증거의 요지\n1. 피고인의 판시 제1의 사실에 부합하는 법정진술\n1. 증인 G, F에 대한 각 증인신문조서\n1. 영수증\n1. 현장 사진\n법령의 적용\n1. 범죄사실에 대한 해당법조 및 형의 선택\n형법 제347조 제1항, 제136조 제1항, 각 징역형 선택\n1. 경합범가중\n형법 제37조 전단, 제38조 제1항 제2호, 제50조\n1. 집행유예\n형법 제62조 제1항\n양형의 이유\n1. 법률상 처단형의 범위: 징역 1월∼15년\n2. 양형기준에 따른 권고형의 범위\n가. 제1범죄(사기)\n[유형의 결정]\n사기범죄 > 01. 일반사기 > [제1유형] 1억 원 미만\n[특별양형인자]\n- 감경요소: 미필적 고의로 기망행위를 저지른 경우 또는 기망행위의 정도가 약한 경우, 처벌불원\n[권고영역 및 권고형의 범위]\n특별감경영역, 징역 1월∼1년\n[일반양형인자] 없음\n나. 제2범죄(공무집행방해)\n[유형의 결정]\n공무집행방해범죄 > 01. 공무집행방해 > [제1유형] 공무집행방해/직무강요\n[특별양형인자]\n- 감경요소: 폭행·협박·위계의 정도가 경미한 경우\n[권고영역 및 권고형의 범위]\n감경영역, 징역 1월∼8월\n[일반양형인자]\n- 감경요소: 심신미약(본인 책임 있음)\n다. 다수범죄 처리기준에 따른 권고형의 범위: 징역 1월∼1년4월(제1범죄 상한 + 제2범죄 상한의 1/2)\n3. 선고형의 결정: 징역 6월에 집행유예 1년\n만취상태에서 식당에서 소란을 피웠고, 112신고로 출동한 경찰관이 여러 차례 귀가를 종용하였음에도 이를 거부하고 경찰관의 가슴을 밀친 점 등을 종합하면 죄책을 가볍게 볼 수 없으므로 징역형을 선택하되, 평소 주량보다 훨씬 많은 술을 마신 탓에 제정신을 가누지 못해 저지른 범행으로 보이고 폭행 정도가 매우 경미한 점, 피고인이 술이 깬 후 자신의 경솔한 언동을 깊이 반성하면서 재범하지 않기 위해 정신건강의학과의 치료 및 상담을 받고 있는 점, 식당 업주에게 피해를 변상하여 용서를 받은 점, 피고인의 나이와 가족관계 등의 사정을 참작하여 형의 집행을 유예하고, 범행 경위와 범행 후 피고인의 태도 등에 비추어 볼 때 재범의 위험성은 그다지 우려하지 않아도 될 것으로 보여 보호관찰 등 부수처분은 부과하지 않음.\n이상의 이유로 주문과 같이 판결한다."
99 | }
100 | ```
101 | - `id`: a data id.
102 | - `precedent`: a case from the court of Korea. It includes the ruling (주문), the gist of claim (청구취지), the claim of appeal (항소취지), and
103 | the reasoning (이유).
104 |
105 | ### `casename_classification`
106 |
107 | - Task: for the given facts (사실관계), a model is asked to predict the case name.
108 | - The dataset consists of 10k `(facts, case name)` pairs extracted from Korean precedents.
109 | - There are 100 classes (case categories) and each class contains 100 corresponding examples.
110 | - 8,000 training, 1,000 validation, 1,000 test, and 1,294 test2 examples. The test2 set consists of examples that do not overlap with the precedents in `precedent_corpus`.
111 | - We also provide `casename_classification_plus`, a dataset that extends `casename_classification` by including infrequent case categories. `casename_classification_plus` consists of 31,283 examples with 603 case categories in total. See our paper for details.
112 | - Example
113 |
114 | ```json
115 | {
116 |   "id": 80,
117 |   "casetype": "criminal",
118 |   "casename": "감염병의예방및관리에관한법률위반",
119 |   "facts": "질병관리청장, 시·도지사 또는 시장·군수·구청장은 제1급 감염병이 발생한 경우 감염병의 전파방지 및 예방을 위하여 감염병의심자를 적당한 장소에 일정한 기간 격리시키는 조치를 하여야 하고, 그 격리조치를 받은 사람은 이를 위반하여서는 아니 된다. 피고인은 해외에서 국내로 입국하였음을 이유로 2021. 4. 21.경 감염병의심자로 분류되었고, 같은 날 창녕군수로부터 ‘2021. 4. 21.부터 2021. 5. 5. 12:00경까지 피고인의 주거지인 경남 창녕군 B에서 격리해야 한다’는 내용의 자가격리 통지서를 수령하였다. 1. 2021. 4. 27.자 범행 그럼에도 불구하고 피고인은 2021. 4. 27. 11:20경에서 같은 날 11:59경까지 사이에 위 격리장소를 무단으로 이탈하여 자신의 승용차를 이용하여 경남 창녕군 C에 있는 ‘D’ 식당에 다녀오는 등 자가격리 조치를 위반하였다. 2. 2021. 5. 3.자 범행 피고인은 2021. 5. 3. 10:00경에서 같은 날 11:35경까지 사이에 위 격리장소를 무단으로 이탈하여 자신의 승용차를 이용하여 불상의 장소를 다녀오는 등 자가격리 조치를 위반하였다."
120 | }
121 | ```
122 | - `id`: a data id.
123 | - `casetype`: a case type. The value is either `civil` (민사) or `criminal` (형사).
124 | - `casename`: a case name.
125 | - `facts`: facts (사실관계) extracted from the `reasoning` (이유) section of individual cases.
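Casename prediction is scored with exact match (EM; see the benchmark table above). A minimal sketch of loading the split and scoring predicted strings — for illustration only, the repository's metric lives in `lbox_open/metric/exact_match.py`:

```python
from datasets import load_dataset

def exact_match(predictions, references):
    """Fraction of predictions identical to the gold strings after trimming whitespace."""
    assert len(predictions) == len(references)
    return sum(p.strip() == r.strip() for p, r in zip(predictions, references)) / len(references)

data_cn = load_dataset("lbox/lbox_open", "casename_classification")
gold = [ex["casename"] for ex in data_cn["validation"]]
print(exact_match(gold, gold))  # 1.0 by construction; replace `gold` with model outputs
```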
126 |
127 | ### `statute_classification`
128 |
129 | - Task: for the given facts (사실관계), a model is asked to predict the related statutes (법령).
130 | - The dataset consists of 2,760 `(facts, statutes)` pairs extracted from individual Korean legal cases.
131 | - There are 46 classes (case categories) and each class has 60 examples.
132 | - 2,208 training, 276 validation, 276 test, and 538 test2 examples. The test2 set consists of examples that do not overlap with the precedents in `precedent_corpus`.
133 | - We also release `statute_classification_plus`, a dataset that extends `statute_classification` by including less frequent case categories. `statute_classification_plus` includes 17,730 examples with 434 case categories and 1,015 statutes in total.
134 | - Example
135 |
136 | ```json
137 | {
138 |   "id": 5180,
139 |   "casetype": "criminal",
140 |   "casename": "사문서위조, 위조사문서행사",
141 |   "statutes": [
142 |     "형법 제231조",
143 |     "형법 제234조"
144 |   ],
145 |   "facts": "1. 사문서위조 피고인은 2014. 5. 10.경 서울 송파구 또는 하남시 이하 알 수 없는 장소에서 영수증문구용지에 검정색 볼펜을 사용하여 수신인란에 ‘A’, 일금란에 ‘오천오백육십만원정’, 내역 란에 ‘2010가합7485사건의 합의금 및 피해 보상금 완결조’, 발행일란에 ‘2014년 5월 10일’이라고 기재한 뒤, 발행인 옆에 피고인이 임의로 만들었던 B의 도장을 찍었다. 이로써 피고인은 행사할 목적으로 사실증명에 관한 사문서인 B 명의의 영수증 1장을 위조하였다. 2. 위조사문서행사 피고인은 2014. 10. 16.경 하남시 이하 알 수 없는 장소에서 피고인이 B에 대한 채무를 모두 변제하였기 때문에 B가 C회사에 채권을 양도한 것을 인정할 수 없다는 취지의 내용증명원과 함께 위와 같이 위조한 영수증 사본을 마치 진정하게 성립한 문서인 것처럼 B에게 우편으로 보냈다. 이로써 피고인은 위조한 사문서를 행사하였다."
146 | }
147 |
148 | ```
149 |
150 | - `id`: a data id.
151 | - `casetype`: a case type. The value is always `criminal`.
152 | - `casename`: a case name.
153 | - `statutes`: the related statutes.
154 | - `facts`: facts (사실관계) extracted from the `reasoning` (이유) section of individual cases.
155 |
156 | ### `ljp_criminal`
157 |
158 | - Task: a model needs to predict the ranges of the fine (벌금), imprisonment with labor (징역), and imprisonment without labor (금고).
159 | - 10,500 `facts` and the corresponding punishments are extracted from cases in the following case categories: “indecent
160 | act by compulsion” (강제추행), “obstruction of performance of official duties” (공무집행방해), “bodily injuries from traffic
161 | accident” (교통사고처리특례법위반(치상)), “drunk driving” (도로교통법위반(음주운전)), “fraud” (사기), “inflicting bodily injuries” (상해), and
162 | “violence” (폭행).
163 | - 8,400 training, 1,050 validation, 1,050 test, and 928 test2 examples. The test2 set consists of the examples from the test set that do not overlap with the precedents in `precedent_corpus`.
164 | - Example
165 | ```json
166 | {
167 |   "casename": "공무집행방해",
168 |   "casetype": "criminal",
169 |   "facts": "피고인은 2020. 3. 13. 18:57경 수원시 장안구 B 앞 노상에서 지인인 C와 술을 마시던 중 C를 때려 112신고를 받고 출동한 수원중부경찰서 D지구대 소속 경위 E가 C의 진술을 청취하고 있는 모습을 보고 화가 나 '씨발,개새끼'라며 욕설을 하고, 위 E가 이를 제지하며 귀가를 종용하자 그의 왼쪽 뺨을 오른 주먹으로 1회 때려 폭행하였다.\n이로써 피고인은 경찰관의 112신고사건 처리에 관한 정당한 직무집행을 방해하였다. 증거의 요지\n1. 피고인의 법정진술\n1. 피고인에 대한 경찰 피의자신문조서\n1. E에 대한 경찰 진술조서\n1. 현장사진 등, 바디캠영상",
170 |   "id": 2300,
171 |   "label": {
172 |     "fine_lv": 0,
173 |     "imprisonment_with_labor_lv": 2,
174 |     "imprisonment_without_labor_lv": 0,
175 |     "text": "징역 6월"
176 |   },
177 |   "reason": "양형의 이유\n1. 법률상 처단형의 범위: 징역 1월∼5년\n2. 양형기준에 따른 권고형의 범위\n[유형의 결정]\n공무집행방해범죄 > 01. 공무집행방해 > [제1유형] 공무집행방해/직무강요\n[특별양형인자] 없음\n[권고영역 및 권고형의 범위] 기본영역, 징역 6월∼1년6월\n3. 선고형의 결정\n피고인이 싸움 발생 신고를 받고 출동한 경찰관에게 욕설을 퍼붓고 귀가를 종용한다는 이유로 경찰관의 뺨을 때리는 등 폭행을 행사하여 경찰관의 정당한 공무집행을 방해한 점에서 그 죄책이 매우 무겁다. 피고인의 범죄 전력도 상당히 많다.\n다만, 피고인이 범행을 인정하면서 반성하고 있는 점, 공무집행방해 범죄로 처벌받은 전력이 없는 점 등은 피고인에게 유리한 정상으로 참작한다.\n그 밖에 피고인의 연령, 성행, 환경, 가족관계, 건강상태, 범행의 동기와 수단 및 결과, 범행 후의 정황 등 이 사건 기록 및 변론에 나타난 모든 양형요소를 종합하여, 주문과 같이 형을 정한다.",
178 |   "ruling": {
179 |     "parse": {
180 |       "fine": {
181 |         "type": "",
182 |         "unit": "",
183 |         "value": -1
184 |       },
185 |       "imprisonment": {
186 |         "type": "징역",
187 |         "unit": "mo",
188 |         "value": 6
189 |       }
190 |     },
191 |     "text": "피고인을 징역 6월에 처한다.\n다만 이 판결 확정일로부터 2년간 위 형의 집행을 유예한다."
192 |   }
193 | }
194 | ```
195 |
196 | - `id`: a data id.
197 | - `casetype`: a case type. The value is always `criminal`.
198 | - `casename`: a case name.
199 | - `facts`: facts (사실관계) extracted from the `reasoning` (이유) section of individual cases.
200 | - `label`
201 |   - `fine_lv`: a label representing the individual ranges of the fine amount. See our paper for details.
202 |   - `imprisonment_with_labor_lv`: a label representing the ranges of the imprisonment with labor.
203 |   - `imprisonment_without_labor_lv`: a label for the imprisonment-without-labor case.
204 | - `reason`: the reason for the punishment (양형의 이유).
205 | - `ruling`: the ruling (주문) and its parsing result. `""` and `-1` indicate null values.
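To get a quick feel for the targets, one can count the punishment-level labels over the training split (a throwaway sketch, not part of the repository):

```python
from collections import Counter

from datasets import load_dataset

data_ljp_criminal = load_dataset("lbox/lbox_open", "ljp_criminal")
# Distribution of the imprisonment-with-labor level over the 8,400 training examples.
lv_counts = Counter(ex["label"]["imprisonment_with_labor_lv"] for ex in data_ljp_criminal["train"])
print(lv_counts.most_common())
```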
`"" and -1` indicates null values. 206 | 207 | ### `ljp_civil` 208 | 209 | - Task: a model is asked to predict the claim acceptance level (= "the approved money" / "the claimed money") 210 | - 4,678 `facts` and the corresponding acceptance lv from 4 case categories: 929 examples from “price of 211 | indemnification” (구상금), 745 examples from “loan” (대여금), 1,004 examples from “unfair profits” (부당이득금), and 2,000 212 | examples from “lawsuit for damages (etc)” (손해배상(기)). 213 | - 3,742 training, 467 validation, 467 test, 403 test2 examples. The test2 set consists of the test set examples those do not overlap with the precedents in `precedent_corpus`. 214 | - Example 215 | ```json 216 | { 217 | "id": 99, 218 | "casetype": "civil", 219 | "casename": "구상금", 220 | "claim_acceptance_lv": 1, 221 | "facts": "가. C는 2017. 7. 21. D으로부터 100,000,000원을 이율 연 25%, 변제기 2017. 8. 20.로 정하여 차용하였고(이하 ‘이 사건 차용금채무'라고 한다), 피고는 이 사건 차용금 채무를 보증한도액 140,000,000원, 보증기한 10년으로 정하여 연대보증하였으며, 같은 날 이 사건 차용금채무에 관한 공정증서를 작성하였다(공증인가 법무법인 E 증서 2017년 제392호, 이하 ‘이 사건 공정증서'라고 한다).\n나. 원고는 이 사건 차용금채무와 관련하여 원고 소유의 안산시 상록구 F, G, H 및 그 지상 건물(이하 ‘이 사건 부동산'이라고 한다)을 담보로 제공하기로 하여 2017. 7. 21. 수원지방법원 안산지원 접수 제53820호로 채권최고액 140,000,000원, 채무자 C, 근저당권자 D으로 한 근저당권설정등기를 경료하는 한편, 2018. 7. 13. D에게 이 사건 공정증서에 기한 채무를 2018. 7. 31.까지 변제하고, 변제기 이후 연 24%의 비율로 계산한 지연손해금을 지급하기로 하는 차용증을 작성하여 주었다(이하 ‘이 사건 차용증'이라고 한다).\n다. 원고는 2019. 11. 29. D에게 이 사건 차용금채무 원리금으로 합계 157,500,000원을 변제하였다.", 222 | "gist_of_claim": { 223 | "money": { 224 | "provider": "피고", 225 | "taker": "원고", 226 | "unit": "won", 227 | "value": 140000000 228 | }, 229 | "text": "피고는 원고에게 140,000,000원 및 이에 대한 2019. 11. 30.부터 이 사건 소장 부본 송달일까지는 연 5%의, 그 다음날부터 다 갚는 날까지는 연 12%의 각 비율로 계산한 돈을 지급하라." 230 | }, 231 | "ruling": { 232 | "litigation_cost": 0.5, 233 | "money": { 234 | "provider": "피고", 235 | "taker": "원고", 236 | "unit": "won", 237 | "value": 78750000 238 | }, 239 | "text": "1. 피고는 원고에게 78,750,000원 및 이에 대한 2019. 11. 30.부터 2021. 11. 26.까지는 연 5%의, 그 다음날부터 다 갚는 날까지는 연 12%의 각 비율로 계산한 돈을 지급하라.\n2. 원고의 나머지 청구를 기각한다.\n3. 소송비용 중 1/2은 원고가 나머지는 피고가 각 부담한다.\n4. 제1항은 가집행할 수 있다." 240 | } 241 | } 242 | 243 | ``` 244 | 245 | - `id`: a data id. 246 | - `casetype`: a case type. The value is always `civil`. 247 | - `casename`: a case name. 248 | - `facts`: facts (사실관계) extracted from `reasoning` (이유) section of individual cases. 249 | - `claim_acceptaance_lv`: the claim acceptance level. `0`, `1`, and `2` indicate rejection, partial approval, and full approval respectively. 250 | - `gist_of_claim`: a gist of claim from plaintiffs (청구 취지) and its parsing result. 251 | - `ruling`: a ruling (주문) and its parsing results. 252 | - `litigation_cost`: the ratio of the litigation cost that the plaintiff should pay. 253 | 254 | ### `summarization` 255 | 256 | - Task: a model is asked to summarize precedents from the Supreme Court of Korea. 257 | - The dataset is obtained from [LAW OPEN DATA](https://www.law.go.kr/LSO/main.do). 258 | - The dataset consists of 20k `(precendent, summary)` pairs. 259 | - 16,000 training, 2,000 validation, and 2,000 test examples. 260 | - We also provide `summarization_plus` by extending `summarization` with precedents with longer text making the task more challenging and realistic. In the extended dataset there are a total of 51,114 examples. The average number of tokens in the precedents and the corresponding summaries are 1,516 and 248 respectively. The maximum number of tokens in the input texts and the summaries are 93,420 and 6,536 respectively. 
254 | ### `summarization`
255 |
256 | - Task: a model is asked to summarize precedents from the Supreme Court of Korea.
257 | - The dataset is obtained from [LAW OPEN DATA](https://www.law.go.kr/LSO/main.do).
258 | - The dataset consists of 20k `(precedent, summary)` pairs.
259 | - 16,000 training, 2,000 validation, and 2,000 test examples.
260 | - We also provide `summarization_plus`, which extends `summarization` with longer precedents, making the task more challenging and realistic. The extended dataset contains 51,114 examples in total. The average numbers of tokens in the precedents and the corresponding summaries are 1,516 and 248, respectively. The maximum numbers of tokens in the input texts and the summaries are 93,420 and 6,536, respectively.
261 |
262 | - Example
263 |
264 | ```json
265 | {
266 |   "id": 16454,
267 |   "summary": "[1] 피고와 제3자 사이에 있었던 민사소송의 확정판결의 존재를 넘어서 그 판결의 이유를 구성하는 사실관계들까지 법원에 현저한 사실로 볼 수는 없다. 민사재판에 있어서 이미 확정된 관련 민사사건의 판결에서 인정된 사실은 특별한 사정이 없는 한 유력한 증거가 되지만, 당해 민사재판에서 제출된 다른 증거 내용에 비추어 확정된 관련 민사사건 판결의 사실인정을 그대로 채용하기 어려운 경우에는 합리적인 이유를 설시하여 이를 배척할 수 있다는 법리도 그와 같이 확정된 민사판결 이유 중의 사실관계가 현저한 사실에 해당하지 않음을 전제로 한 것이다.\n\n\n[2] 원심이 다른 하급심판결의 이유 중 일부 사실관계에 관한 인정 사실을 그대로 인정하면서, 위 사정들이 ‘이 법원에 현저한 사실’이라고 본 사안에서, 당해 재판의 제1심 및 원심에서 다른 하급심판결의 판결문 등이 증거로 제출된 적이 없고, 당사자들도 이에 관하여 주장한 바가 없음에도 이를 ‘법원에 현저한 사실’로 본 원심판단에 법리오해의 잘못이 있다고 한 사례.",
268 |   "precedent": "주문\n원심판결을 파기하고, 사건을 광주지방법원 본원 합의부에 환송한다.\n\n이유\n상고이유를 판단한다.\n1. 피고와 제3자 사이에 있었던 민사소송의 확정판결의 존재를 넘어서 그 판결의 이유를 구성하는 사실관계들까지 법원에 현저한 사실로 볼 수는 없다(대법원 2010. 1. 14. 선고 2009다69531 판결 참조). 민사재판에 있어서 이미 확정된 관련 민사사건의 판결에서 인정된 사실은 특별한 사정이 없는 한 유력한 증거가 되지만, 당해 민사재판에서 제출된 다른 증거 내용에 비추어 확정된 관련 민사사건 판결의 사실인정을 그대로 채용하기 어려운 경우에는 합리적인 이유를 설시하여 이를 배척할 수 있다는 법리(대법원 2018. 8. 30. 선고 2016다46338, 46345 판결 등 참조)도 그와 같이 확정된 민사판결 이유 중의 사실관계가 현저한 사실에 해당하지 않음을 전제로 한 것이다.\n2. 원심은 광주고등법원 2003나8816 판결 이유 중 ‘소외인이 피고 회사를 설립한 경위’에 관한 인정 사실, 광주지방법원 목포지원 2001가합1664 판결과 광주고등법원 2003나416 판결 이유 중 ‘피고 회사 이사회의 개최 여부’에 관한 인정 사실을 그대로 인정하면서, 위 사정들이 ‘이 법원에 현저한 사실’이라고 보았다.\n그런데 이 사건 기록에 의하면, 광주고등법원 2003나8816 판결, 광주지방법원 목포지원 2001가합1664 판결, 광주고등법원 2003나416 판결은 제1심 및 원심에서 판결문 등이 증거로 제출된 적이 없고, 당사자들도 이에 관하여 주장한 바가 없다.\n그렇다면 원심은 ‘법원에 현저한 사실’에 관한 법리를 오해한 나머지 필요한 심리를 다하지 아니한 채, 당사자가 증거로 제출하지 않고 심리가 되지 않았던 위 각 판결들에서 인정된 사실관계에 기하여 판단한 잘못이 있다. 이 점을 지적하는 상고이유 주장은 이유 있다.\n3. 그러므로 나머지 상고이유에 대한 판단을 생략한 채 원심판결을 파기하고, 사건을 다시 심리·판단하게 하기 위하여 원심법원에 환송하기로 하여, 관여 대법관의 일치된 의견으로 주문과 같이 판결한다."
269 | }
270 | ```
271 |
272 | - `id`: a data id.
273 | - `summary`: a summary (판결요지) of the given precedent (판결문).
274 | - `precedent`: a case from the Supreme Court of Korea.
275 |
276 |
277 |
278 | # Models
279 |
280 | ## How to use the language model `lcube-base`
281 | ```python
282 | # !pip install transformers==4.19.4
283 | import transformers
284 |
285 | model = transformers.GPT2LMHeadModel.from_pretrained("lbox/lcube-base")
286 | tokenizer = transformers.AutoTokenizer.from_pretrained(
287 |     "lbox/lcube-base",
288 |     bos_token="[BOS]",
289 |     unk_token="[UNK]",
290 |     pad_token="[PAD]",
291 |     mask_token="[MASK]",
292 | )
293 |
294 | text = "피고인은 불상지에 있는 커피숍에서, 피해자 B으로부터"
295 | model_inputs = tokenizer(text,
296 |                          max_length=1024,
297 |                          padding=True,
298 |                          truncation=True,
299 |                          return_tensors='pt')
300 | out = model.generate(
301 |     model_inputs["input_ids"],
302 |     max_new_tokens=150,
303 |     pad_token_id=tokenizer.pad_token_id,
304 |     use_cache=True,
305 |     repetition_penalty=1.2,
306 |     top_k=5,
307 |     top_p=0.9,
308 |     temperature=1,
309 |     num_beams=2,
310 | )
311 | tokenizer.batch_decode(out)
312 | ```
313 |
314 | ## Fine-tuning
315 | ### Setup
316 |
317 | ```bash
318 | conda create -n lbox-open python=3.8.11
319 | conda install pytorch==1.10.1 torchvision torchaudio cudatoolkit=11.3 -c pytorch
320 | pip install -r requirements.txt
321 | ```
322 |
323 | ### Training
324 |
325 | ```bash
326 | python run_model.py [TRAINING_CONFIG_FILE_PATH] --mode train
327 | ```
328 | See also `scripts/train_[TASK].sh`.
329 |
330 | ### Test
331 |
332 | 1. Make the test config file from the training config file by copying it and changing the values of the `trained` and `path` fields as shown below.
333 | ```yaml
334 | train:
335 |   weights:
336 |     trained: true
337 |     path: ./models/[THE NAME OF THE TRAINING CONFIG FILE]/epoch=[XX]-step=[XX].ckpt
338 | ```
339 | 2. Run:
340 | ```bash 341 | python run_model.py [TEST_CONFIG_FILE_PATH] --mode test 342 | ```` 343 | See also `scripts/test_[TASK].sh` 344 | 345 | 346 | 347 | # Licensing Information 348 | 349 | Copyright 2022-present [LBox Co. Ltd.](https://lbox.kr/) 350 | 351 | Licensed under the [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) 352 | -------------------------------------------------------------------------------- /configs/casename_classification/casename.kogpt2.e2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 960 16 | task: casename_classification 17 | subtask: casename_classification 18 | target_field: facts 19 | target_parses_dict: 20 | casename_classification: 21 | - casename 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: kogpt2 27 | path: skt/kogpt2-base-v2 28 | revision: 29 | precision: bf16 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 4 36 | batch_size_prediction: 12 37 | check_val_every_n_epoch: 1 38 | fast_dev_run: false 39 | max_epochs: 10 40 | multiple_trainloader_mode: 41 | seed: 2 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.0001 62 | optimizer_type: adamw 63 | swa: 64 | use: true 65 | lr: 0.0001 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 0 70 | val_check_interval: 1.0 71 | validation_metric: em 72 | validation_target_parse: casename_classification 73 | validation_sub_param: 74 | method: text_em 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 79 | max_new_tokens: 64 80 | min_length: 1 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "*" 89 | value_sep_token: "|" 90 | empty_token: "없음" 91 | 92 | -------------------------------------------------------------------------------- /configs/casename_classification/casename.kogpt2.e3.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 960 16 | task: casename_classification 17 | subtask: casename_classification 18 | target_field: facts 19 | target_parses_dict: 20 | casename_classification: 21 | - casename 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: kogpt2 27 | path: skt/kogpt2-base-v2 28 | revision: 29 | 
precision: bf16 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 4 36 | batch_size_prediction: 12 37 | check_val_every_n_epoch: 1 38 | fast_dev_run: false 39 | max_epochs: 10 40 | multiple_trainloader_mode: 41 | seed: 3 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.0001 62 | optimizer_type: adamw 63 | swa: 64 | use: true 65 | lr: 0.0001 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 0 70 | val_check_interval: 1.0 71 | validation_metric: em 72 | validation_target_parse: casename_classification 73 | validation_sub_param: 74 | method: text_em 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 79 | max_new_tokens: 64 80 | min_length: 1 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "*" 89 | value_sep_token: "|" 90 | empty_token: "없음" 91 | 92 | -------------------------------------------------------------------------------- /configs/casename_classification/casename.kogpt2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 960 16 | task: casename_classification 17 | subtask: casename_classification 18 | target_field: facts 19 | target_parses_dict: 20 | casename_classification: 21 | - casename 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: kogpt2 27 | path: skt/kogpt2-base-v2 28 | revision: 29 | precision: bf16 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 4 36 | batch_size_prediction: 12 37 | check_val_every_n_epoch: 1 38 | fast_dev_run: false 39 | max_epochs: 10 40 | multiple_trainloader_mode: 41 | seed: 1 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.0001 62 | optimizer_type: adamw 63 | swa: 64 | use: true 65 | lr: 0.0001 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 0 70 | val_check_interval: 1.0 71 | validation_metric: em 72 | validation_target_parse: casename_classification 73 | validation_sub_param: 74 | method: text_em 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 79 | 
max_new_tokens: 64 80 | min_length: 1 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "*" 89 | value_sep_token: "|" 90 | empty_token: "없음" 91 | 92 | -------------------------------------------------------------------------------- /configs/casename_classification/casename.lcube-base.e2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 64 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 960 16 | task: casename_classification 17 | subtask: casename_classification 18 | target_field: facts 19 | target_parses_dict: 20 | casename_classification: 21 | - casename 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: legal-gpt 27 | path: lbox/lcube-base 28 | revision: 29 | precision: bf16 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 4 36 | batch_size_prediction: 12 37 | check_val_every_n_epoch: 1 38 | fast_dev_run: false 39 | max_epochs: 10 40 | multiple_trainloader_mode: 41 | seed: 2 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.0001 62 | optimizer_type: adamw 63 | swa: 64 | use: true 65 | lr: 0.0001 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 1 70 | val_check_interval: 1.0 71 | validation_metric: em 72 | validation_target_parse: casename_classification 73 | validation_sub_param: 74 | method: text_em 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 64 79 | max_new_tokens: 64 80 | min_length: 1 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "*" 89 | value_sep_token: "|" 90 | empty_token: "없음" 91 | 92 | 93 | -------------------------------------------------------------------------------- /configs/casename_classification/casename.lcube-base.e3.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 64 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 960 16 | task: casename_classification 17 | subtask: casename_classification 18 | target_field: facts 19 | target_parses_dict: 20 | casename_classification: 21 | - casename 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: legal-gpt 27 | path: lbox/lcube-base 28 | revision: 29 | precision: bf16 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 4 36 | 
batch_size_prediction: 12 37 | check_val_every_n_epoch: 1 38 | fast_dev_run: false 39 | max_epochs: 10 40 | multiple_trainloader_mode: 41 | seed: 3 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.0001 62 | optimizer_type: adamw 63 | swa: 64 | use: true 65 | lr: 0.0001 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 1 70 | val_check_interval: 1.0 71 | validation_metric: em 72 | validation_target_parse: casename_classification 73 | validation_sub_param: 74 | method: text_em 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 64 79 | max_new_tokens: 64 80 | min_length: 1 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "*" 89 | value_sep_token: "|" 90 | empty_token: "없음" 91 | 92 | 93 | -------------------------------------------------------------------------------- /configs/casename_classification/casename.lcube-base.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 64 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 960 16 | task: casename_classification 17 | subtask: casename_classification 18 | target_field: facts 19 | target_parses_dict: 20 | casename_classification: 21 | - casename 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: legal-gpt 27 | path: lbox/lcube-base 28 | revision: 29 | precision: bf16 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 4 36 | batch_size_prediction: 12 37 | check_val_every_n_epoch: 1 38 | fast_dev_run: false 39 | max_epochs: 10 40 | multiple_trainloader_mode: 41 | seed: 1 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.0001 62 | optimizer_type: adamw 63 | swa: 64 | use: true 65 | lr: 0.0001 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 1 70 | val_check_interval: 1.0 71 | validation_metric: em 72 | validation_target_parse: casename_classification 73 | validation_sub_param: 74 | method: text_em 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 64 79 | max_new_tokens: 64 80 | min_length: 1 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "*" 89 | value_sep_token: "|" 90 | empty_token: "없음" 91 | 92 | 93 | 
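
Note: the casename_classification configs above are plain YAML files, and the data module in lbox_open/data_module/data_precedent.py reads the resulting config with attribute-style access (cfg.model.task, cfg.train.get(...)), which matches an OmegaConf DictConfig. Below is a minimal sketch of loading and overriding one of these configs; OmegaConf is an assumption made for illustration here, not a confirmed detail of run_model.py or the training scripts in scripts/.

from omegaconf import OmegaConf  # assumed config library for this sketch

# Load one of the casename-classification configs shipped in the repo.
cfg = OmegaConf.load("configs/casename_classification/casename.lcube-base.yaml")

print(cfg.model.task)            # casename_classification
print(cfg.model.plm.path)        # lbox/lcube-base
print(cfg.train.seed)            # 1
print(cfg.infer.max_new_tokens)  # 64

# The .e2/.e3 variants of this file appear to differ only in train.seed,
# so a dotlist override reproduces them without copying the whole file:
cfg_e2 = OmegaConf.merge(cfg, OmegaConf.from_dotlist(["train.seed=2"]))
assert cfg_e2.train.seed == 2
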
-------------------------------------------------------------------------------- /configs/ljp/civil/ljp.civil.kogpt2.e2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 1021 16 | task: ljp_civil 17 | subtask: civil 18 | target_field: facts 19 | target_parses_dict: 20 | claim_acceptance_lv: 21 | - claim_acceptance_lv 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: kogpt2 27 | path: skt/kogpt2-base-v2 28 | revision: 29 | precision: bf16 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 4 36 | batch_size_prediction: 12 37 | check_val_every_n_epoch: 1 38 | fast_dev_run: false 39 | max_epochs: 15 40 | multiple_trainloader_mode: 41 | seed: 2 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.00005 62 | optimizer_type: adamw 63 | swa: 64 | use: true 65 | lr: 0.00005 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 0 70 | val_check_interval: 0.5 71 | validation_metric: em 72 | validation_target_parse: claim_acceptance_lv 73 | validation_sub_param: 74 | method: text_em 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 79 | max_new_tokens: 3 80 | min_length: 1 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "*" 89 | value_sep_token: "|" 90 | empty_token: "없음" 91 | 92 | -------------------------------------------------------------------------------- /configs/ljp/civil/ljp.civil.kogpt2.e3.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 1021 16 | task: ljp_civil 17 | subtask: civil 18 | target_field: facts 19 | target_parses_dict: 20 | claim_acceptance_lv: 21 | - claim_acceptance_lv 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: kogpt2 27 | path: skt/kogpt2-base-v2 28 | revision: 29 | precision: bf16 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 4 36 | batch_size_prediction: 12 37 | check_val_every_n_epoch: 1 38 | fast_dev_run: false 39 | max_epochs: 15 40 | multiple_trainloader_mode: 41 | seed: 3 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: 
data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.00005 62 | optimizer_type: adamw 63 | swa: 64 | use: true 65 | lr: 0.00005 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 0 70 | val_check_interval: 0.5 71 | validation_metric: em 72 | validation_target_parse: claim_acceptance_lv 73 | validation_sub_param: 74 | method: text_em 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 79 | max_new_tokens: 3 80 | min_length: 1 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "*" 89 | value_sep_token: "|" 90 | empty_token: "없음" 91 | 92 | -------------------------------------------------------------------------------- /configs/ljp/civil/ljp.civil.kogpt2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 1021 16 | task: ljp_civil 17 | subtask: civil 18 | target_field: facts 19 | target_parses_dict: 20 | claim_acceptance_lv: 21 | - claim_acceptance_lv 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: kogpt2 27 | path: skt/kogpt2-base-v2 28 | revision: 29 | precision: bf16 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 4 36 | batch_size_prediction: 12 37 | check_val_every_n_epoch: 1 38 | fast_dev_run: false 39 | max_epochs: 15 40 | multiple_trainloader_mode: 41 | seed: 1 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.00005 62 | optimizer_type: adamw 63 | swa: 64 | use: true 65 | lr: 0.00005 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 0 70 | val_check_interval: 0.5 71 | validation_metric: em 72 | validation_target_parse: claim_acceptance_lv 73 | validation_sub_param: 74 | method: text_em 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 79 | max_new_tokens: 3 80 | min_length: 1 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "*" 89 | value_sep_token: "|" 90 | empty_token: "없음" 91 | 92 | -------------------------------------------------------------------------------- /configs/ljp/civil/ljp.civil.lcube-base.e2.yaml: 
-------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 1021 16 | task: ljp_civil 17 | subtask: civil 18 | target_field: facts 19 | target_parses_dict: 20 | claim_acceptance_lv: 21 | - claim_acceptance_lv 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: legal-gpt 27 | path: lbox/lcube-base 28 | revision: 29 | precision: bf16 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 4 36 | batch_size_prediction: 12 37 | check_val_every_n_epoch: 1 38 | fast_dev_run: false 39 | max_epochs: 15 40 | multiple_trainloader_mode: 41 | seed: 2 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.00005 62 | optimizer_type: adamw 63 | swa: 64 | use: true 65 | lr: 0.00005 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 0 70 | val_check_interval: 0.5 71 | validation_metric: em 72 | validation_target_parse: claim_acceptance_lv 73 | validation_sub_param: 74 | method: text_em 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 79 | max_new_tokens: 3 80 | min_length: 1 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "*" 89 | value_sep_token: "|" 90 | empty_token: "없음" 91 | 92 | -------------------------------------------------------------------------------- /configs/ljp/civil/ljp.civil.lcube-base.e3.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 1021 16 | task: ljp_civil 17 | subtask: civil 18 | target_field: facts 19 | target_parses_dict: 20 | claim_acceptance_lv: 21 | - claim_acceptance_lv 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: legal-gpt 27 | path: lbox/lcube-base 28 | revision: 29 | precision: bf16 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 4 36 | batch_size_prediction: 12 37 | check_val_every_n_epoch: 1 38 | fast_dev_run: false 39 | max_epochs: 15 40 | multiple_trainloader_mode: 41 | seed: 3 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 
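  # The optim block below defines two AdamW parameter groups: "prompt"
  # (lr 0.1, warmup_constant schedule with 10 warmup steps) and "plm"
  # (the pretrained language model, lr 5e-5). The swa block enables
  # Stochastic Weight Averaging (start epoch 4, 6 annealing epochs, lr 5e-5).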
50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.00005 62 | optimizer_type: adamw 63 | swa: 64 | use: true 65 | lr: 0.00005 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 0 70 | val_check_interval: 0.5 71 | validation_metric: em 72 | validation_target_parse: claim_acceptance_lv 73 | validation_sub_param: 74 | method: text_em 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 79 | max_new_tokens: 3 80 | min_length: 1 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "*" 89 | value_sep_token: "|" 90 | empty_token: "없음" 91 | 92 | -------------------------------------------------------------------------------- /configs/ljp/civil/ljp.civil.lcube-base.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 1021 16 | task: ljp_civil 17 | subtask: civil 18 | target_field: facts 19 | target_parses_dict: 20 | claim_acceptance_lv: 21 | - claim_acceptance_lv 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: legal-gpt 27 | path: lbox/lcube-base 28 | revision: 29 | precision: bf16 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 4 36 | batch_size_prediction: 12 37 | check_val_every_n_epoch: 1 38 | fast_dev_run: false 39 | max_epochs: 15 40 | multiple_trainloader_mode: 41 | seed: 1 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.00005 62 | optimizer_type: adamw 63 | swa: 64 | use: true 65 | lr: 0.00005 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 0 70 | val_check_interval: 0.5 71 | validation_metric: em 72 | validation_target_parse: claim_acceptance_lv 73 | validation_sub_param: 74 | method: text_em 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 79 | max_new_tokens: 3 80 | min_length: 1 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "*" 89 | value_sep_token: "|" 90 | empty_token: "없음" 91 | 92 | -------------------------------------------------------------------------------- /configs/ljp/criminal/ljp.criminal.kogpt2.e2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | 
test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 1012 16 | task: ljp_criminal 17 | subtask: criminal 18 | target_field: facts 19 | target_parses_dict: 20 | fine_imprisonment_lvs: 21 | - fine_lv 22 | - imprisonment_with_labor_lv 23 | - imprisonment_without_labor_lv 24 | path_template: 25 | plm: 26 | freeze: false 27 | eval_mode: false 28 | name: kogpt2 29 | path: skt/kogpt2-base-v2 30 | revision: 31 | precision: bf16 32 | 33 | train: 34 | accelerator: auto 35 | accumulate_grad_batches: 2 36 | limit_val_batches: 1.0 37 | batch_size: 4 38 | batch_size_prediction: 12 39 | check_val_every_n_epoch: 1 40 | fast_dev_run: false 41 | max_epochs: 20 42 | multiple_trainloader_mode: 43 | seed: 2 44 | strategy: null 45 | weight: 46 | trained: false 47 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 48 | save_path_dir: ./data/models 49 | do_not_load_pretrained_weight: false 50 | old_format: false 51 | log_dir: ./logs 52 | optim: 53 | gradient_clip_val: 1.0 54 | gradient_clip_algorithm: norm 55 | prompt: 56 | lr: 0.1 57 | optimizer_type: adamw 58 | lr_scheduler_type: warmup_constant 59 | lr_scheduler_param: 60 | warmup_constant: 61 | num_warmup_steps: 10 62 | plm: 63 | lr: 0.00005 64 | optimizer_type: adamw 65 | swa: 66 | use: true 67 | lr: 0.00005 68 | swa_epoch_start: 4 69 | annealing_epochs: 6 70 | profiler: null 71 | num_sanity_val_steps: 0 72 | val_check_interval: 0.5 73 | validation_metric: em 74 | validation_target_parse: fine_imprisonment_lvs 75 | validation_sub_param: 76 | method: average 77 | target_sub_parse: average 78 | 79 | infer: 80 | max_length: 81 | max_new_tokens: 12 82 | min_length: 5 83 | temperature: 1.0 84 | do_sample: False 85 | top_k: 0 86 | top_p: 0.9 87 | repetition_penalty: 1.0 88 | num_beams: 1 89 | bad_words_ids: null 90 | parse_sep_token: "," 91 | value_sep_token: "|" 92 | empty_token: "0" 93 | 94 | -------------------------------------------------------------------------------- /configs/ljp/criminal/ljp.criminal.kogpt2.e3.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 1012 16 | task: ljp_criminal 17 | subtask: criminal 18 | target_field: facts 19 | target_parses_dict: 20 | fine_imprisonment_lvs: 21 | - fine_lv 22 | - imprisonment_with_labor_lv 23 | - imprisonment_without_labor_lv 24 | path_template: 25 | plm: 26 | freeze: false 27 | eval_mode: false 28 | name: kogpt2 29 | path: skt/kogpt2-base-v2 30 | revision: 31 | precision: bf16 32 | 33 | train: 34 | accelerator: auto 35 | accumulate_grad_batches: 2 36 | limit_val_batches: 1.0 37 | batch_size: 4 38 | batch_size_prediction: 12 39 | check_val_every_n_epoch: 1 40 | fast_dev_run: false 41 | max_epochs: 10 42 | multiple_trainloader_mode: 43 | seed: 3 44 | strategy: null 45 | weight: 46 | trained: false 47 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 48 | save_path_dir: ./data/models 49 | do_not_load_pretrained_weight: false 50 | old_format: false 51 | log_dir: ./logs 52 | optim: 53 | gradient_clip_val: 1.0 54 | 
gradient_clip_algorithm: norm 55 | prompt: 56 | lr: 0.1 57 | optimizer_type: adamw 58 | lr_scheduler_type: warmup_constant 59 | lr_scheduler_param: 60 | warmup_constant: 61 | num_warmup_steps: 10 62 | plm: 63 | lr: 0.00005 64 | optimizer_type: adamw 65 | swa: 66 | use: true 67 | lr: 0.00005 68 | swa_epoch_start: 4 69 | annealing_epochs: 6 70 | profiler: null 71 | num_sanity_val_steps: 0 72 | val_check_interval: 0.5 73 | validation_metric: em 74 | validation_target_parse: fine_imprisonment_lvs 75 | validation_sub_param: 76 | method: average 77 | target_sub_parse: average 78 | 79 | infer: 80 | max_length: 81 | max_new_tokens: 12 82 | min_length: 5 83 | temperature: 1.0 84 | do_sample: False 85 | top_k: 0 86 | top_p: 0.9 87 | repetition_penalty: 1.0 88 | num_beams: 1 89 | bad_words_ids: null 90 | parse_sep_token: "," 91 | value_sep_token: "|" 92 | empty_token: "0" 93 | 94 | -------------------------------------------------------------------------------- /configs/ljp/criminal/ljp.criminal.kogpt2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 1012 16 | task: ljp_criminal 17 | subtask: criminal 18 | target_field: facts 19 | target_parses_dict: 20 | fine_imprisonment_lvs: 21 | - fine_lv 22 | - imprisonment_with_labor_lv 23 | - imprisonment_without_labor_lv 24 | path_template: 25 | plm: 26 | freeze: false 27 | eval_mode: false 28 | name: kogpt2 29 | path: skt/kogpt2-base-v2 30 | revision: 31 | precision: bf16 32 | 33 | train: 34 | accelerator: auto 35 | accumulate_grad_batches: 2 36 | limit_val_batches: 1.0 37 | batch_size: 4 38 | batch_size_prediction: 12 39 | check_val_every_n_epoch: 1 40 | fast_dev_run: false 41 | max_epochs: 20 42 | multiple_trainloader_mode: 43 | seed: 1 44 | strategy: null 45 | weight: 46 | trained: false 47 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 48 | save_path_dir: ./data/models 49 | do_not_load_pretrained_weight: false 50 | old_format: false 51 | log_dir: ./logs 52 | optim: 53 | gradient_clip_val: 1.0 54 | gradient_clip_algorithm: norm 55 | prompt: 56 | lr: 0.1 57 | optimizer_type: adamw 58 | lr_scheduler_type: warmup_constant 59 | lr_scheduler_param: 60 | warmup_constant: 61 | num_warmup_steps: 10 62 | plm: 63 | lr: 0.00005 64 | optimizer_type: adamw 65 | swa: 66 | use: true 67 | lr: 0.00005 68 | swa_epoch_start: 4 69 | annealing_epochs: 6 70 | profiler: null 71 | num_sanity_val_steps: 0 72 | val_check_interval: 0.5 73 | validation_metric: em 74 | validation_target_parse: fine_imprisonment_lvs 75 | validation_sub_param: 76 | method: average 77 | target_sub_parse: average 78 | 79 | infer: 80 | max_length: 81 | max_new_tokens: 12 82 | min_length: 5 83 | temperature: 1.0 84 | do_sample: False 85 | top_k: 0 86 | top_p: 0.9 87 | repetition_penalty: 1.0 88 | num_beams: 1 89 | bad_words_ids: null 90 | parse_sep_token: "," 91 | value_sep_token: "|" 92 | empty_token: "0" 93 | 94 | -------------------------------------------------------------------------------- /configs/ljp/criminal/ljp.criminal.lcube-base.e2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | 
validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 1012 16 | task: ljp_criminal 17 | subtask: criminal 18 | target_field: facts 19 | target_parses_dict: 20 | fine_imprisonment_lvs: 21 | - fine_lv 22 | - imprisonment_with_labor_lv 23 | - imprisonment_without_labor_lv 24 | path_template: 25 | plm: 26 | freeze: false 27 | eval_mode: false 28 | name: legal-gpt 29 | path: lbox/lcube-base 30 | revision: 31 | precision: bf16 32 | 33 | train: 34 | accelerator: auto 35 | accumulate_grad_batches: 2 36 | limit_val_batches: 1.0 37 | batch_size: 4 38 | batch_size_prediction: 12 39 | check_val_every_n_epoch: 1 40 | fast_dev_run: false 41 | max_epochs: 20 42 | multiple_trainloader_mode: 43 | seed: 2 44 | strategy: null 45 | weight: 46 | trained: false 47 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 48 | save_path_dir: ./data/models 49 | do_not_load_pretrained_weight: false 50 | old_format: false 51 | log_dir: ./logs 52 | optim: 53 | gradient_clip_val: 1.0 54 | gradient_clip_algorithm: norm 55 | prompt: 56 | lr: 0.1 57 | optimizer_type: adamw 58 | lr_scheduler_type: warmup_constant 59 | lr_scheduler_param: 60 | warmup_constant: 61 | num_warmup_steps: 10 62 | plm: 63 | lr: 0.00005 64 | optimizer_type: adamw 65 | swa: 66 | use: true 67 | lr: 0.00005 68 | swa_epoch_start: 4 69 | annealing_epochs: 6 70 | profiler: null 71 | num_sanity_val_steps: 0 72 | val_check_interval: 0.5 73 | validation_metric: em 74 | validation_target_parse: fine_imprisonment_lvs 75 | validation_sub_param: 76 | method: average 77 | target_sub_parse: average 78 | 79 | infer: 80 | max_length: 81 | max_new_tokens: 12 82 | min_length: 5 83 | temperature: 1.0 84 | do_sample: False 85 | top_k: 0 86 | top_p: 0.9 87 | repetition_penalty: 1.0 88 | num_beams: 1 89 | bad_words_ids: null 90 | parse_sep_token: "," 91 | value_sep_token: "|" 92 | empty_token: "0" 93 | 94 | -------------------------------------------------------------------------------- /configs/ljp/criminal/ljp.criminal.lcube-base.e3.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 1012 16 | task: ljp_criminal 17 | subtask: criminal 18 | target_field: facts 19 | target_parses_dict: 20 | fine_imprisonment_lvs: 21 | - fine_lv 22 | - imprisonment_with_labor_lv 23 | - imprisonment_without_labor_lv 24 | path_template: 25 | plm: 26 | freeze: false 27 | eval_mode: false 28 | name: legal-gpt 29 | path: lbox/lcube-base 30 | revision: 31 | precision: bf16 32 | 33 | train: 34 | accelerator: auto 35 | accumulate_grad_batches: 2 36 | limit_val_batches: 1.0 37 | batch_size: 4 38 | batch_size_prediction: 12 39 | check_val_every_n_epoch: 1 40 | fast_dev_run: false 41 | max_epochs: 10 42 | multiple_trainloader_mode: 43 | seed: 3 44 | strategy: null 45 | weight: 46 | trained: false 47 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 48 | save_path_dir: ./data/models 49 | do_not_load_pretrained_weight: false 50 | old_format: false 51 | log_dir: ./logs 52 | 
optim: 53 | gradient_clip_val: 1.0 54 | gradient_clip_algorithm: norm 55 | prompt: 56 | lr: 0.1 57 | optimizer_type: adamw 58 | lr_scheduler_type: warmup_constant 59 | lr_scheduler_param: 60 | warmup_constant: 61 | num_warmup_steps: 10 62 | plm: 63 | lr: 0.00005 64 | optimizer_type: adamw 65 | swa: 66 | use: true 67 | lr: 0.00005 68 | swa_epoch_start: 4 69 | annealing_epochs: 6 70 | profiler: null 71 | num_sanity_val_steps: 0 72 | val_check_interval: 0.5 73 | validation_metric: em 74 | validation_target_parse: fine_imprisonment_lvs 75 | validation_sub_param: 76 | method: average 77 | target_sub_parse: average 78 | 79 | infer: 80 | max_length: 81 | max_new_tokens: 12 82 | min_length: 5 83 | temperature: 1.0 84 | do_sample: False 85 | top_k: 0 86 | top_p: 0.9 87 | repetition_penalty: 1.0 88 | num_beams: 1 89 | bad_words_ids: null 90 | parse_sep_token: "," 91 | value_sep_token: "|" 92 | empty_token: "0" 93 | 94 | -------------------------------------------------------------------------------- /configs/ljp/criminal/ljp.criminal.lcube-base.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 1012 16 | task: ljp_criminal 17 | subtask: criminal 18 | target_field: facts 19 | target_parses_dict: 20 | fine_imprisonment_lvs: 21 | - fine_lv 22 | - imprisonment_with_labor_lv 23 | - imprisonment_without_labor_lv 24 | path_template: 25 | plm: 26 | freeze: false 27 | eval_mode: false 28 | name: legal-gpt 29 | path: lbox/lcube-base 30 | revision: 31 | precision: bf16 32 | 33 | train: 34 | accelerator: auto 35 | accumulate_grad_batches: 2 36 | limit_val_batches: 1.0 37 | batch_size: 4 38 | batch_size_prediction: 12 39 | check_val_every_n_epoch: 1 40 | fast_dev_run: false 41 | max_epochs: 20 42 | multiple_trainloader_mode: 43 | seed: 1 44 | strategy: null 45 | weight: 46 | trained: false 47 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 48 | save_path_dir: ./data/models 49 | do_not_load_pretrained_weight: false 50 | old_format: false 51 | log_dir: ./logs 52 | optim: 53 | gradient_clip_val: 1.0 54 | gradient_clip_algorithm: norm 55 | prompt: 56 | lr: 0.1 57 | optimizer_type: adamw 58 | lr_scheduler_type: warmup_constant 59 | lr_scheduler_param: 60 | warmup_constant: 61 | num_warmup_steps: 10 62 | plm: 63 | lr: 0.00005 64 | optimizer_type: adamw 65 | swa: 66 | use: true 67 | lr: 0.00005 68 | swa_epoch_start: 4 69 | annealing_epochs: 6 70 | profiler: null 71 | num_sanity_val_steps: 0 72 | val_check_interval: 0.5 73 | validation_metric: em 74 | validation_target_parse: fine_imprisonment_lvs 75 | validation_sub_param: 76 | method: average 77 | target_sub_parse: average 78 | 79 | infer: 80 | max_length: 81 | max_new_tokens: 12 82 | min_length: 5 83 | temperature: 1.0 84 | do_sample: False 85 | top_k: 0 86 | top_p: 0.9 87 | repetition_penalty: 1.0 88 | num_beams: 1 89 | bad_words_ids: null 90 | parse_sep_token: "," 91 | value_sep_token: "|" 92 | empty_token: "0" 93 | 94 | -------------------------------------------------------------------------------- /configs/statute_classification/statute.kogpt2.e2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: 
lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 960 16 | task: statute_classification 17 | subtask: statute_classification 18 | target_field: facts 19 | target_parses_dict: 20 | statute_classification: 21 | - statute 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: kogpt2 27 | path: skt/kogpt2-base-v2 28 | revision: 29 | precision: bf16 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 4 36 | batch_size_prediction: 12 37 | check_val_every_n_epoch: 1 38 | fast_dev_run: false 39 | max_epochs: 15 40 | multiple_trainloader_mode: 41 | seed: 2 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.00005 62 | optimizer_type: adamw 63 | swa: 64 | use: true 65 | lr: 0.00005 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 0 70 | val_check_interval: 0.5 71 | validation_metric: em 72 | validation_target_parse: statute_classification 73 | validation_sub_param: 74 | method: text_em 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 79 | max_new_tokens: 64 80 | min_length: 5 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "," 89 | value_sep_token: "|" 90 | empty_token: "0" 91 | 92 | -------------------------------------------------------------------------------- /configs/statute_classification/statute.kogpt2.e3.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 960 16 | task: statute_classification 17 | subtask: statute_classification 18 | target_field: facts 19 | target_parses_dict: 20 | statute_classification: 21 | - statute 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: kogpt2 27 | path: skt/kogpt2-base-v2 28 | revision: 29 | precision: bf16 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 4 36 | batch_size_prediction: 12 37 | check_val_every_n_epoch: 1 38 | fast_dev_run: false 39 | max_epochs: 15 40 | multiple_trainloader_mode: 41 | seed: 3 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | 
gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.00005 62 | optimizer_type: adamw 63 | swa: 64 | use: true 65 | lr: 0.00005 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 0 70 | val_check_interval: 0.5 71 | validation_metric: em 72 | validation_target_parse: statute_classification 73 | validation_sub_param: 74 | method: text_em 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 79 | max_new_tokens: 64 80 | min_length: 5 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "," 89 | value_sep_token: "|" 90 | empty_token: "0" 91 | 92 | -------------------------------------------------------------------------------- /configs/statute_classification/statute.kogpt2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 960 16 | task: statute_classification 17 | subtask: statute_classification 18 | target_field: facts 19 | target_parses_dict: 20 | statute_classification: 21 | - statute 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: kogpt2 27 | path: skt/kogpt2-base-v2 28 | revision: 29 | precision: bf16 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 4 36 | batch_size_prediction: 12 37 | check_val_every_n_epoch: 1 38 | fast_dev_run: false 39 | max_epochs: 15 40 | multiple_trainloader_mode: 41 | seed: 1 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.00005 62 | optimizer_type: adamw 63 | swa: 64 | use: true 65 | lr: 0.00005 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 0 70 | val_check_interval: 0.5 71 | validation_metric: em 72 | validation_target_parse: statute_classification 73 | validation_sub_param: 74 | method: text_em 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 79 | max_new_tokens: 64 80 | min_length: 5 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "," 89 | value_sep_token: "|" 90 | empty_token: "0" 91 | 92 | -------------------------------------------------------------------------------- /configs/statute_classification/statute.lcube-base.e2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: 
test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 960 16 | task: statute_classification 17 | subtask: statute_classification 18 | target_field: facts 19 | target_parses_dict: 20 | statute_classification: 21 | - statute 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: legal-gpt 27 | path: lbox/lcube-base 28 | revision: 29 | precision: bf16 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 4 36 | batch_size_prediction: 12 37 | check_val_every_n_epoch: 1 38 | fast_dev_run: false 39 | max_epochs: 15 40 | multiple_trainloader_mode: 41 | seed: 2 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.00005 62 | optimizer_type: adamw 63 | swa: 64 | use: true 65 | lr: 0.00005 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 0 70 | val_check_interval: 0.5 71 | validation_metric: em 72 | validation_target_parse: statute_classification 73 | validation_sub_param: 74 | method: text_em 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 79 | max_new_tokens: 64 80 | min_length: 5 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "," 89 | value_sep_token: "|" 90 | empty_token: "0" 91 | 92 | -------------------------------------------------------------------------------- /configs/statute_classification/statute.lcube-base.e3.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 960 16 | task: statute_classification 17 | subtask: statute_classification 18 | target_field: facts 19 | target_parses_dict: 20 | statute_classification: 21 | - statute 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: legal-gpt 27 | path: lbox/lcube-base 28 | revision: 29 | precision: bf16 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 4 36 | batch_size_prediction: 12 37 | check_val_every_n_epoch: 1 38 | fast_dev_run: false 39 | max_epochs: 15 40 | multiple_trainloader_mode: 41 | seed: 3 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: 
warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.00005 62 | optimizer_type: adamw 63 | swa: 64 | use: true 65 | lr: 0.00005 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 0 70 | val_check_interval: 0.5 71 | validation_metric: em 72 | validation_target_parse: statute_classification 73 | validation_sub_param: 74 | method: text_em 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 79 | max_new_tokens: 64 80 | min_length: 5 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "," 89 | value_sep_token: "|" 90 | empty_token: "0" 91 | 92 | -------------------------------------------------------------------------------- /configs/statute_classification/statute.lcube-base.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test2 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 960 16 | task: statute_classification 17 | subtask: statute_classification 18 | target_field: facts 19 | target_parses_dict: 20 | statute_classification: 21 | - statute 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: legal-gpt 27 | path: lbox/lcube-base 28 | revision: 29 | precision: bf16 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 4 36 | batch_size_prediction: 12 37 | check_val_every_n_epoch: 1 38 | fast_dev_run: false 39 | max_epochs: 15 40 | multiple_trainloader_mode: 41 | seed: 1 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.00005 62 | optimizer_type: adamw 63 | swa: 64 | use: true 65 | lr: 0.00005 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 0 70 | val_check_interval: 0.5 71 | validation_metric: em 72 | validation_target_parse: statute_classification 73 | validation_sub_param: 74 | method: text_em 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 79 | max_new_tokens: 64 80 | min_length: 5 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "," 89 | value_sep_token: "|" 90 | empty_token: "0" 91 | 92 | -------------------------------------------------------------------------------- /configs/summarization/summarization.kogpt2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | 
decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 768 16 | task: summarization 17 | subtask: summarization 18 | target_field: precedent 19 | target_parses_dict: 20 | summarization: 21 | - summarization 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: kogpt2 27 | path: skt/kogpt2-base-v2 28 | revision: 29 | precision: 32 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 6 36 | batch_size_prediction: 12 37 | check_val_every_n_epoch: 2 38 | fast_dev_run: false 39 | max_epochs: 20 40 | multiple_trainloader_mode: 41 | seed: 1 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.0001 62 | optimizer_type: adamw 63 | swa: 64 | use: false 65 | lr: 0.00005 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 0 70 | val_check_interval: 1.0 71 | validation_metric: rougeL 72 | validation_target_parse: summarization 73 | validation_sub_param: 74 | method: rougeL 75 | target_sub_parse: 76 | 77 | 78 | infer: 79 | max_length: 80 | max_new_tokens: 256 81 | min_length: 5 82 | temperature: 1.0 83 | do_sample: False 84 | top_k: 0 85 | top_p: 0.9 86 | repetition_penalty: 1.0 87 | num_beams: 1 88 | bad_words_ids: null 89 | parse_sep_token: "*" 90 | value_sep_token: "|" 91 | empty_token: "없음" 92 | 93 | -------------------------------------------------------------------------------- /configs/summarization/summarization.lcube-base.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 1024 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 768 16 | task: summarization 17 | subtask: summarization 18 | target_field: precedent 19 | target_parses_dict: 20 | summarization: 21 | - summarization 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: legal-gpt 27 | path: lbox/lcube-base 28 | revision: 29 | precision: 32 30 | 31 | train: 32 | accelerator: auto 33 | accumulate_grad_batches: 2 34 | limit_val_batches: 1.0 35 | batch_size: 6 36 | batch_size_prediction: 12 37 | check_val_every_n_epoch: 2 38 | fast_dev_run: false 39 | max_epochs: 20 40 | multiple_trainloader_mode: 41 | seed: 1 42 | strategy: null 43 | weight: 44 | trained: false 45 | path: 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.0001 62 | optimizer_type: adamw 63 | swa: 64 | use: false 65 | lr: 0.00005 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 0 70 | val_check_interval: 
0.5 71 | validation_metric: rougeL 72 | validation_target_parse: summarization 73 | validation_sub_param: 74 | method: rougeL 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 79 | max_new_tokens: 256 80 | min_length: 5 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "*" 89 | value_sep_token: "|" 90 | empty_token: "없음" 91 | 92 | -------------------------------------------------------------------------------- /configs/summarization/summarization.legal-mt5s.test.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_card: lbox/lbox_open 3 | training_set_name: train 4 | validation_set_name: validation 5 | test_set_name: test 6 | use_local_data: false 7 | path_train: 8 | path_valid: 9 | path_test: 10 | 11 | model: 12 | decoder_max_length: 512 13 | input_template_type: 0 14 | model_type: generative 15 | max_seq_length: 1024 16 | task: summarization 17 | subtask: summarization 18 | target_field: precedent 19 | target_parses_dict: 20 | summarization: 21 | - summarization 22 | path_template: 23 | plm: 24 | freeze: false 25 | eval_mode: false 26 | name: mt5 27 | path: google/mt5-small 28 | revision: 29 | precision: bf16 30 | train: 31 | accelerator: auto 32 | accumulate_grad_batches: 1 33 | limit_train_batches: 0.2 34 | limit_val_batches: 4 35 | batch_size: 12 36 | batch_size_prediction: 36 37 | check_val_every_n_epoch: 2 38 | fast_dev_run: false 39 | max_epochs: 60 40 | multiple_trainloader_mode: 41 | seed: 1 42 | strategy: null 43 | weight: 44 | trained: true 45 | path: saved/models/lbox-open/legal-mt5s-summarization.pt 46 | save_path_dir: ./data/models 47 | do_not_load_pretrained_weight: false 48 | old_format: false 49 | log_dir: ./logs 50 | optim: 51 | gradient_clip_val: 1.0 52 | gradient_clip_algorithm: norm 53 | prompt: 54 | lr: 0.1 55 | optimizer_type: adamw 56 | lr_scheduler_type: warmup_constant 57 | lr_scheduler_param: 58 | warmup_constant: 59 | num_warmup_steps: 10 60 | plm: 61 | lr: 0.0001 62 | optimizer_type: adamw 63 | swa: 64 | use: false 65 | lr: 0.0001 66 | swa_epoch_start: 4 67 | annealing_epochs: 6 68 | profiler: null 69 | num_sanity_val_steps: 0 70 | val_check_interval: 0.5 71 | validation_metric: rougeL 72 | validation_target_parse: summarization 73 | validation_sub_param: 74 | method: rougeL 75 | target_sub_parse: 76 | 77 | infer: 78 | max_length: 512 79 | max_new_tokens: 512 80 | min_length: 5 81 | temperature: 1.0 82 | do_sample: False 83 | top_k: 0 84 | top_p: 0.9 85 | repetition_penalty: 1.0 86 | num_beams: 1 87 | bad_words_ids: null 88 | parse_sep_token: "*" 89 | value_sep_token: "|" 90 | empty_token: "없음" 91 | 92 | 93 | -------------------------------------------------------------------------------- /lbox_open/constants/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants_fie import * 2 | -------------------------------------------------------------------------------- /lbox_open/constants/constants_fie.py: -------------------------------------------------------------------------------- 1 | # LBox Open 2 | # Copyright (c) 2022-present LBox Co. Ltd. 
3 | # CC BY-NC 4.0 4 | 5 | ENG_TO_KOR_PARSE_NAMES_LJP_CRIMINAL = { 6 | "fine_lv": "벌금", 7 | "imprisonment_with_labor_lv": "징역", 8 | "imprisonment_without_labor_lv": "금고", 9 | } 10 | -------------------------------------------------------------------------------- /lbox_open/data_module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lbox-kr/lbox-open/fdad4b039af718d2b171e561e75f5771515572df/lbox_open/data_module/__init__.py -------------------------------------------------------------------------------- /lbox_open/data_module/data_precedent.py: -------------------------------------------------------------------------------- 1 | # LBox Open 2 | # Copyright (c) 2022-present LBox Co. Ltd. 3 | # CC BY-NC 4.0 4 | 5 | import datasets 6 | import pytorch_lightning as pl 7 | from openprompt import PromptDataLoader 8 | from pytorch_lightning.trainer.supporters import CombinedLoader 9 | 10 | from lbox_open import openprompt_wrapper 11 | from lbox_open.template import prompt_generation_utils 12 | 13 | 14 | class PrecedentData(object): 15 | def __init__(self, cfg, mode, target_parse, target_sub_parses, raw_data): 16 | assert mode in ["train", "valid", "test", "predict"] 17 | self.cfg = cfg 18 | self.mode = mode 19 | self.target_parse = target_parse 20 | self.label_key = self._get_label_key(target_parse) 21 | self.target_sub_parses = target_sub_parses 22 | self.data_aug_param = cfg.train.get("data_aug_param", None) 23 | self.doc_id_key = self._get_doc_id(cfg.model.task) 24 | if raw_data is not None: 25 | self.features = self._gen_input_features(raw_data) 26 | 27 | def __getitem__(self, idx): 28 | return self.features[idx] 29 | 30 | def get_text_a(self, raw_data1): 31 | if isinstance(self.cfg.model.target_field, list): 32 | text_a = "" 33 | if self.cfg.model.task == "ljp_civil": 34 | for i, k in enumerate(self.cfg.model.target_field): 35 | if k == "facts": 36 | text_a += f"사실관계: {raw_data1[k]}\n" 37 | elif k == "claim": 38 | text_a += f"청구취지: {raw_data1[k]['text']}\n" 39 | else: 40 | raise NotImplementedError 41 | text_a = text_a.strip() 42 | else: 43 | for i, k in enumerate(self.cfg.model.target_field): 44 | text_a += f"{raw_data1[k]}\n" 45 | text_a = text_a.strip() 46 | 47 | else: 48 | text_a = raw_data1[self.cfg.model.target_field] 49 | return text_a 50 | 51 | def _get_label_key(self, target_parse): 52 | if target_parse in ["claim_acceptance_lv"]: 53 | label_key = "claim_acceptance_lv" 54 | elif target_parse in ["casename_classification"]: 55 | label_key = "casename" 56 | elif target_parse in ["statute_classification"]: 57 | label_key = "statutes" 58 | elif target_parse in ["summarization"]: 59 | label_key = "summary" 60 | else: 61 | label_key = "label" 62 | 63 | return label_key 64 | 65 | def _gen_input_features(self, raw_data): 66 | features = [] 67 | 68 | for i, raw_data1 in enumerate(raw_data): 69 | try: 70 | text_a = self.get_text_a(raw_data1) 71 | 72 | if self.label_key in raw_data1: 73 | tgt_text = prompt_generation_utils.gen_output_template( 74 | self.cfg.model.task, 75 | self.target_parse, 76 | self.target_sub_parses, 77 | raw_data1[self.label_key], 78 | self.cfg.infer.parse_sep_token, 79 | ) 80 | else: 81 | assert self.mode == "predict" 82 | tgt_text = "This is a dummy text." 
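                # No gold label exists in predict mode, so a placeholder target
                # is set; this keeps the InputExampleWrapper construction below
                # identical across train/valid/test/predict.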
83 | 84 | feature = openprompt_wrapper.InputExampleWrapper( 85 | text_a=text_a, 86 | text_b="", 87 | tgt_text=tgt_text, 88 | guid=str(raw_data1[self.doc_id_key]), 89 | ) 90 | except Exception as e: 91 | print(f"doc_id: {self.doc_id_key}") 92 | print(repr(e)) 93 | raise e 94 | features.append(feature) 95 | return features 96 | 97 | def __len__(self): 98 | if self.mode != "predict": 99 | return len(self.features) 100 | else: 101 | return 0 102 | 103 | def __iter__(self): 104 | return iter(self.features) 105 | 106 | def _get_doc_id(self, task): 107 | 108 | if task in [ 109 | "ljp_civil", 110 | "ljp_criminal", 111 | "casename_classification", 112 | "statute_classification", 113 | "summarization", 114 | ]: 115 | doc_id_key = "id" 116 | else: 117 | raise NotImplementedError 118 | return doc_id_key 119 | 120 | 121 | class PrecedentDataModule(pl.LightningDataModule): 122 | def __init__( 123 | self, cfg, plm_tokenizer, TokenizerWrapper, input_templates, raw_data=None 124 | ): 125 | super().__init__() 126 | self.cfg = cfg 127 | self.task = cfg.model.task 128 | self.raw_data = raw_data 129 | 130 | self.plm_tokenizer = plm_tokenizer 131 | self.TokenizerWrapperClass = TokenizerWrapper 132 | 133 | self.data_ts = {} 134 | self.data_vs = {} 135 | self.data_es = {} 136 | 137 | self.input_templates = input_templates 138 | self.target_parses_dict = cfg.model.target_parses_dict 139 | if len(self.target_parses_dict) > 1: 140 | raise Exception("Multitask learning is currently not supported!") 141 | 142 | self.use_local_data = cfg.data.use_local_data 143 | self.dataset_card = cfg.data.dataset_card 144 | 145 | self.training_set_name = cfg.data.training_set_name 146 | self.validation_set_name = cfg.data.validation_set_name 147 | self.test_set_name = cfg.data.test_set_name 148 | 149 | def setup(self, stage): 150 | if not self.use_local_data: 151 | assert self.raw_data is None 152 | self.raw_data = datasets.load_dataset(self.dataset_card, self.task) 153 | 154 | # Assign train/val datasets for use in dataloaders 155 | if stage in ["fit", "test"] or stage is None: 156 | for target_parse, target_sub_parses in self.target_parses_dict.items(): 157 | self.data_ts[target_parse] = PrecedentData( 158 | self.cfg, 159 | "train", 160 | target_parse, 161 | target_sub_parses, 162 | self.raw_data[self.training_set_name], 163 | ).features 164 | self.data_vs[target_parse] = PrecedentData( 165 | self.cfg, 166 | "valid", 167 | target_parse, 168 | target_sub_parses, 169 | self.raw_data[self.validation_set_name], 170 | ).features 171 | if self.test_set_name in self.raw_data: 172 | self.data_es[target_parse] = PrecedentData( 173 | self.cfg, 174 | "test", 175 | target_parse, 176 | target_sub_parses, 177 | self.raw_data[self.test_set_name], 178 | ).features 179 | 180 | def train_dataloader(self): 181 | data_loaders = {} 182 | for target_parse, target_sub_parses in self.target_parses_dict.items(): 183 | data_loaders[target_parse] = PromptDataLoader( 184 | dataset=self.data_ts[target_parse], 185 | template=self.input_templates[target_parse], 186 | tokenizer=self.plm_tokenizer, 187 | tokenizer_wrapper_class=self.TokenizerWrapperClass, 188 | max_seq_length=self.cfg.model.max_seq_length, 189 | decoder_max_length=self.cfg.model.decoder_max_length, 190 | batch_size=self.cfg.train.batch_size, 191 | shuffle=True, 192 | teacher_forcing=True, 193 | predict_eos_token=True, 194 | truncate_method="head", 195 | ).dataloader 196 | 197 | return data_loaders 198 | 199 | def val_dataloader(self): 200 | data_loaders = {} 201 | 202 | for target_parse, target_sub_parses in
self.target_parses_dict.items(): 203 | data_loaders[target_parse] = PromptDataLoader( 204 | dataset=self.data_vs[target_parse], 205 | template=self.input_templates[target_parse], 206 | tokenizer=self.plm_tokenizer, 207 | tokenizer_wrapper_class=self.TokenizerWrapperClass, 208 | max_seq_length=self.cfg.model.max_seq_length, 209 | decoder_max_length=self.cfg.model.decoder_max_length, 210 | batch_size=self.cfg.train.batch_size_prediction, 211 | shuffle=False, 212 | teacher_forcing=False, 213 | predict_eos_token=True, 214 | truncate_method="head", 215 | ).dataloader 216 | 217 | data_loaders = CombinedLoader(data_loaders) 218 | 219 | return data_loaders 220 | 221 | def test_dataloader(self): 222 | data_loaders = {} 223 | for target_parse, target_sub_parses in self.target_parses_dict.items(): 224 | data_loaders[target_parse] = PromptDataLoader( 225 | dataset=self.data_es[target_parse], 226 | template=self.input_templates[target_parse], 227 | tokenizer=self.plm_tokenizer, 228 | tokenizer_wrapper_class=self.TokenizerWrapperClass, 229 | max_seq_length=self.cfg.model.max_seq_length, 230 | decoder_max_length=self.cfg.model.decoder_max_length, 231 | batch_size=self.cfg.train.batch_size_prediction, 232 | shuffle=False, 233 | teacher_forcing=False, 234 | predict_eos_token=True, 235 | truncate_method="head", 236 | ).dataloader 237 | 238 | data_loaders = CombinedLoader(data_loaders) 239 | 240 | return data_loaders 241 | -------------------------------------------------------------------------------- /lbox_open/datasets_script/lbox_open.py: -------------------------------------------------------------------------------- 1 | # LBox Open 2 | # Copyright (c) 2022-present LBox Co. Ltd. 3 | # CC BY-NC 4.0 4 | # 2022.10.18, Wonseok: Add casename_classification_plus, statute_classification_plus, summarization_plus datasets 5 | 6 | 7 | import json 8 | 9 | import datasets 10 | 11 | _CASENAME_CLASSIFICATION_FEATURES = { 12 | "id": datasets.Value("int64"), 13 | "casetype": datasets.Value("string"), 14 | "casename": datasets.Value("string"), 15 | "facts": datasets.Value("string"), 16 | } 17 | 18 | 19 | _STATUTE_CLASSIFICATION_FEATURES = { 20 | "id": datasets.Value("int64"), 21 | "casetype": datasets.Value("string"), 22 | "casename": datasets.Value("string"), 23 | "statutes": datasets.features.Sequence(datasets.Value("string")), 24 | "facts": datasets.Value("string"), 25 | } 26 | 27 | _LJP_CRIMINAL = { 28 | "id": datasets.Value("int64"), 29 | "casetype": datasets.Value("string"), 30 | "casename": datasets.Value("string"), 31 | "facts": datasets.Value("string"), 32 | "reason": datasets.Value("string"), 33 | "label": { 34 | "text": datasets.Value("string"), 35 | "fine_lv": datasets.Value("int64"), 36 | "imprisonment_with_labor_lv": datasets.Value("int64"), 37 | "imprisonment_without_labor_lv": datasets.Value("int64"), 38 | }, 39 | "ruling": { 40 | "text": datasets.Value("string"), 41 | "parse": { 42 | "fine": { 43 | "type": datasets.Value("string"), 44 | "unit": datasets.Value("string"), 45 | "value": datasets.Value("int64"), 46 | }, 47 | "imprisonment": { 48 | "type": datasets.Value("string"), 49 | "unit": datasets.Value("string"), 50 | "value": datasets.Value("int64"), 51 | }, 52 | }, 53 | }, 54 | } 55 | 56 | _LJP_CIVIL = { 57 | "id": datasets.Value("int64"), 58 | "casetype": datasets.Value("string"), 59 | "casename": datasets.Value("string"), 60 | "facts": datasets.Value("string"), 61 | "claim_acceptance_lv": datasets.Value("int64"), 62 | "gist_of_claim": { 63 | "text": datasets.Value("string"), 64 | "money": { 65 | 
"provider": datasets.Value("string"), 66 | "taker": datasets.Value("string"), 67 | "unit": datasets.Value("string"), 68 | "value": datasets.Value("int64"), 69 | }, 70 | }, 71 | "ruling": { 72 | "text": datasets.Value("string"), 73 | "money": { 74 | "provider": datasets.Value("string"), 75 | "taker": datasets.Value("string"), 76 | "unit": datasets.Value("string"), 77 | "value": datasets.Value("int64"), 78 | }, 79 | "litigation_cost": datasets.Value("float32"), 80 | }, 81 | } 82 | 83 | _SUMMARIZATION_FEATURES = { 84 | "id": datasets.Value("int64"), 85 | "summary": datasets.Value("string"), 86 | "precedent": datasets.Value("string"), 87 | } 88 | 89 | _PRECEDENT_CORPUS_FEATURES = { 90 | "id": datasets.Value("int64"), 91 | "precedent": datasets.Value("string"), 92 | } 93 | 94 | 95 | class LBoxOpenConfig(datasets.BuilderConfig): 96 | """BuilderConfig for OpenLBox.""" 97 | 98 | def __init__( 99 | self, 100 | features, 101 | data_url, 102 | citation, 103 | url, 104 | label_classes=("False", "True"), 105 | **kwargs, 106 | ): 107 | # Version history: 108 | # 0.1.0: Initial version. 109 | super(LBoxOpenConfig, self).__init__( 110 | version=datasets.Version("0.2.0"), **kwargs 111 | ) 112 | self.features = features 113 | self.label_classes = label_classes 114 | self.data_url = data_url 115 | self.citation = citation 116 | self.url = url 117 | 118 | 119 | class LBoxOpen(datasets.GeneratorBasedBuilder): 120 | """The Legal AI Benchmark dataset from Korean Legal Cases.""" 121 | 122 | BUILDER_CONFIGS = [ 123 | LBoxOpenConfig( 124 | name="casename_classification", 125 | description="", 126 | features=_CASENAME_CLASSIFICATION_FEATURES, 127 | data_url="https://lbox-open.s3.ap-northeast-2.amazonaws.com/precedent_benchmark_dataset/casename_classification/v0.1.2/", 128 | citation="", 129 | url="lbox.kr", 130 | ), 131 | LBoxOpenConfig( 132 | name="casename_classification_plus", 133 | description="", 134 | features=_CASENAME_CLASSIFICATION_FEATURES, 135 | data_url="https://lbox-open.s3.ap-northeast-2.amazonaws.com/precedent_benchmark_dataset/casename_classification/v0.1.2_plus/", 136 | citation="", 137 | url="lbox.kr", 138 | ), 139 | LBoxOpenConfig( 140 | name="statute_classification", 141 | description="", 142 | features=_STATUTE_CLASSIFICATION_FEATURES, 143 | data_url="https://lbox-open.s3.ap-northeast-2.amazonaws.com/precedent_benchmark_dataset/statute_classification/v0.1.2/", 144 | citation="", 145 | url="lbox.kr", 146 | ), 147 | LBoxOpenConfig( 148 | name="statute_classification_plus", 149 | description="", 150 | features=_STATUTE_CLASSIFICATION_FEATURES, 151 | data_url="https://lbox-open.s3.ap-northeast-2.amazonaws.com/precedent_benchmark_dataset/statute_classification/v0.1.2_plus/", 152 | citation="", 153 | url="lbox.kr", 154 | ), 155 | LBoxOpenConfig( 156 | name="ljp_criminal", 157 | description="", 158 | features=_LJP_CRIMINAL, 159 | data_url="https://lbox-open.s3.ap-northeast-2.amazonaws.com/precedent_benchmark_dataset/judgement_prediction/v0.1.2/criminal/", 160 | citation="", 161 | url="lbox.kr", 162 | ), 163 | LBoxOpenConfig( 164 | name="ljp_civil", 165 | description="", 166 | features=_LJP_CIVIL, 167 | data_url="https://lbox-open.s3.ap-northeast-2.amazonaws.com/precedent_benchmark_dataset/judgement_prediction/v0.1.2/civil/", 168 | citation="", 169 | url="lbox.kr", 170 | ), 171 | LBoxOpenConfig( 172 | name="summarization", 173 | description="", 174 | features=_SUMMARIZATION_FEATURES, 175 | data_url="https://lbox-open.s3.ap-northeast-2.amazonaws.com/precedent_benchmark_dataset/summarization/v0.1.0/", 176 
| citation="", 177 | url="lbox.kr", 178 | ), 179 | LBoxOpenConfig( 180 | name="summarization_plus", 181 | description="", 182 | features=_SUMMARIZATION_FEATURES, 183 | data_url="https://lbox-open.s3.ap-northeast-2.amazonaws.com/precedent_benchmark_dataset/summarization/v0.1.0_plus/", 184 | citation="", 185 | url="lbox.kr", 186 | ), 187 | LBoxOpenConfig( 188 | name="precedent_corpus", 189 | description="", 190 | features=_PRECEDENT_CORPUS_FEATURES, 191 | data_url="https://lbox-open.s3.ap-northeast-2.amazonaws.com/precedent_benchmark_dataset/case_corpus/v0.1.0/", 192 | citation="", 193 | url="lbox.kr", 194 | ), 195 | ] 196 | 197 | def _info(self): 198 | return datasets.DatasetInfo( 199 | description="", 200 | features=datasets.Features(self.config.features), 201 | homepage=self.config.url, 202 | citation="", 203 | ) 204 | 205 | def _split_generators(self, dl_manager): 206 | if self.config.name == "precedent_corpus": 207 | dl_dir = { 208 | "train": dl_manager.download_and_extract( 209 | f"{self.config.data_url}case_corpus-150k.jsonl" 210 | ) 211 | or "", 212 | } 213 | 214 | return [ 215 | datasets.SplitGenerator( 216 | name=datasets.Split.TRAIN, 217 | gen_kwargs={ 218 | "data_file": dl_dir["train"], 219 | "split": datasets.Split.TRAIN, 220 | }, 221 | ) 222 | ] 223 | 224 | elif self.config.name in [ 225 | "casename_classification", 226 | "statute_classification", 227 | "ljp_criminal", 228 | "ljp_civil", 229 | ]: 230 | dl_dir = { 231 | "train": dl_manager.download_and_extract( 232 | f"{self.config.data_url}train.jsonl" 233 | ) 234 | or "", 235 | "valid": dl_manager.download_and_extract( 236 | f"{self.config.data_url}valid.jsonl" 237 | ) 238 | or "", 239 | "test": dl_manager.download_and_extract( 240 | f"{self.config.data_url}test.jsonl" 241 | ) 242 | or "", 243 | "test2": dl_manager.download_and_extract( 244 | f"{self.config.data_url}test2.jsonl" 245 | ) 246 | or "", 247 | } 248 | 249 | return [ 250 | datasets.SplitGenerator( 251 | name=datasets.Split.TRAIN, 252 | gen_kwargs={ 253 | "data_file": dl_dir["train"], 254 | "split": datasets.Split.TRAIN, 255 | }, 256 | ), 257 | datasets.SplitGenerator( 258 | name=datasets.Split.VALIDATION, 259 | gen_kwargs={ 260 | "data_file": dl_dir["valid"], 261 | "split": datasets.Split.VALIDATION, 262 | }, 263 | ), 264 | datasets.SplitGenerator( 265 | name=datasets.Split.TEST, 266 | gen_kwargs={ 267 | "data_file": dl_dir["test"], 268 | "split": datasets.Split.TEST, 269 | }, 270 | ), 271 | datasets.SplitGenerator( 272 | name="test2", 273 | gen_kwargs={ 274 | "data_file": dl_dir["test2"], 275 | "split": "test2", 276 | }, 277 | ), 278 | ] 279 | else: 280 | dl_dir = { 281 | "train": dl_manager.download_and_extract( 282 | f"{self.config.data_url}train.jsonl" 283 | ) 284 | or "", 285 | "valid": dl_manager.download_and_extract( 286 | f"{self.config.data_url}valid.jsonl" 287 | ) 288 | or "", 289 | "test": dl_manager.download_and_extract( 290 | f"{self.config.data_url}test.jsonl" 291 | ) 292 | or "", 293 | } 294 | 295 | return [ 296 | datasets.SplitGenerator( 297 | name=datasets.Split.TRAIN, 298 | gen_kwargs={ 299 | "data_file": dl_dir["train"], 300 | "split": datasets.Split.TRAIN, 301 | }, 302 | ), 303 | datasets.SplitGenerator( 304 | name=datasets.Split.VALIDATION, 305 | gen_kwargs={ 306 | "data_file": dl_dir["valid"], 307 | "split": datasets.Split.VALIDATION, 308 | }, 309 | ), 310 | datasets.SplitGenerator( 311 | name=datasets.Split.TEST, 312 | gen_kwargs={ 313 | "data_file": dl_dir["test"], 314 | "split": datasets.Split.TEST, 315 | }, 316 | ), 317 | ] 318 | 319 | def 
_generate_examples(self, data_file, split): 320 | with open(data_file, encoding="utf-8") as f: 321 | for line in f: 322 | row = json.loads(line) 323 | yield row["id"], row 324 | -------------------------------------------------------------------------------- /lbox_open/metric/exact_match.py: -------------------------------------------------------------------------------- 1 | # LBox Open 2 | # Copyright (c) 2022-present LBox Co. Ltd. 3 | # CC BY-NC 4.0 4 | 5 | from collections import defaultdict 6 | 7 | 8 | class ExactMatch: 9 | def __init__(self, parse_keys, empty_value): 10 | 11 | if "doc_id" in parse_keys: 12 | parse_keys.remove("doc_id") 13 | 14 | self.parse_keys = parse_keys 15 | 16 | self.empty_value = empty_value 17 | 18 | def is_empty(self, value): 19 | return (str(value) == str(self.empty_value)) or (value is None) 20 | 21 | def compare_parse(self, gt_parse, pr_parse): 22 | cnt_tp = defaultdict(int) # both exsit and pr is correct 23 | cnt_fn = defaultdict(int) # gt exists but pr is empty 24 | cnt_fp = defaultdict( 25 | int 26 | ) # [gt empty but pr exists] or [gt exists yet pr is wrong] 27 | cnt_tn = defaultdict(int) # gt & pr both empty 28 | 29 | for key in self.parse_keys: 30 | gt_val = gt_parse[key] 31 | pr_val = pr_parse[key] 32 | 33 | if self.is_empty(gt_val): 34 | if self.is_empty(pr_val): 35 | cnt_tn[key] += 1 36 | else: 37 | cnt_fp[key] += 1 38 | else: 39 | if self.is_empty(pr_val): 40 | cnt_fn[key] += 1 41 | else: 42 | if str(gt_val) == str(pr_val): 43 | cnt_tp[key] += 1 44 | else: 45 | cnt_fp[key] += 1 46 | 47 | return (cnt_tp, cnt_fp, cnt_fn, cnt_tn) 48 | 49 | def imp_fill_cnt(self, cnt_all, cnt): 50 | for key in self.parse_keys: 51 | cnt_all[key] += cnt[key] 52 | 53 | def calculate_micro_f1(self, cnt_tp_all, cnt_fp_all, cnt_fn_all): 54 | f1_all = {} 55 | for key in self.parse_keys: 56 | tp = cnt_tp_all[key] 57 | fp = cnt_fp_all[key] 58 | fn = cnt_fn_all[key] 59 | 60 | p = tp / (tp + fp + 1e-5) 61 | r = tp / (tp + fn + 1e-5) 62 | f1 = 2 * p * r / (p + r + 1e-5) 63 | 64 | f1_all[key] = f1 65 | 66 | return f1_all 67 | 68 | def compare_parses(self, gt_parses, pr_parses, confidences=None, threshold=0.0): 69 | cnt_tp_all = defaultdict(int) 70 | cnt_fn_all = defaultdict(int) 71 | cnt_fp_all = defaultdict(int) 72 | cnt_tn_all = defaultdict(int) 73 | if confidences is None: 74 | confidences = [1.0] * len(gt_parses) 75 | assert threshold == 0.0 76 | cnt = 0 77 | for gt_parse, pr_parse, confidence in zip(gt_parses, pr_parses, confidences): 78 | if confidence < threshold: 79 | continue 80 | cnt += 1 81 | (cnt_tp, cnt_fp, cnt_fn, cnt_tn) = self.compare_parse(gt_parse, pr_parse) 82 | 83 | self.imp_fill_cnt(cnt_tp_all, cnt_tp) 84 | self.imp_fill_cnt(cnt_fp_all, cnt_fp) 85 | self.imp_fill_cnt(cnt_fn_all, cnt_fn) 86 | self.imp_fill_cnt(cnt_tn_all, cnt_tn) 87 | 88 | f1_all = self.calculate_micro_f1(cnt_tp_all, cnt_fp_all, cnt_fn_all) 89 | th_recall = cnt / len(confidences) 90 | 91 | return ( 92 | f1_all, 93 | cnt_tp_all, 94 | cnt_fp_all, 95 | cnt_fn_all, 96 | cnt_tn_all, 97 | th_recall, 98 | ) 99 | -------------------------------------------------------------------------------- /lbox_open/metric/rouge_metric_utils.py: -------------------------------------------------------------------------------- 1 | from rouge_score.tokenizers import Tokenizer 2 | 3 | 4 | class WhiteSpaceTokenizer(Tokenizer): 5 | def tokenize(self, text): 6 | return text.split() 7 | -------------------------------------------------------------------------------- /lbox_open/model/generative_baseline_model.py: 
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 |
5 | import os
6 | from collections import defaultdict
7 | from itertools import zip_longest
8 | from pathlib import Path
9 | from pprint import pprint
10 |
11 |
12 | import datasets
13 | import pytorch_lightning as pl
14 | import torch
15 | from openprompt.utils.metrics import generation_metric
16 | from transformers.generation_utils import GenerationMixin
17 | from rouge_score import rouge_scorer
18 | import numpy as np
19 |
20 | import lbox_open.utils.general_utils as gu
21 | from lbox_open import openprompt_wrapper
22 | from lbox_open.model.model_optimizer import get_lr_dict, get_optimizer
23 | from lbox_open.parser.output_parser_utils import (
24 |     cal_em_from_parses,
25 |     get_parses_from_eval_results,
26 | )
27 | from lbox_open.metric import rouge_metric_utils
28 |
29 |
30 | class GenerativeParser(pl.LightningModule, GenerationMixin):
31 |     def __init__(self, cfg, plm, plm_tokenizer, input_templates):
32 |         super().__init__()
33 |         self.task = cfg.model.task
34 |         self.mparam = cfg.model
35 |         self.tparam = cfg.train
36 |         self.iparam = cfg.infer
37 |         self.cfg_name = cfg.name
38 |         self.target_parses_dict = cfg.model.target_parses_dict
39 |
40 |         self.prompt_models = {}
41 |         self.plm = plm
42 |         for target_parse, target_sub_parses in cfg.model.target_parses_dict.items():
43 |             # keep these around in case we fine-tune the plm
44 |             prompt_model = openprompt_wrapper.PromptForGenerationCustom(
45 |                 plm=plm,
46 |                 template=input_templates[target_parse],
47 |                 freeze_plm=cfg.model.plm.freeze,
48 |                 tokenizer=plm_tokenizer,
49 |                 plm_eval_mode=cfg.model.plm.eval_mode,
50 |             )
51 |
52 |             self.prompt_models[target_parse] = prompt_model
53 |
54 |         self.prompt_models = torch.nn.ModuleDict(self.prompt_models)
55 |
56 |         # if self.plm.config.is_encoder_decoder:
57 |         self.generation_arguments = {
58 |             "max_length": cfg.infer.max_length,
59 |             "max_new_tokens": cfg.infer.get("max_new_tokens", None),
60 |             "min_length": cfg.infer.min_length,
61 |             "temperature": cfg.infer.temperature,
62 |             "do_sample": cfg.infer.do_sample,
63 |             "top_k": cfg.infer.top_k,
64 |             "top_p": cfg.infer.top_p,
65 |             "repetition_penalty": cfg.infer.repetition_penalty,
66 |             "num_beams": cfg.infer.num_beams,
67 |             "bad_words_ids": cfg.infer.bad_words_ids,
68 |             "use_cache": True,
69 |         }
70 |
71 |         if plm.config.is_encoder_decoder:
72 |             # remove max_new_tokens
73 |             print("The model is an encoder-decoder model, so max_new_tokens is removed from the generation arguments.")
74 |             self.generation_arguments.pop("max_new_tokens")
75 |         else:
76 |             if cfg.infer.get("max_new_tokens", None):
77 |                 print(
78 |                     "The max_length generation option will be ignored because max_new_tokens is set."
79 | ) 80 | self.generation_arguments["max_length"] = None 81 | 82 | self.rouge_scorer = rouge_scorer.RougeScorer( 83 | ["rouge1", "rouge2", "rougeL"], tokenizer=rouge_metric_utils.WhiteSpaceTokenizer() 84 | ) 85 | def forward(self, target_parse, batch): 86 | loss = self.prompt_models[target_parse](batch[target_parse]) 87 | return loss 88 | 89 | def training_step(self, batch, batch_idx): 90 | n_keys = len(self.target_parses_dict) 91 | loss = 0 92 | for i_target, (target_parse, _) in enumerate(self.target_parses_dict.items()): 93 | loss += self.forward(target_parse, batch) 94 | return {"loss": loss / n_keys} 95 | 96 | def training_epoch_end(self, outputs): 97 | 98 | loss_all = torch.stack(self.gather_loss(outputs)) 99 | ave_loss = torch.mean(loss_all) 100 | self.log("training__ave_loss", ave_loss) 101 | 102 | def gather_loss(self, outputs): 103 | loss_all = [] 104 | for output in outputs: 105 | loss_all.append(output["loss"]) 106 | 107 | return loss_all 108 | 109 | def validation_step(self, batch, batch_idx): 110 | return self._eval_step(batch, batch_idx) 111 | 112 | def validation_epoch_end(self, outputs): 113 | ( 114 | eval_score, 115 | doc_ids_all, 116 | pr_texts_all, 117 | gt_texts_all, 118 | confidences_all, 119 | ) = self._eval_epoch_end(outputs) 120 | print("\nValidation!-----------------------------------------") 121 | pprint(eval_score) 122 | pprint(f"GT: {gt_texts_all[self.tparam.validation_target_parse][0:2]}") 123 | pprint(f"PR: {pr_texts_all[self.tparam.validation_target_parse][0:2]}") 124 | 125 | if self.tparam.validation_metric in ["sentence_bleu"]: 126 | validation_score = eval_score[self.tparam.validation_target_parse] 127 | 128 | elif self.tparam.validation_metric in ["rougeL"]: 129 | validation_score = eval_score[self.tparam.validation_target_parse] 130 | 131 | elif self.tparam.validation_metric in ["em"]: 132 | if self.tparam.validation_sub_param.method == "single_parse": 133 | sub_parse_name = self.tparam.validation_sub_param.target_sub_parse 134 | validation_score = eval_score[self.tparam.validation_target_parse][ 135 | "f1" 136 | ][sub_parse_name] 137 | elif self.tparam.validation_sub_param.method == "average": 138 | validation_score = 0 139 | cnt = 0 140 | for sub_parse_name, score in eval_score[ 141 | self.tparam.validation_target_parse 142 | ]["f1"].items(): 143 | validation_score += score 144 | cnt += 1 145 | validation_score /= cnt 146 | elif self.tparam.validation_sub_param.method == "text_em": 147 | validation_score = eval_score[self.tparam.validation_target_parse][ 148 | "text_em" 149 | ] 150 | else: 151 | raise ValueError 152 | for sub_parse_name, score in eval_score[ 153 | self.tparam.validation_target_parse 154 | ]["f1"].items(): 155 | self.log(sub_parse_name, score) 156 | self.log( 157 | f"{self.tparam.validation_target_parse}_text_em", 158 | eval_score[self.tparam.validation_target_parse]["text_em"], 159 | ) 160 | else: 161 | raise ValueError 162 | 163 | self.log( 164 | f"{self.tparam.validation_metric}_{self.tparam.validation_sub_param.method}", 165 | validation_score, 166 | ) 167 | 168 | def test_step(self, batch, batch_idx): 169 | return self._eval_step(batch, batch_idx) 170 | 171 | def test_epoch_end(self, outputs): 172 | output_save_dir = ( 173 | Path(self.tparam.weight.path).parent / "analysis" / self.cfg_name 174 | ) 175 | os.makedirs(output_save_dir, exist_ok=True) 176 | ( 177 | eval_score, 178 | doc_ids_all, 179 | pr_texts_all, 180 | gt_texts_all, 181 | confidences_all, 182 | ) = self._eval_epoch_end( 183 | outputs, save=True, 
output_save_dir=output_save_dir, verbose=True 184 | ) 185 | print("Test!-----------------------------------------------") 186 | print(eval_score) 187 | 188 | output_save_path_eval_score = output_save_dir / "eval_score.json" 189 | gu.save_json(output_save_path_eval_score, eval_score) 190 | 191 | eval_result = { 192 | "doc_ids": doc_ids_all, 193 | "pr_texts": pr_texts_all, 194 | "gt_texts": gt_texts_all, 195 | } 196 | output_save_path_eval_result = output_save_dir / "eval_result.json" 197 | gu.save_json(output_save_path_eval_result, eval_result) 198 | 199 | output_save_path_confidences = output_save_dir / "confidences.json" 200 | gu.save_json(output_save_path_confidences, confidences_all) 201 | 202 | # add doc_ids to confidences_all 203 | confidences_all_with_doc_ids = {} 204 | for key_target_parse, confidences in confidences_all.items(): 205 | c_with_ids = [ 206 | (doc_id, c) 207 | for doc_id, c in zip_longest(doc_ids_all[key_target_parse], confidences) 208 | ] 209 | confidences_all_with_doc_ids[key_target_parse] = c_with_ids 210 | 211 | output_save_path_confidences_with_doc_ids = ( 212 | output_save_dir / "confidences_with_doc_ids.json" 213 | ) 214 | gu.save_json( 215 | output_save_path_confidences_with_doc_ids, confidences_all_with_doc_ids 216 | ) 217 | 218 | def _eval_step(self, batch, batch_idx): 219 | 220 | out = defaultdict(dict) 221 | for target_parse, _ in self.target_parses_dict.items(): 222 | _prs, _gts, confidences = self.evaluate(target_parse, batch) 223 | 224 | # add confidences as a saved output. 225 | out[target_parse]["pr_texts"] = _prs 226 | out[target_parse]["gt_texts"] = _gts 227 | out[target_parse]["doc_ids"] = batch[target_parse]["guid"] 228 | out[target_parse]["confidences"] = confidences 229 | 230 | return out 231 | 232 | def _eval_epoch_end(self, outputs, save=False, output_save_dir=None, verbose=False): 233 | # outputs = [list of each step outputs] 234 | pr_texts_all = self.gather_step_outputs("pr_texts", outputs) 235 | gt_texts_all = self.gather_step_outputs("gt_texts", outputs) 236 | doc_ids_all = self.gather_step_outputs("doc_ids", outputs) 237 | confidences_all = self.gather_step_outputs("confidences", outputs) 238 | 239 | eval_score = self.cal_score( 240 | doc_ids_all, 241 | pr_texts_all, 242 | gt_texts_all, 243 | save=save, 244 | output_save_dir=output_save_dir, 245 | confidences=confidences_all, 246 | threshold=0.0, 247 | verbose=False, 248 | ) 249 | 250 | return eval_score, doc_ids_all, pr_texts_all, gt_texts_all, confidences_all 251 | 252 | def cal_score( 253 | self, 254 | doc_ids_all, 255 | pr_texts_all, 256 | gt_texts_all, 257 | save=False, 258 | output_save_dir=None, 259 | confidences=None, 260 | threshold=0.0, 261 | verbose=False, 262 | input_texts=None, 263 | ): 264 | 265 | if self.tparam.validation_metric == "sentence_bleu": 266 | eval_score = {} 267 | for target_parse, _ in self.target_parses_dict.items(): 268 | groundtruth_sentence = gt_texts_all[target_parse] 269 | generated_sentence = pr_texts_all[target_parse] 270 | eval_score[target_parse] = generation_metric( 271 | generated_sentence, groundtruth_sentence, "sentence_bleu" 272 | ) 273 | elif self.tparam.validation_metric == "rougeL": 274 | eval_score = {} 275 | for target_parse, _ in self.target_parses_dict.items(): 276 | pr_texts = pr_texts_all[target_parse] 277 | gt_texts = gt_texts_all[target_parse] 278 | target_scores = [] 279 | for pr_text, gt_text in zip_longest(pr_texts, gt_texts): 280 | r_score = self.rouge_scorer.score( 281 | prediction=pr_text, target=gt_text 282 | ) 283 | 284 | 
target_scores.append( 285 | r_score[self.tparam.validation_metric].fmeasure 286 | ) 287 | 288 | eval_score[target_parse] = np.mean( 289 | target_scores 290 | ) 291 | print(eval_score) 292 | 293 | elif self.tparam.validation_metric == "em": 294 | # EM score 295 | parses = get_parses_from_eval_results( 296 | self.iparam, 297 | self.target_parses_dict, 298 | doc_ids_all, 299 | gt_texts_all, 300 | pr_texts_all, 301 | ) 302 | 303 | # analysis 304 | eval_score = cal_em_from_parses( 305 | self.iparam, 306 | self.target_parses_dict, 307 | parses, 308 | verbose=verbose, 309 | save=save, 310 | output_save_dir=output_save_dir, 311 | input_texts=input_texts, 312 | confidences=confidences, 313 | threshold=threshold, 314 | ) 315 | 316 | # text exact matching 317 | for target_parse, target_sub_parses in self.target_parses_dict.items(): 318 | gt_texts = gt_texts_all[target_parse] 319 | pr_texts = pr_texts_all[target_parse] 320 | corrects = [str(x) == str(y) for x, y in zip(gt_texts, pr_texts)] 321 | text_em_score = sum(corrects) / len(corrects) 322 | eval_score[target_parse]["text_em"] = text_em_score 323 | 324 | else: 325 | raise ValueError 326 | return eval_score 327 | 328 | def gather_step_outputs(self, key, outputs): 329 | outputs_all = defaultdict(list) 330 | 331 | for target_parse, _ in self.target_parses_dict.items(): 332 | for output in outputs: 333 | outputs_all[target_parse] += output[target_parse][key] 334 | 335 | return outputs_all 336 | 337 | def configure_optimizers(self): 338 | optimizer = get_optimizer(self.mparam, self.tparam, self) 339 | lr_dict = get_lr_dict(optimizer, self.tparam, "prompt") 340 | 341 | return {"optimizer": optimizer, "lr_scheduler": lr_dict} 342 | 343 | def evaluate(self, target_parse, batch): 344 | generated_sentence = [] 345 | groundtruth_sentence = [] 346 | 347 | seqs, output_sentence, confidences = self.prompt_models[target_parse].generate( 348 | batch[target_parse], **self.generation_arguments 349 | ) 350 | generated_sentence.extend(output_sentence) 351 | groundtruth_sentence.extend(batch[target_parse]["tgt_text"]) 352 | 353 | return generated_sentence, groundtruth_sentence, confidences 354 | -------------------------------------------------------------------------------- /lbox_open/model/model_optimizer.py: -------------------------------------------------------------------------------- 1 | # LBox Open 2 | # Copyright (c) 2022-present LBox Co. Ltd. 
3 | # CC BY-NC 4.0 4 | 5 | import torch 6 | import transformers 7 | 8 | map_optimizers_name_to_type = { 9 | "sgd": torch.optim.SGD, 10 | "adam": torch.optim.Adam, 11 | "adamw": torch.optim.AdamW, 12 | } 13 | 14 | 15 | def get_optimizer(mparam, tparam, model): 16 | # todo: plm training part 17 | _lr_type, lr_param = get_lr_type_and_param(tparam, "prompt") 18 | 19 | # prompt 20 | optimizer_type = map_optimizers_name_to_type[tparam.optim.prompt.optimizer_type] 21 | 22 | if model.task in [ 23 | "ljp_civil", 24 | "ljp_criminal", 25 | "casename_classification", 26 | "statute_classification", 27 | "summarization", 28 | ]: 29 | optimizer_grouped_parameters = [] 30 | if not mparam.plm.freeze: 31 | optimizer_grouped_parameters.append( 32 | { 33 | "params": list( 34 | filter(lambda p: p.requires_grad, model.plm.parameters()) 35 | ), 36 | "lr": tparam.optim.plm.lr, 37 | } 38 | ) 39 | 40 | for target_parse, _target_sub_parses in model.target_parses_dict.items(): 41 | optimizer_grouped_parameters.append( 42 | { 43 | "params": [ 44 | p 45 | for n, p in model.prompt_models[ 46 | target_parse 47 | ].template.named_parameters() 48 | if "raw_embedding" not in n 49 | ] 50 | } 51 | ) 52 | 53 | optimizer = optimizer_type( 54 | optimizer_grouped_parameters, lr=tparam.optim.prompt.lr, weight_decay=0 55 | ) 56 | 57 | else: 58 | raise NotImplementedError 59 | 60 | return optimizer 61 | 62 | 63 | def get_lr_type_and_param(tparam, key): 64 | lr_type = tparam.optim[key].lr_scheduler_type 65 | lr_param = tparam.optim[key].lr_scheduler_param[lr_type] 66 | return lr_type, lr_param 67 | 68 | 69 | def gen_lr_scheduler(tparam, optimizer, lr_type, lr_param): 70 | if lr_type == "constant": 71 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR( 72 | optimizer, lr_lambda=[lambda epoch: 1, lambda epoch: 1], verbose=True 73 | ) 74 | elif lr_type == "multi_step_lr": 75 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( 76 | optimizer, 77 | milestones=lr_param["milestones"], 78 | gamma=lr_param["gamma"], 79 | verbose=True, 80 | ) 81 | 82 | elif lr_type == "warmup_constant": 83 | lr_scheduler = transformers.get_constant_schedule_with_warmup( 84 | optimizer, num_warmup_steps=lr_param.num_warmup_steps 85 | ) 86 | elif lr_type == "cos_with_hard_restarts": 87 | lr_scheduler = transformers.get_cosine_with_hard_restarts_schedule_with_warmup( 88 | optimizer, 89 | num_warmup_steps=lr_param.num_warmup_steps, 90 | num_training_steps=lr_param.num_training_steps, 91 | num_cycles=lr_param.num_cycles, 92 | ) 93 | elif lr_type == "linear": 94 | lr_scheduler = transformers.get_linear_schedule_with_warmup( 95 | optimizer, 96 | num_warmup_steps=lr_param.num_warmup_steps, 97 | num_training_steps=tparam.max_epochs, 98 | ) 99 | 100 | else: 101 | raise NotImplementedError 102 | return lr_scheduler 103 | 104 | 105 | def get_lr_dict(optimizer, tparam, key): 106 | lr_type, lr_param = get_lr_type_and_param(tparam, key) 107 | lr_scheduler = gen_lr_scheduler(tparam, optimizer, lr_type, lr_param) 108 | lr_dict = { 109 | "scheduler": lr_scheduler, 110 | "interval": "epoch", 111 | "frequency": 1, 112 | "monitor": "val_loss", 113 | "strict": True, 114 | "name": None, 115 | } 116 | 117 | return lr_dict 118 | -------------------------------------------------------------------------------- /lbox_open/openprompt_wrapper/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_utils import InputExampleWrapper 2 | from .pipeline_base import PromptForGenerationCustom 3 | from .plms import load_plm_wrapper 4 | 
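Editor's note: the sketch below is a hedged illustration of how the three exports above fit together; it is not repository code. The checkpoint name "skt/kogpt2-base-v2", the toy template text, and the example facts are assumptions for illustration -- the actual experiments build their templates in lbox_open/template and are driven by run_model.py with the YAML configs.

# A minimal usage sketch (assumptions flagged above).
from openprompt import PromptDataLoader
from openprompt.prompts import ManualTemplate

from lbox_open.openprompt_wrapper import (
    InputExampleWrapper,
    PromptForGenerationCustom,
    load_plm_wrapper,
)

# load_plm_wrapper returns (model, tokenizer, config, TokenizerWrapper class).
plm, tokenizer, plm_config, WrapperClass = load_plm_wrapper(
    "kogpt2", "skt/kogpt2-base-v2", use_custom_loader=True  # assumed checkpoint
)

# A toy prompt: the facts go into text_a and the model fills in the mask.
template = ManualTemplate(
    tokenizer=tokenizer, text='{"placeholder":"text_a"} 죄명: {"mask"}'
)
examples = [InputExampleWrapper(guid="0", text_a="피고인은 ...", tgt_text="사기")]

loader = PromptDataLoader(
    dataset=examples,
    template=template,
    tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass,
    max_seq_length=512,
    decoder_max_length=32,
    batch_size=1,
    shuffle=False,
    teacher_forcing=False,
    predict_eos_token=True,
    truncate_method="head",
).dataloader

model = PromptForGenerationCustom(plm=plm, template=template, tokenizer=tokenizer)
for batch in loader:
    # generate() returns raw sequences, decoded sentences, and per-sequence mean
    # token probabilities (the "confidences" recovered from early OpenPrompt).
    _, sentences, confidences = model.generate(batch, max_new_tokens=16)
    print(sentences, confidences)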
--------------------------------------------------------------------------------
/lbox_open/openprompt_wrapper/data_utils/__init__.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from openprompt import data_utils
4 |
5 |
6 | class InputExampleWrapper(data_utils.InputExample):
7 |     def to_json_string(self):
8 |         r"""Serialize this instance to a JSON string."""
9 |         # return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
10 |         return (
11 |             json.dumps(self.to_dict(), indent=2, sort_keys=True, ensure_ascii=False)
12 |             + "\n"
13 |         )
14 |
--------------------------------------------------------------------------------
/lbox_open/openprompt_wrapper/pipeline_base.py:
--------------------------------------------------------------------------------
1 | # Wonseok added PromptForGenerationCustom by copying and tweaking the OpenPrompt-v1.0.0 PromptForGeneration class.
2 | # We modify two things: (1) L343--L345 for compatibility with transformers 4.19.4, and
3 | # (2) the recovery of "confidences", which was available in the initial version of OpenPrompt.
4 |
5 | from copy import deepcopy
6 | from typing import Any, Dict, Optional, Union
7 |
8 | import numpy as np
9 | import torch
10 | from openprompt.data_utils import InputFeatures
11 | from openprompt.pipeline_base import PromptForGeneration, PromptModel
12 | from openprompt.prompt_base import Template, Verbalizer
13 | from openprompt.utils import round_list, signature
14 | from openprompt.utils.logging import logger
15 | from torch import nn
16 | from transformers.generation_utils import GenerationMixin
17 | from transformers.tokenization_utils import PreTrainedTokenizer
18 | from transformers.utils.dummy_pt_objects import PreTrainedModel
19 | from yacs.config import CfgNode
20 |
21 |
22 | class PromptForGenerationCustom(torch.nn.Module, GenerationMixin):
23 |     r"""``PromptModel`` with generation loss calculation and generation utils integrated.
24 |
25 |
26 |     Args:
27 |         plm (:obj:`PretrainedModel`): A pre-trained model you decide to use for generation, e.g. GPT.
28 |         template (:obj:`Template`): A ``Template`` object used to wrap the input text, e.g. ``PrefixTemplate``.
29 |         tokenizer (:obj:`Tokenizer`): A ``Tokenizer`` of the current model.
30 |         gen_config (:obj:`CfgNode`): The generation configs to pass into ``GenerationMixin.generate``.
31 |         freeze_plm (:obj:`bool`): Whether or not to freeze the pretrained language model.
32 |         plm_eval_mode (:obj:`bool`): A stronger freezing mode than ``freeze_plm``: the dropout of the model is turned off, no matter whether the other parts are set to train.
33 | """ 34 | 35 | def __init__( 36 | self, 37 | plm: PreTrainedModel, 38 | template: Template, 39 | freeze_plm: bool = False, 40 | plm_eval_mode: bool = False, 41 | gen_config: Optional[CfgNode] = None, 42 | tokenizer: Optional[PreTrainedTokenizer] = None, 43 | ): 44 | super().__init__() 45 | self.freeze_plm = freeze_plm 46 | if tokenizer is None: 47 | assert ( 48 | template.tokenizer is not None 49 | ), "Tokenizer can't be set from input args or template" 50 | self.tokenizer = template.tokenizer 51 | else: 52 | self.tokenizer = tokenizer 53 | self.prompt_model = PromptModel(plm, template, freeze_plm, plm_eval_mode) 54 | 55 | self.loss_fct = nn.CrossEntropyLoss(reduction="none") 56 | self.config = plm.config 57 | if gen_config: 58 | for key in gen_config: 59 | setattr(self.config, key, gen_config[key]) 60 | self.in_generation_function = False 61 | 62 | self.main_input_name = ( 63 | self.prompt_model.main_input_name 64 | ) # for transformers 4.17.0 and higher. 65 | 66 | @property 67 | def plm(self): 68 | return self.prompt_model.plm 69 | 70 | @property 71 | def template(self): 72 | return self.prompt_model.template 73 | 74 | @property 75 | def device(self): 76 | return self.plm.device 77 | 78 | def shift_logits_and_labels(self, logits, loss_ids, reference_ids): 79 | 80 | r""" 81 | Left shift the label, and make label of the positions that are 82 | not loss position to -100, which is the ignore index in pytorch's 83 | loss function. 84 | 85 | Args: 86 | logits (:obj:`torch.Tensor`): 87 | batch (:obj:`InputFeatures`): The input features of batchified data sequences. 88 | 89 | Returns: 90 | shift_logits (:obj:`torch.Tensor`): 91 | shift_input_ids (:obj:`List[int]`): 92 | 93 | """ 94 | 95 | shift_logits = logits[..., :-1, :].contiguous() 96 | shift_loss_ids = loss_ids[..., 1:].contiguous() 97 | shift_input_ids = reference_ids[..., 1:].contiguous() 98 | shift_input_ids = torch.where(shift_loss_ids > 0, shift_input_ids, -100) 99 | return shift_logits, shift_input_ids 100 | 101 | def forward(self, *args, **kwargs): 102 | r"""In generation process, it will use the plm's forward function. 103 | This is because, in the first step we will directly call the process_batch function to 104 | generate initial input with the template, after that the all template 105 | have been processed into the past_key_value, 106 | then we can use the normal generation function. 107 | In learning process, the forward is linked to ``_forward`` functions. 108 | in which the loss will be calculated for all the positions in the same time. 109 | """ 110 | if self.in_generation_function: 111 | return self.plm.forward(*args, **kwargs) 112 | else: 113 | return self._forward(*args, **kwargs) 114 | 115 | def _forward(self, batch: Union[Dict, InputFeatures]) -> torch.Tensor: 116 | r""" 117 | This is the forward method of the training of generation in prompt-learning framework. 118 | 119 | Args: 120 | batch (:obj:`Union[Dict, InputFeatures]`): The input features of batchified data sequences. 121 | 122 | Returns: 123 | loss(:obj:torch.Tensor): The loss of the current generation procedure. 
124 | """ 125 | if self.config.is_encoder_decoder: 126 | reference_ids = batch["decoder_input_ids"] 127 | else: 128 | reference_ids = batch[ 129 | "input_ids" 130 | ] # in case in some template, these field is dropped 131 | outputs = self.prompt_model(batch) 132 | logits = outputs.logits 133 | logits, labels = self.shift_logits_and_labels( 134 | logits, batch["loss_ids"], reference_ids 135 | ) 136 | batch_size, seq_len, vocab_size = logits.shape 137 | loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) 138 | loss = loss.view(batch_size, -1).sum(dim=-1) # TODO support more objectives 139 | loss = loss.mean() 140 | return loss 141 | 142 | def generate( 143 | self, 144 | batch: Union[Dict, InputFeatures], 145 | verbose: Optional[bool] = False, 146 | **generation_kwargs, 147 | ): 148 | r"""This function wraps the generate() methods in parent class ``GenerationMixin``. 149 | Forward uses the ``PretrainedModel``'s forward method. 150 | generation_kwargs include all the parameters that are passed in to 151 | ``transformers.generation_util.GenerationMixin.generate`` 152 | 153 | Args: 154 | batch (:obj:`Union[Dict, InputFeatures]`): The input features of batchified data sequences. 155 | verbose (:obj:`Optional[bool]`): Set to true to verbose the generated sentence. 156 | 157 | Returns: 158 | output_sequences (:obj:`List[torch.Tensor]`): The raw sequences generated by the generation model. 159 | generated_sentences (:obj:`List[torch.Tensor]`): The generated sentences that have been post-processed. 160 | """ 161 | input_generation_kwargs = { 162 | key: value 163 | for key, value in generation_kwargs.items() 164 | if key in signature(GenerationMixin.generate).args 165 | } 166 | if self.config.is_encoder_decoder: 167 | loss_ids_start = batch["loss_ids"].argmax(dim=-1) 168 | assert ( 169 | loss_ids_start.min() == loss_ids_start.max() 170 | ), "The generation start from different position in a batch." 171 | batch["decoder_input_ids"] = batch["decoder_input_ids"][ 172 | :, : loss_ids_start.min() + 1 173 | ] 174 | input_length = batch["decoder_input_ids"].size(1) 175 | batch_size = batch["decoder_input_ids"].size(0) 176 | 177 | self.generate_ith_token = 0 178 | self.in_generation_function = True 179 | 180 | output_dict = super().generate( 181 | **batch, 182 | **input_generation_kwargs, 183 | pad_token_id=self.tokenizer.pad_token_id, 184 | eos_token_id=self.tokenizer.eos_token_id, 185 | output_scores=True, 186 | return_dict_in_generate=True, 187 | ) 188 | output_sequences = output_dict["sequences"] 189 | output_scores = output_dict[ 190 | "scores" 191 | ] # (L tuples, (B batches, N tokens)). 
each tuple = (B, 192 | self.in_generation_function = False 193 | output_sequences = output_sequences.cpu().tolist() 194 | generated_sentences, confidences = self.post_processing_with_confidence( 195 | output_sequences=output_sequences, 196 | input_lengths=input_length, 197 | output_scores=output_scores, 198 | ) 199 | # output_sequences = super().generate(**batch, **input_generation_kwargs, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id) 200 | # self.in_generation_function = False 201 | # output_sequences = output_sequences.cpu().tolist() 202 | # generated_sentences = self.post_processing(output_sequences=output_sequences, input_lengths=input_length) 203 | else: 204 | input_length = batch["input_ids"].size(1) 205 | batch_size = batch["input_ids"].size(0) 206 | 207 | # Currently huggingface transformers only support single sample generation, or padding to the left (instead of the right). 208 | # because it will only extract the last position of the output 209 | # generate one_by_one 210 | if "input_ids_len" in batch: 211 | input_real_lens = batch["input_ids_len"] 212 | else: 213 | input_real_lens = torch.sum( 214 | (batch["input_ids"] != self.tokenizer.pad_token_id).to(torch.int), 215 | dim=-1, 216 | ) 217 | output_sequences = [] 218 | output_scores = [] 219 | for instance_id in range(batch_size): 220 | # remove the pad token 221 | instance = { 222 | key: batch[key][instance_id : instance_id + 1][ 223 | :, : input_real_lens[instance_id] 224 | ] 225 | for key in batch 226 | if isinstance(batch[key], torch.Tensor) 227 | and batch[key].shape[:2] == torch.Size([batch_size, input_length]) 228 | } 229 | self.generate_ith_token = 0 230 | self.in_generation_function = True 231 | output_dict = super().generate( 232 | **instance, 233 | **input_generation_kwargs, 234 | pad_token_id=self.tokenizer.pad_token_id, 235 | eos_token_id=self.tokenizer.eos_token_id, 236 | output_scores=True, 237 | return_dict_in_generate=True, 238 | ) 239 | output_sequence = output_dict["sequences"] 240 | self.in_generation_function = False 241 | output_sequences.extend( 242 | output_sequence.cpu().tolist() 243 | ) # TODO: to support generate multiple sentence 244 | 245 | output_score = output_dict["scores"] 246 | output_scores.append(output_score) 247 | 248 | generated_sentences, confidences = self.post_processing_with_confidence( 249 | output_sequences=output_sequences, 250 | input_lengths=input_real_lens.cpu().tolist(), 251 | output_scores=output_scores, 252 | ) 253 | # for instance_id in range(batch_size): 254 | # # remove the pad token 255 | # instance = {key: batch[key][instance_id:instance_id+1][:,:input_real_lens[instance_id]] for key in batch if isinstance(batch[key], torch.Tensor) and batch[key].shape[:2]==torch.Size([batch_size, input_length])} 256 | # self.generate_ith_token = 0 257 | # self.in_generation_function = True 258 | # output_sequence = super().generate(**instance, **input_generation_kwargs, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id) 259 | # self.in_generation_function = False 260 | # output_sequences.extend(output_sequence.cpu().tolist()) # TODO: to support generate multiple sentence 261 | # generated_sentences = self.post_processing(output_sequences=output_sequences, input_lengths=input_real_lens.cpu().tolist()) 262 | if verbose: 263 | logger.info(f"Generated:{generated_sentences}") 264 | return output_sequences, generated_sentences, confidences 265 | 266 | def post_processing(self, output_sequences, input_lengths): 267 | r""" 
268 |         Post-process the sequences generated by the generation model.
269 |
270 |         Args:
271 |             output_sequences (:obj:`torch.Tensor`): The raw sequences generated by the generation model.
272 |             input_lengths (:obj:`int` or `list`): The length(s) of the input sequence.
273 |
274 |         Returns:
275 |             :obj:`List`: The generated sentences that have been post-processed.
276 |         """
277 |         generated_sentences = []
278 |         if type(input_lengths) == int:
279 |             input_lengths = [input_lengths] * len(output_sequences)
280 |         for sent_id, seq in enumerate(output_sequences):
281 |             seq = seq[input_lengths[sent_id] :]
282 |
283 |             if (
284 |                 hasattr(self.tokenizer, "eos_token")
285 |                 and self.tokenizer.eos_token is not None
286 |             ):
287 |                 text_output = self.tokenizer.decode(
288 |                     seq, clean_up_tokenization_spaces=True, skip_special_tokens=False
289 |                 )
290 |                 idx = text_output.find(self.tokenizer.eos_token)
291 |                 if idx >= 0:
292 |                     text_output = text_output[:idx]
293 |             else:
294 |                 text_output = self.tokenizer.decode(
295 |                     seq, clean_up_tokenization_spaces=True, skip_special_tokens=True
296 |                 )
297 |             text_output = text_output.strip()
298 |             generated_sentences.append(text_output)
299 |         return generated_sentences
300 |
301 |     def prepare_inputs_for_generation(
302 |         self, input_ids: Optional[torch.Tensor] = None, **model_kwargs
303 |     ):
304 |         r"""This function wraps the ``prepare_inputs_for_generation`` function in HuggingFace transformers.
305 |
306 |         When `past` is not in model_kwargs, we prepare the input from scratch.
307 |         When `past` is in model_kwargs, we don't need to prepare the template-wrapped input;
308 |         instead we use the inner pretrained model's function to prepare the next step's input.
309 |         ``model_kwargs`` includes all the arguments passed in the `batch`: InputFeatures, except ``input_ids``,
310 |         as long as they do not conflict with keywords in ``generation_kwargs``.
311 |
312 |
313 |         Args:
314 |             input_ids(:obj:`torch.Tensor`): Indices of input sequence tokens in the vocabulary.
315 |         """
316 |         if (
317 |             self.generate_ith_token == 0 and "encoder_outputs" not in model_kwargs
318 |         ):  # generating the first token in a decoder-only setting.
319 |
320 |             batch = InputFeatures(input_ids=input_ids, **model_kwargs)
321 |             model_inputs = self.prompt_model.prepare_model_inputs(batch)
322 |             # TODO: check compatibility with more models; gpt2 and T5 have been checked.
323 |         else:  # subsequent generation steps can use the default behavior
324 |             model_inputs = self.plm.prepare_inputs_for_generation(
325 |                 input_ids, **model_kwargs
326 |             )
327 |         self.last_model_inputs = model_inputs  # to update the model_kwargs in _update_model_kwargs_for_generation, an in-place operation.
328 |         return model_inputs
329 |
330 |     def _update_model_kwargs_for_generation(
331 |         self, outputs, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
332 |     ) -> Dict[str, Any]:
333 |         r"""The parent class's ``_update_model_kwargs_for_generation`` method will
334 |         add ``past_key_values`` to model_kwargs, and update ``token_type_ids`` and ``attention_mask``.
335 |
336 |         In case some of the model_kwargs are modified in the prepare_inputs_for_generation function
337 |         and should be used as the subsequent model_kwargs, we update these kwargs before the parent class
338 |         call.
339 |
340 |         Other updates should be added here after the parent's function call.
341 |
342 |         Args:
343 |             outputs (:obj:`torch.Tensor`):
344 |             is_encoder_decoder (:obj:`bool`, defaults to False):
345 |         """
346 |         if self.generate_ith_token == 0:
347 |             for key in self.last_model_inputs:
348 |                 if key in model_kwargs:
349 |                     model_kwargs[key] = self.last_model_inputs[key]
350 |         model_kwargs = super(
351 |             PromptForGeneration, PromptForGeneration
352 |         )._update_model_kwargs_for_generation(
353 |             outputs=outputs,
354 |             model_kwargs=model_kwargs,
355 |             is_encoder_decoder=is_encoder_decoder,
356 |         )
357 |         self.generate_ith_token += 1
358 |         return model_kwargs
359 |
360 |     def _prepare_encoder_decoder_kwargs_for_generation(
361 |         self,
362 |         input_ids: torch.LongTensor,
363 |         model_kwargs,
364 |         model_input_name: Optional[str] = None,
365 |     ) -> Dict[str, Any]:
366 |         r"""This function resembles the corresponding function in ``GenerationMixin``.
367 |
368 |         Args:
369 |             input_ids (:obj:`torch.LongTensor`): The input ids fed to the encoder.
370 |         """
371 |         if "encoder_outputs" not in model_kwargs:
372 |             # retrieve encoder hidden states
373 |             encoder = self.plm.get_encoder()
374 |             encoder_kwargs = {
375 |                 argument: value
376 |                 for argument, value in model_kwargs.items()
377 |                 if not (
378 |                     argument.startswith("decoder_") or argument.startswith("cross_attn")
379 |                 )
380 |             }
381 |             model_input_name = (
382 |                 model_input_name
383 |                 if model_input_name is not None
384 |                 else self.main_input_name
385 |             )
386 |             batch = {model_input_name: input_ids, **encoder_kwargs}
387 |             model_inputs = self.prompt_model.prepare_model_inputs(
388 |                 batch
389 |             )  # This line differs from the original code base: we should process the input
390 |             # with our template, then pass it into the model.
391 |             # Some of the arguments may have been changed by the template,
392 |             # e.g. the attention mask. Here we update the model_kwargs.
393 |             for key in model_kwargs:
394 |                 if key in model_inputs:
395 |                     model_kwargs[key] = model_inputs[key]
396 |             model_inputs_with_use_cache_false = deepcopy(model_inputs)
397 |             model_inputs_with_use_cache_false["use_cache"] = False
398 |             model_kwargs["encoder_outputs"] = encoder(
399 |                 return_dict=True, **model_inputs_with_use_cache_false
400 |             )
401 |         return model_kwargs
402 |
403 |     ## We comment out this code since it conflicts with [OpenDelta](https://github.com/thunlp/OpenDelta)
404 |     # def state_dict(self, *args, **kwargs):
405 |     #     """ Save the model using template and plm's save methods. """
406 |     #     _state_dict = {}
407 |     #     if not self.prompt_model.freeze_plm:
408 |     #         _state_dict['plm'] = self.plm.state_dict(*args, **kwargs)
409 |     #     _state_dict['template'] = self.template.state_dict(*args, **kwargs)
410 |     #     return _state_dict
411 |
412 |     # def load_state_dict(self, state_dict, *args, **kwargs):
413 |     #     """ Load the model using template and plm's load methods.
""" 414 | # if 'plm' in state_dict and not self.prompt_model.freeze_plm: 415 | # self.plm.load_state_dict(state_dict['plm'], *args, **kwargs) 416 | # self.template.load_state_dict(state_dict['template'], *args, **kwargs) 417 | 418 | def _reorder_cache(self, past, beam_idx): 419 | r"""Use the plm's default _reorder_cache function""" 420 | return self.plm._reorder_cache(past, beam_idx) 421 | 422 | def parallelize(self, device_map=None): 423 | r"""Parallelize the model across device""" 424 | if hasattr(self.plm, "parallelize"): 425 | self.plm.parallelize(device_map) 426 | self.device_map = self.plm.device_map 427 | else: 428 | raise NotImplementedError( 429 | "parallelize method was not implemented for this plm." 430 | ) 431 | 432 | def deparallelize(self): 433 | r"""Deparallelize the model across device""" 434 | if hasattr(self.plm, "deparallelize"): 435 | self.plm.deparallelize() 436 | self.device_map = None 437 | else: 438 | raise NotImplementedError( 439 | "parallelize method was not implemented for this plm." 440 | ) 441 | 442 | def post_processing_with_confidence( 443 | self, output_sequences, input_lengths, output_scores 444 | ): 445 | r""" 446 | Post-process the sequences generated by the generation model. 447 | 448 | Args: 449 | output_sequences (:obj:`torch.Tensor`): The raw sequences generated by the generation model. 450 | input_lengths (:obj:`int` or `list`): The length(s) of the input sequence. 451 | 452 | Returns: 453 | :obj:`List`: The generated sentences that have been post-processed. 454 | """ 455 | generated_sentences = [] 456 | if type(input_lengths) == int: 457 | input_lengths = [input_lengths] * len(output_sequences) 458 | confidences = [] 459 | confidences_list = [] 460 | for sent_id, seq in enumerate(output_sequences): 461 | seq = seq[input_lengths[sent_id] :] 462 | if self.config.is_encoder_decoder: 463 | # [T, B, Ntoken] 464 | assert len(seq) == len( 465 | output_scores 466 | ) # (T, B, Ntoken), T is a length of sequence. 
467 |             else:
468 |                 # [B, T, Ntoken]
469 |                 assert len(seq) == len(output_scores[sent_id])
470 |
471 |             text_output = self.tokenizer.decode(seq, clean_up_tokenization_spaces=True)
472 |             idx = text_output.find(self.tokenizer.eos_token)
473 |             if idx >= 0:
474 |                 text_output = text_output[:idx]
475 |             text_output = text_output.strip()
476 |             generated_sentences.append(text_output)
477 |
478 |             if self.tokenizer.eos_token_id in seq:
479 |                 idx_token = seq.index(self.tokenizer.eos_token_id)
480 |             else:
481 |                 idx_token = -1
482 |
483 |             if idx_token >= 0:
484 |                 seq_trimmed = seq[:idx_token]
485 |             else:
486 |                 seq_trimmed = seq
487 |
488 |             confidence_list = []
489 |             for i_tok, tok_id in enumerate(seq_trimmed):
490 |                 if self.config.is_encoder_decoder:
491 |                     # [T, B, Ntoken]
492 |                     scores = output_scores[i_tok]  # [B, Ntok]
493 |                     prob = scores[sent_id, :].softmax(-1)
494 |                 else:
495 |                     # [B, T, Ntoken]
496 |                     scores = output_scores[sent_id]  # [L, Ntok]
497 |                     prob = scores[i_tok].softmax(-1)[0]
498 |                 confidence_list.append(prob[tok_id].item())
499 |             confidences_list.append(confidence_list)
500 |             confidences.append(np.mean(confidence_list))
501 |
502 |         return generated_sentences, confidences
503 |
--------------------------------------------------------------------------------
/lbox_open/openprompt_wrapper/plms/__init__.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 |
4 | from openprompt import plms
5 | from transformers import (
6 |     AutoTokenizer,
7 |     GPT2Config,
8 |     GPT2LMHeadModel,
9 |     MT5Config,
10 |     MT5ForConditionalGeneration,
11 |     MT5Tokenizer,
12 |     PreTrainedTokenizer,
13 |     PreTrainedTokenizerFast,
14 | )
15 |
16 | from .lm import LMTFastokenizerWrapperCustom
17 |
18 |
19 | def get_model_class(plm_type: str):
20 |     return _MODEL_CLASSES[plm_type]
21 |
22 |
23 | MT5TokenizerWrapper = plms.T5TokenizerWrapper
24 |
25 | _MODEL_CLASSES = {
26 |     "mt5": plms.ModelClass(
27 |         **{
28 |             "config": MT5Config,
29 |             "tokenizer": MT5Tokenizer,
30 |             "model": MT5ForConditionalGeneration,
31 |             "wrapper": MT5TokenizerWrapper,
32 |         }
33 |     ),
34 |     "kogpt2": plms.ModelClass(
35 |         **{
36 |             "config": GPT2Config,
37 |             "tokenizer": PreTrainedTokenizerFast,
38 |             "model": GPT2LMHeadModel,
39 |             "wrapper": LMTFastokenizerWrapperCustom,
40 |         }
41 |     ),
42 |     "legal-gpt": plms.ModelClass(
43 |         **{
44 |             "config": GPT2Config,
45 |             "tokenizer": AutoTokenizer,
46 |             "model": GPT2LMHeadModel,
47 |             "wrapper": LMTFastokenizerWrapperCustom,
48 |         }
49 |     ),
50 | }
51 |
52 |
53 | def load_plm_wrapper(
54 |     model_name,
55 |     model_path,
56 |     specials_to_add=None,
57 |     revision=None,
58 |     do_not_load_pretrained_weight=False,
59 |     use_custom_loader=False,
60 | ):
61 |     if not use_custom_loader:
62 |         return plms.load_plm(model_name, model_path, specials_to_add)
63 |     else:
64 |         model_class = get_model_class(plm_type=model_name)
65 |         wrapper = model_class.wrapper
66 |         if model_name in ["kogpt2"]:
67 |             model_config = model_class.config.from_pretrained(
68 |                 model_path, revision=revision
69 |             )
70 |             if do_not_load_pretrained_weight:
71 |                 model = model_class.model(
72 |                     config=model_config,
73 |                 )
74 |             else:
75 |                 model = model_class.model.from_pretrained(
76 |                     model_path, revision=revision, config=model_config
77 |                 )
78 |
79 |             tokenizer = model_class.tokenizer.from_pretrained(
80 |                 model_path,
81 |                 bos_token="</s>",
82 |                 eos_token="</s>",
83 |                 unk_token="<unk>",
84 |                 pad_token="<pad>",
85 |                 mask_token="<mask>",
86 |             )
87 |         elif model_name in ["legal-gpt"]:
88 |             model_config = model_class.config.from_pretrained(
89 |                 model_path, revision=revision
90 |             )
91 |             if do_not_load_pretrained_weight:
92 |                 model = model_class.model(
93 |                     config=model_config,
94 |                 )
95 |             else:
96 |                 model = model_class.model.from_pretrained(
97 |                     model_path, revision=revision, config=model_config
98 |                 )
99 |             tokenizer = model_class.tokenizer.from_pretrained(
100 |                 model_path,
101 |                 bos_token="[BOS]",
102 |                 unk_token="[UNK]",
103 |                 pad_token="[PAD]",
104 |                 mask_token="[MASK]",
105 |             )
106 |
107 |         else:
108 |             model_config = model_class.config.from_pretrained(
109 |                 model_path, revision=revision
110 |             )
111 |             if do_not_load_pretrained_weight:
112 |                 model = model_class.model(
113 |                     config=model_config,
114 |                 )
115 |             else:
116 |
117 |                 model = model_class.model.from_pretrained(
118 |                     model_path, revision=revision, config=model_config
119 |                 )
120 |
121 |             if "gpt" in model_name:  # add pad token for gpt
122 |                 specials_to_add = ["<pad>"]
123 |
124 |             tokenizer = model_class.tokenizer.from_pretrained(model_path)
125 |             model, tokenizer = plms.add_special_tokens(
126 |                 model, tokenizer, specials_to_add=specials_to_add
127 |             )
128 |
129 |             if model_name in ["mt5"]:
130 |                 _path = (
131 |                     Path(__file__).parent.resolve() / "mt5_additional_special_tokens.json"
132 |                 )
133 |                 with open(_path) as f:
134 |                     mt5_additional_special_tokens = json.load(f)
135 |                 tokenizer.add_special_tokens(
136 |                     {
137 |                         "additional_special_tokens": mt5_additional_special_tokens[
138 |                             "additional_special_tokens"
139 |                         ]
140 |                     }
141 |                 )
142 |
143 |         return model, tokenizer, model_config, wrapper
144 |
--------------------------------------------------------------------------------
/lbox_open/openprompt_wrapper/plms/lm.py:
--------------------------------------------------------------------------------
1 | # Wonseok added LMTFastokenizerWrapperCustom, copied from the OpenPrompt-v1.0.0 LMTokenizerWrapper class.
2 | # - The only difference is that it inherits from FastTokenizerWrapper instead of TokenizerWrapper.
3 |
4 | from collections import defaultdict
5 | from typing import Optional
6 |
7 | from transformers.tokenization_utils import PreTrainedTokenizer
8 |
9 | from .utils import FastTokenizerWrapper
10 |
11 |
12 | class LMTFastokenizerWrapperCustom(FastTokenizerWrapper):
13 |     r"""
14 |     LMTokenizer wraps a causal language model. Therefore it can only predict the '<mask>' position
15 |     at the end of the sentence. A prefix-style template like 'A <mask> news : <text_a> <text_b>' is
16 |     not applicable in this situation.
17 |     For the template where there is '<text_a>' or '<text_b>' after '<mask>', we raise an exception and terminate
18 |     the program.
19 |     For the template where there are template words after '<mask>', we ignore these template words.
20 |     Moreover, it can only predict one '<mask>' position. All templates that have multiple '<mask>' will
21 |     give rise to an exception.
22 | """ 23 | 24 | def __init__( 25 | self, 26 | max_seq_length: int, 27 | tokenizer: PreTrainedTokenizer, 28 | truncate_method: Optional[str] = "tail", 29 | predict_eos_token: Optional[bool] = False, 30 | **kwargs 31 | ): 32 | super().__init__( 33 | max_seq_length=max_seq_length, 34 | tokenizer=tokenizer, 35 | truncate_method=truncate_method, 36 | ) 37 | self.predict_eos = predict_eos_token 38 | 39 | @property 40 | def num_special_tokens_to_add(self): 41 | if not hasattr(self, "_num_specials"): 42 | self._num_specials = self.tokenizer.num_special_tokens_to_add() 43 | return self._num_specials 44 | 45 | def tokenize_one_example(self, wrapped_example, teacher_forcing): 46 | """# TODO doens't consider the situation that input has two parts""" 47 | wrapped_example, others = wrapped_example 48 | 49 | if teacher_forcing: 50 | 51 | tgt_text = others["tgt_text"] 52 | if isinstance(tgt_text, str): 53 | tgt_text = [tgt_text] 54 | 55 | if self.predict_eos: 56 | if not wrapped_example[-1]["text"].endswith(self.tokenizer.eos_token): 57 | wrapped_example.append( 58 | { 59 | "text": self.tokenizer.eos_token, 60 | "shortenable_ids": 0, 61 | "loss_ids": 1, 62 | } 63 | ) 64 | 65 | encoder_inputs = defaultdict(list) 66 | 67 | num_mask_token_used = 0 68 | 69 | for piece_id, piece in enumerate(wrapped_example): 70 | if len(piece["text"]) == 0: 71 | continue 72 | 73 | if ( 74 | piece["text"] == self.tokenizer.eos_token 75 | and self.predict_eos 76 | and wrapped_example[piece_id - 1]["loss_ids"] == 1 77 | ): # eos after the mask also need to be pred 78 | piece["loss_ids"] = 1 79 | 80 | if piece["text"] == self.template_mask_token: 81 | if teacher_forcing: 82 | piece["text"] = " " + tgt_text[num_mask_token_used] + " " 83 | else: 84 | encoder_inputs["loss_ids"][-1][-1] = 1 85 | break 86 | 87 | if piece["text"] in self.special_tokens_maps.keys(): 88 | to_replace = self.special_tokens_maps[piece["text"]] 89 | if to_replace is not None: 90 | piece["text"] = to_replace 91 | else: 92 | raise KeyError( 93 | "This tokenizer doesn't specify {} token.".format(piece["text"]) 94 | ) 95 | 96 | if "soft_token_ids" in piece and piece["soft_token_ids"] != 0: 97 | encode_text = [ 98 | 0 99 | ] # can be replace by any token, since these token will use their own embeddings 100 | else: 101 | encode_text = self.tokenizer.encode( 102 | piece["text"], add_special_tokens=False 103 | ) 104 | 105 | encoding_length = len(encode_text) 106 | 107 | encoder_inputs["input_ids"].append(encode_text) 108 | for key in piece: 109 | if key not in ["text"]: 110 | encoder_inputs[key].append([piece[key]] * encoding_length) 111 | 112 | encoder_inputs = self.truncate(encoder_inputs=encoder_inputs) 113 | 114 | # delete shortenable ids 115 | encoder_inputs.pop("shortenable_ids") 116 | encoder_inputs = self.concate_parts(input_dict=encoder_inputs) 117 | encoder_inputs = self.add_special_tokens( 118 | encoder_inputs=encoder_inputs 119 | ) # this will do nothing in GPT2 tokenizer 120 | # create special input ids 121 | encoder_inputs["attention_mask"] = [1] * len(encoder_inputs["input_ids"]) 122 | if self.create_token_type_ids: 123 | encoder_inputs["token_type_ids"] = [0] * len(encoder_inputs["input_ids"]) 124 | # pad to max length 125 | input_ids_len = len(encoder_inputs["input_ids"]) 126 | encoder_inputs = self.padding( 127 | input_dict=encoder_inputs, 128 | max_len=self.max_seq_length, 129 | pad_id_for_inputs=self.tokenizer.pad_token_id, 130 | ) 131 | encoder_inputs = {**encoder_inputs, "input_ids_len": input_ids_len} 132 | return encoder_inputs 133 | 
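# Editor's note: an illustrative, hedged sketch of the wrapper's contract under
# teacher forcing; the piece dicts mimic what an OpenPrompt template emits, and
# the field values are assumptions, not repository data.
#
#   wrapped_example = (
#       [
#           {"text": "사실관계: ...", "loss_ids": 0, "shortenable_ids": 1},
#           {"text": "<mask>", "loss_ids": 1, "shortenable_ids": 0},
#       ],
#       {"tgt_text": "사기"},
#   )
#
# With teacher_forcing=True, tokenize_one_example() splices " 사기 " in place of
# the "<mask>" piece and, with predict_eos_token=True, appends an eos piece with
# loss_ids=1, so the loss covers only the target span plus the eos token. The
# returned features also carry "input_ids_len" (the unpadded length), which
# PromptForGenerationCustom.generate() uses to strip right-padding before
# decoder-only generation.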
--------------------------------------------------------------------------------
/lbox_open/openprompt_wrapper/plms/mt5_additional_special_tokens.json:
--------------------------------------------------------------------------------
1 | {
2 |   "additional_special_tokens": [
3 |     "<extra_id_0>",
4 |     "<extra_id_1>",
5 |     "<extra_id_2>",
6 |     "<extra_id_3>",
7 |     "<extra_id_4>",
8 |     "<extra_id_5>",
9 |     "<extra_id_6>",
10 |     "<extra_id_7>",
11 |     "<extra_id_8>",
12 |     "<extra_id_9>",
13 |     "<extra_id_10>",
14 |     "<extra_id_11>",
15 |     "<extra_id_12>",
16 |     "<extra_id_13>",
17 |     "<extra_id_14>",
18 |     "<extra_id_15>",
19 |     "<extra_id_16>",
20 |     "<extra_id_17>",
21 |     "<extra_id_18>",
22 |     "<extra_id_19>",
23 |     "<extra_id_20>",
24 |     "<extra_id_21>",
25 |     "<extra_id_22>",
26 |     "<extra_id_23>",
27 |     "<extra_id_24>",
28 |     "<extra_id_25>",
29 |     "<extra_id_26>",
30 |     "<extra_id_27>",
31 |     "<extra_id_28>",
32 |     "<extra_id_29>",
33 |     "<extra_id_30>",
34 |     "<extra_id_31>",
35 |     "<extra_id_32>",
36 |     "<extra_id_33>",
37 |     "<extra_id_34>",
38 |     "<extra_id_35>",
39 |     "<extra_id_36>",
40 |     "<extra_id_37>",
41 |     "<extra_id_38>",
42 |     "<extra_id_39>",
43 |     "<extra_id_40>",
44 |     "<extra_id_41>",
45 |     "<extra_id_42>",
46 |     "<extra_id_43>",
47 |     "<extra_id_44>",
48 |     "<extra_id_45>",
49 |     "<extra_id_46>",
50 |     "<extra_id_47>",
51 |     "<extra_id_48>",
52 |     "<extra_id_49>",
53 |     "<extra_id_50>",
54 |     "<extra_id_51>",
55 |     "<extra_id_52>",
56 |     "<extra_id_53>",
57 |     "<extra_id_54>",
58 |     "<extra_id_55>",
59 |     "<extra_id_56>",
60 |     "<extra_id_57>",
61 |     "<extra_id_58>",
62 |     "<extra_id_59>",
63 |     "<extra_id_60>",
64 |     "<extra_id_61>",
65 |     "<extra_id_62>",
66 |     "<extra_id_63>",
67 |     "<extra_id_64>",
68 |     "<extra_id_65>",
69 |     "<extra_id_66>",
70 |     "<extra_id_67>",
71 |     "<extra_id_68>",
72 |     "<extra_id_69>",
73 |     "<extra_id_70>",
74 |     "<extra_id_71>",
75 |     "<extra_id_72>",
76 |     "<extra_id_73>",
77 |     "<extra_id_74>",
78 |     "<extra_id_75>",
79 |     "<extra_id_76>",
80 |     "<extra_id_77>",
81 |     "<extra_id_78>",
82 |     "<extra_id_79>",
83 |     "<extra_id_80>",
84 |     "<extra_id_81>",
85 |     "<extra_id_82>",
86 |     "<extra_id_83>",
87 |     "<extra_id_84>",
88 |     "<extra_id_85>",
89 |     "<extra_id_86>",
90 |     "<extra_id_87>",
91 |     "<extra_id_88>",
92 |     "<extra_id_89>",
93 |     "<extra_id_90>",
94 |     "<extra_id_91>",
95 |     "<extra_id_92>",
96 |     "<extra_id_93>",
97 |     "<extra_id_94>",
98 |     "<extra_id_95>",
99 |     "<extra_id_96>",
100 |     "<extra_id_97>",
101 |     "<extra_id_98>",
102 |     "<extra_id_99>"
103 |   ]
104 | }
--------------------------------------------------------------------------------
/lbox_open/openprompt_wrapper/plms/utils.py:
--------------------------------------------------------------------------------
1 | # Wonseok added FastTokenizerWrapper. The class inherits from the OpenPrompt-v1.0.0 TokenizerWrapper class.
2 | 
3 | import warnings
4 | 
5 | import numpy as np
6 | from openprompt import plms
7 | 
8 | 
9 | class FastTokenizerWrapper(plms.utils.TokenizerWrapper):
10 |     def add_special_tokens(self, encoder_inputs):
11 |         # add special tokens
12 |         for key in encoder_inputs:
13 |             if key == "input_ids":
14 |                 with warnings.catch_warnings():
15 |                     warnings.simplefilter("ignore")
16 |                     encoder_inputs[
17 |                         key
18 |                     ] = self.tokenizer.build_inputs_with_special_tokens(
19 |                         encoder_inputs[key]
20 |                     )
21 |             else:
22 |                 # special_tokens_mask = np.array(self.tokenizer.get_special_tokens_mask(encoder_inputs[key], already_has_special_tokens=True))
23 |                 special_tokens_mask = np.array([0] * len(encoder_inputs[key]))
24 |                 with_special_tokens = np.array(
25 |                     self.tokenizer.build_inputs_with_special_tokens(encoder_inputs[key])
26 |                 )
27 |                 if key in ["soft_token_ids"]:  # TODO maybe more than this
28 |                     encoder_inputs[key] = (
29 |                         (1 - special_tokens_mask) * with_special_tokens
30 |                     ).tolist()  # use 0 as special
31 |                 else:
32 |                     encoder_inputs[key] = (
33 |                         (1 - special_tokens_mask) * with_special_tokens
34 |                         - special_tokens_mask * 100
35 |                     ).tolist()  # use -100 as special
36 |         return encoder_inputs
37 | 
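The `- special_tokens_mask * 100` arithmetic above fills special-token positions of label-like fields with -100, matching the ignore_index convention of PyTorch's cross-entropy loss. Note that in the class above the mask is currently hard-coded to zeros (the real mask computation is commented out), so no position is actually rewritten. A toy check of the arithmetic itself (illustrative only, not repository code):

import numpy as np

special_tokens_mask = np.array([1, 0, 0])   # pretend position 0 is a special token
with_special_tokens = np.array([7, 8, 9])   # some label-like field
out = (
    (1 - special_tokens_mask) * with_special_tokens
    - special_tokens_mask * 100
).tolist()
assert out == [-100, 8, 9]                  # -100 marks the special position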
--------------------------------------------------------------------------------
/lbox_open/parser/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lbox-kr/lbox-open/fdad4b039af718d2b171e561e75f5771515572df/lbox_open/parser/__init__.py
--------------------------------------------------------------------------------
/lbox_open/parser/output_parser.py:
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 | 
5 | import re
6 | 
7 | 
8 | def sep_token_parser_baseline(
9 |     parse_sep_token, value_sep_token, empty_token, sub_parse_keys, text
10 | ):
11 |     parse = {}
12 |     char_comma = ","
13 | 
14 |     # protect ',' characters that appear inside numbers (e.g. "5,000,000")
15 |     ms_money = re.finditer(r"\d[\d,.]+\d", text)
16 |     ms_comma = re.finditer(char_comma, text)
17 | 
18 |     idxs_comma = [m.start() for m in ms_comma]
19 |     idxs_comma_I = []
20 |     for m_money in ms_money:
21 |         st = m_money.start()
22 |         ed = m_money.end()
23 |         for idx_comma in idxs_comma:
24 |             if idx_comma >= st and idx_comma <= ed:
25 |                 idxs_comma_I.append(idx_comma)
26 | 
27 |     text_copy = ""
28 |     rpl_sym = "★"
29 |     for i, c in enumerate(text):
30 |         if i in idxs_comma_I:
31 |             text_copy += rpl_sym
32 |         else:
33 |             text_copy += c
34 | 
35 |     values = text_copy.split(parse_sep_token)
36 | 
37 |     for i, k in enumerate(sub_parse_keys):
38 |         if i <= len(values) - 1:
39 |             if empty_token in values[i]:
40 |                 vals = empty_token
41 |             else:
42 |                 parse_values_before_split = values[i]
43 |                 parse_values = parse_values_before_split.split(value_sep_token)
44 |                 vals = []
45 |                 for val in parse_values:
46 |                     v = val.replace(rpl_sym, char_comma).strip()
47 |                     v = re.sub(r"\s", "", v)
48 |                     vals.append(v)
49 |         else:
50 |             vals = None
51 |         parse[k] = vals
52 |     return parse
53 | 
54 | 
55 | def sep_token_based_parser(
56 |     target_parse, parse_sep_token, value_sep_token, empty_token, keys, text
57 | ):
58 |     if target_parse in [
59 |         "fine_imprisonment_lvs",
60 |         "claim_acceptance_lv",
61 |         "casename_classification",
62 |         "statute_classification",
63 |     ]:
64 |         # print(text)
65 |         parse = sep_token_parser_baseline(
66 |             parse_sep_token, value_sep_token, empty_token, keys, text
67 |         )
68 |     else:
69 |         raise NotImplementedError
70 | 
71 |     return parse
72 | 
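A minimal sketch of the separator-based parser above (illustrative only; the separator/empty tokens and sub-parse keys below are placeholders — the real values come from `infer_param` in the experiment configs):

from lbox_open.parser.output_parser import sep_token_parser_baseline

text = "벌금 30,000,000 원<p_sep>징역 2<p_sep><empty>"   # hypothetical model output
parse = sep_token_parser_baseline(
    parse_sep_token="<p_sep>",
    value_sep_token="<v_sep>",
    empty_token="<empty>",
    sub_parse_keys=["fine", "imprisonment", "suspension"],  # hypothetical keys
    text=text,
)
# Commas inside "30,000,000" are protected before splitting, so:
# parse == {"fine": ["벌금30,000,000원"], "imprisonment": ["징역2"], "suspension": "<empty>"}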
--------------------------------------------------------------------------------
/lbox_open/parser/output_parser_utils.py:
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 | 
5 | from collections import defaultdict
6 | from itertools import zip_longest
7 | from pathlib import Path
8 | 
9 | import lbox_open.utils.general_utils as gu
10 | from lbox_open.metric.exact_match import ExactMatch
11 | 
12 | from .output_parser import sep_token_based_parser
13 | 
14 | 
15 | def text_to_parse_separator_based(
16 |     target_parse,
17 |     parse_sep_token,
18 |     value_sep_token,
19 |     empty_token,
20 |     target_sub_parses,
21 |     texts,
22 | ):
23 |     return list(
24 |         map(
25 |             lambda x: sep_token_based_parser(
26 |                 target_parse,
27 |                 parse_sep_token,
28 |                 value_sep_token,
29 |                 empty_token,
30 |                 target_sub_parses,
31 |                 x,
32 |             ),
33 |             texts,
34 |         )
35 |     )
36 | 
37 | 
38 | def get_parses_from_eval_results(
39 |     infer_param,
40 |     target_parses_dict,
41 |     doc_ids,
42 |     gt_texts,
43 |     pr_texts,
44 | ):
45 |     parses = defaultdict(dict)
46 |     for target_parse, target_sub_parses in target_parses_dict.items():
47 |         gt_parses = text_to_parse_separator_based(
48 |             target_parse,
49 |             infer_param.parse_sep_token,
50 |             infer_param.value_sep_token,
51 |             infer_param.empty_token,
52 |             target_sub_parses,
53 |             gt_texts[target_parse],
54 |         )
55 | 
56 |         pr_parses = text_to_parse_separator_based(
57 |             target_parse,
58 |             infer_param.parse_sep_token,
59 |             infer_param.value_sep_token,
60 |             infer_param.empty_token,
61 |             target_sub_parses,
62 |             pr_texts[target_parse],
63 |         )
64 | 
65 |         # insert doc_ids
66 |         for doc_id, gt_parse, pr_parse in zip_longest(
67 |             doc_ids[target_parse], gt_parses, pr_parses
68 |         ):
69 |             gt_parse["doc_id"] = doc_id
70 |             pr_parse["doc_id"] = doc_id
71 | 
72 |         parses[target_parse]["gt_parses"] = gt_parses
73 |         parses[target_parse]["pr_parses"] = pr_parses
74 | 
75 |     return parses
76 | 
77 | 
78 | def cal_em_from_parses(
79 |     infer_param,
80 |     target_parses_dict,
81 |     parses,
82 |     verbose=False,
83 |     save=False,
84 |     output_save_dir=None,
85 |     confidences=None,
86 |     threshold=0.0,
87 |     input_texts=None,
88 | ):
89 |     em_scores_full = {}
90 |     for target_parse, target_sub_parses in target_parses_dict.items():
91 | 
92 |         gt_parses = parses[target_parse]["gt_parses"]
93 |         pr_parses = parses[target_parse]["pr_parses"]
94 | 
95 |         if confidences is None:
96 |             _confs = [1.0] * len(gt_parses)
97 |         else:
98 |             _confs = confidences[target_parse]
99 | 
100 |         exact_match = ExactMatch(
101 |             list(gt_parses[0].keys()), empty_value=infer_param.empty_token
102 |         )
103 | 
104 |         (
105 |             f1_all,
106 |             cnt_tp_all,
107 |             cnt_fp_all,
108 |             cnt_fn_all,
109 |             cnt_tn_all,
110 |             th_recall,
111 |         ) = exact_match.compare_parses(gt_parses, pr_parses, _confs, threshold)
112 | 
113 |         if verbose:
114 |             print(f"Target_parse: {target_parse} with th-recall: {th_recall}")
115 |             print("tp-------------------")
116 |             print(cnt_tp_all)
117 |             print("fp-------------------")
118 |             print(cnt_fp_all)
119 |             print("fn-------------------")
120 |             print(cnt_fn_all)
121 |             print("tn-------------------")
122 |             print(cnt_tn_all)
123 |             print("f1-------------------")
124 |             print(f1_all)
125 | 
126 |         score = {
127 |             "f1": f1_all,
128 |             "tp": cnt_tp_all,
129 |             "fp": cnt_fp_all,
130 |             "fn": cnt_fn_all,
131 |             "tn": cnt_tn_all,
132 |             "th_recall": th_recall,
133 |         }
134 |         em_scores_full[target_parse] = score
135 | 
136 |         if save:
137 |             if output_save_dir is not None:
138 |                 if "path_eval_result" in infer_param:
139 |                     print("path_eval_result is ignored!!!")
140 |             else:
141 |                 output_save_dir = infer_param.path_eval_result
142 | 
143 |             # path_save_dir = os.path.dirname(output_save_dir)
144 |             path_save_dir = output_save_dir
145 |             path_save = Path(path_save_dir) / f"eval_parse_{target_parse}.json"
146 |             gu.save_json(path_save, parses)
147 | 
148 |             path_save = Path(path_save_dir) / f"score_exact_match_{target_parse}.json"
149 |             gu.save_json(path_save, score)
150 | 
151 |     return em_scores_full
152 | 
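A small sketch of the parse extraction step above (illustrative only; the separator tokens are placeholders, and in real runs `infer_param` and `target_parses_dict` come from the experiment config):

import munch

from lbox_open.parser.output_parser_utils import get_parses_from_eval_results

# Sketch only: placeholder tokens and a single made-up document.
infer_param = munch.munchify(
    {"parse_sep_token": "<p_sep>", "value_sep_token": "<v_sep>", "empty_token": "<empty>"}
)
target_parses_dict = {"statute_classification": ["statute_classification"]}
doc_ids = {"statute_classification": [0]}
gt_texts = {"statute_classification": ["형법 제329조"]}
pr_texts = {"statute_classification": ["형법 제329조"]}

parses = get_parses_from_eval_results(
    infer_param, target_parses_dict, doc_ids, gt_texts, pr_texts
)
# Whitespace is stripped by the parser, so:
# parses["statute_classification"]["gt_parses"][0]
#   == {"statute_classification": ["형법제329조"], "doc_id": 0}
# cal_em_from_parses then scores gt_parses against pr_parses with ExactMatch.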
--------------------------------------------------------------------------------
/lbox_open/pipeline/__init__.py:
--------------------------------------------------------------------------------
1 | from .lbox_open_pipeline import *
2 | 
--------------------------------------------------------------------------------
/lbox_open/pipeline/lbox_open_pipeline.py:
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 | 
5 | from pathlib import Path
6 | 
7 | import pytorch_lightning as pl
8 | import torch
9 | 
10 | from lbox_open import openprompt_wrapper
11 | from lbox_open.data_module.data_precedent import PrecedentDataModule
12 | from lbox_open.model.generative_baseline_model import GenerativeParser
13 | from lbox_open.template import prompt_generation_utils
14 | from lbox_open.utils import general_utils as gu
15 | 
16 | 
17 | def get_data_module(
18 |     cfg,
19 |     plm_tokenizer,
20 |     TokenizerWrapper,
21 |     input_templates,
22 | ):
23 | 
24 |     if cfg.data.use_local_data:
25 |         raw_data = {
26 |             "train": gu.load_jsonl(cfg.data.path_train, None),
27 |             "valid": gu.load_jsonl(cfg.data.path_valid, None),
28 |         }
29 |         if cfg.data.path_test is not None:
30 |             raw_data["test"] = gu.load_jsonl(cfg.data.path_test, None)
31 |     else:
32 |         raw_data = None
33 | 
34 |     if cfg.model.task in [
35 |         "ljp_civil",
36 |         "ljp_criminal",
37 |         "casename_classification",
38 |         "statute_classification",
39 |         "summarization",
40 |     ]:
41 |         data_module = PrecedentDataModule(
42 |             cfg,
43 |             plm_tokenizer,
44 |             TokenizerWrapper,
45 |             input_templates,
46 |             raw_data,
47 |         )
48 |     else:
49 |         raise NotImplementedError
50 | 
51 |     return data_module
52 | 
53 | 
54 | def get_plm(cfg):
55 |     (
56 |         plm,
57 |         plm_tokenizer,
58 |         plm_model_config,
59 |         TokenizerWrapperClass,
60 |     ) = openprompt_wrapper.load_plm_wrapper(
61 |         model_name=cfg.model.plm.name,
62 |         model_path=cfg.model.plm.path,
63 |         revision=cfg.model.plm.revision,
64 |         do_not_load_pretrained_weight=cfg.train.weight.do_not_load_pretrained_weight,
65 |         use_custom_loader=True,
66 |     )
67 |     return plm, plm_tokenizer, plm_model_config, TokenizerWrapperClass
68 | 
69 | 
70 | def gen_input_templates(cfg, plm, plm_tokenizer):
71 |     input_templates = {}
72 |     for target_parse, target_sub_parses in cfg.model.target_parses_dict.items():
73 |         input_templates[target_parse] = prompt_generation_utils.gen_template(
74 |             cfg.model.task,
75 |             target_parse,
76 |             cfg.model.input_template_type,
77 |             plm,
78 |             plm_tokenizer,
79 |         )
80 | 
81 |     return input_templates
82 | 
83 | 
84 | def get_model(cfg, plm, plm_tokenizer, input_templates):
85 |     if cfg.model.model_type == "generative":
86 |         model = GenerativeParser(cfg, plm, plm_tokenizer, input_templates)
87 |     else:
88 |         raise NotImplementedError
89 | 
90 |     if cfg.train.weight.trained:
91 |         path_load = Path(cfg.train.weight.path)
92 | 
93 |         if cfg.model.task in [
94 |             "ljp_civil",
95 |             "ljp_criminal",
96 |             "casename_classification",
97 |             "statute_classification",
98 |             "summarization",
99 |         ]:
100 |             ckpt = torch.load(path_load)
101 |             if "state_dict" in ckpt:
102 |                 ckpt_state_dict = ckpt["state_dict"]
103 |             else:
104 |                 ckpt_state_dict = ckpt
105 |             model.load_state_dict(ckpt_state_dict, strict=False)
106 | 
107 |         else:
108 |             raise NotImplementedError
109 | 
110 |         print(f"The model weights are loaded from {path_load}.")
111 | 
112 |     return model
113 | 
114 | 
115 | def get_trainer(cfg):
116 |     from pytorch_lightning import loggers as pl_loggers
117 | 
118 |     tparam = cfg.train
119 |     mparam = cfg.model
120 | 
121 |     log_dir = Path(cfg.train.log_dir) / cfg.name
122 |     tb_logger = pl_loggers.TensorBoardLogger(log_dir)
123 | 
124 |     pl.utilities.seed.seed_everything(seed=cfg.train.seed, workers=False)
125 | 
126 |     n_gpus = torch.cuda.device_count()
127 | 
128 |     callbacks = [
129 |         pl.callbacks.ModelCheckpoint(
130 |             monitor=f"{cfg.train.validation_metric}_{cfg.train.validation_sub_param.method}",
131 |             dirpath=gu.get_model_saving_path(tparam.weight.save_path_dir, cfg.name),
132 |             save_top_k=1,
133 |             mode="max",
134 |             save_last=False,
135 |         )
136 |     ]
137 |     if tparam.optim.swa.use:
138 |         callbacks.append(
139 |             pl.callbacks.StochasticWeightAveraging(
140 |                 swa_epoch_start=tparam.optim.swa.swa_epoch_start,
141 |                 swa_lrs=tparam.optim.swa.lr,
142 |                 annealing_epochs=tparam.optim.swa.annealing_epochs,
143 |             )
144 |         )
145 | 
146 |     trainer = pl.Trainer(
147 |         logger=tb_logger,
148 |         accelerator=tparam.accelerator,
149 |         strategy=tparam.strategy,
150 |         max_epochs=tparam.max_epochs,
151 |         precision=mparam.precision if torch.cuda.is_available() else 32,
152 |         num_sanity_val_steps=tparam.num_sanity_val_steps,
153 |         gpus=n_gpus,
154 |         check_val_every_n_epoch=tparam.check_val_every_n_epoch,
155 |         gradient_clip_val=tparam.optim.gradient_clip_val,
156 |         gradient_clip_algorithm=tparam.optim.gradient_clip_algorithm,
157 |         accumulate_grad_batches=tparam.accumulate_grad_batches,
158 |         val_check_interval=tparam.val_check_interval,
159 |         profiler=tparam.profiler,
160 |         fast_dev_run=tparam.fast_dev_run,
161 |         callbacks=callbacks,
162 |         limit_train_batches=tparam.get("limit_train_batches", 1.0),
163 |         limit_val_batches=tparam.get("limit_val_batches", 1.0),
164 |     )
165 |     return trainer
166 | 
167 | 
168 | def prepare_modules(mode, cfg):  # "mode" is currently unused; train and test share the same setup
169 | 
170 |     # get pretrained language models
171 |     plm, plm_tokenizer, plm_model_config, TokenizerWrapperClass = get_plm(cfg)
172 | 
173 |     # gen templates
174 |     input_templates = gen_input_templates(cfg, plm, plm_tokenizer)
175 | 
176 |     # get data module
177 |     data_module = get_data_module(
178 |         cfg, plm_tokenizer, TokenizerWrapperClass, input_templates
179 |     )
180 | 
181 |     # get model
182 |     model = get_model(cfg, plm, plm_tokenizer, input_templates)
183 | 
184 |     # get trainer
185 |     trainer = get_trainer(cfg)
186 | 
187 |     return data_module, model, trainer
188 | 
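A partial sketch of the config object the pipeline above reads (illustrative only; the actual values live in the YAML files under configs/ and are loaded with general_utils.load_cfg — every value below, including the model path and sub-parse keys, is a placeholder assumption, and optimizer/trainer fields read by get_trainer are omitted):

import munch

cfg = munch.munchify({
    "model": {
        "task": "casename_classification",
        "model_type": "generative",
        "input_template_type": 0,
        "target_parses_dict": {"casename_classification": ["casename_classification"]},  # assumed sub-parse keys
        "plm": {"name": "gpt2", "path": "lbox/lcube-base", "revision": None},  # assumed values
        "precision": 32,
    },
    "data": {
        "use_local_data": False,
        "path_train": None, "path_valid": None, "path_test": None,
    },
    "train": {
        "seed": 1,
        "log_dir": "logs",
        "weight": {
            "do_not_load_pretrained_weight": False,
            "trained": False, "path": None,
            "save_path_dir": "saved",
        },
        # ... validation_metric, optim.*, and pl.Trainer fields (see get_trainer)
    },
})
cfg.name = "example.yaml"   # normally set by general_utils.load_cfg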
--------------------------------------------------------------------------------
/lbox_open/template/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lbox-kr/lbox-open/fdad4b039af718d2b171e561e75f5771515572df/lbox_open/template/__init__.py
--------------------------------------------------------------------------------
/lbox_open/template/prompt_generation_utils.py:
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 | 
5 | from openprompt import prompts
6 | 
7 | from lbox_open.template import prompt_templates
8 | 
9 | from ..constants import ENG_TO_KOR_PARSE_NAMES_LJP_CRIMINAL
10 | 
11 | 
12 | def gen_template(task, key, template_type, plm, tokenizer):
13 |     mytemplate = prompts.MixedTemplate(
14 |         model=plm,
15 |         tokenizer=tokenizer,
16 |         text=prompt_templates.gen_input_template_str(task, key, template_type),
17 |     )
18 | 
19 |     return mytemplate
20 | 
21 | 
22 | def gen_output_template(
23 |     task,
24 |     key,
25 |     sub_keys,
26 |     label,
27 |     parse_sep_token,
28 | ):
29 |     """Serialize a label into the target text that the generative parser is trained on."""
30 |     # todo: move the template part to ./prompt_templates.py
31 | 
32 |     if task == "ljp_criminal":
33 |         if key == "fine_imprisonment_lvs":
34 |             label_dict = label
35 |             out = ""
36 |             for sub_key in sub_keys:
37 |                 sub_key_kor = ENG_TO_KOR_PARSE_NAMES_LJP_CRIMINAL[sub_key]
38 |                 out += f"{sub_key_kor}{label_dict[sub_key]}{parse_sep_token} "
39 |             out = out.strip(f"{parse_sep_token} ")
40 | 
41 |         else:
42 |             raise NotImplementedError
43 | 
44 |     elif task == "ljp_civil":
45 |         if key == "claim_acceptance_lv":
46 |             out = str(label)
47 |         else:
48 |             raise NotImplementedError
49 | 
50 |     elif task == "casename_classification":
51 |         if key == "casename_classification":
52 |             out = str(label)
53 |         else:
54 |             raise NotImplementedError
55 |     elif task == "statute_classification":
56 |         assert isinstance(label, list)
57 |         if key == "statute_classification":
58 |             out = f"{parse_sep_token} ".join(label)
59 |         else:
60 |             raise NotImplementedError
61 |     elif task == "summarization":
62 |         if key == "summarization":
63 |             out = str(label)
64 |         else:
65 |             raise NotImplementedError
66 |     else:
67 |         raise NotImplementedError
68 | 
69 |     return out
70 | 
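A concrete example of the statute-classification branch above (illustrative only; the separator token is a placeholder — the real one comes from the config):

from lbox_open.template.prompt_generation_utils import gen_output_template

out = gen_output_template(
    task="statute_classification",
    key="statute_classification",
    sub_keys=None,                       # unused for this task
    label=["형법 제329조", "형법 제330조"],
    parse_sep_token="<p_sep>",           # placeholder separator
)
assert out == "형법 제329조<p_sep> 형법 제330조"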
--------------------------------------------------------------------------------
/lbox_open/template/prompt_templates.py:
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 | 
5 | 
6 | def gen_input_template_str(task, key, template_type):
7 |     if key == "fine_imprisonment_lvs":
8 |         if template_type == 0:
9 |             input_template_str = (
10 |                 '{"placeholder":"text_a"} 형사사건에 대하여 순서대로 벌금, 징역, 금고 레벨을 쓰시오. {"mask"}'
11 |             )
12 |         else:
13 |             raise NotImplementedError
14 |     elif key == "claim_acceptance_lv":
15 |         if template_type == 0:
16 |             input_template_str = (
17 |                 '{"placeholder":"text_a"} 주어진 사실관계, 청구 취지를 읽고, 주장 인정율을 예측하시오. {"mask"}'
18 |             )
19 |         else:
20 |             raise NotImplementedError
21 |     elif key == "casename_classification":
22 |         if template_type == 0:
23 |             input_template_str = (
24 |                 '{"placeholder":"text_a"} 주어진 사실관계를 읽고, 사건명을 예측하시오. {"mask"}'
25 |             )
26 |         else:
27 |             raise NotImplementedError
28 |     elif key == "statute_classification":
29 |         if template_type == 0:
30 |             input_template_str = (
31 |                 '{"placeholder":"text_a"} 주어진 사실관계를 읽고, 적용될 형법 조항들을 예측하시오. {"mask"}'
32 |             )
33 |         else:
34 |             raise NotImplementedError
35 |     elif key == "summarization":
36 |         if template_type == 0:
37 |             input_template_str = '{"placeholder":"text_a"}\n요약하시오.\n{"mask"}'
38 |         else:
39 |             raise NotImplementedError
40 |     else:
41 |         raise NotImplementedError
42 | 
43 |     return input_template_str
44 | 
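For instance, the summarization template (type 0) renders to the raw OpenPrompt mixed-template string below, taken directly from the function above:

from lbox_open.template.prompt_templates import gen_input_template_str

s = gen_input_template_str("summarization", "summarization", 0)
assert s == '{"placeholder":"text_a"}\n요약하시오.\n{"mask"}'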
{"mask"}' 32 | ) 33 | else: 34 | raise NotImplementedError 35 | elif key == "summarization": 36 | if type == 0: 37 | input_template_str = '{"placeholder":"text_a"}\n요약하시오.\n{"mask"}' 38 | else: 39 | raise NotImplementedError 40 | else: 41 | raise NotImplementedError 42 | 43 | return input_template_str 44 | -------------------------------------------------------------------------------- /lbox_open/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lbox-kr/lbox-open/fdad4b039af718d2b171e561e75f5771515572df/lbox_open/utils/__init__.py -------------------------------------------------------------------------------- /lbox_open/utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # LBox Open 2 | # Copyright (c) 2022-present LBox Co. Ltd. 3 | # CC BY-NC 4.0 4 | 5 | import json 6 | import os 7 | import pickle 8 | import subprocess 9 | import time 10 | from pathlib import Path 11 | 12 | from tqdm import tqdm 13 | 14 | 15 | def stop_flag(idx, toy_size): 16 | # idx + 1 = length 17 | data_size = idx + 1 18 | if toy_size is not None: 19 | if toy_size <= data_size: 20 | return True 21 | else: 22 | return False 23 | 24 | 25 | def save_pkl(path_save, data): 26 | with open(path_save, "wb") as f: 27 | pickle.dump(data, f) 28 | 29 | 30 | def load_pkl(path_load): 31 | with open(path_load, "rb") as f: 32 | data = pickle.load(f) 33 | return data 34 | 35 | 36 | def save_json(path_save, data): 37 | with open(path_save, "w") as f: 38 | json.dump(data, f, ensure_ascii=False) 39 | 40 | 41 | def load_json(fpath): 42 | with open(fpath) as f: 43 | return json.load(f) 44 | 45 | 46 | def save_jsonl(path_save, data): 47 | with open(path_save, "w") as f: 48 | for t1 in data: 49 | f.writelines(json.dumps(t1, ensure_ascii=False)) 50 | f.writelines("\n") 51 | 52 | 53 | def load_jsonl(fpath, toy_size=None): 54 | data = [] 55 | with open(fpath) as f: 56 | for i, line in tqdm(enumerate(f)): 57 | try: 58 | data1 = json.loads(line) 59 | except: 60 | print(f"{i}th sample failed.") 61 | print(f"We will wkip this!") 62 | print(line) 63 | data1 = None 64 | if data1 is not None: 65 | data.append(data1) 66 | if stop_flag(i, toy_size): 67 | break 68 | 69 | return data 70 | 71 | 72 | def my_timeit(func): 73 | def wrapped_func(*args, **kwargs): 74 | st = time.time() 75 | results = func(*args, **kwargs) 76 | ed = time.time() 77 | print(f"func {func.__name__} taks {ed - st} sec.") 78 | return results 79 | 80 | return wrapped_func 81 | 82 | 83 | def flatten_list(list_): 84 | out = [] 85 | for x in list_: 86 | if isinstance(x, list): 87 | out += flatten_list(x) 88 | else: 89 | out += [x] 90 | 91 | return out 92 | 93 | 94 | def load_cfg(path_cfg): 95 | import munch 96 | import yaml 97 | 98 | with open(path_cfg) as f: 99 | cfg = yaml.full_load(f) 100 | cfg = munch.munchify(cfg) 101 | cfg.name = path_cfg.__str__().split("/")[-1] 102 | return cfg 103 | 104 | 105 | def get_model_saving_path(save_dir, cfg_name): 106 | return Path(save_dir) / cfg_name 107 | 108 | 109 | def download_url(path_save, url): 110 | p = subprocess.Popen(["wget", "-q", "-O", path_save.__str__(), url]) 111 | sts = os.waitpid(p.pid, 0) 112 | 113 | 114 | def get_local_rank(): 115 | """ 116 | Pytorch lightning save local rank to environment variable "LOCAL_RANK". 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm==4.62.3
2 | munch==2.5.0
3 | 
4 | transformers==4.19.4
5 | pytorch-lightning==1.5.8
6 | 
7 | setuptools==59.5.0  # for compatibility with pytorch 1.10
8 | sentencepiece==0.1.96
9 | 
10 | # for OpenPrompt
11 | openprompt==1.0.0
12 | rouge_score==0.1.2
13 | 
14 | # for facts
15 | thefuzz==0.19.0
16 | nltk==3.6.7
17 | python-Levenshtein==0.12.2
18 | 
19 | # misc
20 | scikit-learn==0.23.2
21 | tweepy==3.10.0
--------------------------------------------------------------------------------
/run_model.py:
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 | 
5 | import argparse
6 | 
7 | from lbox_open.pipeline import prepare_modules
8 | from lbox_open.utils import general_utils
9 | 
10 | 
11 | def main():
12 |     parser = argparse.ArgumentParser()
13 |     parser.add_argument("path_cfg", default="")
14 |     parser.add_argument("--mode", default="")
15 |     args = parser.parse_args()
16 | 
17 |     cfg = general_utils.load_cfg(args.path_cfg)
18 | 
19 |     if args.mode == "train":
20 |         data_module, model, trainer = prepare_modules("train", cfg)
21 |         trainer.fit(model, data_module)
22 | 
23 |     elif args.mode == "test":
24 |         data_module, model, trainer = prepare_modules("test", cfg)
25 |         trainer.test(model, datamodule=data_module)
26 |     else:
27 |         print(
28 |             f"{args.mode} mode is not supported. The mode should be either 'train' or 'test'."
29 |         )
30 | 
31 | 
32 | if __name__ == "__main__":
33 |     main()
34 | 
--------------------------------------------------------------------------------
/scripts/predict_summarization.sh:
--------------------------------------------------------------------------------
1 | config="configs/summarization/summarization.legal-mt5s.predict.yaml"
2 | export CUDA_VISIBLE_DEVICES=0
3 | python run_model.py $config --mode test
4 | 
--------------------------------------------------------------------------------
/scripts/test_casename.sh:
--------------------------------------------------------------------------------
1 | #config="configs/casename_classification/casename.kogpt2.test.yaml"
2 | config="configs/casename_classification/casename.lcube-base.test.yaml"
3 | export CUDA_VISIBLE_DEVICES=0
4 | python run_model.py $config --mode test
5 | 
6 | 
--------------------------------------------------------------------------------
/scripts/test_ljp_civil.sh:
--------------------------------------------------------------------------------
1 | #config="configs/ljp/civil/ljp.civil.kogpt2.test.yaml"
2 | config="configs/ljp/civil/ljp.civil.lcube-base.test.yaml"
3 | export CUDA_VISIBLE_DEVICES=0
4 | python run_model.py $config --mode test
5 | 
--------------------------------------------------------------------------------
/scripts/test_ljp_criminal.sh:
--------------------------------------------------------------------------------
1 | config="configs/ljp/criminal/ljp.criminal.lcube-base.test.yaml"
2 | #config="configs/ljp/criminal/ljp.criminal.kogpt2.test.yaml"
3 | export CUDA_VISIBLE_DEVICES=0
4 | python run_model.py $config --mode test
5 | 
--------------------------------------------------------------------------------
/scripts/test_statute.sh:
--------------------------------------------------------------------------------
1 | #config="configs/statute_classification/statute.kogpt2.test.yaml"
2 | config="configs/statute_classification/statute.lcube-base.test.yaml"
3 | export CUDA_VISIBLE_DEVICES=0
4 | python run_model.py $config --mode test
5 | 
--------------------------------------------------------------------------------
/scripts/test_summarization.sh:
--------------------------------------------------------------------------------
1 | config="configs/summarization/summarization.legal-mt5s.test.yaml"
2 | export CUDA_VISIBLE_DEVICES=0
3 | python run_model.py $config --mode test
4 | 
--------------------------------------------------------------------------------
/scripts/train_casename.sh:
--------------------------------------------------------------------------------
1 | #config="configs/casename_classification/casename.kogpt2.yaml"
2 | config="configs/casename_classification/casename.lcube-base.yaml"
3 | export CUDA_VISIBLE_DEVICES=0
4 | python run_model.py $config --mode train
5 | 
6 | 
--------------------------------------------------------------------------------
/scripts/train_ljp_civil.sh:
--------------------------------------------------------------------------------
1 | #config="configs/ljp/civil/ljp.civil.kogpt2.yaml"
2 | config="configs/ljp/civil/ljp.civil.lcube-base.yaml"
3 | export CUDA_VISIBLE_DEVICES=0
4 | python run_model.py $config --mode train
5 | 
--------------------------------------------------------------------------------
/scripts/train_ljp_criminal.sh:
--------------------------------------------------------------------------------
1 | #config="configs/ljp/criminal/ljp.criminal.kogpt2.yaml"
2 | config="configs/ljp/criminal/ljp.criminal.lcube-base.yaml"
3 | export CUDA_VISIBLE_DEVICES=0
4 | python run_model.py $config --mode train 5 | -------------------------------------------------------------------------------- /scripts/train_statute.sh: -------------------------------------------------------------------------------- 1 | #config="configs/statute_classification/statute.kogpt2.yaml" 2 | config="configs/statute_classification/statute.lcube-base.yaml" 3 | export CUDA_VISIBLE_DEVICES=0 4 | python run_model.py $config --mode train 5 | -------------------------------------------------------------------------------- /scripts/train_summarization.sh: -------------------------------------------------------------------------------- 1 | #config="configs/summarization/summarization.kogpt2.yaml" 2 | # config="configs/summarization/summarization.lcube-base.yaml" 3 | config="configs/summarization/summarization.legal-mt5s.yaml" 4 | export CUDA_VISIBLE_DEVICES=0 5 | python run_model.py $config --mode train 6 | --------------------------------------------------------------------------------