├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── casename_classification
│   │   ├── casename.kogpt2.e2.yaml
│   │   ├── casename.kogpt2.e3.yaml
│   │   ├── casename.kogpt2.yaml
│   │   ├── casename.lcube-base.e2.yaml
│   │   ├── casename.lcube-base.e3.yaml
│   │   └── casename.lcube-base.yaml
│   ├── ljp
│   │   ├── civil
│   │   │   ├── ljp.civil.kogpt2.e2.yaml
│   │   │   ├── ljp.civil.kogpt2.e3.yaml
│   │   │   ├── ljp.civil.kogpt2.yaml
│   │   │   ├── ljp.civil.lcube-base.e2.yaml
│   │   │   ├── ljp.civil.lcube-base.e3.yaml
│   │   │   └── ljp.civil.lcube-base.yaml
│   │   └── criminal
│   │       ├── ljp.criminal.kogpt2.e2.yaml
│   │       ├── ljp.criminal.kogpt2.e3.yaml
│   │       ├── ljp.criminal.kogpt2.yaml
│   │       ├── ljp.criminal.lcube-base.e2.yaml
│   │       ├── ljp.criminal.lcube-base.e3.yaml
│   │       └── ljp.criminal.lcube-base.yaml
│   ├── statute_classification
│   │   ├── statute.kogpt2.e2.yaml
│   │   ├── statute.kogpt2.e3.yaml
│   │   ├── statute.kogpt2.yaml
│   │   ├── statute.lcube-base.e2.yaml
│   │   ├── statute.lcube-base.e3.yaml
│   │   └── statute.lcube-base.yaml
│   └── summarization
│       ├── summarization.kogpt2.yaml
│       ├── summarization.lcube-base.yaml
│       └── summarization.legal-mt5s.test.yaml
├── lbox_open
│   ├── constants
│   │   ├── __init__.py
│   │   └── constants_fie.py
│   ├── data_module
│   │   ├── __init__.py
│   │   └── data_precedent.py
│   ├── datasets_script
│   │   └── lbox_open.py
│   ├── metric
│   │   ├── exact_match.py
│   │   └── rouge_metric_utils.py
│   ├── model
│   │   ├── generative_baseline_model.py
│   │   └── model_optimizer.py
│   ├── openprompt_wrapper
│   │   ├── __init__.py
│   │   ├── data_utils
│   │   │   └── __init__.py
│   │   ├── pipeline_base.py
│   │   └── plms
│   │       ├── __init__.py
│   │       ├── lm.py
│   │       ├── mt5_additional_special_tokens.json
│   │       └── utils.py
│   ├── parser
│   │   ├── __init__.py
│   │   ├── output_parser.py
│   │   └── output_parser_utils.py
│   ├── pipeline
│   │   ├── __init__.py
│   │   └── lbox_open_pipeline.py
│   ├── template
│   │   ├── __init__.py
│   │   ├── prompt_generation_utils.py
│   │   └── prompt_templates.py
│   └── utils
│       ├── __init__.py
│       └── general_utils.py
├── requirements.txt
├── run_model.py
└── scripts
    ├── predict_summarization.sh
    ├── test_casename.sh
    ├── test_ljp_civil.sh
    ├── test_ljp_criminal.sh
    ├── test_statute.sh
    ├── test_summarization.sh
    ├── train_casename.sh
    ├── train_ljp_civil.sh
    ├── train_ljp_criminal.sh
    ├── train_statute.sh
    └── train_summarization.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | .idea/
132 | configs/summarization/summarization.legal-mt5s.yaml
133 | saved
134 | configs/summarization/summarization.legal-mt5s.predict.yaml
135 | logs/
136 | data/
137 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial 4.0 International Public
58 | License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial 4.0 International Public License ("Public
63 | License"). To the extent this Public License may be interpreted as a
64 | contract, You are granted the Licensed Rights in consideration of Your
65 | acceptance of these terms and conditions, and the Licensor grants You
66 | such rights in consideration of benefits the Licensor receives from
67 | making the Licensed Material available under these terms and
68 | conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. Copyright and Similar Rights means copyright and/or similar rights
88 | closely related to copyright including, without limitation,
89 | performance, broadcast, sound recording, and Sui Generis Database
90 | Rights, without regard to how the rights are labeled or
91 | categorized. For purposes of this Public License, the rights
92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
93 | Rights.
94 | d. Effective Technological Measures means those measures that, in the
95 | absence of proper authority, may not be circumvented under laws
96 | fulfilling obligations under Article 11 of the WIPO Copyright
97 | Treaty adopted on December 20, 1996, and/or similar international
98 | agreements.
99 |
100 | e. Exceptions and Limitations means fair use, fair dealing, and/or
101 | any other exception or limitation to Copyright and Similar Rights
102 | that applies to Your use of the Licensed Material.
103 |
104 | f. Licensed Material means the artistic or literary work, database,
105 | or other material to which the Licensor applied this Public
106 | License.
107 |
108 | g. Licensed Rights means the rights granted to You subject to the
109 | terms and conditions of this Public License, which are limited to
110 | all Copyright and Similar Rights that apply to Your use of the
111 | Licensed Material and that the Licensor has authority to license.
112 |
113 | h. Licensor means the individual(s) or entity(ies) granting rights
114 | under this Public License.
115 |
116 | i. NonCommercial means not primarily intended for or directed towards
117 | commercial advantage or monetary compensation. For purposes of
118 | this Public License, the exchange of the Licensed Material for
119 | other material subject to Copyright and Similar Rights by digital
120 | file-sharing or similar means is NonCommercial provided there is
121 | no payment of monetary compensation in connection with the
122 | exchange.
123 |
124 | j. Share means to provide material to the public by any means or
125 | process that requires permission under the Licensed Rights, such
126 | as reproduction, public display, public performance, distribution,
127 | dissemination, communication, or importation, and to make material
128 | available to the public including in ways that members of the
129 | public may access the material from a place and at a time
130 | individually chosen by them.
131 |
132 | k. Sui Generis Database Rights means rights other than copyright
133 | resulting from Directive 96/9/EC of the European Parliament and of
134 | the Council of 11 March 1996 on the legal protection of databases,
135 | as amended and/or succeeded, as well as other essentially
136 | equivalent rights anywhere in the world.
137 |
138 | l. You means the individual or entity exercising the Licensed Rights
139 | under this Public License. Your has a corresponding meaning.
140 |
141 |
142 | Section 2 -- Scope.
143 |
144 | a. License grant.
145 |
146 | 1. Subject to the terms and conditions of this Public License,
147 | the Licensor hereby grants You a worldwide, royalty-free,
148 | non-sublicensable, non-exclusive, irrevocable license to
149 | exercise the Licensed Rights in the Licensed Material to:
150 |
151 | a. reproduce and Share the Licensed Material, in whole or
152 | in part, for NonCommercial purposes only; and
153 |
154 | b. produce, reproduce, and Share Adapted Material for
155 | NonCommercial purposes only.
156 |
157 | 2. Exceptions and Limitations. For the avoidance of doubt, where
158 | Exceptions and Limitations apply to Your use, this Public
159 | License does not apply, and You do not need to comply with
160 | its terms and conditions.
161 |
162 | 3. Term. The term of this Public License is specified in Section
163 | 6(a).
164 |
165 | 4. Media and formats; technical modifications allowed. The
166 | Licensor authorizes You to exercise the Licensed Rights in
167 | all media and formats whether now known or hereafter created,
168 | and to make technical modifications necessary to do so. The
169 | Licensor waives and/or agrees not to assert any right or
170 | authority to forbid You from making technical modifications
171 | necessary to exercise the Licensed Rights, including
172 | technical modifications necessary to circumvent Effective
173 | Technological Measures. For purposes of this Public License,
174 | simply making modifications authorized by this Section 2(a)
175 | (4) never produces Adapted Material.
176 |
177 | 5. Downstream recipients.
178 |
179 | a. Offer from the Licensor -- Licensed Material. Every
180 | recipient of the Licensed Material automatically
181 | receives an offer from the Licensor to exercise the
182 | Licensed Rights under the terms and conditions of this
183 | Public License.
184 |
185 | b. No downstream restrictions. You may not offer or impose
186 | any additional or different terms or conditions on, or
187 | apply any Effective Technological Measures to, the
188 | Licensed Material if doing so restricts exercise of the
189 | Licensed Rights by any recipient of the Licensed
190 | Material.
191 |
192 | 6. No endorsement. Nothing in this Public License constitutes or
193 | may be construed as permission to assert or imply that You
194 | are, or that Your use of the Licensed Material is, connected
195 | with, or sponsored, endorsed, or granted official status by,
196 | the Licensor or others designated to receive attribution as
197 | provided in Section 3(a)(1)(A)(i).
198 |
199 | b. Other rights.
200 |
201 | 1. Moral rights, such as the right of integrity, are not
202 | licensed under this Public License, nor are publicity,
203 | privacy, and/or other similar personality rights; however, to
204 | the extent possible, the Licensor waives and/or agrees not to
205 | assert any such rights held by the Licensor to the limited
206 | extent necessary to allow You to exercise the Licensed
207 | Rights, but not otherwise.
208 |
209 | 2. Patent and trademark rights are not licensed under this
210 | Public License.
211 |
212 | 3. To the extent possible, the Licensor waives any right to
213 | collect royalties from You for the exercise of the Licensed
214 | Rights, whether directly or through a collecting society
215 | under any voluntary or waivable statutory or compulsory
216 | licensing scheme. In all other cases the Licensor expressly
217 | reserves any right to collect such royalties, including when
218 | the Licensed Material is used other than for NonCommercial
219 | purposes.
220 |
221 |
222 | Section 3 -- License Conditions.
223 |
224 | Your exercise of the Licensed Rights is expressly made subject to the
225 | following conditions.
226 |
227 | a. Attribution.
228 |
229 | 1. If You Share the Licensed Material (including in modified
230 | form), You must:
231 |
232 | a. retain the following if it is supplied by the Licensor
233 | with the Licensed Material:
234 |
235 | i. identification of the creator(s) of the Licensed
236 | Material and any others designated to receive
237 | attribution, in any reasonable manner requested by
238 | the Licensor (including by pseudonym if
239 | designated);
240 |
241 | ii. a copyright notice;
242 |
243 | iii. a notice that refers to this Public License;
244 |
245 | iv. a notice that refers to the disclaimer of
246 | warranties;
247 |
248 | v. a URI or hyperlink to the Licensed Material to the
249 | extent reasonably practicable;
250 |
251 | b. indicate if You modified the Licensed Material and
252 | retain an indication of any previous modifications; and
253 |
254 | c. indicate the Licensed Material is licensed under this
255 | Public License, and include the text of, or the URI or
256 | hyperlink to, this Public License.
257 |
258 | 2. You may satisfy the conditions in Section 3(a)(1) in any
259 | reasonable manner based on the medium, means, and context in
260 | which You Share the Licensed Material. For example, it may be
261 | reasonable to satisfy the conditions by providing a URI or
262 | hyperlink to a resource that includes the required
263 | information.
264 |
265 | 3. If requested by the Licensor, You must remove any of the
266 | information required by Section 3(a)(1)(A) to the extent
267 | reasonably practicable.
268 |
269 | 4. If You Share Adapted Material You produce, the Adapter's
270 | License You apply must not prevent recipients of the Adapted
271 | Material from complying with this Public License.
272 |
273 |
274 | Section 4 -- Sui Generis Database Rights.
275 |
276 | Where the Licensed Rights include Sui Generis Database Rights that
277 | apply to Your use of the Licensed Material:
278 |
279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
280 | to extract, reuse, reproduce, and Share all or a substantial
281 | portion of the contents of the database for NonCommercial purposes
282 | only;
283 |
284 | b. if You include all or a substantial portion of the database
285 | contents in a database in which You have Sui Generis Database
286 | Rights, then the database in which You have Sui Generis Database
287 | Rights (but not its individual contents) is Adapted Material; and
288 |
289 | c. You must comply with the conditions in Section 3(a) if You Share
290 | all or a substantial portion of the contents of the database.
291 |
292 | For the avoidance of doubt, this Section 4 supplements and does not
293 | replace Your obligations under this Public License where the Licensed
294 | Rights include other Copyright and Similar Rights.
295 |
296 |
297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
298 |
299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
309 |
310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
319 |
320 | c. The disclaimer of warranties and limitation of liability provided
321 | above shall be interpreted in a manner that, to the extent
322 | possible, most closely approximates an absolute disclaimer and
323 | waiver of all liability.
324 |
325 |
326 | Section 6 -- Term and Termination.
327 |
328 | a. This Public License applies for the term of the Copyright and
329 | Similar Rights licensed here. However, if You fail to comply with
330 | this Public License, then Your rights under this Public License
331 | terminate automatically.
332 |
333 | b. Where Your right to use the Licensed Material has terminated under
334 | Section 6(a), it reinstates:
335 |
336 | 1. automatically as of the date the violation is cured, provided
337 | it is cured within 30 days of Your discovery of the
338 | violation; or
339 |
340 | 2. upon express reinstatement by the Licensor.
341 |
342 | For the avoidance of doubt, this Section 6(b) does not affect any
343 | right the Licensor may have to seek remedies for Your violations
344 | of this Public License.
345 |
346 | c. For the avoidance of doubt, the Licensor may also offer the
347 | Licensed Material under separate terms or conditions or stop
348 | distributing the Licensed Material at any time; however, doing so
349 | will not terminate this Public License.
350 |
351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
352 | License.
353 |
354 |
355 | Section 7 -- Other Terms and Conditions.
356 |
357 | a. The Licensor shall not be bound by any additional or different
358 | terms or conditions communicated by You unless expressly agreed.
359 |
360 | b. Any arrangements, understandings, or agreements regarding the
361 | Licensed Material not stated herein are separate from and
362 | independent of the terms and conditions of this Public License.
363 |
364 |
365 | Section 8 -- Interpretation.
366 |
367 | a. For the avoidance of doubt, this Public License does not, and
368 | shall not be interpreted to, reduce, limit, restrict, or impose
369 | conditions on any use of the Licensed Material that could lawfully
370 | be made without permission under this Public License.
371 |
372 | b. To the extent possible, if any provision of this Public License is
373 | deemed unenforceable, it shall be automatically reformed to the
374 | minimum extent necessary to make it enforceable. If the provision
375 | cannot be reformed, it shall be severed from this Public License
376 | without affecting the enforceability of the remaining terms and
377 | conditions.
378 |
379 | c. No term or condition of this Public License will be waived and no
380 | failure to comply consented to unless expressly agreed to by the
381 | Licensor.
382 |
383 | d. Nothing in this Public License constitutes or may be interpreted
384 | as a limitation upon, or waiver of, any privileges and immunities
385 | that apply to the Licensor or You, including from the legal
386 | processes of any jurisdiction or authority.
387 |
388 | =======================================================================
389 |
390 | Creative Commons is not a party to its public
391 | licenses. Notwithstanding, Creative Commons may elect to apply one of
392 | its public licenses to material it publishes and in those instances
393 | will be considered the “Licensor.” The text of the Creative Commons
394 | public licenses is dedicated to the public domain under the CC0 Public
395 | Domain Dedication. Except for the limited purpose of indicating that
396 | material is shared under a Creative Commons public license or as
397 | otherwise permitted by the Creative Commons policies published at
398 | creativecommons.org/policies, Creative Commons does not authorize the
399 | use of the trademark "Creative Commons" or any other trademark or logo
400 | of Creative Commons without its prior written consent including,
401 | without limitation, in connection with any unauthorized modifications
402 | to any of its public licenses or any other arrangements,
403 | understandings, or agreements concerning use of licensed material. For
404 | the avoidance of doubt, this paragraph does not form part of the
405 | public licenses.
406 |
407 | Creative Commons may be contacted at creativecommons.org.
408 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LBox Open
2 |
3 | A multi-task benchmark for Korean legal language understanding and judgement prediction by [LBox](https://lbox.kr)
4 |
5 | # Authors
6 |
7 | - [Wonseok Hwang](mailto:wonseok.hwang@lbox.kr)
8 | - [Dongjun Lee](mailto:dongjun.lee@lbox.kr)
9 | - [Kyoungyeon Cho](mailto:kycho@lbox.kr)
10 | - [Hanuhl Lee](mailto:leehanuhl@lbox.kr)
11 | - [Minjoon Seo](mailto:minjoon@lbox.kr)
12 |
13 | # Updates
14 | - Dec 2, 2022: We release [1,024 additional examples of `drunk driving` cases](https://cdn.lbox.kr/public/dataset/lbox-open/precedent_benchmark_dataset/judgement_prediction/ljp_criminal_drunk_driving_plus_1024.jsonl) for the `ljp_criminal` task. Compared to the `ljp_criminal` data, it also includes parses extracted from the facts (blood alcohol level, driving distance, type of car, previous criminal history) and the suspension-of-execution period. See also [this issue](https://github.com/lbox-kr/lbox-open/issues/10). The data will be integrated into `ljp_criminal` in the next release.
15 |
16 | - Dec 2, 2022: We will present our recent work ["Data-efficient End-to-end Information Extraction for Statistical Legal Analysis"](https://arxiv.org/abs/2211.01692) at [NLLP workshop @ EMNLP22](https://nllpw.org/workshop/)!
17 |
18 | - Nov 8, 2022: We release `legal-mt5-small`, a domain-adapted mt5-small trained on `precedent_corpus`. We also release the `legal-mt5-small` fine-tuned on the `summarization` dataset. Both models can be downloaded from [here](https://drive.google.com/file/d/1lZaUtDPCkAOcwaxBzFo-QHecGAQendOd/view?usp=share_link)! To use the models, run `cd [project-dir]; tar xvfz legal-mt5-small.tar.gz`.
19 | - Oct 25, 2022: The [`act_on_special_cases_concerning_the_settlement_of_traffic_accidents_corpus`](https://cdn.lbox.kr/public/dataset/lbox-open/precedent_benchmark_dataset/act_on_special_cases_concerning_the_settlement_of_traffic_accidents_corpus/act_on_special_cases_concerning_the_settlement_of_traffic_accidents_corpus.jsonl) corpus (교통사고처리특례법위반(치상)) has been released. The corpus consists of 768 criminal cases. It will be integrated into `precedent_corpus` in the future (some overlap between `precedent_corpus` and this corpus is expected). See also [this issue](https://github.com/lbox-kr/lbox-open/issues/9).
20 | - Oct 18, 2022: We release three new datasets `casename_classification_plus`, `statute_classification_plus`, and `summarization_plus`!
21 | - Oct 2, 2022: [`defamation corpus-v0.1`](https://cdn.lbox.kr/public/dataset/lbox-open/precedent_benchmark_dataset/defamation_corpus/defamation_corpus.jsonl) has been added. The corpus consists of 1,536 criminal cases related to "defamation (명예훼손)". The corpus will be integrated into `precedent_corpus` in the future (at the moment, there can be some overlap between `precedent_corpus` and `defamation corpus-v0.1`). See also [this issue](https://github.com/lbox-kr/lbox-open/issues/4#issue-1393652876).
22 | - Sep 2022: Our paper has been accepted for publication in the NeurIPS 2022 Datasets and Benchmarks track! There will be major updates to the paper, the datasets, and the models soon! Meanwhile, you can check the most recent version of our paper on [OpenReview](https://openreview.net/forum?id=TaARsI_Iio)
23 | - Jun 2022: We release `lbox-open-v0.2`!
24 |   - Two legal judgement prediction tasks, `ljp_criminal` and `ljp_civil`, are added to LBox Open.
25 |   - `LCube-base`, an LBox legal language model with 124M parameters, is added.
26 |   - The baseline scores and the corresponding training/test scripts are added.
27 |   - Other updates
28 |     - Some missing values in the `facts` fields of `casename_classification` and `statute_classification` are updated.
29 |     - `case_corpus` is renamed to `precedent_corpus`.
30 | - Mar 2022: We release `lbox-open-v0.1`!
31 |
32 | # Paper
33 |
34 | [A Multi-Task Benchmark for Korean Legal Language Understanding and Judgement Prediction](https://arxiv.org/abs/2206.05224)
35 |
36 | # Benchmarks
37 |
38 | - Last updated at Oct 18 2022
39 |
40 | | **Model** | casename | statute | ljp-criminal | ljp-civil | summarization |
41 | |-------------------|----------------|----------------|-----------------------------------------------------------------------|----------------|------------------|
42 | |                            | EM             | EM             | F1-fine<br>F1-imprisonment w/ labor<br>F1-imprisonment w/o labor | EM             | R1<br>R2<br>RL             |
43 | | KoGPT2                     | $78.5 \pm 0.3$ | $85.7 \pm 0.8$ | $49.9 \pm 1.7$<br>$67.5 \pm 1.1$<br>$69.2 \pm 1.6$                | $66.0 \pm 0.5$ | $47.2$<br>$39.1$<br>$45.7$ |
44 | | KoGPT2 + `d.a.`            | $81.9 \pm 0.2$ | $89.4 \pm 0.5$ | $49.8$<br>$65.4$<br>$70.1$                                        | $64.7 \pm 1.1$ | $49.2$<br>$40.9$<br>$47.7$ |
45 | | LCube-base (ours)          | $81.1 \pm 0.3$ | $87.6 \pm 0.5$ | $46.4 \pm 2.8$<br>$69.3 \pm 0.3$<br>$70.3 \pm 0.7$                | $67.6 \pm 1.3$ | $46.0$<br>$37.7$<br>$44.5$ |
46 | | LCube-base + `d.a.` (ours) | $82.7 \pm 0.6$ | $89.3 \pm 0.4$ | $48.1 \pm 1.2$<br>$67.4 \pm 1.5$<br>$69.9 \pm 1.1$                | $60.9 \pm 1.1$ | $47.8$<br>$39.5$<br>$46.4$ |
47 | | mt5-small                  | $81.0 \pm 1.3$ | $87.2 \pm 0.3$ | $49.1 \pm 1.3$<br>$66.6 \pm 0.6$<br>$69.8 \pm 1.0$                | $68.9 \pm 0.8$ | $56.2$<br>$47.8$<br>$54.7$ |
48 | | mt5-small + `d.a.`         | $82.2 \pm 0.2$ | $88.8 \pm 0.5$ | $51.8 \pm 0.7$<br>$68.9 \pm 0.3$<br>$70.3 \pm 0.7$                | $69.1 \pm 0.1$ | $56.2$<br>$47.7$<br>$54.8$ |
49 |
50 | - The errors are estimated from three independent experiments performed with different random seeds.
51 | - ROUGE scores are computed at word level (a minimal illustration follows this list).
52 | - `d.a.` stands for domain adaptation, i.e., additional pre-training on the precedent corpus only.
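A minimal sketch of what word-level ROUGE-1 F1 means here, assuming whitespace tokenization; this is our illustration, not necessarily the exact code in `lbox_open/metric/rouge_metric_utils.py`.

```python
from collections import Counter

def rouge1_f1(reference: str, candidate: str) -> float:
    """Word-level ROUGE-1 F1 with whitespace tokenization (illustrative)."""
    ref_counts = Counter(reference.split())
    cand_counts = Counter(candidate.split())
    overlap = sum((ref_counts & cand_counts).values())  # clipped unigram matches
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)
```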
53 |
54 | # Dataset
55 |
56 | ## How to use the dataset
57 |
58 | We use the [`datasets`](https://github.com/huggingface/datasets) library from Hugging Face.
59 |
60 | ```python
61 | # !pip install datasets
62 | from datasets import load_dataset
63 |
64 | # casename classification task
65 | data_cn = load_dataset("lbox/lbox_open", "casename_classification")
66 | data_cn_plus = load_dataset("lbox/lbox_open", "casename_classification_plus")
67 |
68 | # statute classification task
69 | data_st = load_dataset("lbox/lbox_open", "statute_classification")
70 | data_st_plus = load_dataset("lbox/lbox_open", "statute_classification_plus")
71 |
72 | # Legal judgement prediction tasks
73 | data_ljp_criminal = load_dataset("lbox/lbox_open", "ljp_criminal")
74 | data_ljp_civil = load_dataset("lbox/lbox_open", "ljp_civil")
75 |
76 | # case summarization task
77 | data_summ = load_dataset("lbox/lbox_open", "summarization")
78 | data_summ_plus = load_dataset("lbox/lbox_open", "summarization_plus")
79 |
80 | # precedent corpus
81 | data_corpus = load_dataset("lbox/lbox_open", "precedent_corpus")
82 |
83 |
84 | ```
85 |
86 | - [Explore the dataset on Colab](https://colab.research.google.com/drive/1R4T91Ix__-4rjtxATh7JeTX69zYrmWy0?usp=sharing)
87 |
88 | ## Dataset Description
89 | ### `precedent_corpus`
90 | - Korean legal precedent corpus.
91 | - The corpus consists of 150k cases.
92 | - About 80k cases are from [LAW OPEN DATA](https://www.law.go.kr/LSO/main.do) and 70k are from the LBox database.
93 |
94 | - Example
95 | ```json
96 | {
97 | "id": 99990,
98 | "precedent": "주문\n피고인을 징역 6개월에 처한다.\n다만, 이 판결 확정일로부터 1년간 위 형의 집행을 유예한다.\n\n이유\n범 죄 사 실\n1. 사기\n피고인은 2020. 12. 15. 16:00경 경북 칠곡군 B에 있는 피해자 C이 운영하는 ‘D’에서, 마치 정상적으로 대금을 지급할 것처럼 행세하면서 피해자에게 술을 주문하였다.\n그러나 사실 피고인은 수중에 충분한 현금이나 신용카드 등 결제 수단을 가지고 있지 않아 정상적으로 대금을 지급할 의사나 능력이 없었다.\n그럼에도 피고인은 위와 같이 피해자를 기망하여 이에 속은 피해자로부터 즉석에서 합계 8,000원 상당의 술을 교부받았다.\n2. 공무집행방해\n피고인은 제1항 기재 일시·장소에서, ‘손님이 술값을 지불하지 않고 있다’는 내용의 112신고를 접수하고 현장에 출동한 칠곡경찰서 E지구대 소속 경찰관 F로부터 술값을 지불하고 귀가할 것을 권유받자, “징역가고 싶은데 무전취식했으니 유치장에 넣어 달라”고 말하면서 순찰차에 타려고 하였다. 이에 경찰관들이 수회 귀가 할 것을 재차 종용하였으나, 피고인은 경찰관들을 향해 “내가 돌로 순찰차를 찍으면 징역갑니까?, 내여경 엉덩이 발로 차면 들어갈 수 있나?”라고 말하고, 이를 제지하는 F의 가슴을 팔꿈치로 수회 밀쳐 폭행하였다.\n이로써 피고인은 경찰관의 112신고사건 처리에 관한 정당한 직무집행을 방해하였다. 증거의 요지\n1. 피고인의 판시 제1의 사실에 부합하는 법정진술\n1. 증인 G, F에 대한 각 증인신문조서\n1. 영수증\n1. 현장 사진\n법령의 적용\n1. 범죄사실에 대한 해당법조 및 형의 선택\n형법 제347조 제1항, 제136조 제1항, 각 징역형 선택\n1. 경합범가중\n형법 제37조 전단, 제38조 제1항 제2호, 제50조\n1. 집행유예\n형법 제62조 제1항\n양형의 이유\n1. 법률상 처단형의 범위: 징역 1월∼15년\n2. 양형기준에 따른 권고형의 범위\n가. 제1범죄(사기)\n[유형의 결정]\n사기범죄 > 01. 일반사기 > [제1유형] 1억 원 미만\n[특별양형인자]\n- 감경요소: 미필적 고의로 기망행위를 저지른 경우 또는 기망행위의 정도가 약한 경우, 처벌불원\n[권고영역 및 권고형의 범위]\n특별감경영역, 징역 1월∼1년\n[일반양형인자] 없음\n나. 제2범죄(공무집행방해)\n[유형의 결정]\n공무집행방해범죄 > 01. 공무집행방해 > [제1유형] 공무집행방해/직무강요\n[특별양형인자]\n- 감경요소: 폭행·협박·위계의 정도가 경미한 경우\n[권고영역 및 권고형의 범위]\n감경영역, 징역 1월∼8월\n[일반양형인자]\n- 감경요소: 심신미약(본인 책임 있음)\n다. 다수범죄 처리기준에 따른 권고형의 범위: 징역 1월∼1년4월(제1범죄 상한 + 제2범죄 상한의 1/2)\n3. 선고형의 결정: 징역 6월에 집행유예 1년\n만취상태에서 식당에서 소란을 피웠고, 112신고로 출동한 경찰관이 여러 차례 귀가를 종용하였음에도 이를 거부하고 경찰관의 가슴을 밀친 점 등을 종합하면 죄책을 가볍게 볼 수 없으므로 징역형을 선택하되, 평소 주량보다 훨씬 많은 술을 마신 탓에 제정신을 가누지 못해 저지른 범행으로 보이고 폭행 정도가 매우 경미한 점, 피고인이 술이 깬 후 자신의 경솔한 언동을 깊이 반성하면서 재범하지 않기 위해 정신건강의학과의 치료 및 상담을 받고 있는 점, 식당 업주에게 피해를 변상하여 용서를 받은 점, 피고인의 나이와 가족관계 등의 사정을 참작하여 형의 집행을 유예하고, 범행 경위와 범행 후 피고인의 태도 등에 비추어 볼 때 재범의 위험성은 그다지 우려하지 않아도 될 것으로 보여 보호관찰 등 부수처분은 부과하지 않음.\n이상의 이유로 주문과 같이 판결한다."
99 | }
100 | ```
101 | - `id`: a data id.
102 | - `precedent`: a case from the court of Korea. It includes the ruling (주문), the gist of claim (청구취지), the claim of appeal (항소취지), and
103 | the reasoning (이유).
104 |
105 | ### `casename_classification`
106 |
107 | - Task: for the given facts (사실관계), a model is asked to predict the case name.
108 | - The dataset consists of 10k `(facts, case name)` pairs extracted from Korean precedents.
109 | - There are 100 classes (case categories) and each class contains 100 corresponding examples (see the quick check after this list).
110 | - 8,000 training, 1,000 validation, 1,000 test, and 1,294 test2 examples. The test2 set consists of examples that do not overlap with the precedents in `precedent_corpus`.
111 | - We also provide `casename_classification_plus`, a dataset that extends `casename_classification` by including infrequent case categories. `casename_classification_plus` consists of 31,283 examples with 603 case categories in total. See our paper for details.
112 | - Example
113 |
114 | ```json
115 | {
116 | "id": 80,
117 | "casetype": "criminal",
118 | "casename": "감염병의예방및관리에관한법률위반",
119 | "facts": "질병관리청장, 시·도지사 또는 시장·군수·구청장은 제1급 감염병이 발생한 경우 감염병의 전파방지 및 예방을 위하여 감염병의심자를 적당한 장소에 일정한 기간 격리시키는 조치를 하여야 하고, 그 격리조치를 받은 사람은 이를 위반하여서는 아니 된다. 피고인은 해외에서 국내로 입국하였음을 이유로 2021. 4. 21.경 감염병의심자로 분류되었고, 같은 날 창녕군수로부터 ‘2021. 4. 21.부터 2021. 5. 5. 12:00경까지 피고인의 주거지인 경남 창녕군 B에서 격리해야 한다’는 내용의 자가격리 통지서를 수령하였다. 1. 2021. 4. 27.자 범행 그럼에도 불구하고 피고인은 2021. 4. 27. 11:20경에서 같은 날 11:59경까지 사이에 위 격리장소를 무단으로 이탈하여 자신의 승용차를 이용하여 경남 창녕군 C에 있는 ‘D’ 식당에 다녀오는 등 자가격리 조치를 위반하였다. 2. 2021. 5. 3.자 범행 피고인은 2021. 5. 3. 10:00경에서 같은 날 11:35경까지 사이에 위 격리장소를 무단으로 이탈하여 자신의 승용차를 이용하여 불상의 장소를 다녀오는 등 자가격리 조치를 위반하였다."
120 | }
121 | ```
122 | - `id`: a data id.
123 | - `casetype`: a case type. The value is either `civil` (민사) or `criminal` (형사).
124 | - `casename`: a case name.
125 | - `facts`: facts (사실관계) extracted from the `reasoning` (이유) section of individual cases.
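A quick check of the class statistics above; a minimal sketch, assuming the `datasets` setup from "How to use the dataset".

```python
from collections import Counter
from datasets import load_dataset

data_cn = load_dataset("lbox/lbox_open", "casename_classification")
counts = Counter(example["casename"] for example in data_cn["train"])
print(len(counts))            # number of distinct case names in the training split
print(counts.most_common(3))  # most frequent case names with their counts
```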
126 |
127 | ### `statute_classification`
128 |
129 | - Task: for the given facts (사실관계), a model is asked to predict the related statutes (법령).
130 | - The dataset consists of 2,760 `(facts, statutes)` pairs extracted from individual Korean legal cases.
131 | - There are 46 classes (case categories) and each class has 60 examples.
132 | - 2,208 training, 276 validation, 276 test, and 538 test2 examples. The test2 set consists of examples that do not overlap with the precedents in `precedent_corpus`.
133 | - We also release `statute_classification_plus`, a dataset that extends `statute_classification` by including less frequent case categories. `statute_classification_plus` includes 17,730 examples with 434 case categories and 1,015 statutes in total.
134 | - Example
135 |
136 | ```json
137 | {
138 | "id": 5180,
139 | "casetype": "criminal",
140 | "casename": "사문서위조, 위조사문서행사",
141 | "statutes": [
142 | "형법 제231조",
143 | "형법 제234조"
144 | ],
145 | "facts": "1. 사문서위조 피고인은 2014. 5. 10.경 서울 송파구 또는 하남시 이하 알 수 없는 장소에서 영수증문구용지에 검정색 볼펜을 사용하여 수신인란에 ‘A’, 일금란에 ‘오천오백육십만원정’, 내역 란에 ‘2010가합7485사건의 합의금 및 피해 보상금 완결조’, 발행일란에 ‘2014년 5월 10일’이라고 기재한 뒤, 발행인 옆에 피고인이 임의로 만들었던 B의 도장을 찍었다. 이로써 피고인은 행사할 목적으로 사실증명에 관한 사문서인 B 명의의 영수증 1장을 위조하였다. 2. 위조사문서행사 피고인은 2014. 10. 16.경 하남시 이하 알 수 없는 장소에서 피고인이 B에 대한 채무를 모두 변제하였기 때문에 B가 C회사에 채권을 양도한 것을 인정할 수 없다는 취지의 내용증명원과 함께 위와 같이 위조한 영수증 사본을 마치 진정하게 성립한 문서인 것처럼 B에게 우편으로 보냈다. 이로써 피고인은 위조한 사문서를 행사하였다."
146 | }
147 |
148 | ```
149 |
150 | - `id`: a data id.
151 | - `casetype`: a case type. The value is always `criminal`.
152 | - `casename`: a case name.
153 | - `statutes`: the related statutes (a set-level exact-match sketch follows this list).
154 | - `facts`: facts (사실관계) extracted from the `reasoning` (이유) section of individual cases.
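The benchmark table reports EM for this task. Below is a minimal sketch that treats the full set of statutes as the target; this is our illustration, not necessarily what `lbox_open/metric/exact_match.py` implements.

```python
from typing import List

def statutes_exact_match(gold: List[str], pred: List[str]) -> bool:
    """True only when the predicted statutes match the gold set exactly."""
    return set(gold) == set(pred)

# Order does not matter under this set-level reading of exact match.
assert statutes_exact_match(["형법 제231조", "형법 제234조"],
                            ["형법 제234조", "형법 제231조"])
```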
155 |
156 | ### `ljp_criminal`
157 |
158 | - Task: a model needs to predict the ranges of the fine (벌금), imprisonment with labor (징역), and imprisonment without labor (금고).
159 | - 10,500 `facts` and the corresponding punishments are extracted from cases in the following case categories: "indecent
160 |   act by compulsion" (강제추행), "obstruction of performance of official duties" (공무집행방해), "bodily injuries from traffic
161 |   accident" (교통사고처리특례법위반(치상)), "drunk driving" (도로교통법위반(음주운전)), "fraud" (사기), "inflicting bodily injuries" (상해), and
162 |   "violence" (폭행).
163 | - 8,400 training, 1,050 validation, 1,050 test, and 928 test2 examples. The test2 set consists of the examples from the test set that do not overlap with the precedents in `precedent_corpus`.
164 | - Example
165 | ```json
166 | {
167 | "casename": "공무집행방해",
168 | "casetype": "criminal",
169 | "facts": "피고인은 2020. 3. 13. 18:57경 수원시 장안구 B 앞 노상에서 지인인 C와 술을 마시던 중 C를 때려 112신고를 받고 출동한 수원중부경찰서 D지구대 소속 경위 E가 C의 진술을 청취하고 있는 모습을 보고 화가 나 '씨발,개새끼'라며 욕설을 하고, 위 E가 이를 제지하며 귀가를 종용하자 그의 왼쪽 뺨을 오른 주먹으로 1회 때려 폭행하였다.\n이로써 피고인은 경찰관의 112신고사건 처리에 관한 정당한 직무집행을 방해하였다. 증거의 요지\n1. 피고인의 법정진술\n1. 피고인에 대한 경찰 피의자신문조서\n1. E에 대한 경찰 진술조서\n1. 현장사진 등, 바디캠영상",
170 | "id": 2300,
171 | "label": {
172 | "fine_lv": 0,
173 | "imprisonment_with_labor_lv": 2,
174 | "imprisonment_without_labor_lv": 0,
175 | "text": "징역 6월"
176 | },
177 | "reason": "양형의 이유\n1. 법률상 처단형의 범위: 징역 1월∼5년\n2. 양형기준에 따른 권고형의 범위\n[유형의 결정]\n공무집행방해범죄 > 01. 공무집행방해 > [제1유형] 공무집행방해/직무강요\n[특별양형인자] 없음\n[권고영역 및 권고형의 범위] 기본영역, 징역 6월∼1년6월\n3. 선고형의 결정\n피고인이 싸움 발생 신고를 받고 출동한 경찰관에게 욕설을 퍼붓고 귀가를 종용한다는 이유로 경찰관의 뺨을 때리는 등 폭행을 행사하여 경찰관의 정당한 공무집행을 방해한 점에서 그 죄책이 매우 무겁다. 피고인의 범죄 전력도 상당히 많다.\n다만, 피고인이 범행을 인정하면서 반성하고 있는 점, 공무집행방해 범죄로 처벌받은 전력이 없는 점 등은 피고인에게 유리한 정상으로 참작한다.\n그 밖에 피고인의 연령, 성행, 환경, 가족관계, 건강상태, 범행의 동기와 수단 및 결과, 범행 후의 정황 등 이 사건 기록 및 변론에 나타난 모든 양형요소를 종합하여, 주문과 같이 형을 정한다.",
178 | "ruling": {
179 | "parse": {
180 | "fine": {
181 | "type": "",
182 | "unit": "",
183 | "value": -1
184 | },
185 | "imprisonment": {
186 | "type": "징역",
187 | "unit": "mo",
188 | "value": 6
189 | }
190 | },
191 | "text": "피고인을 징역 6월에 처한다.\n다만 이 판결 확정일로부터 2년간 위 형의 집행을 유예한다."
192 | }
193 | }
194 | ```
195 |
196 | - `id`: a data id.
197 | - `casetype`: a case type. The value is always `criminal`.
198 | - `casename`: a case name.
199 | - `facts`: facts (사실관계) extracted from the `reasoning` (이유) section of individual cases.
200 | - `label`
201 |   - `fine_lv`: a label representing the range of the fine amount. See our paper for details.
202 |   - `imprisonment_with_labor_lv`: a label representing the range of the imprisonment with labor.
203 |   - `imprisonment_without_labor_lv`: a label for the imprisonment without labor case.
204 | - `reason`: the reason for the punishment (양형의 이유).
205 | - `ruling`: the ruling (주문) and its parsing result. `""` and `-1` indicate null values (see the parsing sketch below).
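A minimal sketch for reading the `ruling.parse` structure, assuming the null conventions above (`""` and `-1`); the helper name is ours, the field names follow the example.

```python
from typing import Optional

def imprisonment_months(ruling: dict) -> Optional[int]:
    """Return the imprisonment term in months, or None if the ruling has none."""
    parse = ruling["parse"]["imprisonment"]
    if parse["type"] == "" or parse["value"] == -1:
        return None  # "" and -1 mark null values
    assert parse["unit"] == "mo"  # months, as in the example above
    return parse["value"]

ruling = {
    "parse": {
        "fine": {"type": "", "unit": "", "value": -1},
        "imprisonment": {"type": "징역", "unit": "mo", "value": 6},
    }
}
print(imprisonment_months(ruling))  # 6
```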
206 |
207 | ### `ljp_civil`
208 |
209 | - Task: a model is asked to predict the claim acceptance level (= "the approved money" / "the claimed money").
210 | - 4,678 `facts` and the corresponding acceptance levels from 4 case categories: 929 examples from "price of
211 |   indemnification" (구상금), 745 examples from "loan" (대여금), 1,004 examples from "unfair profits" (부당이득금), and 2,000
212 |   examples from "lawsuit for damages (etc.)" (손해배상(기)).
213 | - 3,742 training, 467 validation, 467 test, and 403 test2 examples. The test2 set consists of the test set examples that do not overlap with the precedents in `precedent_corpus`.
214 | - Example
215 | ```json
216 | {
217 | "id": 99,
218 | "casetype": "civil",
219 | "casename": "구상금",
220 | "claim_acceptance_lv": 1,
221 | "facts": "가. C는 2017. 7. 21. D으로부터 100,000,000원을 이율 연 25%, 변제기 2017. 8. 20.로 정하여 차용하였고(이하 ‘이 사건 차용금채무'라고 한다), 피고는 이 사건 차용금 채무를 보증한도액 140,000,000원, 보증기한 10년으로 정하여 연대보증하였으며, 같은 날 이 사건 차용금채무에 관한 공정증서를 작성하였다(공증인가 법무법인 E 증서 2017년 제392호, 이하 ‘이 사건 공정증서'라고 한다).\n나. 원고는 이 사건 차용금채무와 관련하여 원고 소유의 안산시 상록구 F, G, H 및 그 지상 건물(이하 ‘이 사건 부동산'이라고 한다)을 담보로 제공하기로 하여 2017. 7. 21. 수원지방법원 안산지원 접수 제53820호로 채권최고액 140,000,000원, 채무자 C, 근저당권자 D으로 한 근저당권설정등기를 경료하는 한편, 2018. 7. 13. D에게 이 사건 공정증서에 기한 채무를 2018. 7. 31.까지 변제하고, 변제기 이후 연 24%의 비율로 계산한 지연손해금을 지급하기로 하는 차용증을 작성하여 주었다(이하 ‘이 사건 차용증'이라고 한다).\n다. 원고는 2019. 11. 29. D에게 이 사건 차용금채무 원리금으로 합계 157,500,000원을 변제하였다.",
222 | "gist_of_claim": {
223 | "money": {
224 | "provider": "피고",
225 | "taker": "원고",
226 | "unit": "won",
227 | "value": 140000000
228 | },
229 | "text": "피고는 원고에게 140,000,000원 및 이에 대한 2019. 11. 30.부터 이 사건 소장 부본 송달일까지는 연 5%의, 그 다음날부터 다 갚는 날까지는 연 12%의 각 비율로 계산한 돈을 지급하라."
230 | },
231 | "ruling": {
232 | "litigation_cost": 0.5,
233 | "money": {
234 | "provider": "피고",
235 | "taker": "원고",
236 | "unit": "won",
237 | "value": 78750000
238 | },
239 | "text": "1. 피고는 원고에게 78,750,000원 및 이에 대한 2019. 11. 30.부터 2021. 11. 26.까지는 연 5%의, 그 다음날부터 다 갚는 날까지는 연 12%의 각 비율로 계산한 돈을 지급하라.\n2. 원고의 나머지 청구를 기각한다.\n3. 소송비용 중 1/2은 원고가 나머지는 피고가 각 부담한다.\n4. 제1항은 가집행할 수 있다."
240 | }
241 | }
242 |
243 | ```
244 |
245 | - `id`: a data id.
246 | - `casetype`: a case type. The value is always `civil`.
247 | - `casename`: a case name.
248 | - `facts`: facts (사실관계) extracted from the `reasoning` (이유) section of individual cases.
249 | - `claim_acceptance_lv`: the claim acceptance level. `0`, `1`, and `2` indicate rejection, partial approval, and full approval, respectively (a toy computation follows this list).
250 | - `gist_of_claim`: the gist of the claim from the plaintiff (청구취지) and its parsing result.
251 | - `ruling`: a ruling (주문) and its parsing results.
252 |   - `litigation_cost`: the ratio of the litigation cost that the plaintiff should pay.
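A toy computation of the acceptance level using the example above: 78,750,000 won approved out of 140,000,000 won claimed gives a ratio of 0.5625, i.e. partial approval (`claim_acceptance_lv = 1`). The thresholds here illustrate the definition; they are not the dataset's exact labeling code.

```python
def acceptance_lv(claimed_won: int, approved_won: int) -> int:
    """0 = rejection, 1 = partial approval, 2 = full approval (illustrative)."""
    ratio = approved_won / claimed_won
    if ratio == 0.0:
        return 0
    if ratio < 1.0:
        return 1
    return 2

print(acceptance_lv(140_000_000, 78_750_000))  # 1 (partial approval)
```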
253 |
254 | ### `summarization`
255 |
256 | - Task: a model is asked to summarize precedents from the Supreme Court of Korea.
257 | - The dataset is obtained from [LAW OPEN DATA](https://www.law.go.kr/LSO/main.do).
258 | - The dataset consists of 20k `(precedent, summary)` pairs.
259 | - 16,000 training, 2,000 validation, and 2,000 test examples.
260 | - We also provide `summarization_plus`, which extends `summarization` with longer precedents, making the task more challenging and realistic. The extended dataset contains a total of 51,114 examples. The average numbers of tokens in the precedents and the corresponding summaries are 1,516 and 248, respectively; the maximum numbers of tokens in the input texts and the summaries are 93,420 and 6,536, respectively (a rough length check is sketched at the end of this section).
261 |
262 | - Example
263 |
264 | ```json
265 | {
266 | "id": 16454,
267 | "summary": "[1] 피고와 제3자 사이에 있었던 민사소송의 확정판결의 존재를 넘어서 그 판결의 이유를 구성하는 사실관계들까지 법원에 현저한 사실로 볼 수는 없다. 민사재판에 있어서 이미 확정된 관련 민사사건의 판결에서 인정된 사실은 특별한 사정이 없는 한 유력한 증거가 되지만, 당해 민사재판에서 제출된 다른 증거 내용에 비추어 확정된 관련 민사사건 판결의 사실인정을 그대로 채용하기 어려운 경우에는 합리적인 이유를 설시하여 이를 배척할 수 있다는 법리도 그와 같이 확정된 민사판결 이유 중의 사실관계가 현저한 사실에 해당하지 않음을 전제로 한 것이다.\n\n\n[2] 원심이 다른 하급심판결의 이유 중 일부 사실관계에 관한 인정 사실을 그대로 인정하면서, 위 사정들이 ‘이 법원에 현저한 사실’이라고 본 사안에서, 당해 재판의 제1심 및 원심에서 다른 하급심판결의 판결문 등이 증거로 제출된 적이 없고, 당사자들도 이에 관하여 주장한 바가 없음에도 이를 ‘법원에 현저한 사실’로 본 원심판단에 법리오해의 잘못이 있다고 한 사례.",
268 | "precedent": "주문\n원심판결을 파기하고, 사건을 광주지방법원 본원 합의부에 환송한다.\n\n이유\n상고이유를 판단한다.\n1. 피고와 제3자 사이에 있었던 민사소송의 확정판결의 존재를 넘어서 그 판결의 이유를 구성하는 사실관계들까지 법원에 현저한 사실로 볼 수는 없다(대법원 2010. 1. 14. 선고 2009다69531 판결 참조). 민사재판에 있어서 이미 확정된 관련 민사사건의 판결에서 인정된 사실은 특별한 사정이 없는 한 유력한 증거가 되지만, 당해 민사재판에서 제출된 다른 증거 내용에 비추어 확정된 관련 민사사건 판결의 사실인정을 그대로 채용하기 어려운 경우에는 합리적인 이유를 설시하여 이를 배척할 수 있다는 법리(대법원 2018. 8. 30. 선고 2016다46338, 46345 판결 등 참조)도 그와 같이 확정된 민사판결 이유 중의 사실관계가 현저한 사실에 해당하지 않음을 전제로 한 것이다.\n2. 원심은 광주고등법원 2003나8816 판결 이유 중 ‘소외인이 피고 회사를 설립한 경위’에 관한 인정 사실, 광주지방법원 목포지원 2001가합1664 판결과 광주고등법원 2003나416 판결 이유 중 ‘피고 회사 이사회의 개최 여부’에 관한 인정 사실을 그대로 인정하면서, 위 사정들이 ‘이 법원에 현저한 사실’이라고 보았다.\n그런데 이 사건 기록에 의하면, 광주고등법원 2003나8816 판결, 광주지방법원 목포지원 2001가합1664 판결, 광주고등법원 2003나416 판결은 제1심 및 원심에서 판결문 등이 증거로 제출된 적이 없고, 당사자들도 이에 관하여 주장한 바가 없다.\n그렇다면 원심은 ‘법원에 현저한 사실’에 관한 법리를 오해한 나머지 필요한 심리를 다하지 아니한 채, 당사자가 증거로 제출하지 않고 심리가 되지 않았던 위 각 판결들에서 인정된 사실관계에 기하여 판단한 잘못이 있다. 이 점을 지적하는 상고이유 주장은 이유 있다.\n3. 그러므로 나머지 상고이유에 대한 판단을 생략한 채 원심판결을 파기하고, 사건을 다시 심리·판단하게 하기 위하여 원심법원에 환송하기로 하여, 관여 대법관의 일치된 의견으로 주문과 같이 판결한다."
269 | }
270 | ```
271 |
272 | - `id`: a data id.
273 | - `summary`: a summary (판결요지) of the given precedent (판결문).
274 | - `precedent`: a case from the Supreme Court of Korea.
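A rough length check in the spirit of the statistics quoted for `summarization_plus`, using whitespace tokens as a proxy (the token counts above may come from a different tokenizer); a sketch assuming the `datasets` setup shown earlier.

```python
from datasets import load_dataset

data_summ = load_dataset("lbox/lbox_open", "summarization")
precedent_lens = [len(ex["precedent"].split()) for ex in data_summ["train"]]
summary_lens = [len(ex["summary"].split()) for ex in data_summ["train"]]
print(sum(precedent_lens) / len(precedent_lens))  # average precedent length
print(sum(summary_lens) / len(summary_lens))      # average summary length
```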
275 |
276 |
277 |
278 | # Models
279 |
280 | ## How to use the language model `lcube-base`
281 | ```python
282 | # !pip install transformers==4.19.4
283 | import transformers
284 |
285 | model = transformers.GPT2LMHeadModel.from_pretrained("lbox/lcube-base")
286 | tokenizer = transformers.AutoTokenizer.from_pretrained(
287 | "lbox/lcube-base",
288 | bos_token="[BOS]",
289 | unk_token="[UNK]",
290 | pad_token="[PAD]",
291 | mask_token="[MASK]",
292 | )
293 |
294 | text = "피고인은 불상지에 있는 커피숍에서, 피해자 B으로부터"
295 | model_inputs = tokenizer(text,
296 | max_length=1024,
297 | padding=True,
298 | truncation=True,
299 | return_tensors='pt')
300 | out = model.generate(
301 | model_inputs["input_ids"],
302 | max_new_tokens=150,
303 | pad_token_id=tokenizer.pad_token_id,
304 | use_cache=True,
305 | repetition_penalty=1.2,
306 | top_k=5,
307 | top_p=0.9,
308 | temperature=1,
309 | num_beams=2,
310 | )
311 | tokenizer.batch_decode(out)
312 | ```
313 |
314 | ## Fine-tuning
315 | ### Setup
316 |
317 | ```bash
318 | conda create -n lbox-open python=3.8.11
319 | conda install pytorch==1.10.1 torchvision torchaudio cudatoolkit=11.3 -c pytorch
320 | pip install -r requirements.txt
321 | ```
322 |
323 | ### Training
324 |
325 | ```bash
326 | python run_model.py [TRAINING_CONFIG_FILE_PATH] --mode train
327 | ```
328 | See also `scripts/train_[TASK].sh`.
329 |
330 | ### Test
331 |
332 | 1. Make a test config file from the training config file by copying it and changing the values of the `trained` and `path` fields as shown below.
333 | ```yaml
334 | train:
335 | weights:
336 | trained: true
337 |     path: ./models/[THE NAME OF THE TRAINING CONFIG FILE]/epoch=[XX]-step=[XX].ckpt
338 | ```
339 | 2. Run the test:
340 | ```bash
341 | python run_model.py [TEST_CONFIG_FILE_PATH] --mode test
342 | ```
343 | See also `scripts/test_[TASK].sh`.
344 |
345 |
346 |
347 | # Licensing Information
348 |
349 | Copyright 2022-present [LBox Co. Ltd.](https://lbox.kr/)
350 |
351 | Licensed under the [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/)
352 |
--------------------------------------------------------------------------------
/configs/casename_classification/casename.kogpt2.e2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 960
16 | task: casename_classification
17 | subtask: casename_classification
18 | target_field: facts
19 | target_parses_dict:
20 | casename_classification:
21 | - casename
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: kogpt2
27 | path: skt/kogpt2-base-v2
28 | revision:
29 | precision: bf16
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 4
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 1
38 | fast_dev_run: false
39 | max_epochs: 10
40 | multiple_trainloader_mode:
41 | seed: 2
42 | strategy: null
43 | weight:
44 | trained: false
45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.0001
62 | optimizer_type: adamw
63 | swa:
64 | use: true
65 | lr: 0.0001
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 0
70 | val_check_interval: 1.0
71 | validation_metric: em
72 | validation_target_parse: casename_classification
73 | validation_sub_param:
74 | method: text_em
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length:
79 | max_new_tokens: 64
80 | min_length: 1
81 | temperature: 1.0
82 | do_sample: False
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: "*"
89 | value_sep_token: "|"
90 | empty_token: "없음"
91 |
92 |
--------------------------------------------------------------------------------
/configs/casename_classification/casename.kogpt2.e3.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 960
16 | task: casename_classification
17 | subtask: casename_classification
18 | target_field: facts
19 | target_parses_dict:
20 | casename_classification:
21 | - casename
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: kogpt2
27 | path: skt/kogpt2-base-v2
28 | revision:
29 | precision: bf16
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 4
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 1
38 | fast_dev_run: false
39 | max_epochs: 10
40 | multiple_trainloader_mode:
41 | seed: 3
42 | strategy: null
43 | weight:
44 | trained: false
45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.0001
62 | optimizer_type: adamw
63 | swa:
64 | use: true
65 | lr: 0.0001
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 0
70 | val_check_interval: 1.0
71 | validation_metric: em
72 | validation_target_parse: casename_classification
73 | validation_sub_param:
74 | method: text_em
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length:
79 | max_new_tokens: 64
80 | min_length: 1
81 | temperature: 1.0
82 | do_sample: False
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: "*"
89 | value_sep_token: "|"
90 | empty_token: "없음"
91 |
92 |
--------------------------------------------------------------------------------
/configs/casename_classification/casename.kogpt2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 960
16 | task: casename_classification
17 | subtask: casename_classification
18 | target_field: facts
19 | target_parses_dict:
20 | casename_classification:
21 | - casename
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: kogpt2
27 | path: skt/kogpt2-base-v2
28 | revision:
29 | precision: bf16
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 4
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 1
38 | fast_dev_run: false
39 | max_epochs: 10
40 | multiple_trainloader_mode:
41 | seed: 1
42 | strategy: null
43 | weight:
44 | trained: false
45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.0001
62 | optimizer_type: adamw
63 | swa:
64 | use: true
65 | lr: 0.0001
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 0
70 | val_check_interval: 1.0
71 | validation_metric: em
72 | validation_target_parse: casename_classification
73 | validation_sub_param:
74 | method: text_em
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length:
79 | max_new_tokens: 64
80 | min_length: 1
81 | temperature: 1.0
82 | do_sample: False
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: "*"
89 | value_sep_token: "|"
90 | empty_token: "없음"
91 |
92 |
--------------------------------------------------------------------------------
/configs/casename_classification/casename.lcube-base.e2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 64
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 960
16 | task: casename_classification
17 | subtask: casename_classification
18 | target_field: facts
19 | target_parses_dict:
20 | casename_classification:
21 | - casename
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: legal-gpt
27 | path: lbox/lcube-base
28 | revision:
29 | precision: bf16
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 4
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 1
38 | fast_dev_run: false
39 | max_epochs: 10
40 | multiple_trainloader_mode:
41 | seed: 2
42 | strategy: null
43 | weight:
44 | trained: false
45 | path:
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.0001
62 | optimizer_type: adamw
63 | swa:
64 | use: true
65 | lr: 0.0001
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 1
70 | val_check_interval: 1.0
71 | validation_metric: em
72 | validation_target_parse: casename_classification
73 | validation_sub_param:
74 | method: text_em
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length: 64
79 | max_new_tokens: 64
80 | min_length: 1
81 | temperature: 1.0
82 | do_sample: False
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: "*"
89 | value_sep_token: "|"
90 | empty_token: "없음"
91 |
92 |
93 |
--------------------------------------------------------------------------------
/configs/casename_classification/casename.lcube-base.e3.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 64
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 960
16 | task: casename_classification
17 | subtask: casename_classification
18 | target_field: facts
19 | target_parses_dict:
20 | casename_classification:
21 | - casename
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: legal-gpt
27 | path: lbox/lcube-base
28 | revision:
29 | precision: bf16
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 4
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 1
38 | fast_dev_run: false
39 | max_epochs: 10
40 | multiple_trainloader_mode:
41 | seed: 3
42 | strategy: null
43 | weight:
44 | trained: false
45 | path:
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.0001
62 | optimizer_type: adamw
63 | swa:
64 | use: true
65 | lr: 0.0001
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 1
70 | val_check_interval: 1.0
71 | validation_metric: em
72 | validation_target_parse: casename_classification
73 | validation_sub_param:
74 | method: text_em
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length: 64
79 | max_new_tokens: 64
80 | min_length: 1
81 | temperature: 1.0
82 |   do_sample: false
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: "*"
89 | value_sep_token: "|"
90 | empty_token: "없음"
91 |
92 |
93 |
--------------------------------------------------------------------------------
/configs/casename_classification/casename.lcube-base.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 64
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 960
16 | task: casename_classification
17 | subtask: casename_classification
18 | target_field: facts
19 | target_parses_dict:
20 | casename_classification:
21 | - casename
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: legal-gpt
27 | path: lbox/lcube-base
28 | revision:
29 | precision: bf16
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 4
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 1
38 | fast_dev_run: false
39 | max_epochs: 10
40 | multiple_trainloader_mode:
41 | seed: 1
42 | strategy: null
43 | weight:
44 | trained: false
45 | path:
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.0001
62 | optimizer_type: adamw
63 | swa:
64 | use: true
65 | lr: 0.0001
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 1
70 | val_check_interval: 1.0
71 | validation_metric: em
72 | validation_target_parse: casename_classification
73 | validation_sub_param:
74 | method: text_em
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length: 64
79 | max_new_tokens: 64
80 | min_length: 1
81 | temperature: 1.0
82 |   do_sample: false
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: "*"
89 | value_sep_token: "|"
90 | empty_token: "없음"
91 |
92 |
93 |
--------------------------------------------------------------------------------
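Note on the casename_classification config family: the three lcube-base files differ only in train.seed (1 for the base file, 2 for .e2, 3 for .e3), and the kogpt2 counterparts additionally swap model.plm to skt/kogpt2-base-v2; the prompt/PLM learning rates and the SWA schedule are shared. A minimal loading sketch, assuming OmegaConf (the loader actually used by run_model.py may differ):

    # Minimal sketch, assuming OmegaConf; run_model.py's real loader may differ.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/casename_classification/casename.lcube-base.yaml")
    assert cfg.model.task == "casename_classification"

    # The .e2 variant is reproduced by overriding only the seed.
    cfg_e2 = OmegaConf.merge(cfg, OmegaConf.create({"train": {"seed": 2}}))
    print(cfg_e2.train.seed)  # 2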
/configs/ljp/civil/ljp.civil.kogpt2.e2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 1021
16 | task: ljp_civil
17 | subtask: civil
18 | target_field: facts
19 | target_parses_dict:
20 | claim_acceptance_lv:
21 | - claim_acceptance_lv
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: kogpt2
27 | path: skt/kogpt2-base-v2
28 | revision:
29 | precision: bf16
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 4
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 1
38 | fast_dev_run: false
39 | max_epochs: 15
40 | multiple_trainloader_mode:
41 | seed: 2
42 | strategy: null
43 | weight:
44 | trained: false
45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.00005
62 | optimizer_type: adamw
63 | swa:
64 | use: true
65 | lr: 0.00005
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 0
70 | val_check_interval: 0.5
71 | validation_metric: em
72 | validation_target_parse: claim_acceptance_lv
73 | validation_sub_param:
74 | method: text_em
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length:
79 | max_new_tokens: 3
80 | min_length: 1
81 | temperature: 1.0
82 |   do_sample: false
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: "*"
89 | value_sep_token: "|"
90 | empty_token: "없음"
91 |
92 |
--------------------------------------------------------------------------------
/configs/ljp/civil/ljp.civil.kogpt2.e3.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 1021
16 | task: ljp_civil
17 | subtask: civil
18 | target_field: facts
19 | target_parses_dict:
20 | claim_acceptance_lv:
21 | - claim_acceptance_lv
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: kogpt2
27 | path: skt/kogpt2-base-v2
28 | revision:
29 | precision: bf16
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 4
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 1
38 | fast_dev_run: false
39 | max_epochs: 15
40 | multiple_trainloader_mode:
41 | seed: 3
42 | strategy: null
43 | weight:
44 | trained: false
45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.00005
62 | optimizer_type: adamw
63 | swa:
64 | use: true
65 | lr: 0.00005
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 0
70 | val_check_interval: 0.5
71 | validation_metric: em
72 | validation_target_parse: claim_acceptance_lv
73 | validation_sub_param:
74 | method: text_em
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length:
79 | max_new_tokens: 3
80 | min_length: 1
81 | temperature: 1.0
82 |   do_sample: false
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: "*"
89 | value_sep_token: "|"
90 | empty_token: "없음"
91 |
92 |
--------------------------------------------------------------------------------
/configs/ljp/civil/ljp.civil.kogpt2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 1021
16 | task: ljp_civil
17 | subtask: civil
18 | target_field: facts
19 | target_parses_dict:
20 | claim_acceptance_lv:
21 | - claim_acceptance_lv
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: kogpt2
27 | path: skt/kogpt2-base-v2
28 | revision:
29 | precision: bf16
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 4
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 1
38 | fast_dev_run: false
39 | max_epochs: 15
40 | multiple_trainloader_mode:
41 | seed: 1
42 | strategy: null
43 | weight:
44 | trained: false
45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.00005
62 | optimizer_type: adamw
63 | swa:
64 | use: true
65 | lr: 0.00005
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 0
70 | val_check_interval: 0.5
71 | validation_metric: em
72 | validation_target_parse: claim_acceptance_lv
73 | validation_sub_param:
74 | method: text_em
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length:
79 | max_new_tokens: 3
80 | min_length: 1
81 | temperature: 1.0
82 |   do_sample: false
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: "*"
89 | value_sep_token: "|"
90 | empty_token: "없음"
91 |
92 |
--------------------------------------------------------------------------------
/configs/ljp/civil/ljp.civil.lcube-base.e2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 1021
16 | task: ljp_civil
17 | subtask: civil
18 | target_field: facts
19 | target_parses_dict:
20 | claim_acceptance_lv:
21 | - claim_acceptance_lv
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: legal-gpt
27 | path: lbox/lcube-base
28 | revision:
29 | precision: bf16
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 4
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 1
38 | fast_dev_run: false
39 | max_epochs: 15
40 | multiple_trainloader_mode:
41 | seed: 2
42 | strategy: null
43 | weight:
44 | trained: false
45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.00005
62 | optimizer_type: adamw
63 | swa:
64 | use: true
65 | lr: 0.00005
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 0
70 | val_check_interval: 0.5
71 | validation_metric: em
72 | validation_target_parse: claim_acceptance_lv
73 | validation_sub_param:
74 | method: text_em
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length:
79 | max_new_tokens: 3
80 | min_length: 1
81 | temperature: 1.0
82 |   do_sample: false
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: "*"
89 | value_sep_token: "|"
90 | empty_token: "없음"
91 |
92 |
--------------------------------------------------------------------------------
/configs/ljp/civil/ljp.civil.lcube-base.e3.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 1021
16 | task: ljp_civil
17 | subtask: civil
18 | target_field: facts
19 | target_parses_dict:
20 | claim_acceptance_lv:
21 | - claim_acceptance_lv
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: legal-gpt
27 | path: lbox/lcube-base
28 | revision:
29 | precision: bf16
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 4
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 1
38 | fast_dev_run: false
39 | max_epochs: 15
40 | multiple_trainloader_mode:
41 | seed: 3
42 | strategy: null
43 | weight:
44 | trained: false
45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.00005
62 | optimizer_type: adamw
63 | swa:
64 | use: true
65 | lr: 0.00005
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 0
70 | val_check_interval: 0.5
71 | validation_metric: em
72 | validation_target_parse: claim_acceptance_lv
73 | validation_sub_param:
74 | method: text_em
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length:
79 | max_new_tokens: 3
80 | min_length: 1
81 | temperature: 1.0
82 |   do_sample: false
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: "*"
89 | value_sep_token: "|"
90 | empty_token: "없음"
91 |
92 |
--------------------------------------------------------------------------------
/configs/ljp/civil/ljp.civil.lcube-base.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 1021
16 | task: ljp_civil
17 | subtask: civil
18 | target_field: facts
19 | target_parses_dict:
20 | claim_acceptance_lv:
21 | - claim_acceptance_lv
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: legal-gpt
27 | path: lbox/lcube-base
28 | revision:
29 | precision: bf16
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 4
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 1
38 | fast_dev_run: false
39 | max_epochs: 15
40 | multiple_trainloader_mode:
41 | seed: 1
42 | strategy: null
43 | weight:
44 | trained: false
45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.00005
62 | optimizer_type: adamw
63 | swa:
64 | use: true
65 | lr: 0.00005
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 0
70 | val_check_interval: 0.5
71 | validation_metric: em
72 | validation_target_parse: claim_acceptance_lv
73 | validation_sub_param:
74 | method: text_em
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length:
79 | max_new_tokens: 3
80 | min_length: 1
81 | temperature: 1.0
82 |   do_sample: false
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: "*"
89 | value_sep_token: "|"
90 | empty_token: "없음"
91 |
92 |
--------------------------------------------------------------------------------
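The civil LJP configs above target a single parse, claim_acceptance_lv, which is why infer.max_new_tokens is only 3: the model needs to emit just one short level token after the prompt. A hypothetical post-processing sketch; the authoritative parsing lives in lbox_open/parser/output_parser.py and may differ:

    # Hypothetical sketch; the real parser is lbox_open/parser/output_parser.py.
    def parse_claim_acceptance_lv(generated: str) -> int:
        # With infer.max_new_tokens: 3 the decoded continuation is expected
        # to be a short level token; assume it encodes a small integer.
        return int(generated.strip())

    print(parse_claim_acceptance_lv(" 2 "))  # 2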
/configs/ljp/criminal/ljp.criminal.kogpt2.e2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 1012
16 | task: ljp_criminal
17 | subtask: criminal
18 | target_field: facts
19 | target_parses_dict:
20 | fine_imprisonment_lvs:
21 | - fine_lv
22 | - imprisonment_with_labor_lv
23 | - imprisonment_without_labor_lv
24 | path_template:
25 | plm:
26 | freeze: false
27 | eval_mode: false
28 | name: kogpt2
29 | path: skt/kogpt2-base-v2
30 | revision:
31 | precision: bf16
32 |
33 | train:
34 | accelerator: auto
35 | accumulate_grad_batches: 2
36 | limit_val_batches: 1.0
37 | batch_size: 4
38 | batch_size_prediction: 12
39 | check_val_every_n_epoch: 1
40 | fast_dev_run: false
41 | max_epochs: 20
42 | multiple_trainloader_mode:
43 | seed: 2
44 | strategy: null
45 | weight:
46 | trained: false
47 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
48 | save_path_dir: ./data/models
49 | do_not_load_pretrained_weight: false
50 | old_format: false
51 | log_dir: ./logs
52 | optim:
53 | gradient_clip_val: 1.0
54 | gradient_clip_algorithm: norm
55 | prompt:
56 | lr: 0.1
57 | optimizer_type: adamw
58 | lr_scheduler_type: warmup_constant
59 | lr_scheduler_param:
60 | warmup_constant:
61 | num_warmup_steps: 10
62 | plm:
63 | lr: 0.00005
64 | optimizer_type: adamw
65 | swa:
66 | use: true
67 | lr: 0.00005
68 | swa_epoch_start: 4
69 | annealing_epochs: 6
70 | profiler: null
71 | num_sanity_val_steps: 0
72 | val_check_interval: 0.5
73 | validation_metric: em
74 | validation_target_parse: fine_imprisonment_lvs
75 | validation_sub_param:
76 | method: average
77 | target_sub_parse: average
78 |
79 | infer:
80 | max_length:
81 | max_new_tokens: 12
82 | min_length: 5
83 | temperature: 1.0
84 |   do_sample: false
85 | top_k: 0
86 | top_p: 0.9
87 | repetition_penalty: 1.0
88 | num_beams: 1
89 | bad_words_ids: null
90 | parse_sep_token: ","
91 | value_sep_token: "|"
92 | empty_token: "0"
93 |
94 |
--------------------------------------------------------------------------------
/configs/ljp/criminal/ljp.criminal.kogpt2.e3.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 1012
16 | task: ljp_criminal
17 | subtask: criminal
18 | target_field: facts
19 | target_parses_dict:
20 | fine_imprisonment_lvs:
21 | - fine_lv
22 | - imprisonment_with_labor_lv
23 | - imprisonment_without_labor_lv
24 | path_template:
25 | plm:
26 | freeze: false
27 | eval_mode: false
28 | name: kogpt2
29 | path: skt/kogpt2-base-v2
30 | revision:
31 | precision: bf16
32 |
33 | train:
34 | accelerator: auto
35 | accumulate_grad_batches: 2
36 | limit_val_batches: 1.0
37 | batch_size: 4
38 | batch_size_prediction: 12
39 | check_val_every_n_epoch: 1
40 | fast_dev_run: false
41 | max_epochs: 10
42 | multiple_trainloader_mode:
43 | seed: 3
44 | strategy: null
45 | weight:
46 | trained: false
47 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
48 | save_path_dir: ./data/models
49 | do_not_load_pretrained_weight: false
50 | old_format: false
51 | log_dir: ./logs
52 | optim:
53 | gradient_clip_val: 1.0
54 | gradient_clip_algorithm: norm
55 | prompt:
56 | lr: 0.1
57 | optimizer_type: adamw
58 | lr_scheduler_type: warmup_constant
59 | lr_scheduler_param:
60 | warmup_constant:
61 | num_warmup_steps: 10
62 | plm:
63 | lr: 0.00005
64 | optimizer_type: adamw
65 | swa:
66 | use: true
67 | lr: 0.00005
68 | swa_epoch_start: 4
69 | annealing_epochs: 6
70 | profiler: null
71 | num_sanity_val_steps: 0
72 | val_check_interval: 0.5
73 | validation_metric: em
74 | validation_target_parse: fine_imprisonment_lvs
75 | validation_sub_param:
76 | method: average
77 | target_sub_parse: average
78 |
79 | infer:
80 | max_length:
81 | max_new_tokens: 12
82 | min_length: 5
83 | temperature: 1.0
84 |   do_sample: false
85 | top_k: 0
86 | top_p: 0.9
87 | repetition_penalty: 1.0
88 | num_beams: 1
89 | bad_words_ids: null
90 | parse_sep_token: ","
91 | value_sep_token: "|"
92 | empty_token: "0"
93 |
94 |
--------------------------------------------------------------------------------
/configs/ljp/criminal/ljp.criminal.kogpt2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 1012
16 | task: ljp_criminal
17 | subtask: criminal
18 | target_field: facts
19 | target_parses_dict:
20 | fine_imprisonment_lvs:
21 | - fine_lv
22 | - imprisonment_with_labor_lv
23 | - imprisonment_without_labor_lv
24 | path_template:
25 | plm:
26 | freeze: false
27 | eval_mode: false
28 | name: kogpt2
29 | path: skt/kogpt2-base-v2
30 | revision:
31 | precision: bf16
32 |
33 | train:
34 | accelerator: auto
35 | accumulate_grad_batches: 2
36 | limit_val_batches: 1.0
37 | batch_size: 4
38 | batch_size_prediction: 12
39 | check_val_every_n_epoch: 1
40 | fast_dev_run: false
41 | max_epochs: 20
42 | multiple_trainloader_mode:
43 | seed: 1
44 | strategy: null
45 | weight:
46 | trained: false
47 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
48 | save_path_dir: ./data/models
49 | do_not_load_pretrained_weight: false
50 | old_format: false
51 | log_dir: ./logs
52 | optim:
53 | gradient_clip_val: 1.0
54 | gradient_clip_algorithm: norm
55 | prompt:
56 | lr: 0.1
57 | optimizer_type: adamw
58 | lr_scheduler_type: warmup_constant
59 | lr_scheduler_param:
60 | warmup_constant:
61 | num_warmup_steps: 10
62 | plm:
63 | lr: 0.00005
64 | optimizer_type: adamw
65 | swa:
66 | use: true
67 | lr: 0.00005
68 | swa_epoch_start: 4
69 | annealing_epochs: 6
70 | profiler: null
71 | num_sanity_val_steps: 0
72 | val_check_interval: 0.5
73 | validation_metric: em
74 | validation_target_parse: fine_imprisonment_lvs
75 | validation_sub_param:
76 | method: average
77 | target_sub_parse: average
78 |
79 | infer:
80 | max_length:
81 | max_new_tokens: 12
82 | min_length: 5
83 | temperature: 1.0
84 |   do_sample: false
85 | top_k: 0
86 | top_p: 0.9
87 | repetition_penalty: 1.0
88 | num_beams: 1
89 | bad_words_ids: null
90 | parse_sep_token: ","
91 | value_sep_token: "|"
92 | empty_token: "0"
93 |
94 |
--------------------------------------------------------------------------------
/configs/ljp/criminal/ljp.criminal.lcube-base.e2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 1012
16 | task: ljp_criminal
17 | subtask: criminal
18 | target_field: facts
19 | target_parses_dict:
20 | fine_imprisonment_lvs:
21 | - fine_lv
22 | - imprisonment_with_labor_lv
23 | - imprisonment_without_labor_lv
24 | path_template:
25 | plm:
26 | freeze: false
27 | eval_mode: false
28 | name: legal-gpt
29 | path: lbox/lcube-base
30 | revision:
31 | precision: bf16
32 |
33 | train:
34 | accelerator: auto
35 | accumulate_grad_batches: 2
36 | limit_val_batches: 1.0
37 | batch_size: 4
38 | batch_size_prediction: 12
39 | check_val_every_n_epoch: 1
40 | fast_dev_run: false
41 | max_epochs: 20
42 | multiple_trainloader_mode:
43 | seed: 2
44 | strategy: null
45 | weight:
46 | trained: false
47 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
48 | save_path_dir: ./data/models
49 | do_not_load_pretrained_weight: false
50 | old_format: false
51 | log_dir: ./logs
52 | optim:
53 | gradient_clip_val: 1.0
54 | gradient_clip_algorithm: norm
55 | prompt:
56 | lr: 0.1
57 | optimizer_type: adamw
58 | lr_scheduler_type: warmup_constant
59 | lr_scheduler_param:
60 | warmup_constant:
61 | num_warmup_steps: 10
62 | plm:
63 | lr: 0.00005
64 | optimizer_type: adamw
65 | swa:
66 | use: true
67 | lr: 0.00005
68 | swa_epoch_start: 4
69 | annealing_epochs: 6
70 | profiler: null
71 | num_sanity_val_steps: 0
72 | val_check_interval: 0.5
73 | validation_metric: em
74 | validation_target_parse: fine_imprisonment_lvs
75 | validation_sub_param:
76 | method: average
77 | target_sub_parse: average
78 |
79 | infer:
80 | max_length:
81 | max_new_tokens: 12
82 | min_length: 5
83 | temperature: 1.0
84 |   do_sample: false
85 | top_k: 0
86 | top_p: 0.9
87 | repetition_penalty: 1.0
88 | num_beams: 1
89 | bad_words_ids: null
90 | parse_sep_token: ","
91 | value_sep_token: "|"
92 | empty_token: "0"
93 |
94 |
--------------------------------------------------------------------------------
/configs/ljp/criminal/ljp.criminal.lcube-base.e3.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 1012
16 | task: ljp_criminal
17 | subtask: criminal
18 | target_field: facts
19 | target_parses_dict:
20 | fine_imprisonment_lvs:
21 | - fine_lv
22 | - imprisonment_with_labor_lv
23 | - imprisonment_without_labor_lv
24 | path_template:
25 | plm:
26 | freeze: false
27 | eval_mode: false
28 | name: legal-gpt
29 | path: lbox/lcube-base
30 | revision:
31 | precision: bf16
32 |
33 | train:
34 | accelerator: auto
35 | accumulate_grad_batches: 2
36 | limit_val_batches: 1.0
37 | batch_size: 4
38 | batch_size_prediction: 12
39 | check_val_every_n_epoch: 1
40 | fast_dev_run: false
41 | max_epochs: 10
42 | multiple_trainloader_mode:
43 | seed: 3
44 | strategy: null
45 | weight:
46 | trained: false
47 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
48 | save_path_dir: ./data/models
49 | do_not_load_pretrained_weight: false
50 | old_format: false
51 | log_dir: ./logs
52 | optim:
53 | gradient_clip_val: 1.0
54 | gradient_clip_algorithm: norm
55 | prompt:
56 | lr: 0.1
57 | optimizer_type: adamw
58 | lr_scheduler_type: warmup_constant
59 | lr_scheduler_param:
60 | warmup_constant:
61 | num_warmup_steps: 10
62 | plm:
63 | lr: 0.00005
64 | optimizer_type: adamw
65 | swa:
66 | use: true
67 | lr: 0.00005
68 | swa_epoch_start: 4
69 | annealing_epochs: 6
70 | profiler: null
71 | num_sanity_val_steps: 0
72 | val_check_interval: 0.5
73 | validation_metric: em
74 | validation_target_parse: fine_imprisonment_lvs
75 | validation_sub_param:
76 | method: average
77 | target_sub_parse: average
78 |
79 | infer:
80 | max_length:
81 | max_new_tokens: 12
82 | min_length: 5
83 | temperature: 1.0
84 |   do_sample: false
85 | top_k: 0
86 | top_p: 0.9
87 | repetition_penalty: 1.0
88 | num_beams: 1
89 | bad_words_ids: null
90 | parse_sep_token: ","
91 | value_sep_token: "|"
92 | empty_token: "0"
93 |
94 |
--------------------------------------------------------------------------------
/configs/ljp/criminal/ljp.criminal.lcube-base.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 1012
16 | task: ljp_criminal
17 | subtask: criminal
18 | target_field: facts
19 | target_parses_dict:
20 | fine_imprisonment_lvs:
21 | - fine_lv
22 | - imprisonment_with_labor_lv
23 | - imprisonment_without_labor_lv
24 | path_template:
25 | plm:
26 | freeze: false
27 | eval_mode: false
28 | name: legal-gpt
29 | path: lbox/lcube-base
30 | revision:
31 | precision: bf16
32 |
33 | train:
34 | accelerator: auto
35 | accumulate_grad_batches: 2
36 | limit_val_batches: 1.0
37 | batch_size: 4
38 | batch_size_prediction: 12
39 | check_val_every_n_epoch: 1
40 | fast_dev_run: false
41 | max_epochs: 20
42 | multiple_trainloader_mode:
43 | seed: 1
44 | strategy: null
45 | weight:
46 | trained: false
47 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
48 | save_path_dir: ./data/models
49 | do_not_load_pretrained_weight: false
50 | old_format: false
51 | log_dir: ./logs
52 | optim:
53 | gradient_clip_val: 1.0
54 | gradient_clip_algorithm: norm
55 | prompt:
56 | lr: 0.1
57 | optimizer_type: adamw
58 | lr_scheduler_type: warmup_constant
59 | lr_scheduler_param:
60 | warmup_constant:
61 | num_warmup_steps: 10
62 | plm:
63 | lr: 0.00005
64 | optimizer_type: adamw
65 | swa:
66 | use: true
67 | lr: 0.00005
68 | swa_epoch_start: 4
69 | annealing_epochs: 6
70 | profiler: null
71 | num_sanity_val_steps: 0
72 | val_check_interval: 0.5
73 | validation_metric: em
74 | validation_target_parse: fine_imprisonment_lvs
75 | validation_sub_param:
76 | method: average
77 | target_sub_parse: average
78 |
79 | infer:
80 | max_length:
81 | max_new_tokens: 12
82 | min_length: 5
83 | temperature: 1.0
84 |   do_sample: false
85 | top_k: 0
86 | top_p: 0.9
87 | repetition_penalty: 1.0
88 | num_beams: 1
89 | bad_words_ids: null
90 | parse_sep_token: ","
91 | value_sep_token: "|"
92 | empty_token: "0"
93 |
94 |
--------------------------------------------------------------------------------
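Unlike the other tasks, the criminal LJP configs set parse_sep_token "," and empty_token "0", because the target concatenates three sub-parses (fine_lv, imprisonment_with_labor_lv, imprisonment_without_labor_lv). A hypothetical decoding sketch, assuming the generated string looks like "벌금|1,징역|2,금고|0" (Korean parse names joined to levels by value_sep_token "|"); the authoritative logic is in lbox_open/parser/output_parser.py:

    # Hypothetical sketch; assumes output of the form "벌금|1,징역|2,금고|0".
    PARSE_SEP, VALUE_SEP, EMPTY = ",", "|", "0"

    def parse_fine_imprisonment_lvs(generated: str) -> dict:
        levels = {}
        for chunk in generated.strip().split(PARSE_SEP):
            name, _, lv = chunk.partition(VALUE_SEP)
            levels[name.strip()] = int(lv or EMPTY)
        return levels

    print(parse_fine_imprisonment_lvs("벌금|1,징역|2,금고|0"))
    # {'벌금': 1, '징역': 2, '금고': 0}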
/configs/statute_classification/statute.kogpt2.e2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 960
16 | task: statute_classification
17 | subtask: statute_classification
18 | target_field: facts
19 | target_parses_dict:
20 | statute_classification:
21 | - statute
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: kogpt2
27 | path: skt/kogpt2-base-v2
28 | revision:
29 | precision: bf16
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 4
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 1
38 | fast_dev_run: false
39 | max_epochs: 15
40 | multiple_trainloader_mode:
41 | seed: 2
42 | strategy: null
43 | weight:
44 | trained: false
45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.00005
62 | optimizer_type: adamw
63 | swa:
64 | use: true
65 | lr: 0.00005
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 0
70 | val_check_interval: 0.5
71 | validation_metric: em
72 | validation_target_parse: statute_classification
73 | validation_sub_param:
74 | method: text_em
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length:
79 | max_new_tokens: 64
80 | min_length: 5
81 | temperature: 1.0
82 |   do_sample: false
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: ","
89 | value_sep_token: "|"
90 | empty_token: "0"
91 |
92 |
--------------------------------------------------------------------------------
/configs/statute_classification/statute.kogpt2.e3.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 960
16 | task: statute_classification
17 | subtask: statute_classification
18 | target_field: facts
19 | target_parses_dict:
20 | statute_classification:
21 | - statute
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: kogpt2
27 | path: skt/kogpt2-base-v2
28 | revision:
29 | precision: bf16
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 4
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 1
38 | fast_dev_run: false
39 | max_epochs: 15
40 | multiple_trainloader_mode:
41 | seed: 3
42 | strategy: null
43 | weight:
44 | trained: false
45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.00005
62 | optimizer_type: adamw
63 | swa:
64 | use: true
65 | lr: 0.00005
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 0
70 | val_check_interval: 0.5
71 | validation_metric: em
72 | validation_target_parse: statute_classification
73 | validation_sub_param:
74 | method: text_em
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length:
79 | max_new_tokens: 64
80 | min_length: 5
81 | temperature: 1.0
82 |   do_sample: false
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: ","
89 | value_sep_token: "|"
90 | empty_token: "0"
91 |
92 |
--------------------------------------------------------------------------------
/configs/statute_classification/statute.kogpt2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 960
16 | task: statute_classification
17 | subtask: statute_classification
18 | target_field: facts
19 | target_parses_dict:
20 | statute_classification:
21 | - statute
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: kogpt2
27 | path: skt/kogpt2-base-v2
28 | revision:
29 | precision: bf16
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 4
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 1
38 | fast_dev_run: false
39 | max_epochs: 15
40 | multiple_trainloader_mode:
41 | seed: 1
42 | strategy: null
43 | weight:
44 | trained: false
45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.00005
62 | optimizer_type: adamw
63 | swa:
64 | use: true
65 | lr: 0.00005
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 0
70 | val_check_interval: 0.5
71 | validation_metric: em
72 | validation_target_parse: statute_classification
73 | validation_sub_param:
74 | method: text_em
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length:
79 | max_new_tokens: 64
80 | min_length: 5
81 | temperature: 1.0
82 |   do_sample: false
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: ","
89 | value_sep_token: "|"
90 | empty_token: "0"
91 |
92 |
--------------------------------------------------------------------------------
/configs/statute_classification/statute.lcube-base.e2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 960
16 | task: statute_classification
17 | subtask: statute_classification
18 | target_field: facts
19 | target_parses_dict:
20 | statute_classification:
21 | - statute
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: legal-gpt
27 | path: lbox/lcube-base
28 | revision:
29 | precision: bf16
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 4
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 1
38 | fast_dev_run: false
39 | max_epochs: 15
40 | multiple_trainloader_mode:
41 | seed: 2
42 | strategy: null
43 | weight:
44 | trained: false
45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.00005
62 | optimizer_type: adamw
63 | swa:
64 | use: true
65 | lr: 0.00005
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 0
70 | val_check_interval: 0.5
71 | validation_metric: em
72 | validation_target_parse: statute_classification
73 | validation_sub_param:
74 | method: text_em
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length:
79 | max_new_tokens: 64
80 | min_length: 5
81 | temperature: 1.0
82 |   do_sample: false
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: ","
89 | value_sep_token: "|"
90 | empty_token: "0"
91 |
92 |
--------------------------------------------------------------------------------
/configs/statute_classification/statute.lcube-base.e3.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 960
16 | task: statute_classification
17 | subtask: statute_classification
18 | target_field: facts
19 | target_parses_dict:
20 | statute_classification:
21 | - statute
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: legal-gpt
27 | path: lbox/lcube-base
28 | revision:
29 | precision: bf16
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 4
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 1
38 | fast_dev_run: false
39 | max_epochs: 15
40 | multiple_trainloader_mode:
41 | seed: 3
42 | strategy: null
43 | weight:
44 | trained: false
45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.00005
62 | optimizer_type: adamw
63 | swa:
64 | use: true
65 | lr: 0.00005
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 0
70 | val_check_interval: 0.5
71 | validation_metric: em
72 | validation_target_parse: statute_classification
73 | validation_sub_param:
74 | method: text_em
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length:
79 | max_new_tokens: 64
80 | min_length: 5
81 | temperature: 1.0
82 |   do_sample: false
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: ","
89 | value_sep_token: "|"
90 | empty_token: "0"
91 |
92 |
--------------------------------------------------------------------------------
/configs/statute_classification/statute.lcube-base.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test2
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 960
16 | task: statute_classification
17 | subtask: statute_classification
18 | target_field: facts
19 | target_parses_dict:
20 | statute_classification:
21 | - statute
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: legal-gpt
27 | path: lbox/lcube-base
28 | revision:
29 | precision: bf16
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 4
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 1
38 | fast_dev_run: false
39 | max_epochs: 15
40 | multiple_trainloader_mode:
41 | seed: 1
42 | strategy: null
43 | weight:
44 | trained: false
45 | path: data/models/casename.lv1.d0.1.1.e1.lgpt_tune_plm_only.yaml/epoch=3-step=5335.ckpt
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.00005
62 | optimizer_type: adamw
63 | swa:
64 | use: true
65 | lr: 0.00005
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 0
70 | val_check_interval: 0.5
71 | validation_metric: em
72 | validation_target_parse: statute_classification
73 | validation_sub_param:
74 | method: text_em
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length:
79 | max_new_tokens: 64
80 | min_length: 5
81 | temperature: 1.0
82 |   do_sample: false
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: ","
89 | value_sep_token: "|"
90 | empty_token: "0"
91 |
92 |
--------------------------------------------------------------------------------
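The casename and statute configs validate with validation_metric: em and method: text_em, i.e. exact string match between the generated parse and the gold label. A minimal sketch of such a score, assuming plain whitespace normalization (the repo's implementation lives in lbox_open/metric/exact_match.py and may normalize differently):

    # Minimal sketch, assuming whitespace normalization only; see
    # lbox_open/metric/exact_match.py for the repo's actual metric.
    def text_em(preds, golds):
        norm = lambda s: " ".join(s.split())
        hits = sum(norm(p) == norm(g) for p, g in zip(preds, golds))
        return hits / max(len(golds), 1)

    print(text_em(["형법 제329조"], ["형법  제329조"]))  # 1.0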
/configs/summarization/summarization.kogpt2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 768
16 | task: summarization
17 | subtask: summarization
18 | target_field: precedent
19 | target_parses_dict:
20 | summarization:
21 | - summarization
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: kogpt2
27 | path: skt/kogpt2-base-v2
28 | revision:
29 | precision: 32
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 6
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 2
38 | fast_dev_run: false
39 | max_epochs: 20
40 | multiple_trainloader_mode:
41 | seed: 1
42 | strategy: null
43 | weight:
44 | trained: false
45 | path:
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.0001
62 | optimizer_type: adamw
63 | swa:
64 | use: false
65 | lr: 0.00005
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 0
70 | val_check_interval: 1.0
71 | validation_metric: rougeL
72 | validation_target_parse: summarization
73 | validation_sub_param:
74 | method: rougeL
75 | target_sub_parse:
76 |
77 |
78 | infer:
79 | max_length:
80 | max_new_tokens: 256
81 | min_length: 5
82 | temperature: 1.0
83 |   do_sample: false
84 | top_k: 0
85 | top_p: 0.9
86 | repetition_penalty: 1.0
87 | num_beams: 1
88 | bad_words_ids: null
89 | parse_sep_token: "*"
90 | value_sep_token: "|"
91 | empty_token: "없음"
92 |
93 |
--------------------------------------------------------------------------------
/configs/summarization/summarization.lcube-base.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 1024
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 768
16 | task: summarization
17 | subtask: summarization
18 | target_field: precedent
19 | target_parses_dict:
20 | summarization:
21 | - summarization
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: legal-gpt
27 | path: lbox/lcube-base
28 | revision:
29 | precision: 32
30 |
31 | train:
32 | accelerator: auto
33 | accumulate_grad_batches: 2
34 | limit_val_batches: 1.0
35 | batch_size: 6
36 | batch_size_prediction: 12
37 | check_val_every_n_epoch: 2
38 | fast_dev_run: false
39 | max_epochs: 20
40 | multiple_trainloader_mode:
41 | seed: 1
42 | strategy: null
43 | weight:
44 | trained: false
45 | path:
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.0001
62 | optimizer_type: adamw
63 | swa:
64 | use: false
65 | lr: 0.00005
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 0
70 | val_check_interval: 0.5
71 | validation_metric: rougeL
72 | validation_target_parse: summarization
73 | validation_sub_param:
74 | method: rougeL
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length:
79 | max_new_tokens: 256
80 | min_length: 5
81 | temperature: 1.0
82 |   do_sample: false
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: "*"
89 | value_sep_token: "|"
90 | empty_token: "없음"
91 |
92 |
--------------------------------------------------------------------------------
/configs/summarization/summarization.legal-mt5s.test.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_card: lbox/lbox_open
3 | training_set_name: train
4 | validation_set_name: validation
5 | test_set_name: test
6 | use_local_data: false
7 | path_train:
8 | path_valid:
9 | path_test:
10 |
11 | model:
12 | decoder_max_length: 512
13 | input_template_type: 0
14 | model_type: generative
15 | max_seq_length: 1024
16 | task: summarization
17 | subtask: summarization
18 | target_field: precedent
19 | target_parses_dict:
20 | summarization:
21 | - summarization
22 | path_template:
23 | plm:
24 | freeze: false
25 | eval_mode: false
26 | name: mt5
27 | path: google/mt5-small
28 | revision:
29 | precision: bf16
30 | train:
31 | accelerator: auto
32 | accumulate_grad_batches: 1
33 | limit_train_batches: 0.2
34 | limit_val_batches: 4
35 | batch_size: 12
36 | batch_size_prediction: 36
37 | check_val_every_n_epoch: 2
38 | fast_dev_run: false
39 | max_epochs: 60
40 | multiple_trainloader_mode:
41 | seed: 1
42 | strategy: null
43 | weight:
44 | trained: true
45 | path: saved/models/lbox-open/legal-mt5s-summarization.pt
46 | save_path_dir: ./data/models
47 | do_not_load_pretrained_weight: false
48 | old_format: false
49 | log_dir: ./logs
50 | optim:
51 | gradient_clip_val: 1.0
52 | gradient_clip_algorithm: norm
53 | prompt:
54 | lr: 0.1
55 | optimizer_type: adamw
56 | lr_scheduler_type: warmup_constant
57 | lr_scheduler_param:
58 | warmup_constant:
59 | num_warmup_steps: 10
60 | plm:
61 | lr: 0.0001
62 | optimizer_type: adamw
63 | swa:
64 | use: false
65 | lr: 0.0001
66 | swa_epoch_start: 4
67 | annealing_epochs: 6
68 | profiler: null
69 | num_sanity_val_steps: 0
70 | val_check_interval: 0.5
71 | validation_metric: rougeL
72 | validation_target_parse: summarization
73 | validation_sub_param:
74 | method: rougeL
75 | target_sub_parse:
76 |
77 | infer:
78 | max_length: 512
79 | max_new_tokens: 512
80 | min_length: 5
81 | temperature: 1.0
82 |   do_sample: false
83 | top_k: 0
84 | top_p: 0.9
85 | repetition_penalty: 1.0
86 | num_beams: 1
87 | bad_words_ids: null
88 | parse_sep_token: "*"
89 | value_sep_token: "|"
90 | empty_token: "없음"
91 |
92 |
93 |
--------------------------------------------------------------------------------
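Two details distinguish the summarization configs from the classification ones: the kogpt2/lcube files train in precision 32 with SWA disabled, and the legal-mt5s test file loads a fine-tuned checkpoint (weight.trained: true) instead of training from the base PLM. Note also that limit_val_batches: 4 in the mt5 file is an absolute batch count, while the 1.0 used elsewhere is a fraction (PyTorch Lightning interprets ints as counts and floats as fractions). The infer block's field names line up with Hugging Face generate() keyword arguments; a hedged sketch of that mapping (the actual call site is in lbox_open/model/generative_baseline_model.py and may differ in detail):

    # Sketch of how the infer: block plausibly maps onto transformers'
    # generate(); the repo's actual call site may differ in detail.
    gen_kwargs = dict(
        max_new_tokens=256,   # infer.max_new_tokens
        min_length=5,         # infer.min_length
        temperature=1.0,      # infer.temperature
        do_sample=False,      # greedy decoding; top_k/top_p are then inert
        top_k=0,
        top_p=0.9,
        repetition_penalty=1.0,
        num_beams=1,
        bad_words_ids=None,
    )
    # outputs = model.generate(input_ids, **gen_kwargs)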
/lbox_open/constants/__init__.py:
--------------------------------------------------------------------------------
1 | from .constants_fie import *
2 |
--------------------------------------------------------------------------------
/lbox_open/constants/constants_fie.py:
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 |
5 | ENG_TO_KOR_PARSE_NAMES_LJP_CRIMINAL = {
6 | "fine_lv": "벌금",
7 | "imprisonment_with_labor_lv": "징역",
8 | "imprisonment_without_labor_lv": "금고",
9 | }
10 |
--------------------------------------------------------------------------------
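ENG_TO_KOR_PARSE_NAMES_LJP_CRIMINAL maps the three criminal-LJP sub-parse keys to the Korean terms (fine, imprisonment with labor, imprisonment without labor) that appear in the generated target strings. A hypothetical illustration of how the mapping could render a target; the real rendering is gen_output_template() in lbox_open/template/prompt_generation_utils.py:

    # Hypothetical illustration only; see gen_output_template() in
    # lbox_open/template/prompt_generation_utils.py for the real logic.
    from lbox_open.constants import ENG_TO_KOR_PARSE_NAMES_LJP_CRIMINAL

    lvs = {"fine_lv": 1, "imprisonment_with_labor_lv": 2,
           "imprisonment_without_labor_lv": 0}
    rendered = ",".join(
        f"{ENG_TO_KOR_PARSE_NAMES_LJP_CRIMINAL[k]}|{v}" for k, v in lvs.items()
    )
    print(rendered)  # 벌금|1,징역|2,금고|0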
/lbox_open/data_module/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lbox-kr/lbox-open/fdad4b039af718d2b171e561e75f5771515572df/lbox_open/data_module/__init__.py
--------------------------------------------------------------------------------
/lbox_open/data_module/data_precedent.py:
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 |
5 | import datasets
6 | import pytorch_lightning as pl
7 | from openprompt import PromptDataLoader
8 | from pytorch_lightning.trainer.supporters import CombinedLoader
9 |
10 | from lbox_open import openprompt_wrapper
11 | from lbox_open.template import prompt_generation_utils
12 |
13 |
14 | class PrecedentData(object):
15 | def __init__(self, cfg, mode, target_parse, target_sub_parses, raw_data):
16 | assert mode in ["train", "valid", "test", "predict"]
17 | self.cfg = cfg
18 | self.mode = mode
19 | self.target_parse = target_parse
20 | self.label_key = self._get_label_key(target_parse)
21 | self.target_sub_parses = target_sub_parses
22 | self.data_aug_param = cfg.train.get("data_aug_param", None)
23 | self.doc_id_key = self._get_doc_id(cfg.model.task)
24 | if raw_data is not None:
25 | self.features = self._gen_input_features(raw_data)
26 |
27 | def __getitem__(self, idx):
28 | return self.features[idx]
29 |
30 | def get_text_a(self, raw_data1):
31 | if isinstance(self.cfg.model.target_field, list):
32 | text_a = ""
33 | if self.cfg.model.task == "ljp_civil":
34 | for i, k in enumerate(self.cfg.model.target_field):
35 | if k == "facts":
36 | text_a += f"사실관계: {raw_data1[k]}\n"
37 | elif k == "claim":
38 | text_a += f"청구취지: {raw_data1[k]['text']}\n"
39 | else:
40 | raise NotImplementedError
41 | text_a = text_a.strip()
42 | else:
43 | for i, k in enumerate(self.cfg.model.target_field):
44 | text_a += f"{raw_data1[k]}\n"
45 | text_a = text_a.strip()
46 |
47 | else:
48 | text_a = raw_data1[self.cfg.model.target_field]
49 | return text_a
50 |
51 | def _get_label_key(self, target_parse):
52 | if target_parse in ["claim_acceptance_lv"]:
53 | label_key = "claim_acceptance_lv"
54 | elif target_parse in ["casename_classification"]:
55 | label_key = "casename"
56 | elif target_parse in ["statute_classification"]:
57 | label_key = "statutes"
58 | elif target_parse in ["summarization"]:
59 | label_key = "summary"
60 | else:
61 | label_key = "label"
62 |
63 | return label_key
64 |
65 | def _gen_input_features(self, raw_data):
66 | features = []
67 |
68 | for i, raw_data1 in enumerate(raw_data):
69 | try:
70 | text_a = self.get_text_a(raw_data1)
71 |
72 | if self.label_key in raw_data1:
73 | tgt_text = prompt_generation_utils.gen_output_template(
74 | self.cfg.model.task,
75 | self.target_parse,
76 | self.target_sub_parses,
77 | raw_data1[self.label_key],
78 | self.cfg.infer.parse_sep_token,
79 | )
80 | else:
81 | assert self.mode == "predict"
82 | tgt_text = "This is a dummy text."
83 |
84 | feature = openprompt_wrapper.InputExampleWrapper(
85 | text_a=text_a,
86 | text_b="",
87 | tgt_text=tgt_text,
88 | guid=str(raw_data1[self.doc_id_key]),
89 | )
90 | except Exception as e:
91 |                 print(f"failed example idx: {i}, doc_id key: {self.doc_id_key}")
92 | print(repr(e))
93 | raise e
94 | features.append(feature)
95 | return features
96 |
97 | def __len__(self):
98 | if self.mode != "predict":
99 | return len(self.features)
100 | else:
101 | return 0
102 |
103 |     def __iter__(self):
104 |         return iter(self.features)
105 |
106 | def _get_doc_id(self, task):
107 |
108 | if task in [
109 | "ljp_civil",
110 | "ljp_criminal",
111 | "casename_classification",
112 | "statute_classification",
113 | "summarization",
114 | ]:
115 | doc_id_key = "id"
116 | else:
117 | raise NotImplementedError
118 | return doc_id_key
119 |
120 |
121 | class PrecedentDataModule(pl.LightningDataModule):
122 | def __init__(
123 | self, cfg, plm_tokenizer, TokenizerWrapper, input_templates, raw_data=None
124 | ):
125 | super().__init__()
126 | self.cfg = cfg
127 | self.task = cfg.model.task
128 | self.raw_data = raw_data
129 |
130 | self.plm_tokenizer = plm_tokenizer
131 | self.TokenizerWrapperClass = TokenizerWrapper
132 |
133 | self.data_ts = {}
134 | self.data_vs = {}
135 | self.data_es = {}
136 |
137 | self.input_templates = input_templates
138 | self.target_parses_dict = cfg.model.target_parses_dict
139 | if len(self.target_parses_dict) > 1:
140 | raise Exception("Multitask learning is currently not supported!")
141 |
142 | self.use_local_data = cfg.data.use_local_data
143 | self.dataset_card = cfg.data.dataset_card
144 |
145 | self.training_set_name = cfg.data.training_set_name
146 | self.validation_set_name = cfg.data.validation_set_name
147 | self.test_set_name = cfg.data.test_set_name
148 |
149 | def setup(self, stage):
150 | if not self.use_local_data:
151 | assert self.raw_data is None
152 | self.raw_data = datasets.load_dataset(self.dataset_card, self.task)
153 |
154 | # Assign train/val datasets for use in dataloaders
155 | if stage in ["fit", "test"] or stage is None:
156 | for target_parse, target_sub_parses in self.target_parses_dict.items():
157 | self.data_ts[target_parse] = PrecedentData(
158 | self.cfg,
159 | "train",
160 | target_parse,
161 | target_sub_parses,
162 | self.raw_data[self.training_set_name],
163 | ).features
164 | self.data_vs[target_parse] = PrecedentData(
165 | self.cfg,
166 | "valid",
167 | target_parse,
168 | target_sub_parses,
169 | self.raw_data[self.validation_set_name],
170 | ).features
171 | if "test" in self.raw_data:
172 | self.data_es[target_parse] = PrecedentData(
173 | self.cfg,
174 | "test",
175 | target_parse,
176 | target_sub_parses,
177 | self.raw_data[self.test_set_name],
178 | ).features
179 |
180 | def train_dataloader(self):
181 | data_loaders = {}
182 | for target_parse, target_sub_parses in self.target_parses_dict.items():
183 | data_loaders[target_parse] = PromptDataLoader(
184 | dataset=self.data_ts[target_parse],
185 | template=self.input_templates[target_parse],
186 | tokenizer=self.plm_tokenizer,
187 | tokenizer_wrapper_class=self.TokenizerWrapperClass,
188 | max_seq_length=self.cfg.model.max_seq_length,
189 | decoder_max_length=self.cfg.model.decoder_max_length,
190 | batch_size=self.cfg.train.batch_size,
191 | shuffle=True,
192 | teacher_forcing=True,
193 | predict_eos_token=True,
194 | truncate_method="head",
195 | ).dataloader
196 |
197 | return data_loaders
198 |
199 | def val_dataloader(self):
200 | data_loaders = {}
201 |
202 | for target_parse, target_sub_parses in self.target_parses_dict.items():
203 | data_loaders[target_parse] = PromptDataLoader(
204 | dataset=self.data_vs[target_parse],
205 | template=self.input_templates[target_parse],
206 | tokenizer=self.plm_tokenizer,
207 | tokenizer_wrapper_class=self.TokenizerWrapperClass,
208 | max_seq_length=self.cfg.model.max_seq_length,
209 | decoder_max_length=self.cfg.model.decoder_max_length,
210 | batch_size=self.cfg.train.batch_size_prediction,
211 | shuffle=False,
212 | teacher_forcing=False,
213 | predict_eos_token=True,
214 | truncate_method="head",
215 | ).dataloader
216 |
217 | data_loaders = CombinedLoader(data_loaders)
218 |
219 | return data_loaders
220 |
221 | def test_dataloader(self):
222 | data_loaders = {}
223 | for target_parse, target_sub_parses in self.target_parses_dict.items():
224 | data_loaders[target_parse] = PromptDataLoader(
225 | dataset=self.data_es[target_parse],
226 | template=self.input_templates[target_parse],
227 | tokenizer=self.plm_tokenizer,
228 | tokenizer_wrapper_class=self.TokenizerWrapperClass,
229 | max_seq_length=self.cfg.model.max_seq_length,
230 | decoder_max_length=self.cfg.model.decoder_max_length,
231 | batch_size=self.cfg.train.batch_size_prediction,
232 | shuffle=False,
233 | teacher_forcing=False,
234 | predict_eos_token=True,
235 | truncate_method="head",
236 | ).dataloader
237 |
238 | data_loaders = CombinedLoader(data_loaders)
239 |
240 | return data_loaders
241 |
--------------------------------------------------------------------------------
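The val/test loaders above are wrapped in PyTorch Lightning's CombinedLoader, which zips the per-target-parse dataloaders and yields one dict of batches per step; that is why downstream code can index batch[target_parse]. A self-contained sketch of that behavior:

import torch
from pytorch_lightning.trainer.supporters import CombinedLoader
from torch.utils.data import DataLoader

# One dataloader per target parse; CombinedLoader yields a dict of batches.
loaders = {"claim_acceptance_lv": DataLoader(torch.arange(6), batch_size=2)}
for batch in CombinedLoader(loaders):
    print(batch["claim_acceptance_lv"])  # tensor([0, 1]), tensor([2, 3]), tensor([4, 5])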
/lbox_open/datasets_script/lbox_open.py:
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 | # 2022.10.18, Wonseok: Add casename_classification_plus, statute_classification_plus, summarization_plus datasets
5 |
6 |
7 | import json
8 |
9 | import datasets
10 |
11 | _CASENAME_CLASSIFICATION_FEATURES = {
12 | "id": datasets.Value("int64"),
13 | "casetype": datasets.Value("string"),
14 | "casename": datasets.Value("string"),
15 | "facts": datasets.Value("string"),
16 | }
17 |
18 |
19 | _STATUTE_CLASSIFICATION_FEATURES = {
20 | "id": datasets.Value("int64"),
21 | "casetype": datasets.Value("string"),
22 | "casename": datasets.Value("string"),
23 | "statutes": datasets.features.Sequence(datasets.Value("string")),
24 | "facts": datasets.Value("string"),
25 | }
26 |
27 | _LJP_CRIMINAL = {
28 | "id": datasets.Value("int64"),
29 | "casetype": datasets.Value("string"),
30 | "casename": datasets.Value("string"),
31 | "facts": datasets.Value("string"),
32 | "reason": datasets.Value("string"),
33 | "label": {
34 | "text": datasets.Value("string"),
35 | "fine_lv": datasets.Value("int64"),
36 | "imprisonment_with_labor_lv": datasets.Value("int64"),
37 | "imprisonment_without_labor_lv": datasets.Value("int64"),
38 | },
39 | "ruling": {
40 | "text": datasets.Value("string"),
41 | "parse": {
42 | "fine": {
43 | "type": datasets.Value("string"),
44 | "unit": datasets.Value("string"),
45 | "value": datasets.Value("int64"),
46 | },
47 | "imprisonment": {
48 | "type": datasets.Value("string"),
49 | "unit": datasets.Value("string"),
50 | "value": datasets.Value("int64"),
51 | },
52 | },
53 | },
54 | }
55 |
56 | _LJP_CIVIL = {
57 | "id": datasets.Value("int64"),
58 | "casetype": datasets.Value("string"),
59 | "casename": datasets.Value("string"),
60 | "facts": datasets.Value("string"),
61 | "claim_acceptance_lv": datasets.Value("int64"),
62 | "gist_of_claim": {
63 | "text": datasets.Value("string"),
64 | "money": {
65 | "provider": datasets.Value("string"),
66 | "taker": datasets.Value("string"),
67 | "unit": datasets.Value("string"),
68 | "value": datasets.Value("int64"),
69 | },
70 | },
71 | "ruling": {
72 | "text": datasets.Value("string"),
73 | "money": {
74 | "provider": datasets.Value("string"),
75 | "taker": datasets.Value("string"),
76 | "unit": datasets.Value("string"),
77 | "value": datasets.Value("int64"),
78 | },
79 | "litigation_cost": datasets.Value("float32"),
80 | },
81 | }
82 |
83 | _SUMMARIZATION_FEATURES = {
84 | "id": datasets.Value("int64"),
85 | "summary": datasets.Value("string"),
86 | "precedent": datasets.Value("string"),
87 | }
88 |
89 | _PRECEDENT_CORPUS_FEATURES = {
90 | "id": datasets.Value("int64"),
91 | "precedent": datasets.Value("string"),
92 | }
93 |
94 |
95 | class LBoxOpenConfig(datasets.BuilderConfig):
96 |     """BuilderConfig for LBox Open."""
97 |
98 | def __init__(
99 | self,
100 | features,
101 | data_url,
102 | citation,
103 | url,
104 | label_classes=("False", "True"),
105 | **kwargs,
106 | ):
107 | # Version history:
108 | # 0.1.0: Initial version.
109 | super(LBoxOpenConfig, self).__init__(
110 | version=datasets.Version("0.2.0"), **kwargs
111 | )
112 | self.features = features
113 | self.label_classes = label_classes
114 | self.data_url = data_url
115 | self.citation = citation
116 | self.url = url
117 |
118 |
119 | class LBoxOpen(datasets.GeneratorBasedBuilder):
120 | """The Legal AI Benchmark dataset from Korean Legal Cases."""
121 |
122 | BUILDER_CONFIGS = [
123 | LBoxOpenConfig(
124 | name="casename_classification",
125 | description="",
126 | features=_CASENAME_CLASSIFICATION_FEATURES,
127 | data_url="https://lbox-open.s3.ap-northeast-2.amazonaws.com/precedent_benchmark_dataset/casename_classification/v0.1.2/",
128 | citation="",
129 | url="lbox.kr",
130 | ),
131 | LBoxOpenConfig(
132 | name="casename_classification_plus",
133 | description="",
134 | features=_CASENAME_CLASSIFICATION_FEATURES,
135 | data_url="https://lbox-open.s3.ap-northeast-2.amazonaws.com/precedent_benchmark_dataset/casename_classification/v0.1.2_plus/",
136 | citation="",
137 | url="lbox.kr",
138 | ),
139 | LBoxOpenConfig(
140 | name="statute_classification",
141 | description="",
142 | features=_STATUTE_CLASSIFICATION_FEATURES,
143 | data_url="https://lbox-open.s3.ap-northeast-2.amazonaws.com/precedent_benchmark_dataset/statute_classification/v0.1.2/",
144 | citation="",
145 | url="lbox.kr",
146 | ),
147 | LBoxOpenConfig(
148 | name="statute_classification_plus",
149 | description="",
150 | features=_STATUTE_CLASSIFICATION_FEATURES,
151 | data_url="https://lbox-open.s3.ap-northeast-2.amazonaws.com/precedent_benchmark_dataset/statute_classification/v0.1.2_plus/",
152 | citation="",
153 | url="lbox.kr",
154 | ),
155 | LBoxOpenConfig(
156 | name="ljp_criminal",
157 | description="",
158 | features=_LJP_CRIMINAL,
159 | data_url="https://lbox-open.s3.ap-northeast-2.amazonaws.com/precedent_benchmark_dataset/judgement_prediction/v0.1.2/criminal/",
160 | citation="",
161 | url="lbox.kr",
162 | ),
163 | LBoxOpenConfig(
164 | name="ljp_civil",
165 | description="",
166 | features=_LJP_CIVIL,
167 | data_url="https://lbox-open.s3.ap-northeast-2.amazonaws.com/precedent_benchmark_dataset/judgement_prediction/v0.1.2/civil/",
168 | citation="",
169 | url="lbox.kr",
170 | ),
171 | LBoxOpenConfig(
172 | name="summarization",
173 | description="",
174 | features=_SUMMARIZATION_FEATURES,
175 | data_url="https://lbox-open.s3.ap-northeast-2.amazonaws.com/precedent_benchmark_dataset/summarization/v0.1.0/",
176 | citation="",
177 | url="lbox.kr",
178 | ),
179 | LBoxOpenConfig(
180 | name="summarization_plus",
181 | description="",
182 | features=_SUMMARIZATION_FEATURES,
183 | data_url="https://lbox-open.s3.ap-northeast-2.amazonaws.com/precedent_benchmark_dataset/summarization/v0.1.0_plus/",
184 | citation="",
185 | url="lbox.kr",
186 | ),
187 | LBoxOpenConfig(
188 | name="precedent_corpus",
189 | description="",
190 | features=_PRECEDENT_CORPUS_FEATURES,
191 | data_url="https://lbox-open.s3.ap-northeast-2.amazonaws.com/precedent_benchmark_dataset/case_corpus/v0.1.0/",
192 | citation="",
193 | url="lbox.kr",
194 | ),
195 | ]
196 |
197 | def _info(self):
198 | return datasets.DatasetInfo(
199 | description="",
200 | features=datasets.Features(self.config.features),
201 | homepage=self.config.url,
202 | citation="",
203 | )
204 |
205 | def _split_generators(self, dl_manager):
206 | if self.config.name == "precedent_corpus":
207 | dl_dir = {
208 | "train": dl_manager.download_and_extract(
209 | f"{self.config.data_url}case_corpus-150k.jsonl"
210 | )
211 | or "",
212 | }
213 |
214 | return [
215 | datasets.SplitGenerator(
216 | name=datasets.Split.TRAIN,
217 | gen_kwargs={
218 | "data_file": dl_dir["train"],
219 | "split": datasets.Split.TRAIN,
220 | },
221 | )
222 | ]
223 |
224 | elif self.config.name in [
225 | "casename_classification",
226 | "statute_classification",
227 | "ljp_criminal",
228 | "ljp_civil",
229 | ]:
230 | dl_dir = {
231 | "train": dl_manager.download_and_extract(
232 | f"{self.config.data_url}train.jsonl"
233 | )
234 | or "",
235 | "valid": dl_manager.download_and_extract(
236 | f"{self.config.data_url}valid.jsonl"
237 | )
238 | or "",
239 | "test": dl_manager.download_and_extract(
240 | f"{self.config.data_url}test.jsonl"
241 | )
242 | or "",
243 | "test2": dl_manager.download_and_extract(
244 | f"{self.config.data_url}test2.jsonl"
245 | )
246 | or "",
247 | }
248 |
249 | return [
250 | datasets.SplitGenerator(
251 | name=datasets.Split.TRAIN,
252 | gen_kwargs={
253 | "data_file": dl_dir["train"],
254 | "split": datasets.Split.TRAIN,
255 | },
256 | ),
257 | datasets.SplitGenerator(
258 | name=datasets.Split.VALIDATION,
259 | gen_kwargs={
260 | "data_file": dl_dir["valid"],
261 | "split": datasets.Split.VALIDATION,
262 | },
263 | ),
264 | datasets.SplitGenerator(
265 | name=datasets.Split.TEST,
266 | gen_kwargs={
267 | "data_file": dl_dir["test"],
268 | "split": datasets.Split.TEST,
269 | },
270 | ),
271 | datasets.SplitGenerator(
272 | name="test2",
273 | gen_kwargs={
274 | "data_file": dl_dir["test2"],
275 | "split": "test2",
276 | },
277 | ),
278 | ]
279 | else:
280 | dl_dir = {
281 | "train": dl_manager.download_and_extract(
282 | f"{self.config.data_url}train.jsonl"
283 | )
284 | or "",
285 | "valid": dl_manager.download_and_extract(
286 | f"{self.config.data_url}valid.jsonl"
287 | )
288 | or "",
289 | "test": dl_manager.download_and_extract(
290 | f"{self.config.data_url}test.jsonl"
291 | )
292 | or "",
293 | }
294 |
295 | return [
296 | datasets.SplitGenerator(
297 | name=datasets.Split.TRAIN,
298 | gen_kwargs={
299 | "data_file": dl_dir["train"],
300 | "split": datasets.Split.TRAIN,
301 | },
302 | ),
303 | datasets.SplitGenerator(
304 | name=datasets.Split.VALIDATION,
305 | gen_kwargs={
306 | "data_file": dl_dir["valid"],
307 | "split": datasets.Split.VALIDATION,
308 | },
309 | ),
310 | datasets.SplitGenerator(
311 | name=datasets.Split.TEST,
312 | gen_kwargs={
313 | "data_file": dl_dir["test"],
314 | "split": datasets.Split.TEST,
315 | },
316 | ),
317 | ]
318 |
319 | def _generate_examples(self, data_file, split):
320 | with open(data_file, encoding="utf-8") as f:
321 | for line in f:
322 | row = json.loads(line)
323 | yield row["id"], row
324 |
--------------------------------------------------------------------------------
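With this builder script published as a Hugging Face dataset, any configuration can be pulled directly; a usage sketch, assuming the dataset card name "lbox/lbox_open":

import datasets

# Loads the train/validation/test(/test2) splits defined by the builder above.
ds = datasets.load_dataset("lbox/lbox_open", "casename_classification")
example = ds["train"][0]
print(example["casename"], example["facts"][:50])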
/lbox_open/metric/exact_match.py:
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 |
5 | from collections import defaultdict
6 |
7 |
8 | class ExactMatch:
9 | def __init__(self, parse_keys, empty_value):
10 |
11 |         # copy so the caller's key list is not mutated
12 |         parse_keys = [k for k in parse_keys if k != "doc_id"]
13 |
14 | self.parse_keys = parse_keys
15 |
16 | self.empty_value = empty_value
17 |
18 | def is_empty(self, value):
19 | return (str(value) == str(self.empty_value)) or (value is None)
20 |
21 | def compare_parse(self, gt_parse, pr_parse):
22 |         cnt_tp = defaultdict(int)  # both exist and pr is correct
23 | cnt_fn = defaultdict(int) # gt exists but pr is empty
24 | cnt_fp = defaultdict(
25 | int
26 | ) # [gt empty but pr exists] or [gt exists yet pr is wrong]
27 | cnt_tn = defaultdict(int) # gt & pr both empty
28 |
29 | for key in self.parse_keys:
30 | gt_val = gt_parse[key]
31 | pr_val = pr_parse[key]
32 |
33 | if self.is_empty(gt_val):
34 | if self.is_empty(pr_val):
35 | cnt_tn[key] += 1
36 | else:
37 | cnt_fp[key] += 1
38 | else:
39 | if self.is_empty(pr_val):
40 | cnt_fn[key] += 1
41 | else:
42 | if str(gt_val) == str(pr_val):
43 | cnt_tp[key] += 1
44 | else:
45 | cnt_fp[key] += 1
46 |
47 | return (cnt_tp, cnt_fp, cnt_fn, cnt_tn)
48 |
49 | def imp_fill_cnt(self, cnt_all, cnt):
50 | for key in self.parse_keys:
51 | cnt_all[key] += cnt[key]
52 |
53 | def calculate_micro_f1(self, cnt_tp_all, cnt_fp_all, cnt_fn_all):
54 | f1_all = {}
55 | for key in self.parse_keys:
56 | tp = cnt_tp_all[key]
57 | fp = cnt_fp_all[key]
58 | fn = cnt_fn_all[key]
59 |
60 | p = tp / (tp + fp + 1e-5)
61 | r = tp / (tp + fn + 1e-5)
62 | f1 = 2 * p * r / (p + r + 1e-5)
63 |
64 | f1_all[key] = f1
65 |
66 | return f1_all
67 |
68 | def compare_parses(self, gt_parses, pr_parses, confidences=None, threshold=0.0):
69 | cnt_tp_all = defaultdict(int)
70 | cnt_fn_all = defaultdict(int)
71 | cnt_fp_all = defaultdict(int)
72 | cnt_tn_all = defaultdict(int)
73 | if confidences is None:
74 | confidences = [1.0] * len(gt_parses)
75 | assert threshold == 0.0
76 | cnt = 0
77 | for gt_parse, pr_parse, confidence in zip(gt_parses, pr_parses, confidences):
78 | if confidence < threshold:
79 | continue
80 | cnt += 1
81 | (cnt_tp, cnt_fp, cnt_fn, cnt_tn) = self.compare_parse(gt_parse, pr_parse)
82 |
83 | self.imp_fill_cnt(cnt_tp_all, cnt_tp)
84 | self.imp_fill_cnt(cnt_fp_all, cnt_fp)
85 | self.imp_fill_cnt(cnt_fn_all, cnt_fn)
86 | self.imp_fill_cnt(cnt_tn_all, cnt_tn)
87 |
88 | f1_all = self.calculate_micro_f1(cnt_tp_all, cnt_fp_all, cnt_fn_all)
89 | th_recall = cnt / len(confidences)
90 |
91 | return (
92 | f1_all,
93 | cnt_tp_all,
94 | cnt_fp_all,
95 | cnt_fn_all,
96 | cnt_tn_all,
97 | th_recall,
98 | )
99 |
--------------------------------------------------------------------------------
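A toy run of ExactMatch: per-key true/false positive counts feed an epsilon-smoothed micro F1, with the empty value (here "없음", matching the configs) separating false negatives from false positives.

from lbox_open.metric.exact_match import ExactMatch

em = ExactMatch(parse_keys=["fine_lv", "imprisonment_with_labor_lv"], empty_value="없음")
gt = [{"fine_lv": 1, "imprisonment_with_labor_lv": "없음"}]
pr = [{"fine_lv": 1, "imprisonment_with_labor_lv": 2}]
f1_all, *_, th_recall = em.compare_parses(gt, pr)
print(f1_all)  # fine_lv ≈ 1.0 (true positive), imprisonment_with_labor_lv ≈ 0.0 (false positive)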
/lbox_open/metric/rouge_metric_utils.py:
--------------------------------------------------------------------------------
1 | from rouge_score.tokenizers import Tokenizer
2 |
3 |
4 | class WhiteSpaceTokenizer(Tokenizer):
5 | def tokenize(self, text):
6 | return text.split()
7 |
--------------------------------------------------------------------------------
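This tokenizer exists because rouge_score's default tokenizer keeps only [a-z0-9] tokens and would discard Korean text; splitting on whitespace keeps the eojeol intact. A quick check:

from rouge_score import rouge_scorer

from lbox_open.metric.rouge_metric_utils import WhiteSpaceTokenizer

scorer = rouge_scorer.RougeScorer(["rougeL"], tokenizer=WhiteSpaceTokenizer())
score = scorer.score(target="피고인을 벌금 100만 원에 처한다", prediction="피고인을 벌금 100만 원에 처한다")
print(score["rougeL"].fmeasure)  # 1.0; the default tokenizer would drop the Korean characters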
/lbox_open/model/generative_baseline_model.py:
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 |
5 | import os
6 | from collections import defaultdict
7 | from itertools import zip_longest
8 | from pathlib import Path
9 | from pprint import pprint
10 |
11 |
12 | import datasets
13 | import pytorch_lightning as pl
14 | import torch
15 | from openprompt.utils.metrics import generation_metric
16 | from transformers.generation_utils import GenerationMixin
17 | from rouge_score import rouge_scorer
18 | import numpy as np
19 |
20 | import lbox_open.utils.general_utils as gu
21 | from lbox_open import openprompt_wrapper
22 | from lbox_open.model.model_optimizer import get_lr_dict, get_optimizer
23 | from lbox_open.parser.output_parser_utils import (
24 | cal_em_from_parses,
25 | get_parses_from_eval_results,
26 | )
27 | from lbox_open.metric import rouge_metric_utils
28 |
29 |
30 | class GenerativeParser(pl.LightningModule, GenerationMixin):
31 | def __init__(self, cfg, plm, plm_tokenizer, input_templates):
32 | super().__init__()
33 | self.task = cfg.model.task
34 | self.mparam = cfg.model
35 | self.tparam = cfg.train
36 | self.iparam = cfg.infer
37 | self.cfg_name = cfg.name
38 | self.target_parses_dict = cfg.model.target_parses_dict
39 |
40 | self.prompt_models = {}
41 | self.plm = plm
42 | for target_parse, target_sub_parses in cfg.model.target_parses_dict.items():
43 |             # keep these in case we fine-tune the plm
44 | prompt_model = openprompt_wrapper.PromptForGenerationCustom(
45 | plm=plm,
46 | template=input_templates[target_parse],
47 | freeze_plm=cfg.model.plm.freeze,
48 | tokenizer=plm_tokenizer,
49 | plm_eval_mode=cfg.model.plm.eval_mode,
50 | )
51 |
52 | self.prompt_models[target_parse] = prompt_model
53 |
54 | self.prompt_models = torch.nn.ModuleDict(self.prompt_models)
55 |
56 | # if self.plm.config.is_encoder_decoder:
57 | self.generation_arguments = {
58 | "max_length": cfg.infer.max_length,
59 | "max_new_tokens": cfg.infer.get("max_new_tokens", None),
60 | "min_length": cfg.infer.min_length,
61 | "temperature": cfg.infer.temperature,
62 | "do_sample": cfg.infer.do_sample,
63 | "top_k": cfg.infer.top_k,
64 | "top_p": cfg.infer.top_p,
65 | "repetition_penalty": cfg.infer.repetition_penalty,
66 | "num_beams": cfg.infer.num_beams,
67 | "bad_words_ids": cfg.infer.bad_words_ids,
68 | "use_cache": True,
69 | }
70 |
71 | if plm.config.is_encoder_decoder:
72 | # remove max_new_tokens
73 |             print("The model is an encoder-decoder, so max_new_tokens is removed.")
74 | self.generation_arguments.pop("max_new_tokens")
75 | else:
76 | if cfg.infer.get("max_new_tokens", None):
77 | print(
78 |                     "max_length will be ignored because max_new_tokens is set."
79 | )
80 | self.generation_arguments["max_length"] = None
81 |
82 | self.rouge_scorer = rouge_scorer.RougeScorer(
83 | ["rouge1", "rouge2", "rougeL"], tokenizer=rouge_metric_utils.WhiteSpaceTokenizer()
84 | )
85 | def forward(self, target_parse, batch):
86 | loss = self.prompt_models[target_parse](batch[target_parse])
87 | return loss
88 |
89 | def training_step(self, batch, batch_idx):
90 | n_keys = len(self.target_parses_dict)
91 | loss = 0
92 | for i_target, (target_parse, _) in enumerate(self.target_parses_dict.items()):
93 | loss += self.forward(target_parse, batch)
94 | return {"loss": loss / n_keys}
95 |
96 | def training_epoch_end(self, outputs):
97 |
98 | loss_all = torch.stack(self.gather_loss(outputs))
99 | ave_loss = torch.mean(loss_all)
100 | self.log("training__ave_loss", ave_loss)
101 |
102 | def gather_loss(self, outputs):
103 | loss_all = []
104 | for output in outputs:
105 | loss_all.append(output["loss"])
106 |
107 | return loss_all
108 |
109 | def validation_step(self, batch, batch_idx):
110 | return self._eval_step(batch, batch_idx)
111 |
112 | def validation_epoch_end(self, outputs):
113 | (
114 | eval_score,
115 | doc_ids_all,
116 | pr_texts_all,
117 | gt_texts_all,
118 | confidences_all,
119 | ) = self._eval_epoch_end(outputs)
120 | print("\nValidation!-----------------------------------------")
121 | pprint(eval_score)
122 | pprint(f"GT: {gt_texts_all[self.tparam.validation_target_parse][0:2]}")
123 | pprint(f"PR: {pr_texts_all[self.tparam.validation_target_parse][0:2]}")
124 |
125 | if self.tparam.validation_metric in ["sentence_bleu"]:
126 | validation_score = eval_score[self.tparam.validation_target_parse]
127 |
128 | elif self.tparam.validation_metric in ["rougeL"]:
129 | validation_score = eval_score[self.tparam.validation_target_parse]
130 |
131 | elif self.tparam.validation_metric in ["em"]:
132 | if self.tparam.validation_sub_param.method == "single_parse":
133 | sub_parse_name = self.tparam.validation_sub_param.target_sub_parse
134 | validation_score = eval_score[self.tparam.validation_target_parse][
135 | "f1"
136 | ][sub_parse_name]
137 | elif self.tparam.validation_sub_param.method == "average":
138 | validation_score = 0
139 | cnt = 0
140 | for sub_parse_name, score in eval_score[
141 | self.tparam.validation_target_parse
142 | ]["f1"].items():
143 | validation_score += score
144 | cnt += 1
145 | validation_score /= cnt
146 | elif self.tparam.validation_sub_param.method == "text_em":
147 | validation_score = eval_score[self.tparam.validation_target_parse][
148 | "text_em"
149 | ]
150 | else:
151 | raise ValueError
152 | for sub_parse_name, score in eval_score[
153 | self.tparam.validation_target_parse
154 | ]["f1"].items():
155 | self.log(sub_parse_name, score)
156 | self.log(
157 | f"{self.tparam.validation_target_parse}_text_em",
158 | eval_score[self.tparam.validation_target_parse]["text_em"],
159 | )
160 | else:
161 | raise ValueError
162 |
163 | self.log(
164 | f"{self.tparam.validation_metric}_{self.tparam.validation_sub_param.method}",
165 | validation_score,
166 | )
167 |
168 | def test_step(self, batch, batch_idx):
169 | return self._eval_step(batch, batch_idx)
170 |
171 | def test_epoch_end(self, outputs):
172 | output_save_dir = (
173 | Path(self.tparam.weight.path).parent / "analysis" / self.cfg_name
174 | )
175 | os.makedirs(output_save_dir, exist_ok=True)
176 | (
177 | eval_score,
178 | doc_ids_all,
179 | pr_texts_all,
180 | gt_texts_all,
181 | confidences_all,
182 | ) = self._eval_epoch_end(
183 | outputs, save=True, output_save_dir=output_save_dir, verbose=True
184 | )
185 | print("Test!-----------------------------------------------")
186 | print(eval_score)
187 |
188 | output_save_path_eval_score = output_save_dir / "eval_score.json"
189 | gu.save_json(output_save_path_eval_score, eval_score)
190 |
191 | eval_result = {
192 | "doc_ids": doc_ids_all,
193 | "pr_texts": pr_texts_all,
194 | "gt_texts": gt_texts_all,
195 | }
196 | output_save_path_eval_result = output_save_dir / "eval_result.json"
197 | gu.save_json(output_save_path_eval_result, eval_result)
198 |
199 | output_save_path_confidences = output_save_dir / "confidences.json"
200 | gu.save_json(output_save_path_confidences, confidences_all)
201 |
202 | # add doc_ids to confidences_all
203 | confidences_all_with_doc_ids = {}
204 | for key_target_parse, confidences in confidences_all.items():
205 | c_with_ids = [
206 | (doc_id, c)
207 | for doc_id, c in zip_longest(doc_ids_all[key_target_parse], confidences)
208 | ]
209 | confidences_all_with_doc_ids[key_target_parse] = c_with_ids
210 |
211 | output_save_path_confidences_with_doc_ids = (
212 | output_save_dir / "confidences_with_doc_ids.json"
213 | )
214 | gu.save_json(
215 | output_save_path_confidences_with_doc_ids, confidences_all_with_doc_ids
216 | )
217 |
218 | def _eval_step(self, batch, batch_idx):
219 |
220 | out = defaultdict(dict)
221 | for target_parse, _ in self.target_parses_dict.items():
222 | _prs, _gts, confidences = self.evaluate(target_parse, batch)
223 |
224 | # add confidences as a saved output.
225 | out[target_parse]["pr_texts"] = _prs
226 | out[target_parse]["gt_texts"] = _gts
227 | out[target_parse]["doc_ids"] = batch[target_parse]["guid"]
228 | out[target_parse]["confidences"] = confidences
229 |
230 | return out
231 |
232 | def _eval_epoch_end(self, outputs, save=False, output_save_dir=None, verbose=False):
233 | # outputs = [list of each step outputs]
234 | pr_texts_all = self.gather_step_outputs("pr_texts", outputs)
235 | gt_texts_all = self.gather_step_outputs("gt_texts", outputs)
236 | doc_ids_all = self.gather_step_outputs("doc_ids", outputs)
237 | confidences_all = self.gather_step_outputs("confidences", outputs)
238 |
239 | eval_score = self.cal_score(
240 | doc_ids_all,
241 | pr_texts_all,
242 | gt_texts_all,
243 | save=save,
244 | output_save_dir=output_save_dir,
245 | confidences=confidences_all,
246 | threshold=0.0,
247 | verbose=False,
248 | )
249 |
250 | return eval_score, doc_ids_all, pr_texts_all, gt_texts_all, confidences_all
251 |
252 | def cal_score(
253 | self,
254 | doc_ids_all,
255 | pr_texts_all,
256 | gt_texts_all,
257 | save=False,
258 | output_save_dir=None,
259 | confidences=None,
260 | threshold=0.0,
261 | verbose=False,
262 | input_texts=None,
263 | ):
264 |
265 | if self.tparam.validation_metric == "sentence_bleu":
266 | eval_score = {}
267 | for target_parse, _ in self.target_parses_dict.items():
268 | groundtruth_sentence = gt_texts_all[target_parse]
269 | generated_sentence = pr_texts_all[target_parse]
270 | eval_score[target_parse] = generation_metric(
271 | generated_sentence, groundtruth_sentence, "sentence_bleu"
272 | )
273 | elif self.tparam.validation_metric == "rougeL":
274 | eval_score = {}
275 | for target_parse, _ in self.target_parses_dict.items():
276 | pr_texts = pr_texts_all[target_parse]
277 | gt_texts = gt_texts_all[target_parse]
278 | target_scores = []
279 | for pr_text, gt_text in zip_longest(pr_texts, gt_texts):
280 | r_score = self.rouge_scorer.score(
281 | prediction=pr_text, target=gt_text
282 | )
283 |
284 | target_scores.append(
285 | r_score[self.tparam.validation_metric].fmeasure
286 | )
287 |
288 | eval_score[target_parse] = np.mean(
289 | target_scores
290 | )
291 | print(eval_score)
292 |
293 | elif self.tparam.validation_metric == "em":
294 | # EM score
295 | parses = get_parses_from_eval_results(
296 | self.iparam,
297 | self.target_parses_dict,
298 | doc_ids_all,
299 | gt_texts_all,
300 | pr_texts_all,
301 | )
302 |
303 | # analysis
304 | eval_score = cal_em_from_parses(
305 | self.iparam,
306 | self.target_parses_dict,
307 | parses,
308 | verbose=verbose,
309 | save=save,
310 | output_save_dir=output_save_dir,
311 | input_texts=input_texts,
312 | confidences=confidences,
313 | threshold=threshold,
314 | )
315 |
316 | # text exact matching
317 | for target_parse, target_sub_parses in self.target_parses_dict.items():
318 | gt_texts = gt_texts_all[target_parse]
319 | pr_texts = pr_texts_all[target_parse]
320 | corrects = [str(x) == str(y) for x, y in zip(gt_texts, pr_texts)]
321 | text_em_score = sum(corrects) / len(corrects)
322 | eval_score[target_parse]["text_em"] = text_em_score
323 |
324 | else:
325 | raise ValueError
326 | return eval_score
327 |
328 | def gather_step_outputs(self, key, outputs):
329 | outputs_all = defaultdict(list)
330 |
331 | for target_parse, _ in self.target_parses_dict.items():
332 | for output in outputs:
333 | outputs_all[target_parse] += output[target_parse][key]
334 |
335 | return outputs_all
336 |
337 | def configure_optimizers(self):
338 | optimizer = get_optimizer(self.mparam, self.tparam, self)
339 | lr_dict = get_lr_dict(optimizer, self.tparam, "prompt")
340 |
341 | return {"optimizer": optimizer, "lr_scheduler": lr_dict}
342 |
343 | def evaluate(self, target_parse, batch):
344 | generated_sentence = []
345 | groundtruth_sentence = []
346 |
347 | seqs, output_sentence, confidences = self.prompt_models[target_parse].generate(
348 | batch[target_parse], **self.generation_arguments
349 | )
350 | generated_sentence.extend(output_sentence)
351 | groundtruth_sentence.extend(batch[target_parse]["tgt_text"])
352 |
353 | return generated_sentence, groundtruth_sentence, confidences
354 |
--------------------------------------------------------------------------------
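cal_score forwards per-example confidences and a threshold into the EM evaluation; a toy illustration of that path using ExactMatch directly (low-confidence predictions are skipped, and th_recall reports the fraction kept):

from lbox_open.metric.exact_match import ExactMatch

em = ExactMatch(parse_keys=["fine_lv"], empty_value="없음")
gt = [{"fine_lv": 1}, {"fine_lv": 2}]
pr = [{"fine_lv": 1}, {"fine_lv": 3}]
f1_all, *_, th_recall = em.compare_parses(gt, pr, confidences=[0.9, 0.4], threshold=0.5)
print(f1_all["fine_lv"], th_recall)  # ≈1.0 (only the confident, correct parse is scored), 0.5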
/lbox_open/model/model_optimizer.py:
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 |
5 | import torch
6 | import transformers
7 |
8 | map_optimizers_name_to_type = {
9 | "sgd": torch.optim.SGD,
10 | "adam": torch.optim.Adam,
11 | "adamw": torch.optim.AdamW,
12 | }
13 |
14 |
15 | def get_optimizer(mparam, tparam, model):
16 | # todo: plm training part
17 | _lr_type, lr_param = get_lr_type_and_param(tparam, "prompt")
18 |
19 | # prompt
20 | optimizer_type = map_optimizers_name_to_type[tparam.optim.prompt.optimizer_type]
21 |
22 | if model.task in [
23 | "ljp_civil",
24 | "ljp_criminal",
25 | "casename_classification",
26 | "statute_classification",
27 | "summarization",
28 | ]:
29 | optimizer_grouped_parameters = []
30 | if not mparam.plm.freeze:
31 | optimizer_grouped_parameters.append(
32 | {
33 | "params": list(
34 | filter(lambda p: p.requires_grad, model.plm.parameters())
35 | ),
36 | "lr": tparam.optim.plm.lr,
37 | }
38 | )
39 |
40 | for target_parse, _target_sub_parses in model.target_parses_dict.items():
41 | optimizer_grouped_parameters.append(
42 | {
43 | "params": [
44 | p
45 | for n, p in model.prompt_models[
46 | target_parse
47 | ].template.named_parameters()
48 | if "raw_embedding" not in n
49 | ]
50 | }
51 | )
52 |
53 | optimizer = optimizer_type(
54 | optimizer_grouped_parameters, lr=tparam.optim.prompt.lr, weight_decay=0
55 | )
56 |
57 | else:
58 | raise NotImplementedError
59 |
60 | return optimizer
61 |
62 |
63 | def get_lr_type_and_param(tparam, key):
64 | lr_type = tparam.optim[key].lr_scheduler_type
65 | lr_param = tparam.optim[key].lr_scheduler_param[lr_type]
66 | return lr_type, lr_param
67 |
68 |
69 | def gen_lr_scheduler(tparam, optimizer, lr_type, lr_param):
70 | if lr_type == "constant":
71 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
72 |             optimizer, lr_lambda=lambda epoch: 1, verbose=True  # one lambda covers all param groups
73 | )
74 | elif lr_type == "multi_step_lr":
75 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
76 | optimizer,
77 | milestones=lr_param["milestones"],
78 | gamma=lr_param["gamma"],
79 | verbose=True,
80 | )
81 |
82 | elif lr_type == "warmup_constant":
83 | lr_scheduler = transformers.get_constant_schedule_with_warmup(
84 | optimizer, num_warmup_steps=lr_param.num_warmup_steps
85 | )
86 | elif lr_type == "cos_with_hard_restarts":
87 | lr_scheduler = transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
88 | optimizer,
89 | num_warmup_steps=lr_param.num_warmup_steps,
90 | num_training_steps=lr_param.num_training_steps,
91 | num_cycles=lr_param.num_cycles,
92 | )
93 | elif lr_type == "linear":
94 | lr_scheduler = transformers.get_linear_schedule_with_warmup(
95 | optimizer,
96 | num_warmup_steps=lr_param.num_warmup_steps,
97 | num_training_steps=tparam.max_epochs,
98 | )
99 |
100 | else:
101 | raise NotImplementedError
102 | return lr_scheduler
103 |
104 |
105 | def get_lr_dict(optimizer, tparam, key):
106 | lr_type, lr_param = get_lr_type_and_param(tparam, key)
107 | lr_scheduler = gen_lr_scheduler(tparam, optimizer, lr_type, lr_param)
108 | lr_dict = {
109 | "scheduler": lr_scheduler,
110 | "interval": "epoch",
111 | "frequency": 1,
112 | "monitor": "val_loss",
113 | "strict": True,
114 | "name": None,
115 | }
116 |
117 | return lr_dict
118 |
--------------------------------------------------------------------------------
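A minimal, standalone sketch of the warmup_constant branch (with assumed hyperparameters): AdamW plus transformers' constant schedule with warmup, stepped once per interval as configured in get_lr_dict.

import torch
import transformers

params = [torch.nn.Parameter(torch.zeros(2, 2))]
optimizer = torch.optim.AdamW(params, lr=5e-5, weight_decay=0)
scheduler = transformers.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=10)

for _ in range(3):  # lr ramps linearly over the first 10 steps, then stays at 5e-5
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())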
/lbox_open/openprompt_wrapper/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_utils import InputExampleWrapper
2 | from .pipeline_base import PromptForGenerationCustom
3 | from .plms import load_plm_wrapper
4 |
--------------------------------------------------------------------------------
/lbox_open/openprompt_wrapper/data_utils/__init__.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from openprompt import data_utils
4 |
5 |
6 | class InputExampleWrapper(data_utils.InputExample):
7 | def to_json_string(self):
8 | r"""Serialize this instance to a JSON string."""
9 | # return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
10 | return (
11 | json.dumps(self.to_dict(), indent=2, sort_keys=True, ensure_ascii=False)
12 | + "\n"
13 | )
14 |
--------------------------------------------------------------------------------
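The only change over OpenPrompt's InputExample is ensure_ascii=False, which keeps Korean text human-readable when examples are serialized:

import json

print(json.dumps({"text_a": "벌금"}))                      # {"text_a": "\ubc8c\uae08"}
print(json.dumps({"text_a": "벌금"}, ensure_ascii=False))  # {"text_a": "벌금"}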
/lbox_open/openprompt_wrapper/pipeline_base.py:
--------------------------------------------------------------------------------
1 | # Wonseok added PromptForGenerationCustom by copying and tweaking the OpenPrompt-v1.0.0 PromptForGeneration class.
2 | # We modify two things: (1) L343--L345 for compatibility with transformers 4.19.4, and
3 | # (2) recover "confidences", which was available in the initial version of OpenPrompt.
4 |
5 | from copy import deepcopy
6 | from typing import Any, Dict, Optional, Union
7 |
8 | import numpy as np
9 | import torch
10 | from openprompt.data_utils import InputFeatures
11 | from openprompt.pipeline_base import PromptForGeneration, PromptModel
12 | from openprompt.prompt_base import Template, Verbalizer
13 | from openprompt.utils import round_list, signature
14 | from openprompt.utils.logging import logger
15 | from torch import nn
16 | from transformers.generation_utils import GenerationMixin
17 | from transformers.tokenization_utils import PreTrainedTokenizer
18 | from transformers.utils.dummy_pt_objects import PreTrainedModel
19 | from yacs.config import CfgNode
20 |
21 |
22 | class PromptForGenerationCustom(torch.nn.Module, GenerationMixin):
23 |     r"""``PromptModel`` with generation loss calculation and generation utils integrated.
24 |
25 |
26 | Args:
27 |         plm (:obj:`PretrainedModel`): A pre-trained model you decide to use for generation, e.g. GPT.
28 | template (:obj:`Template`): A ``Template`` object you use to wrap the input text for classification, e.g. ``PrefixTemplate``.
29 | tokenizer (:obj:`Tokenizer`): A ``Tokenizer`` of the current model.
30 |         gen_config (:obj:`CfgNode`): The generation configs to pass into ``GenerationMixin.generate``.
31 | freeze_plm (:obj:`bool`): whether or not to freeze the pretrained language model
32 |         plm_eval_mode (:obj:`bool`): a stronger freezing mode than freeze_plm: the model's dropout is turned off, no matter whether the other parts are set to train mode.
33 | """
34 |
35 | def __init__(
36 | self,
37 | plm: PreTrainedModel,
38 | template: Template,
39 | freeze_plm: bool = False,
40 | plm_eval_mode: bool = False,
41 | gen_config: Optional[CfgNode] = None,
42 | tokenizer: Optional[PreTrainedTokenizer] = None,
43 | ):
44 | super().__init__()
45 | self.freeze_plm = freeze_plm
46 | if tokenizer is None:
47 | assert (
48 | template.tokenizer is not None
49 | ), "Tokenizer can't be set from input args or template"
50 | self.tokenizer = template.tokenizer
51 | else:
52 | self.tokenizer = tokenizer
53 | self.prompt_model = PromptModel(plm, template, freeze_plm, plm_eval_mode)
54 |
55 | self.loss_fct = nn.CrossEntropyLoss(reduction="none")
56 | self.config = plm.config
57 | if gen_config:
58 | for key in gen_config:
59 | setattr(self.config, key, gen_config[key])
60 | self.in_generation_function = False
61 |
62 | self.main_input_name = (
63 | self.prompt_model.main_input_name
64 | ) # for transformers 4.17.0 and higher.
65 |
66 | @property
67 | def plm(self):
68 | return self.prompt_model.plm
69 |
70 | @property
71 | def template(self):
72 | return self.prompt_model.template
73 |
74 | @property
75 | def device(self):
76 | return self.plm.device
77 |
78 | def shift_logits_and_labels(self, logits, loss_ids, reference_ids):
79 |
80 | r"""
81 |         Left-shift the labels, and set the labels of positions that are
82 |         not loss positions to -100, which is the ignore index in PyTorch's
83 |         loss function.
84 |
85 | Args:
86 | logits (:obj:`torch.Tensor`):
87 | batch (:obj:`InputFeatures`): The input features of batchified data sequences.
88 |
89 | Returns:
90 | shift_logits (:obj:`torch.Tensor`):
91 | shift_input_ids (:obj:`List[int]`):
92 |
93 | """
94 |
95 | shift_logits = logits[..., :-1, :].contiguous()
96 | shift_loss_ids = loss_ids[..., 1:].contiguous()
97 | shift_input_ids = reference_ids[..., 1:].contiguous()
98 | shift_input_ids = torch.where(shift_loss_ids > 0, shift_input_ids, -100)
99 | return shift_logits, shift_input_ids
100 |
101 | def forward(self, *args, **kwargs):
102 |         r"""During generation, this uses the plm's forward function.
103 |         In the first step we directly call the process_batch function to
104 |         build the initial input with the template; once the whole template
105 |         has been processed into past_key_values,
106 |         the normal generation function can be used.
107 |         During training, forward is linked to the ``_forward`` function,
108 |         in which the loss is calculated for all positions at the same time.
109 | """
110 | if self.in_generation_function:
111 | return self.plm.forward(*args, **kwargs)
112 | else:
113 | return self._forward(*args, **kwargs)
114 |
115 | def _forward(self, batch: Union[Dict, InputFeatures]) -> torch.Tensor:
116 | r"""
117 | This is the forward method of the training of generation in prompt-learning framework.
118 |
119 | Args:
120 | batch (:obj:`Union[Dict, InputFeatures]`): The input features of batchified data sequences.
121 |
122 | Returns:
123 | loss(:obj:torch.Tensor): The loss of the current generation procedure.
124 | """
125 | if self.config.is_encoder_decoder:
126 | reference_ids = batch["decoder_input_ids"]
127 | else:
128 | reference_ids = batch[
129 | "input_ids"
130 |             ]  # in some templates this field may be dropped
131 | outputs = self.prompt_model(batch)
132 | logits = outputs.logits
133 | logits, labels = self.shift_logits_and_labels(
134 | logits, batch["loss_ids"], reference_ids
135 | )
136 | batch_size, seq_len, vocab_size = logits.shape
137 | loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
138 | loss = loss.view(batch_size, -1).sum(dim=-1) # TODO support more objectives
139 | loss = loss.mean()
140 | return loss
141 |
142 | def generate(
143 | self,
144 | batch: Union[Dict, InputFeatures],
145 | verbose: Optional[bool] = False,
146 | **generation_kwargs,
147 | ):
148 | r"""This function wraps the generate() methods in parent class ``GenerationMixin``.
149 | Forward uses the ``PretrainedModel``'s forward method.
150 | generation_kwargs include all the parameters that are passed in to
151 | ``transformers.generation_util.GenerationMixin.generate``
152 |
153 | Args:
154 | batch (:obj:`Union[Dict, InputFeatures]`): The input features of batchified data sequences.
155 |             verbose (:obj:`Optional[bool]`): Set to True to log the generated sentences.
156 |
157 | Returns:
158 | output_sequences (:obj:`List[torch.Tensor]`): The raw sequences generated by the generation model.
159 |             generated_sentences (:obj:`List[str]`): The generated sentences after post-processing.
160 | """
161 | input_generation_kwargs = {
162 | key: value
163 | for key, value in generation_kwargs.items()
164 | if key in signature(GenerationMixin.generate).args
165 | }
166 | if self.config.is_encoder_decoder:
167 | loss_ids_start = batch["loss_ids"].argmax(dim=-1)
168 | assert (
169 | loss_ids_start.min() == loss_ids_start.max()
170 | ), "The generation start from different position in a batch."
171 | batch["decoder_input_ids"] = batch["decoder_input_ids"][
172 | :, : loss_ids_start.min() + 1
173 | ]
174 | input_length = batch["decoder_input_ids"].size(1)
175 | batch_size = batch["decoder_input_ids"].size(0)
176 |
177 | self.generate_ith_token = 0
178 | self.in_generation_function = True
179 |
180 | output_dict = super().generate(
181 | **batch,
182 | **input_generation_kwargs,
183 | pad_token_id=self.tokenizer.pad_token_id,
184 | eos_token_id=self.tokenizer.eos_token_id,
185 | output_scores=True,
186 | return_dict_in_generate=True,
187 | )
188 | output_sequences = output_dict["sequences"]
189 | output_scores = output_dict[
190 | "scores"
191 |         ]  # a tuple of L tensors, each of shape (B, Ntoken), where L is the generated length
192 | self.in_generation_function = False
193 | output_sequences = output_sequences.cpu().tolist()
194 | generated_sentences, confidences = self.post_processing_with_confidence(
195 | output_sequences=output_sequences,
196 | input_lengths=input_length,
197 | output_scores=output_scores,
198 | )
199 | # output_sequences = super().generate(**batch, **input_generation_kwargs, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id)
200 | # self.in_generation_function = False
201 | # output_sequences = output_sequences.cpu().tolist()
202 | # generated_sentences = self.post_processing(output_sequences=output_sequences, input_lengths=input_length)
203 | else:
204 | input_length = batch["input_ids"].size(1)
205 | batch_size = batch["input_ids"].size(0)
206 |
207 |             # Currently huggingface transformers only supports single-sample generation, or padding to the left (instead of the right),
208 |             # because it will only extract the last position of the output.
209 |             # So we generate one by one.
210 | if "input_ids_len" in batch:
211 | input_real_lens = batch["input_ids_len"]
212 | else:
213 | input_real_lens = torch.sum(
214 | (batch["input_ids"] != self.tokenizer.pad_token_id).to(torch.int),
215 | dim=-1,
216 | )
217 | output_sequences = []
218 | output_scores = []
219 | for instance_id in range(batch_size):
220 | # remove the pad token
221 | instance = {
222 | key: batch[key][instance_id : instance_id + 1][
223 | :, : input_real_lens[instance_id]
224 | ]
225 | for key in batch
226 | if isinstance(batch[key], torch.Tensor)
227 | and batch[key].shape[:2] == torch.Size([batch_size, input_length])
228 | }
229 | self.generate_ith_token = 0
230 | self.in_generation_function = True
231 | output_dict = super().generate(
232 | **instance,
233 | **input_generation_kwargs,
234 | pad_token_id=self.tokenizer.pad_token_id,
235 | eos_token_id=self.tokenizer.eos_token_id,
236 | output_scores=True,
237 | return_dict_in_generate=True,
238 | )
239 | output_sequence = output_dict["sequences"]
240 | self.in_generation_function = False
241 | output_sequences.extend(
242 | output_sequence.cpu().tolist()
243 | ) # TODO: to support generate multiple sentence
244 |
245 | output_score = output_dict["scores"]
246 | output_scores.append(output_score)
247 |
248 | generated_sentences, confidences = self.post_processing_with_confidence(
249 | output_sequences=output_sequences,
250 | input_lengths=input_real_lens.cpu().tolist(),
251 | output_scores=output_scores,
252 | )
253 | # for instance_id in range(batch_size):
254 | # # remove the pad token
255 | # instance = {key: batch[key][instance_id:instance_id+1][:,:input_real_lens[instance_id]] for key in batch if isinstance(batch[key], torch.Tensor) and batch[key].shape[:2]==torch.Size([batch_size, input_length])}
256 | # self.generate_ith_token = 0
257 | # self.in_generation_function = True
258 | # output_sequence = super().generate(**instance, **input_generation_kwargs, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id)
259 | # self.in_generation_function = False
260 | # output_sequences.extend(output_sequence.cpu().tolist()) # TODO: to support generate multiple sentence
261 | # generated_sentences = self.post_processing(output_sequences=output_sequences, input_lengths=input_real_lens.cpu().tolist())
262 | if verbose:
263 | logger.info(f"Generated:{generated_sentences}")
264 | return output_sequences, generated_sentences, confidences
265 |
266 | def post_processing(self, output_sequences, input_lengths):
267 | r"""
268 | Post-process the sequences generated by the generation model.
269 |
270 | Args:
271 | output_sequences (:obj:`torch.Tensor`): The raw sequences generated by the generation model.
272 | input_lengths (:obj:`int` or `list`): The length(s) of the input sequence.
273 |
274 | Returns:
275 | :obj:`List`: The generated sentences that have been post-processed.
276 | """
277 | generated_sentences = []
278 |         if isinstance(input_lengths, int):
279 | input_lengths = [input_lengths] * len(output_sequences)
280 | for sent_id, seq in enumerate(output_sequences):
281 | seq = seq[input_lengths[sent_id] :]
282 |
283 | if (
284 | hasattr(self.tokenizer, "eos_token")
285 | and self.tokenizer.eos_token is not None
286 | ):
287 | text_output = self.tokenizer.decode(
288 | seq, clean_up_tokenization_spaces=True, skip_special_tokens=False
289 | )
290 | idx = text_output.find(self.tokenizer.eos_token)
291 | if idx >= 0:
292 | text_output = text_output[:idx]
293 | else:
294 | text_output = self.tokenizer.decode(
295 | seq, clean_up_tokenization_spaces=True, skip_special_tokens=True
296 | )
297 | text_output = text_output.strip()
298 | generated_sentences.append(text_output)
299 | return generated_sentences
300 |
301 | def prepare_inputs_for_generation(
302 | self, input_ids: Optional[torch.Tensor] = None, **model_kwargs
303 | ):
304 | r"""This function wraps the ``prepare_inputs_for_generation`` function in the huggingface transformers.
305 |
306 |         When `past` is not in model_kwargs, we prepare the input from scratch;
307 |         when `past` is in model_kwargs, we don't need to prepare the
308 |         template-wrapped input, and instead use the inner pretrained model's
309 |         function to prepare the next step's input.
310 |         `model_kwargs` includes all the arguments passed in the `batch` (InputFeatures),
311 |         except ``input_ids``, as long as they do not conflict with keywords in ``generation_kwargs``.
312 |
313 | Args:
314 | input_ids(:obj:`torch.Tensor`): Indices of input sequence tokens in the vocabulary.
315 | """
316 | if (
317 | self.generate_ith_token == 0 and "encoder_outputs" not in model_kwargs
318 | ): # generating the first token in decoder only setting.
319 |
320 | batch = InputFeatures(input_ids=input_ids, **model_kwargs)
321 | model_inputs = self.prompt_model.prepare_model_inputs(batch)
322 | # check the compatibility for more models. Having checked gpt2, T5
323 | else: # generating the subsequence generation can use the default setting
324 | model_inputs = self.plm.prepare_inputs_for_generation(
325 | input_ids, **model_kwargs
326 | )
327 | self.last_model_inputs = model_inputs # to update the model_kwargs in _update_model_kwargs_for_generation, in-place operation.
328 | return model_inputs
329 |
330 | def _update_model_kwargs_for_generation(
331 | self, outputs, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
332 | ) -> Dict[str, Any]:
333 | r"""The parents class's ``_update_model_kwargs_for_generation`` method will
334 | add ``past_key_values`` to model_kwargs, and update ``token_type_ids``, and ``attention_mask_ids``.
335 |
336 | In case some of the model_kwargs are modified in the prepare_inputs_for_generation function
337 |         and should be used as the subsequent model_kwargs, we update these kwargs before the parent class
338 | call.
339 |
340 | Other updates should be added here after the parent's function call.
341 |
342 | Args:
343 | outputs (:obj:`torch.Tensor`):
344 | is_encoder_decoder (:obj:`bool`, defaults to False):
345 | """
346 | if self.generate_ith_token == 0:
347 | for key in self.last_model_inputs:
348 | if key in model_kwargs:
349 | model_kwargs[key] = self.last_model_inputs[key]
350 | model_kwargs = super(
351 | PromptForGeneration, PromptForGeneration
352 | )._update_model_kwargs_for_generation(
353 | outputs=outputs,
354 | model_kwargs=model_kwargs,
355 | is_encoder_decoder=is_encoder_decoder,
356 | )
357 | self.generate_ith_token += 1
358 | return model_kwargs
359 |
360 | def _prepare_encoder_decoder_kwargs_for_generation(
361 | self,
362 | input_ids: torch.LongTensor,
363 | model_kwargs,
364 | model_input_name: Optional[str] = None,
365 | ) -> Dict[str, Any]:
366 |         r"""This function resembles the corresponding function in GenerationMixin.
367 |
368 | Args:
369 |             input_ids (:obj:`torch.LongTensor`): The input ids passed to the encoder.
370 | """
371 | if "encoder_outputs" not in model_kwargs:
372 | # retrieve encoder hidden states
373 | encoder = self.plm.get_encoder()
374 | encoder_kwargs = {
375 | argument: value
376 | for argument, value in model_kwargs.items()
377 | if not (
378 | argument.startswith("decoder_") or argument.startswith("cross_attn")
379 | )
380 | }
381 | model_input_name = (
382 | model_input_name
383 | if model_input_name is not None
384 | else self.main_input_name
385 | )
386 | batch = {model_input_name: input_ids, **encoder_kwargs}
387 | model_inputs = self.prompt_model.prepare_model_inputs(
388 | batch
389 |         )  # This line differs from the original code base: we should process the input
390 | # with our template, then pass it into the model.
391 | # some of the arguments may have been changed by the template,
392 | # e.g. the attention mask. Here we update the model_kwargs
393 | for key in model_kwargs:
394 | if key in model_inputs:
395 | model_kwargs[key] = model_inputs[key]
396 | model_inputs_with_use_cache_false = deepcopy(model_inputs)
397 | model_inputs_with_use_cache_false["use_cache"] = False
398 | model_kwargs["encoder_outputs"] = encoder(
399 | return_dict=True, **model_inputs_with_use_cache_false
400 | )
401 | return model_kwargs
402 |
403 |     ## We comment out this code since it conflicts with [OpenDelta](https://github.com/thunlp/OpenDelta)
404 | # def state_dict(self, *args, **kwargs):
405 | # """ Save the model using template and plm's save methods. """
406 | # _state_dict = {}
407 | # if not self.prompt_model.freeze_plm:
408 | # _state_dict['plm'] = self.plm.state_dict(*args, **kwargs)
409 | # _state_dict['template'] = self.template.state_dict(*args, **kwargs)
410 | # return _state_dict
411 |
412 | # def load_state_dict(self, state_dict, *args, **kwargs):
413 | # """ Load the model using template and plm's load methods. """
414 | # if 'plm' in state_dict and not self.prompt_model.freeze_plm:
415 | # self.plm.load_state_dict(state_dict['plm'], *args, **kwargs)
416 | # self.template.load_state_dict(state_dict['template'], *args, **kwargs)
417 |
418 | def _reorder_cache(self, past, beam_idx):
419 | r"""Use the plm's default _reorder_cache function"""
420 | return self.plm._reorder_cache(past, beam_idx)
421 |
422 | def parallelize(self, device_map=None):
423 | r"""Parallelize the model across device"""
424 | if hasattr(self.plm, "parallelize"):
425 | self.plm.parallelize(device_map)
426 | self.device_map = self.plm.device_map
427 | else:
428 | raise NotImplementedError(
429 | "parallelize method was not implemented for this plm."
430 | )
431 |
432 | def deparallelize(self):
433 | r"""Deparallelize the model across device"""
434 | if hasattr(self.plm, "deparallelize"):
435 | self.plm.deparallelize()
436 | self.device_map = None
437 | else:
438 | raise NotImplementedError(
439 |                 "deparallelize method was not implemented for this plm."
440 | )
441 |
442 | def post_processing_with_confidence(
443 | self, output_sequences, input_lengths, output_scores
444 | ):
445 | r"""
446 | Post-process the sequences generated by the generation model.
447 |
448 | Args:
449 | output_sequences (:obj:`torch.Tensor`): The raw sequences generated by the generation model.
450 | input_lengths (:obj:`int` or `list`): The length(s) of the input sequence.
451 |
452 | Returns:
453 | :obj:`List`: The generated sentences that have been post-processed.
454 | """
455 | generated_sentences = []
456 |         if isinstance(input_lengths, int):
457 | input_lengths = [input_lengths] * len(output_sequences)
458 | confidences = []
459 | confidences_list = []
460 | for sent_id, seq in enumerate(output_sequences):
461 | seq = seq[input_lengths[sent_id] :]
462 | if self.config.is_encoder_decoder:
463 | # [T, B, Ntoken]
464 | assert len(seq) == len(
465 | output_scores
466 |                 )  # (T, B, Ntoken); T is the sequence length.
467 | else:
468 | # [B, T, Ntoken]
469 | assert len(seq) == len(output_scores[sent_id])
470 |
471 | text_output = self.tokenizer.decode(seq, clean_up_tokenization_spaces=True)
472 | idx = text_output.find(self.tokenizer.eos_token)
473 | if idx >= 0:
474 | text_output = text_output[:idx]
475 | text_output = text_output.strip()
476 | generated_sentences.append(text_output)
477 |
478 | if self.tokenizer.eos_token_id in seq:
479 | idx_token = seq.index(self.tokenizer.eos_token_id)
480 | else:
481 | idx_token = -1
482 |
483 | if idx_token >= 0:
484 | seq_trimmed = seq[:idx_token]
485 | else:
486 | seq_trimmed = seq
487 |
488 | confidence_list = []
489 | for i_tok, tok_id in enumerate(seq_trimmed):
490 | if self.config.is_encoder_decoder:
491 | # [T, B, Ntoken]
492 | scores = output_scores[i_tok] # [B, Ntok]
493 | prob = scores[sent_id, :].softmax(-1)
494 | else:
495 | # [B, T, Ntoken]
496 | scores = output_scores[sent_id] # [L, Ntok]
497 | prob = scores[i_tok].softmax(-1)[0]
498 | confidence_list.append(prob[tok_id].item())
499 | confidences_list.append(confidence_list)
500 | confidences.append(np.mean(confidence_list))
501 |
502 | return generated_sentences, confidences
503 |
--------------------------------------------------------------------------------
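The recovered "confidence" above is simply the mean softmax probability of the generated tokens (up to EOS); a standalone sketch of that computation:

import torch

# 4 decoding steps over a toy vocabulary of 100; one tensor of logits per step.
scores = [torch.randn(1, 100) for _ in range(4)]
token_ids = [7, 13, 2, 42]  # the ids actually generated at each step

probs = [step[0].softmax(-1)[tok].item() for step, tok in zip(scores, token_ids)]
confidence = sum(probs) / len(probs)  # mirrors np.mean(confidence_list) above
print(confidence)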
/lbox_open/openprompt_wrapper/plms/__init__.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 |
4 | from openprompt import plms
5 | from transformers import (
6 | AutoTokenizer,
7 | GPT2Config,
8 | GPT2LMHeadModel,
9 | MT5Config,
10 | MT5ForConditionalGeneration,
11 | MT5Tokenizer,
12 | PreTrainedTokenizer,
13 | PreTrainedTokenizerFast,
14 | )
15 |
16 | from .lm import LMTFastokenizerWrapperCustom
17 |
18 |
19 | def get_model_class(plm_type: str):
20 | return _MODEL_CLASSES[plm_type]
21 |
22 |
23 | MT5TokenizerWrapper = plms.T5TokenizerWrapper
24 |
25 | _MODEL_CLASSES = {
26 | "mt5": plms.ModelClass(
27 | **{
28 | "config": MT5Config,
29 | "tokenizer": MT5Tokenizer,
30 | "model": MT5ForConditionalGeneration,
31 | "wrapper": MT5TokenizerWrapper,
32 | }
33 | ),
34 | "kogpt2": plms.ModelClass(
35 | **{
36 | "config": GPT2Config,
37 | "tokenizer": PreTrainedTokenizerFast,
38 | "model": GPT2LMHeadModel,
39 | "wrapper": LMTFastokenizerWrapperCustom,
40 | }
41 | ),
42 | "legal-gpt": plms.ModelClass(
43 | **{
44 | "config": GPT2Config,
45 | "tokenizer": AutoTokenizer,
46 | "model": GPT2LMHeadModel,
47 | "wrapper": LMTFastokenizerWrapperCustom,
48 | }
49 | ),
50 | }
51 |
52 |
53 | def load_plm_wrapper(
54 | model_name,
55 | model_path,
56 | specials_to_add=None,
57 | revision=None,
58 | do_not_load_pretrained_weight=False,
59 | use_custom_loader=False,
60 | ):
61 | if not use_custom_loader:
62 | return plms.load_plm(model_name, model_path, specials_to_add)
63 | else:
64 | model_class = get_model_class(plm_type=model_name)
65 | wrapper = model_class.wrapper
66 | if model_name in ["kogpt2"]:
67 | model_config = model_class.config.from_pretrained(
68 | model_path, revision=revision
69 | )
70 | if do_not_load_pretrained_weight:
71 | model = model_class.model(
72 | config=model_config,
73 | )
74 | else:
75 | model = model_class.model.from_pretrained(
76 | model_path, revision=revision, config=model_config
77 | )
78 |
 79 |             tokenizer = model_class.tokenizer.from_pretrained(
 80 |                 model_path,
 81 |                 bos_token="</s>",
 82 |                 eos_token="</s>",
 83 |                 unk_token="<unk>",
 84 |                 pad_token="<pad>",
 85 |                 mask_token="<mask>",
 86 |             )
87 | elif model_name in ["legal-gpt"]:
88 | model_config = model_class.config.from_pretrained(
89 | model_path, revision=revision
90 | )
91 | if do_not_load_pretrained_weight:
92 | model = model_class.model(
93 | config=model_config,
94 | )
95 | else:
96 | model = model_class.model.from_pretrained(
97 | model_path, revision=revision, config=model_config
98 | )
99 | tokenizer = model_class.tokenizer.from_pretrained(
100 | model_path,
101 | bos_token="[BOS]",
102 | unk_token="[UNK]",
103 | pad_token="[PAD]",
104 | mask_token="[MASK]",
105 | )
106 |
107 | else:
108 | model_config = model_class.config.from_pretrained(
109 | model_path, revision=revision
110 | )
111 | if do_not_load_pretrained_weight:
112 | model = model_class.model(
113 | config=model_config,
114 | )
115 | else:
116 |
117 | model = model_class.model.from_pretrained(
118 | model_path, revision=revision, config=model_config
119 | )
120 |
121 |             if "gpt" in model_name:  # add pad token for gpt
122 |                 specials_to_add = ["<pad>"]
123 |
124 | tokenizer = model_class.tokenizer.from_pretrained(model_path)
125 | model, tokenizer = plms.add_special_tokens(
126 | model, tokenizer, specials_to_add=specials_to_add
127 | )
128 |
129 | if model_name in ["mt5"]:
130 | _path = (
131 | Path(__file__).parent.resolve() / "mt5_additional_special_tokens.json"
132 | )
133 | with open(_path) as f:
134 | mt5_additional_special_tokens = json.load(f)
135 | tokenizer.add_special_tokens(
136 | {
137 | "additional_special_tokens": mt5_additional_special_tokens[
138 | "additional_special_tokens"
139 | ]
140 | }
141 | )
142 |
143 | return model, tokenizer, model_config, wrapper
144 |
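145 | # Illustrative usage (a sketch; "skt/kogpt2-base-v2" is an assumed checkpoint
146 | # path, and calling this downloads weights from the Hugging Face hub):
147 | #
148 | #     model, tokenizer, model_config, WrapperClass = load_plm_wrapper(
149 | #         model_name="kogpt2",
150 | #         model_path="skt/kogpt2-base-v2",
151 | #         use_custom_loader=True,
152 | #     )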
--------------------------------------------------------------------------------
/lbox_open/openprompt_wrapper/plms/lm.py:
--------------------------------------------------------------------------------
1 | # Wonseok added LMTFastokenizerWrapperCustom, which is copied from the OpenPrompt-v1.0.0 LMTokenizerWrapper class.
2 | # - The only difference is that it inherits from FastTokenizerWrapper instead of TokenizerWrapper.
3 |
4 | from collections import defaultdict
5 | from typing import Optional
6 |
7 | from transformers.tokenization_utils import PreTrainedTokenizer
8 |
9 | from .utils import FastTokenizerWrapper
10 |
11 |
12 | class LMTFastokenizerWrapperCustom(FastTokenizerWrapper):
13 |     r"""
14 |     LMTokenizer is a causal language model. Therefore it can only predict the '<mask>'
15 |     position at the end of the sentence. A prefix-style template like: 'A <mask> news : <text_a>' is
16 |     not applicable in this situation.
17 |     For the template where there is '<text_a>' or '<text_b>' after '<mask>', we raise an exception and terminate
18 |     the program.
19 |     For the template where there are template words after '<mask>', we ignore these template words.
20 |     Moreover, it can only predict one '<mask>' position. Any template that has multiple '<mask>' tokens will
21 |     give rise to an exception.
22 |     """
23 |
24 | def __init__(
25 | self,
26 | max_seq_length: int,
27 | tokenizer: PreTrainedTokenizer,
28 | truncate_method: Optional[str] = "tail",
29 | predict_eos_token: Optional[bool] = False,
30 | **kwargs
31 | ):
32 | super().__init__(
33 | max_seq_length=max_seq_length,
34 | tokenizer=tokenizer,
35 | truncate_method=truncate_method,
36 | )
37 | self.predict_eos = predict_eos_token
38 |
39 | @property
40 | def num_special_tokens_to_add(self):
41 | if not hasattr(self, "_num_specials"):
42 | self._num_specials = self.tokenizer.num_special_tokens_to_add()
43 | return self._num_specials
44 |
45 | def tokenize_one_example(self, wrapped_example, teacher_forcing):
46 |         """# TODO doesn't consider the situation where the input has two parts"""
47 | wrapped_example, others = wrapped_example
48 |
49 | if teacher_forcing:
50 |
51 | tgt_text = others["tgt_text"]
52 | if isinstance(tgt_text, str):
53 | tgt_text = [tgt_text]
54 |
55 | if self.predict_eos:
56 | if not wrapped_example[-1]["text"].endswith(self.tokenizer.eos_token):
57 | wrapped_example.append(
58 | {
59 | "text": self.tokenizer.eos_token,
60 | "shortenable_ids": 0,
61 | "loss_ids": 1,
62 | }
63 | )
64 |
65 | encoder_inputs = defaultdict(list)
66 |
67 | num_mask_token_used = 0
68 |
69 | for piece_id, piece in enumerate(wrapped_example):
70 | if len(piece["text"]) == 0:
71 | continue
72 |
73 | if (
74 | piece["text"] == self.tokenizer.eos_token
75 | and self.predict_eos
76 | and wrapped_example[piece_id - 1]["loss_ids"] == 1
77 | ): # eos after the mask also need to be pred
78 | piece["loss_ids"] = 1
79 |
80 | if piece["text"] == self.template_mask_token:
81 | if teacher_forcing:
82 | piece["text"] = " " + tgt_text[num_mask_token_used] + " "
83 | else:
84 | encoder_inputs["loss_ids"][-1][-1] = 1
85 | break
86 |
87 | if piece["text"] in self.special_tokens_maps.keys():
88 | to_replace = self.special_tokens_maps[piece["text"]]
89 | if to_replace is not None:
90 | piece["text"] = to_replace
91 | else:
92 | raise KeyError(
93 | "This tokenizer doesn't specify {} token.".format(piece["text"])
94 | )
95 |
96 | if "soft_token_ids" in piece and piece["soft_token_ids"] != 0:
 97 |                 encode_text = [
 98 |                     0
 99 |                 ]  # can be replaced by any token, since these tokens will use their own embeddings
100 | else:
101 | encode_text = self.tokenizer.encode(
102 | piece["text"], add_special_tokens=False
103 | )
104 |
105 | encoding_length = len(encode_text)
106 |
107 | encoder_inputs["input_ids"].append(encode_text)
108 | for key in piece:
109 | if key not in ["text"]:
110 | encoder_inputs[key].append([piece[key]] * encoding_length)
111 |
112 | encoder_inputs = self.truncate(encoder_inputs=encoder_inputs)
113 |
114 | # delete shortenable ids
115 | encoder_inputs.pop("shortenable_ids")
116 | encoder_inputs = self.concate_parts(input_dict=encoder_inputs)
117 | encoder_inputs = self.add_special_tokens(
118 | encoder_inputs=encoder_inputs
119 | ) # this will do nothing in GPT2 tokenizer
120 | # create special input ids
121 | encoder_inputs["attention_mask"] = [1] * len(encoder_inputs["input_ids"])
122 | if self.create_token_type_ids:
123 | encoder_inputs["token_type_ids"] = [0] * len(encoder_inputs["input_ids"])
124 | # pad to max length
125 | input_ids_len = len(encoder_inputs["input_ids"])
126 | encoder_inputs = self.padding(
127 | input_dict=encoder_inputs,
128 | max_len=self.max_seq_length,
129 | pad_id_for_inputs=self.tokenizer.pad_token_id,
130 | )
131 | encoder_inputs = {**encoder_inputs, "input_ids_len": input_ids_len}
132 | return encoder_inputs
133 |
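134 | # Illustrative input for tokenize_one_example (a sketch of the OpenPrompt
135 | # wrapped-example format this wrapper consumes; the field values are hypothetical):
136 | #
137 | #     wrapped_example = (
138 | #         [
139 | #             {"text": "facts ...", "loss_ids": 0, "shortenable_ids": 1},
140 | #             {"text": "<mask>", "loss_ids": 0, "shortenable_ids": 0},
141 | #         ],
142 | #         {"tgt_text": "theft"},
143 | #     )
144 | #     # With teacher_forcing=True the target text replaces the mask piece; with
145 | #     # teacher_forcing=False encoding stops at the mask and its position is
146 | #     # marked in "loss_ids".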
--------------------------------------------------------------------------------
/lbox_open/openprompt_wrapper/plms/mt5_additional_special_tokens.json:
--------------------------------------------------------------------------------
1 | {
2 | "additional_special_tokens": [
3 | "",
4 | "",
5 | "",
6 | "",
7 | "",
8 | "",
9 | "",
10 | "",
11 | "",
12 | "",
13 | "",
14 | "",
15 | "",
16 | "",
17 | "",
18 | "",
19 | "",
20 | "",
21 | "",
22 | "",
23 | "",
24 | "",
25 | "",
26 | "",
27 | "",
28 | "",
29 | "",
30 | "",
31 | "",
32 | "",
33 | "",
34 | "",
35 | "",
36 | "",
37 | "",
38 | "",
39 | "",
40 | "",
41 | "",
42 | "",
43 | "",
44 | "",
45 | "",
46 | "",
47 | "",
48 | "",
49 | "",
50 | "",
51 | "",
52 | "",
53 | "",
54 | "",
55 | "",
56 | "",
57 | "",
58 | "",
59 | "",
60 | "",
61 | "",
62 | "",
63 | "",
64 | "",
65 | "",
66 | "",
67 | "",
68 | "",
69 | "",
70 | "",
71 | "",
72 | "",
73 | "",
74 | "",
75 | "",
76 | "",
77 | "",
78 | "",
79 | "",
80 | "",
81 | "",
82 | "",
83 | "",
84 | "",
85 | "",
86 | "",
87 | "",
88 | "",
89 | "",
90 | "",
91 | "",
92 | "",
93 | "",
94 | "",
95 | "",
96 | "",
97 | "",
98 | "",
99 | "",
100 | "",
101 | "",
102 | ""
103 | ]
104 | }
--------------------------------------------------------------------------------
/lbox_open/openprompt_wrapper/plms/utils.py:
--------------------------------------------------------------------------------
1 | # Wonseok added FastTokenizerWrapper. The class inherits from the OpenPrompt-v1.0.0 TokenizerWrapper class.
2 |
3 | import warnings
4 |
5 | import numpy as np
6 | from openprompt import plms
7 |
8 |
9 | class FastTokenizerWrapper(plms.utils.TokenizerWrapper):
10 | def add_special_tokens(self, encoder_inputs):
11 | # add special tokens
12 | for key in encoder_inputs:
13 | if key == "input_ids":
14 | with warnings.catch_warnings():
15 | warnings.simplefilter("ignore")
16 | encoder_inputs[
17 | key
18 | ] = self.tokenizer.build_inputs_with_special_tokens(
19 | encoder_inputs[key]
20 | )
21 | else:
22 | # special_tokens_mask = np.array(self.tokenizer.get_special_tokens_mask(encoder_inputs[key], already_has_special_tokens=True))
23 | special_tokens_mask = np.array([0] * len(encoder_inputs[key]))
24 | with_special_tokens = np.array(
25 | self.tokenizer.build_inputs_with_special_tokens(encoder_inputs[key])
26 | )
27 | if key in ["soft_token_ids"]: # TODO maybe more than this
28 | encoder_inputs[key] = (
29 | (1 - special_tokens_mask) * with_special_tokens
30 | ).tolist() # use 0 as special
31 | else:
32 | encoder_inputs[key] = (
33 | (1 - special_tokens_mask) * with_special_tokens
34 | - special_tokens_mask * 100
35 | ).tolist() # use -100 as special
36 | return encoder_inputs
37 |
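38 | # Note on the override above: special_tokens_mask is fixed to all zeros, so for
39 | # GPT2-style tokenizers (which add no special tokens) the ids pass through
40 | # unchanged. The (1 - mask) * ids - mask * 100 arithmetic maps any masked
41 | # position to exactly -100 (e.g., mask [0, 1], ids [5, 6] -> [5, -100]), the
42 | # conventional ignore_index used by cross-entropy losses.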
--------------------------------------------------------------------------------
/lbox_open/parser/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lbox-kr/lbox-open/fdad4b039af718d2b171e561e75f5771515572df/lbox_open/parser/__init__.py
--------------------------------------------------------------------------------
/lbox_open/parser/output_parser.py:
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 |
5 | import re
6 |
7 |
8 | def sep_token_parser_baseline(
9 | parse_sep_token, value_sep_token, empty_token, sub_parse_keys, text
10 | ):
11 | parse = {}
12 | char_comma = ","
13 |
14 |     # protect commas inside numbers (e.g., "1,000,000") from the value split below
15 |     ms_money = re.finditer(r"\d[\d,.]+\d", text)
16 |     ms_comma = re.finditer(char_comma, text)
17 | 
18 |     idxs_comma = [m.start() for m in ms_comma]
19 |     idxs_comma_I = []
20 |     for m_money in ms_money:
21 |         st = m_money.start()
22 |         ed = m_money.end()
23 |         for idx_comma in idxs_comma:
24 |             if st <= idx_comma <= ed:
25 |                 idxs_comma_I.append(idx_comma)
26 |
27 | text_copy = ""
28 | rpl_sym = "★"
29 | for i, c in enumerate(text):
30 | if i in idxs_comma_I:
31 | text_copy += rpl_sym
32 | else:
33 | text_copy += c
34 |
35 | values = text_copy.split(parse_sep_token)
36 |
37 | for i, k in enumerate(sub_parse_keys):
38 | if i <= len(values) - 1:
39 | if empty_token in values[i]:
40 | vals = empty_token
41 | else:
42 | parse_values_before_split = values[i]
43 | parse_values = parse_values_before_split.split(value_sep_token)
44 | vals = []
45 | for val in parse_values:
46 | v = val.replace(rpl_sym, char_comma).strip()
47 |                     v = re.sub(r"\s", "", v)
48 | vals.append(v)
49 | else:
50 | vals = None
51 | parse[k] = vals
52 | return parse
53 |
54 |
55 | def sep_token_based_parser(
56 | target_parse, parse_sep_token, value_sep_token, empty_token, keys, text
57 | ):
58 | if target_parse in [
59 | "fine_imprisonment_lvs",
60 | "claim_acceptance_lv",
61 | "casename_classification",
62 | "statute_classification",
63 | ]:
64 | # print(text)
65 | parse = sep_token_parser_baseline(
66 | parse_sep_token, value_sep_token, empty_token, keys, text
67 | )
68 | else:
69 | raise NotImplementedError
70 |
71 | return parse
72 |
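73 | 
74 | if __name__ == "__main__":
75 |     # Minimal self-check (a sketch; the separator and key names below are
76 |     # hypothetical, not values mandated by the configs).
77 |     parsed = sep_token_parser_baseline(
78 |         parse_sep_token="<sep>",
79 |         value_sep_token=",",
80 |         empty_token="<empty>",
81 |         sub_parse_keys=["statutes"],
82 |         text="형법 제329조, 형법 제350조",
83 |     )
84 |     print(parsed)  # {'statutes': ['형법제329조', '형법제350조']}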
--------------------------------------------------------------------------------
/lbox_open/parser/output_parser_utils.py:
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 |
5 | from collections import defaultdict
6 | from itertools import zip_longest
7 | from pathlib import Path
8 |
9 | import lbox_open.utils.general_utils as gu
10 | from lbox_open.metric.exact_match import ExactMatch
11 |
12 | from .output_parser import sep_token_based_parser
13 |
14 |
15 | def text_to_parse_separator_based(
16 | target_parse,
17 | parse_sep_token,
18 | value_sep_token,
19 | empty_token,
20 | target_sub_parses,
21 | texts,
22 | ):
23 | return list(
24 | map(
25 | lambda x: sep_token_based_parser(
26 | target_parse,
27 | parse_sep_token,
28 | value_sep_token,
29 | empty_token,
30 | target_sub_parses,
31 | x,
32 | ),
33 | texts,
34 | )
35 | )
36 |
37 |
38 | def get_parses_from_eval_results(
39 | infer_param,
40 | target_parses_dict,
41 | doc_ids,
42 | gt_texts,
43 | pr_texts,
44 | ):
45 | parses = defaultdict(dict)
46 | for target_parse, target_sub_parses in target_parses_dict.items():
47 | gt_parses = text_to_parse_separator_based(
48 | target_parse,
49 | infer_param.parse_sep_token,
50 | infer_param.value_sep_token,
51 | infer_param.empty_token,
52 | target_sub_parses,
53 | gt_texts[target_parse],
54 | )
55 |
56 | pr_parses = text_to_parse_separator_based(
57 | target_parse,
58 | infer_param.parse_sep_token,
59 | infer_param.value_sep_token,
60 | infer_param.empty_token,
61 | target_sub_parses,
62 | pr_texts[target_parse],
63 | )
64 |
65 | # insert doc_ids
66 | for doc_id, gt_parse, pr_parse in zip_longest(
67 | doc_ids[target_parse], gt_parses, pr_parses
68 | ):
69 | gt_parse["doc_id"] = doc_id
70 | pr_parse["doc_id"] = doc_id
71 |
72 | parses[target_parse]["gt_parses"] = gt_parses
73 | parses[target_parse]["pr_parses"] = pr_parses
74 |
75 | return parses
76 |
77 |
78 | def cal_em_from_parses(
79 | infer_param,
80 | target_parses_dict,
81 | parses,
82 | verbose=False,
83 | save=False,
84 | output_save_dir=None,
85 | confidences=None,
86 | threshold=0.0,
87 | input_texts=None,
88 | ):
89 | em_scores_full = {}
90 | for target_parse, target_sub_parses in target_parses_dict.items():
91 |
92 | gt_parses = parses[target_parse]["gt_parses"]
93 | pr_parses = parses[target_parse]["pr_parses"]
94 |
95 | if confidences is None:
96 | _confs = [1.0] * len(gt_parses)
97 | else:
98 | _confs = confidences[target_parse]
99 |
100 | exact_match = ExactMatch(
101 | list(gt_parses[0].keys()), empty_value=infer_param.empty_token
102 | )
103 |
104 | (
105 | f1_all,
106 | cnt_tp_all,
107 | cnt_fp_all,
108 | cnt_fn_all,
109 | cnt_tn_all,
110 | th_recall,
111 | ) = exact_match.compare_parses(gt_parses, pr_parses, _confs, threshold)
112 |
113 | if verbose:
114 | print(f"Target_parse: {target_parse} with th-recall: {th_recall}")
115 | print("tp-------------------")
116 | print(cnt_tp_all)
117 | print("fp-------------------")
118 | print(cnt_fp_all)
119 | print("fn-------------------")
120 | print(cnt_fn_all)
121 | print("tn-------------------")
122 | print(cnt_tn_all)
123 | print("f1-------------------")
124 | print(f1_all)
125 |
126 | score = {
127 | "f1": f1_all,
128 | "tp": cnt_tp_all,
129 | "fp": cnt_fp_all,
130 | "fn": cnt_fn_all,
131 | "tn": cnt_tn_all,
132 | "th_recall": th_recall,
133 | }
134 | em_scores_full[target_parse] = score
135 |
136 | if save:
137 | if output_save_dir is not None:
138 | if "path_eval_result" in infer_param:
139 | print("path_eval_result is ignored!!!")
140 | else:
141 | output_save_dir = infer_param.path_eval_result
142 |
143 | # path_save_dir = os.path.dirname(output_save_dir)
144 | path_save_dir = output_save_dir
145 | path_save = Path(path_save_dir) / f"eval_parse_{target_parse}.json"
146 | gu.save_json(path_save, parses)
147 |
148 | path_save = Path(path_save_dir) / f"score_exact_match_{target_parse}.json"
149 | gu.save_json(path_save, score)
150 |
151 | return em_scores_full
152 |
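153 | # Shape sketch of the structures flowing through this module (field names follow
154 | # the functions above; the concrete values are hypothetical):
155 | #
156 | #     parses = {
157 | #         "statute_classification": {
158 | #             "gt_parses": [{"statute_classification": ["형법제329조"], "doc_id": 0}],
159 | #             "pr_parses": [{"statute_classification": ["형법제329조"], "doc_id": 0}],
160 | #         }
161 | #     }
162 | #
163 | # cal_em_from_parses then scores gt_parses against pr_parses with ExactMatch and
164 | # returns per-target F1/TP/FP/FN/TN plus the recall at the given confidence
165 | # threshold.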
--------------------------------------------------------------------------------
/lbox_open/pipeline/__init__.py:
--------------------------------------------------------------------------------
1 | from .lbox_open_pipeline import *
2 |
--------------------------------------------------------------------------------
/lbox_open/pipeline/lbox_open_pipeline.py:
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 |
5 | from pathlib import Path
6 |
7 | import pytorch_lightning as pl
8 | import torch
9 |
10 | from lbox_open import openprompt_wrapper
11 | from lbox_open.data_module.data_precedent import PrecedentDataModule
12 | from lbox_open.model.generative_baseline_model import GenerativeParser
13 | from lbox_open.template import prompt_generation_utils
14 | from lbox_open.utils import general_utils as gu
15 |
16 |
17 | def get_data_module(
18 | cfg,
19 | plm_tokenizer,
20 | TokenizerWrapper,
21 | input_templates,
22 | ):
23 |
24 | if cfg.data.use_local_data:
25 | raw_data = {
26 | "train": gu.load_jsonl(cfg.data.path_train, None),
27 | "valid": gu.load_jsonl(cfg.data.path_valid, None),
28 | }
29 | if cfg.data.path_test is not None:
30 | raw_data["test"] = gu.load_jsonl(cfg.data.path_test, None)
31 | else:
32 | raw_data = None
33 |
34 | if cfg.model.task in [
35 | "ljp_civil",
36 | "ljp_criminal",
37 | "casename_classification",
38 | "statute_classification",
39 | "summarization",
40 | ]:
41 | data_module = PrecedentDataModule(
42 | cfg,
43 | plm_tokenizer,
44 | TokenizerWrapper,
45 | input_templates,
46 | raw_data,
47 | )
48 | else:
49 | raise NotImplementedError
50 |
51 | return data_module
52 |
53 |
54 | def get_plm(cfg):
55 | (
56 | plm,
57 | plm_tokenizer,
58 | plm_model_config,
59 | TokenizerWrapperClass,
60 | ) = openprompt_wrapper.load_plm_wrapper(
61 | model_name=cfg.model.plm.name,
62 | model_path=cfg.model.plm.path,
63 | revision=cfg.model.plm.revision,
64 | do_not_load_pretrained_weight=cfg.train.weight.do_not_load_pretrained_weight,
65 | use_custom_loader=True,
66 | )
67 | return plm, plm_tokenizer, plm_model_config, TokenizerWrapperClass
68 |
69 |
70 | def gen_input_templates(cfg, plm, plm_tokenizer):
71 | input_templates = {}
72 | for target_parse, target_sub_parses in cfg.model.target_parses_dict.items():
73 | input_templates[target_parse] = prompt_generation_utils.gen_template(
74 | cfg.model.task,
75 | target_parse,
76 | cfg.model.input_template_type,
77 | plm,
78 | plm_tokenizer,
79 | )
80 |
81 | return input_templates
82 |
83 |
84 | def get_model(cfg, plm, plm_tokenizer, input_templates):
85 | if cfg.model.model_type == "generative":
86 | model = GenerativeParser(cfg, plm, plm_tokenizer, input_templates)
87 | else:
88 | raise NotImplementedError
89 |
90 | if cfg.train.weight.trained:
91 | path_load = Path(cfg.train.weight.path)
92 |
93 | if cfg.model.task in [
94 | "ljp_civil",
95 | "ljp_criminal",
96 | "casename_classification",
97 | "statute_classification",
98 | "summarization",
99 | ]:
100 | ckpt = torch.load(path_load)
101 | if "state_dict" in ckpt:
102 | ckpt_state_dict = ckpt["state_dict"]
103 | else:
104 | ckpt_state_dict = ckpt
105 | model.load_state_dict(ckpt_state_dict, strict=False)
106 |
107 | else:
108 | raise NotImplementedError
109 |
110 | print(f"The model weights are loaded from {path_load}.")
111 |
112 | return model
113 |
114 |
115 | def get_trainer(cfg):
116 | from pytorch_lightning import loggers as pl_loggers
117 |
118 | tparam = cfg.train
119 | mparam = cfg.model
120 |
121 | log_dir = Path(cfg.train.log_dir) / cfg.name
122 | tb_logger = pl_loggers.TensorBoardLogger(log_dir)
123 |
124 | pl.utilities.seed.seed_everything(seed=cfg.train.seed, workers=False)
125 |
126 | n_gpus = torch.cuda.device_count()
127 |
128 | callbacks = [
129 | pl.callbacks.ModelCheckpoint(
130 | monitor=f"{cfg.train.validation_metric}_{cfg.train.validation_sub_param.method}",
131 | dirpath=gu.get_model_saving_path(tparam.weight.save_path_dir, cfg.name),
132 | save_top_k=1,
133 | mode="max",
134 |             save_last=False,
135 | )
136 | ]
137 | if tparam.optim.swa.use:
138 | callbacks.append(
139 | pl.callbacks.StochasticWeightAveraging(
140 | swa_epoch_start=tparam.optim.swa.swa_epoch_start,
141 | swa_lrs=tparam.optim.swa.lr,
142 | annealing_epochs=tparam.optim.swa.annealing_epochs,
143 | )
144 | )
145 |
146 | trainer = pl.Trainer(
147 | logger=tb_logger,
148 | accelerator=tparam.accelerator,
149 | strategy=tparam.strategy,
150 | max_epochs=tparam.max_epochs,
151 | precision=mparam.precision if torch.cuda.is_available() else 32,
152 | num_sanity_val_steps=tparam.num_sanity_val_steps,
153 | gpus=n_gpus,
154 | check_val_every_n_epoch=tparam.check_val_every_n_epoch,
155 | gradient_clip_val=tparam.optim.gradient_clip_val,
156 | gradient_clip_algorithm=tparam.optim.gradient_clip_algorithm,
157 | accumulate_grad_batches=tparam.accumulate_grad_batches,
158 | val_check_interval=tparam.val_check_interval,
159 | profiler=tparam.profiler,
160 | fast_dev_run=tparam.fast_dev_run,
161 | callbacks=callbacks,
162 | limit_train_batches=tparam.get("limit_train_batches", 1.0),
163 | limit_val_batches=tparam.get("limit_val_batches", 1.0),
164 | )
165 | return trainer
166 |
167 |
168 | def prepare_modules(mode, cfg):
169 |
170 | # get pretrained language models
171 | plm, plm_tokenizer, plm_model_config, TokenizerWrapperClass = get_plm(cfg)
172 |
173 | # gen templates
174 | input_templates = gen_input_templates(cfg, plm, plm_tokenizer)
175 |
176 | # get data module
177 | data_module = get_data_module(
178 | cfg, plm_tokenizer, TokenizerWrapperClass, input_templates
179 | )
180 |
181 | # get model
182 | model = get_model(cfg, plm, plm_tokenizer, input_templates)
183 |
184 | # get trainer
185 | trainer = get_trainer(cfg)
186 |
187 | return data_module, model, trainer
188 |
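189 | # Illustrative flow, mirroring run_model.py (the YAML path is one of the configs
190 | # shipped in this repository):
191 | #
192 | #     from lbox_open.utils import general_utils as gu
193 | #     cfg = gu.load_cfg("configs/casename_classification/casename.lcube-base.yaml")
194 | #     data_module, model, trainer = prepare_modules("train", cfg)
195 | #     trainer.fit(model, data_module)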
--------------------------------------------------------------------------------
/lbox_open/template/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lbox-kr/lbox-open/fdad4b039af718d2b171e561e75f5771515572df/lbox_open/template/__init__.py
--------------------------------------------------------------------------------
/lbox_open/template/prompt_generation_utils.py:
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 |
5 | from openprompt import prompts
6 |
7 | from lbox_open.template import prompt_templates
8 |
9 | from ..constants import ENG_TO_KOR_PARSE_NAMES_LJP_CRIMINAL
10 |
11 |
12 | def gen_template(task, key, type, plm, tokenizer):
13 | mytemplate = prompts.MixedTemplate(
14 | model=plm,
15 | tokenizer=tokenizer,
16 | text=prompt_templates.gen_input_template_str(task, key, type),
17 | )
18 |
19 | return mytemplate
20 |
21 |
22 | def gen_output_template(
23 | task,
24 | key,
25 | sub_keys,
26 | label,
27 | parse_sep_token,
28 | ):
29 | """ """
30 | # todo: move template part to ./template.py
31 |
32 | if task == "ljp_criminal":
33 | if key == "fine_imprisonment_lvs":
34 | label_dict = label
35 | out = ""
36 | for key in sub_keys:
37 | key_kor = ENG_TO_KOR_PARSE_NAMES_LJP_CRIMINAL[key]
38 | out += f"{key_kor}{label_dict[key]}{parse_sep_token} "
39 | out = out.strip(f"{parse_sep_token} ")
40 |
41 | else:
42 | raise NotImplementedError
43 |
44 | elif task == "ljp_civil":
45 | if key == "claim_acceptance_lv":
46 | out = str(label)
47 | else:
48 | raise NotImplementedError
49 |
50 | elif task == "casename_classification":
51 | if key == "casename_classification":
52 | out = str(label)
53 | else:
54 | raise NotImplementedError
55 | elif task == "statute_classification":
56 | assert isinstance(label, list)
57 | if key == "statute_classification":
58 | out = f"{parse_sep_token} ".join(label)
59 | else:
60 | raise NotImplementedError
61 | elif task == "summarization":
62 | if key == "summarization":
63 | out = str(label)
64 | else:
65 | raise NotImplementedError
66 | else:
67 | raise NotImplementedError
68 |
69 | return out
70 |
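71 | # Illustrative outputs of gen_output_template (the separator token "<sep>" is a
72 | # hypothetical value; the real one comes from the config):
73 | #
74 | #     gen_output_template("casename_classification", "casename_classification",
75 | #                         None, "사기", "<sep>")              # -> "사기"
76 | #     gen_output_template("statute_classification", "statute_classification",
77 | #                         None, ["형법 제347조", "형법 제30조"], "<sep>")
78 | #     # -> "형법 제347조<sep> 형법 제30조"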
--------------------------------------------------------------------------------
/lbox_open/template/prompt_templates.py:
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 |
5 |
6 | def gen_input_template_str(task, key, type):
7 | if key == "fine_imprisonment_lvs":
8 | if type == 0:
9 | input_template_str = (
10 | '{"placeholder":"text_a"} 형사사건에 대하여 순서대로 벌금, 징역, 금고 레벨을 쓰시오. {"mask"}'
11 | )
12 | else:
13 | raise NotImplementedError
14 | elif key == "claim_acceptance_lv":
15 | if type == 0:
16 | input_template_str = (
17 | '{"placeholder":"text_a"} 주어진 사실관계, 청구 취지를 읽고, 주장 인정율을 예측하시오. {"mask"}'
18 | )
19 | else:
20 | raise NotImplementedError
21 | elif key == "casename_classification":
22 | if type == 0:
23 | input_template_str = (
24 | '{"placeholder":"text_a"} 주어진 사실관계를 읽고, 사건명을 예측하시오. {"mask"}'
25 | )
26 | else:
27 | raise NotImplementedError
28 | elif key == "statute_classification":
29 | if type == 0:
30 | input_template_str = (
31 | '{"placeholder":"text_a"} 주어진 사실관계를 읽고, 적용될 형법 조항들을 예측하시오. {"mask"}'
32 | )
33 | else:
34 | raise NotImplementedError
35 | elif key == "summarization":
36 | if type == 0:
37 | input_template_str = '{"placeholder":"text_a"}\n요약하시오.\n{"mask"}'
38 | else:
39 | raise NotImplementedError
40 | else:
41 | raise NotImplementedError
42 |
43 | return input_template_str
44 |
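45 | # The returned strings use OpenPrompt's MixedTemplate syntax: {"placeholder":"text_a"}
46 | # is filled with the input text and {"mask"} marks where generation starts. E.g., the
47 | # summarization template reads: <input text>, newline, "요약하시오." ("Summarize."),
48 | # newline, <mask>.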
--------------------------------------------------------------------------------
/lbox_open/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lbox-kr/lbox-open/fdad4b039af718d2b171e561e75f5771515572df/lbox_open/utils/__init__.py
--------------------------------------------------------------------------------
/lbox_open/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 |
5 | import json
6 | import os
7 | import pickle
8 | import subprocess
9 | import time
10 | from pathlib import Path
11 |
12 | from tqdm import tqdm
13 |
14 |
15 | def stop_flag(idx, toy_size):
16 |     # idx + 1 = number of samples read so far
17 |     data_size = idx + 1
18 |     if toy_size is not None:
19 |         if toy_size <= data_size:
20 |             return True
21 |     # keep reading when toy_size is None or has not been reached yet
22 |     return False
23 |
24 |
25 | def save_pkl(path_save, data):
26 | with open(path_save, "wb") as f:
27 | pickle.dump(data, f)
28 |
29 |
30 | def load_pkl(path_load):
31 | with open(path_load, "rb") as f:
32 | data = pickle.load(f)
33 | return data
34 |
35 |
36 | def save_json(path_save, data):
37 | with open(path_save, "w") as f:
38 | json.dump(data, f, ensure_ascii=False)
39 |
40 |
41 | def load_json(fpath):
42 | with open(fpath) as f:
43 | return json.load(f)
44 |
45 |
46 | def save_jsonl(path_save, data):
47 | with open(path_save, "w") as f:
48 | for t1 in data:
49 | f.writelines(json.dumps(t1, ensure_ascii=False))
50 | f.writelines("\n")
51 |
52 |
53 | def load_jsonl(fpath, toy_size=None):
54 | data = []
55 | with open(fpath) as f:
56 | for i, line in tqdm(enumerate(f)):
57 |             try:
58 |                 data1 = json.loads(line)
59 |             except json.JSONDecodeError:
60 |                 print(f"{i}th sample failed.")
61 |                 print("We will skip this!")
62 |                 print(line)
63 |                 data1 = None
64 | if data1 is not None:
65 | data.append(data1)
66 | if stop_flag(i, toy_size):
67 | break
68 |
69 | return data
70 |
71 |
72 | def my_timeit(func):
73 | def wrapped_func(*args, **kwargs):
74 | st = time.time()
75 | results = func(*args, **kwargs)
76 | ed = time.time()
77 |         print(f"func {func.__name__} takes {ed - st} sec.")
78 | return results
79 |
80 | return wrapped_func
81 |
82 |
83 | def flatten_list(list_):
84 | out = []
85 | for x in list_:
86 | if isinstance(x, list):
87 | out += flatten_list(x)
88 | else:
89 | out += [x]
90 |
91 | return out
92 |
93 |
94 | def load_cfg(path_cfg):
95 | import munch
96 | import yaml
97 |
98 | with open(path_cfg) as f:
99 | cfg = yaml.full_load(f)
100 | cfg = munch.munchify(cfg)
101 | cfg.name = path_cfg.__str__().split("/")[-1]
102 | return cfg
103 |
104 |
105 | def get_model_saving_path(save_dir, cfg_name):
106 | return Path(save_dir) / cfg_name
107 |
108 |
109 | def download_url(path_save, url):
110 | p = subprocess.Popen(["wget", "-q", "-O", path_save.__str__(), url])
111 | sts = os.waitpid(p.pid, 0)
112 |
113 |
114 | def get_local_rank():
115 | """
116 | Pytorch lightning save local rank to environment variable "LOCAL_RANK".
117 | From rank_zero_only
118 | """
119 | local_rank = int(os.environ.get("LOCAL_RANK", 0))
120 | return local_rank
121 |
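122 | 
123 | if __name__ == "__main__":
124 |     # Minimal self-checks (a sketch exercising the pure helpers above).
125 |     assert flatten_list([1, [2, [3]], 4]) == [1, 2, 3, 4]
126 |     assert stop_flag(9, toy_size=10) is True   # 10th sample read -> stop
127 |     assert stop_flag(3, toy_size=10) is False  # keep reading
128 |     assert stop_flag(3, toy_size=None) is False
129 |     print("general_utils self-checks passed.")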
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm==4.62.3
2 | munch==2.5.0
3 |
4 | transformers==4.19.4
5 | pytorch-lightning==1.5.8
6 |
7 | setuptools==59.5.0  # for compatibility with pytorch 1.10
8 | sentencepiece==0.1.96
9 |
10 | # for OpenPrompt
11 | openprompt==1.0.0
12 | rouge_score==0.1.2
13 |
14 | # for facts
15 | thefuzz==0.19.0
16 | nltk==3.6.7
17 | python-Levenshtein==0.12.2
18 |
19 | # misc
20 | tweepy==3.10.0
21 | scikit-learn==0.23.2
22 | 
--------------------------------------------------------------------------------
/run_model.py:
--------------------------------------------------------------------------------
1 | # LBox Open
2 | # Copyright (c) 2022-present LBox Co. Ltd.
3 | # CC BY-NC 4.0
4 |
5 | import argparse
6 |
7 | from lbox_open.pipeline import prepare_modules
8 | from lbox_open.utils import general_utils
9 |
10 |
11 | def main():
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("path_cfg", default="")
14 | parser.add_argument("--mode", default="")
15 | args = parser.parse_args()
16 |
17 | cfg = general_utils.load_cfg(args.path_cfg)
18 |
19 | if args.mode == "train":
20 | data_module, model, trainer = prepare_modules("train", cfg)
21 | trainer.fit(model, data_module)
22 |
23 |     elif args.mode == "test":
24 |         data_module, model, trainer = prepare_modules("test", cfg)
25 | trainer.test(model, datamodule=data_module)
26 | else:
27 | print(
28 | f"{args.mode} mode is not supported. The mode should be either 'train' or 'test'."
29 | )
30 |
31 |
32 | if __name__ == "__main__":
33 | main()
34 |
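35 | # Usage (see the scripts/ directory for ready-made invocations):
36 | #   python run_model.py configs/casename_classification/casename.lcube-base.yaml --mode train
37 | #   python run_model.py configs/summarization/summarization.legal-mt5s.test.yaml --mode test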
--------------------------------------------------------------------------------
/scripts/predict_summarization.sh:
--------------------------------------------------------------------------------
1 | config="configs/summarization/summarization.legal-mt5s.predict.yaml"
2 | export CUDA_VISIBLE_DEVICES=0
3 | python run_model.py $config --mode test
4 |
--------------------------------------------------------------------------------
/scripts/test_casename.sh:
--------------------------------------------------------------------------------
1 | #config="configs/casename_classification/casename.kogpt2.test.yaml"
2 | config="configs/casename_classification/casename.lcube-base.test.yaml"
3 | export CUDA_VISIBLE_DEVICES=0
4 | python run_model.py $config --mode test
5 |
6 |
--------------------------------------------------------------------------------
/scripts/test_ljp_civil.sh:
--------------------------------------------------------------------------------
1 | #config="configs/ljp/civil/ljp.civil.kogpt2.test.yaml"
2 | config="configs/ljp/civil/ljp.civil.lcube-base.test.yaml"
3 | export CUDA_VISIBLE_DEVICES=0
4 | python run_model.py $config --mode test
5 |
--------------------------------------------------------------------------------
/scripts/test_ljp_criminal.sh:
--------------------------------------------------------------------------------
1 | config="configs/ljp/criminal/ljp.criminal.lcube-base.test.yaml"
2 | #config="configs/ljp/criminal/ljp.criminal.kogpt2.test.yaml"
3 | export CUDA_VISIBLE_DEVICES=0
4 | python run_model.py $config --mode test
5 |
--------------------------------------------------------------------------------
/scripts/test_statute.sh:
--------------------------------------------------------------------------------
1 | #config="configs/statute_classification/statute.kogpt2.test.yaml"
2 | config="configs/statute_classification/statute.lcube-base.test.yaml"
3 | export CUDA_VISIBLE_DEVICES=0
4 | python run_model.py $config --mode test
5 |
--------------------------------------------------------------------------------
/scripts/test_summarization.sh:
--------------------------------------------------------------------------------
1 | config="configs/summarization/summarization.legal-mt5s.test.yaml"
2 | export CUDA_VISIBLE_DEVICES=0
3 | python run_model.py $config --mode test
4 |
--------------------------------------------------------------------------------
/scripts/train_casename.sh:
--------------------------------------------------------------------------------
1 | #config="configs/casename_classification/casename.kogpt2.yaml"
2 | config="configs/casename_classification/casename.lcube-base.yaml"
3 | export CUDA_VISIBLE_DEVICES=0
4 | python run_model.py $config --mode train
5 |
6 |
--------------------------------------------------------------------------------
/scripts/train_ljp_civil.sh:
--------------------------------------------------------------------------------
1 | #config="configs/ljp/civil/ljp.civil.kogpt2.yaml"
2 | config="configs/ljp/civil/ljp.civil.lcube-base.yaml"
3 | export CUDA_VISIBLE_DEVICES=0
4 | python run_model.py $config --mode train
5 |
--------------------------------------------------------------------------------
/scripts/train_ljp_criminal.sh:
--------------------------------------------------------------------------------
1 | #config="configs/ljp/criminal/ljp.criminal.kogpt2.yaml"
2 | config="configs/ljp/criminal/ljp.criminal.lcube-base.yaml"
3 | export CUDA_VISIBLE_DEVICES=0
4 | python run_model.py $config --mode train
5 |
--------------------------------------------------------------------------------
/scripts/train_statute.sh:
--------------------------------------------------------------------------------
1 | #config="configs/statute_classification/statute.kogpt2.yaml"
2 | config="configs/statute_classification/statute.lcube-base.yaml"
3 | export CUDA_VISIBLE_DEVICES=0
4 | python run_model.py $config --mode train
5 |
--------------------------------------------------------------------------------
/scripts/train_summarization.sh:
--------------------------------------------------------------------------------
1 | #config="configs/summarization/summarization.kogpt2.yaml"
2 | # config="configs/summarization/summarization.lcube-base.yaml"
3 | config="configs/summarization/summarization.legal-mt5s.yaml"
4 | export CUDA_VISIBLE_DEVICES=0
5 | python run_model.py $config --mode train
6 |
--------------------------------------------------------------------------------