├── .gitignore
├── LICENSE
├── README.md
├── README_en.md
├── question_generation
│   ├── __init__.py
│   ├── pipelines
│   │   ├── __init__.py
│   │   ├── question_answering.py
│   │   ├── question_generation.py
│   │   └── utils_qa.py
│   ├── run_qa.py
│   ├── run_qg.py
│   └── train_qa.py
├── requirements.txt
├── setup.cfg
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | dist/
3 | build/
4 | __pycache__/
5 | *.egg-info/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2021- The Algolet team. All rights reserved.
2 |
3 | Apache License
4 | Version 2.0, January 2004
5 | http://www.apache.org/licenses/
6 |
7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8 |
9 | 1. Definitions.
10 |
11 | "License" shall mean the terms and conditions for use, reproduction,
12 | and distribution as defined by Sections 1 through 9 of this document.
13 |
14 | "Licensor" shall mean the copyright owner or entity authorized by
15 | the copyright owner that is granting the License.
16 |
17 | "Legal Entity" shall mean the union of the acting entity and all
18 | other entities that control, are controlled by, or are under common
19 | control with that entity. For the purposes of this definition,
20 | "control" means (i) the power, direct or indirect, to cause the
21 | direction or management of such entity, whether by contract or
22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
23 | outstanding shares, or (iii) beneficial ownership of such entity.
24 |
25 | "You" (or "Your") shall mean an individual or Legal Entity
26 | exercising permissions granted by this License.
27 |
28 | "Source" form shall mean the preferred form for making modifications,
29 | including but not limited to software source code, documentation
30 | source, and configuration files.
31 |
32 | "Object" form shall mean any form resulting from mechanical
33 | transformation or translation of a Source form, including but
34 | not limited to compiled object code, generated documentation,
35 | and conversions to other media types.
36 |
37 | "Work" shall mean the work of authorship, whether in Source or
38 | Object form, made available under the License, as indicated by a
39 | copyright notice that is included in or attached to the work
40 | (an example is provided in the Appendix below).
41 |
42 | "Derivative Works" shall mean any work, whether in Source or Object
43 | form, that is based on (or derived from) the Work and for which the
44 | editorial revisions, annotations, elaborations, or other modifications
45 | represent, as a whole, an original work of authorship. For the purposes
46 | of this License, Derivative Works shall not include works that remain
47 | separable from, or merely link (or bind by name) to the interfaces of,
48 | the Work and Derivative Works thereof.
49 |
50 | "Contribution" shall mean any work of authorship, including
51 | the original version of the Work and any modifications or additions
52 | to that Work or Derivative Works thereof, that is intentionally
53 | submitted to Licensor for inclusion in the Work by the copyright owner
54 | or by an individual or Legal Entity authorized to submit on behalf of
55 | the copyright owner. For the purposes of this definition, "submitted"
56 | means any form of electronic, verbal, or written communication sent
57 | to the Licensor or its representatives, including but not limited to
58 | communication on electronic mailing lists, source code control systems,
59 | and issue tracking systems that are managed by, or on behalf of, the
60 | Licensor for the purpose of discussing and improving the Work, but
61 | excluding communication that is conspicuously marked or otherwise
62 | designated in writing by the copyright owner as "Not a Contribution."
63 |
64 | "Contributor" shall mean Licensor and any individual or Legal Entity
65 | on behalf of whom a Contribution has been received by Licensor and
66 | subsequently incorporated within the Work.
67 |
68 | 2. Grant of Copyright License. Subject to the terms and conditions of
69 | this License, each Contributor hereby grants to You a perpetual,
70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71 | copyright license to reproduce, prepare Derivative Works of,
72 | publicly display, publicly perform, sublicense, and distribute the
73 | Work and such Derivative Works in Source or Object form.
74 |
75 | 3. Grant of Patent License. Subject to the terms and conditions of
76 | this License, each Contributor hereby grants to You a perpetual,
77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78 | (except as stated in this section) patent license to make, have made,
79 | use, offer to sell, sell, import, and otherwise transfer the Work,
80 | where such license applies only to those patent claims licensable
81 | by such Contributor that are necessarily infringed by their
82 | Contribution(s) alone or by combination of their Contribution(s)
83 | with the Work to which such Contribution(s) was submitted. If You
84 | institute patent litigation against any entity (including a
85 | cross-claim or counterclaim in a lawsuit) alleging that the Work
86 | or a Contribution incorporated within the Work constitutes direct
87 | or contributory patent infringement, then any patent licenses
88 | granted to You under this License for that Work shall terminate
89 | as of the date such litigation is filed.
90 |
91 | 4. Redistribution. You may reproduce and distribute copies of the
92 | Work or Derivative Works thereof in any medium, with or without
93 | modifications, and in Source or Object form, provided that You
94 | meet the following conditions:
95 |
96 | (a) You must give any other recipients of the Work or
97 | Derivative Works a copy of this License; and
98 |
99 | (b) You must cause any modified files to carry prominent notices
100 | stating that You changed the files; and
101 |
102 | (c) You must retain, in the Source form of any Derivative Works
103 | that You distribute, all copyright, patent, trademark, and
104 | attribution notices from the Source form of the Work,
105 | excluding those notices that do not pertain to any part of
106 | the Derivative Works; and
107 |
108 | (d) If the Work includes a "NOTICE" text file as part of its
109 | distribution, then any Derivative Works that You distribute must
110 | include a readable copy of the attribution notices contained
111 | within such NOTICE file, excluding those notices that do not
112 | pertain to any part of the Derivative Works, in at least one
113 | of the following places: within a NOTICE text file distributed
114 | as part of the Derivative Works; within the Source form or
115 | documentation, if provided along with the Derivative Works; or,
116 | within a display generated by the Derivative Works, if and
117 | wherever such third-party notices normally appear. The contents
118 | of the NOTICE file are for informational purposes only and
119 | do not modify the License. You may add Your own attribution
120 | notices within Derivative Works that You distribute, alongside
121 | or as an addendum to the NOTICE text from the Work, provided
122 | that such additional attribution notices cannot be construed
123 | as modifying the License.
124 |
125 | You may add Your own copyright statement to Your modifications and
126 | may provide additional or different license terms and conditions
127 | for use, reproduction, or distribution of Your modifications, or
128 | for any such Derivative Works as a whole, provided Your use,
129 | reproduction, and distribution of the Work otherwise complies with
130 | the conditions stated in this License.
131 |
132 | 5. Submission of Contributions. Unless You explicitly state otherwise,
133 | any Contribution intentionally submitted for inclusion in the Work
134 | by You to the Licensor shall be under the terms and conditions of
135 | this License, without any additional terms or conditions.
136 | Notwithstanding the above, nothing herein shall supersede or modify
137 | the terms of any separate license agreement you may have executed
138 | with Licensor regarding such Contributions.
139 |
140 | 6. Trademarks. This License does not grant permission to use the trade
141 | names, trademarks, service marks, or product names of the Licensor,
142 | except as required for reasonable and customary use in describing the
143 | origin of the Work and reproducing the content of the NOTICE file.
144 |
145 | 7. Disclaimer of Warranty. Unless required by applicable law or
146 | agreed to in writing, Licensor provides the Work (and each
147 | Contributor provides its Contributions) on an "AS IS" BASIS,
148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149 | implied, including, without limitation, any warranties or conditions
150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151 | PARTICULAR PURPOSE. You are solely responsible for determining the
152 | appropriateness of using or redistributing the Work and assume any
153 | risks associated with Your exercise of permissions under this License.
154 |
155 | 8. Limitation of Liability. In no event and under no legal theory,
156 | whether in tort (including negligence), contract, or otherwise,
157 | unless required by applicable law (such as deliberate and grossly
158 | negligent acts) or agreed to in writing, shall any Contributor be
159 | liable to You for damages, including any direct, indirect, special,
160 | incidental, or consequential damages of any character arising as a
161 | result of this License or out of the use or inability to use the
162 | Work (including but not limited to damages for loss of goodwill,
163 | work stoppage, computer failure or malfunction, or any and all
164 | other commercial damages or losses), even if such Contributor
165 | has been advised of the possibility of such damages.
166 |
167 | 9. Accepting Warranty or Additional Liability. While redistributing
168 | the Work or Derivative Works thereof, You may choose to offer,
169 | and charge a fee for, acceptance of support, warranty, indemnity,
170 | or other liability obligations and/or rights consistent with this
171 | License. However, in accepting such obligations, You may act only
172 | on Your own behalf and on Your sole responsibility, not on behalf
173 | of any other Contributor, and only if You agree to indemnify,
174 | defend, and hold each Contributor harmless for any liability
175 | incurred by, or claims asserted against, such Contributor by reason
176 | of your accepting any such warranty or additional liability.
177 |
178 | END OF TERMS AND CONDITIONS
179 |
180 | APPENDIX: How to apply the Apache License to your work.
181 |
182 | To apply the Apache License to your work, attach the following
183 | boilerplate notice, with the fields enclosed by brackets "[]"
184 | replaced with your own identifying information. (Don't include
185 | the brackets!) The text should be enclosed in the appropriate
186 | comment syntax for the file format. We also recommend that a
187 | file or class name and description of purpose be included on the
188 | same "printed page" as the copyright notice for easier
189 | identification within third-party archives.
190 |
191 | Copyright [yyyy] [name of copyright owner]
192 |
193 | Licensed under the Apache License, Version 2.0 (the "License");
194 | you may not use this file except in compliance with the License.
195 | You may obtain a copy of the License at
196 |
197 | http://www.apache.org/licenses/LICENSE-2.0
198 |
199 | Unless required by applicable law or agreed to in writing, software
200 | distributed under the License is distributed on an "AS IS" BASIS,
201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202 | See the License for the specific language governing permissions and
203 | limitations under the License.
204 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # 基于mt5的中文问题生成任务 (Chinese question generation with mt5)
3 |
4 |
5 |
6 | 中文说明 |
7 | English
8 |
9 |
10 | A question generation model fine-tuned from the pretrained mt5 model.
11 |
12 | ## Online demo
13 | You can try our model directly online: https://www.algolet.com/applications/qg?accessSource=github
14 |
15 | ## Usage
16 | We provide `question_generation` and `question_answering` `pipeline` APIs; calling the corresponding pipeline makes it easy to run each task.
17 |
18 | Here is how to use the question generation pipeline:
19 | ``` python
20 | >>> from question_generation import pipeline
21 |
22 | # Allocate a pipeline for question-generation
23 | # cpu version; if the device argument is omitted, cpu is the default
24 | >>> qg = pipeline("question-generation", device="cpu")
25 | # gpu version
26 | >>> qg = pipeline("question-generation", device="cuda")
27 | # for single text
28 | >>> qg("在一个寒冷的冬天,赶集完回家的农夫在路边发现了一条冻僵了的蛇。他很可怜蛇,就把它放在怀里。当他身上的热气把蛇温暖以后,蛇很快苏醒了,露出了残忍的本性,给了农夫致命的伤害——咬了农夫一口。农夫临死之前说:“我竟然救了一条可怜的毒蛇,就应该受到这种报应啊!”")
29 | ['在寒冷的冬天,农夫在哪里发现了一条可怜的蛇?', '农夫是如何看待蛇的?', '当农夫遇到蛇时,他做了什么?']
30 |
31 | # for batch input
32 | >>> texts = ["在一个寒冷的冬天,赶集完回家的农夫在路边发现了一条冻僵了的蛇。他很可怜蛇,就把它放在怀里。当他身上的热气把蛇温暖以后,蛇很快苏醒了,露出了残忍的本性,给了农夫致命的伤害——咬了农夫一口。农夫临死之前说:“我竟然救了一条可怜的毒蛇,就应该受到这种报应啊!”"]
33 | >>> qg(texts)
34 | [['在寒冷的冬天,农夫在哪里发现了一条可怜的蛇?', '农夫是如何看待蛇的?', '当农夫遇到蛇时,他做了什么?']]
35 | ```
36 | You can use a model you trained yourself, or download an already fine-tuned model from the Hugging Face hub. Here is how to use it with PyTorch:
37 | ``` python
38 | >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
39 | >>> tokenizer = AutoTokenizer.from_pretrained("algolet/mt5-base-chinese-qg")
40 | >>> model = AutoModelForSeq2SeqLM.from_pretrained("algolet/mt5-base-chinese-qg")
41 | >>> pipe = pipeline("question-generation", model=model, tokenizer=tokenizer)
42 | ```
43 | We also provide a question answering pipeline, which can be combined with the question generation module to build an application that automatically generates and answers questions.
44 |
45 | Here is how to use the question answering pipeline:
46 | ``` python
47 | >>> from question_generation import pipeline
48 |
49 | # Allocate a pipeline for question-answering
50 | >>> qa = pipeline("question-answering")
51 | >>> text = "在一个寒冷的冬天,赶集完回家的农夫在路边发现了一条冻僵了的蛇。他很可怜蛇,就把它放在怀里。当他身上的热气把蛇温暖以后,蛇很快苏醒了,露出了残忍的本性,给了农夫致命的伤害——咬了农夫一口。农夫临死之前说:“我竟然救了一条可怜的毒蛇,就应该受到这种报应啊!”"
52 | # for single qa input
53 | >>> qa({
54 | ... 'question': '在寒冷的冬天,农夫在哪里发现了一条可怜的蛇?',
55 | ... 'context': text
56 | ... })
57 | {'answer': '路边', 'start': 18, 'end': 20, 'score': 1.0}
58 |
59 | # for batch qa inputs
60 | >>> qa([
61 | ... {
62 | ... 'question': '在寒冷的冬天,农夫在哪里发现了一条可怜的蛇?',
63 | ... 'context': text
64 | ... },
65 | ... {
66 | ... 'question': '农夫是如何看待蛇的?',
67 | ... 'context': text
68 | ... },
69 | ... {
70 | ... 'question': '当农夫遇到蛇时,他做了什么?',
71 | ... 'context': text
72 | ... }])
73 | [{'answer': '路边', 'start': 18, 'end': 20, 'score': 1.0},
74 | {'answer': '我竟然救了一条可怜的毒蛇,就应该受到这种报应',
75 | 'start': 102,
76 | 'end': 124,
77 | 'score': 0.9996},
78 | {'answer': '放在怀里', 'start': 40, 'end': 44, 'score': 0.9995}]
79 | ```
80 | In the qa output, `answer` is the extracted answer, and `start` and `end` are the start and end positions of the answer in the original context.
81 |
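Since `start` and `end` are character offsets into the context, slicing the context recovers the answer. A quick check, reusing `qa` and `text` from the example above:

``` python
>>> result = qa({'question': '在寒冷的冬天,农夫在哪里发现了一条可怜的蛇?', 'context': text})
>>> text[result['start']:result['end']] == result['answer']  # text[18:20] == '路边'
True
```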
82 | ## Installation
83 | Requires pytorch>=1.3, transformers>=4.12.5 and datasets>=1.15.1. If pip downloads are slow, you can use the Aliyun mirror by appending -i https://mirrors.aliyun.com/pypi/simple/ to the install commands.
84 |
85 | Install the CUDA build of pytorch:
86 | ```bash
87 | pip3 install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
88 | ```
89 | Install the CPU build of pytorch:
90 | ```bash
91 | pip3 install torch==1.10.0+cpu torchvision==0.11.1+cpu torchaudio==0.10.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
92 | ```
93 | Install transformers and datasets:
94 | ```bash
95 | pip install transformers
96 | pip install datasets
97 | ```
98 | Install this project:
99 | ```bash
100 | pip install question_generation
101 | ```
102 |
103 | ## Model training
104 | You can train your own question generation and question answering models.
105 |
106 | #### Training a question generation model
107 | ##### Training data format
108 | ``` python
109 | >>> train.json
110 | {"data": [{"source_text": "对于某些物理情况,不可能将力的形成归因于势的梯度。这通常是由于宏观物理的考虑,屈服力产生于微观状态的宏观统计平均值。例如,摩擦是由原子间大量静电势的梯度引起的,但表现为独立于任何宏观位置矢量的力模型。非保守力除摩擦力外,还包括其他接触力、拉力、压缩力和阻力。然而,对于任何足够详细的描述,所有这些力都是保守力的结果,因为每一个宏观力都是微观势梯度的净结果。",
111 | "target_text": "拉力、压缩和拉力是什么力?{sep_token}静电梯度电势会产生什么?{sep_token}为什么这些力是无法建模的呢?"}
112 | {"source_text": "绿宝石失窃案 (法语: Les Bijoux de la Castafiore ;英语: The Castafiore Emerald )是丁丁历险记的第21部作品。作者是比利时漫画家埃尔热。本作与之前的丁丁历险记有著很大的不同,丁丁首次进行没有离开自己家的冒险,同时故事中没有明显的反派角色,充满了喜剧色彩。丁丁和船长原本在城堡悠闲度假,却因歌后突然造访而弄得鸡飞狗跳;媒体对歌后的行踪极度关注,穷追猛打;歌后一颗珍贵的绿宝石又突然失踪,引起了一波接一波的疑团,究竟谁的嫌疑最大?是船长刚刚收留的一伙吉卜赛人?是偷偷混入记者群中的神秘男子?是歌后的贴身女仆?还是行迹鬼祟的钢琴师?",
113 | "target_text": "故事中引起众多谜团的原因是?{sep_token}此部作品与以往不同的地方在于哪里?{sep_token}丁丁和船长的悠闲假期因何被打乱?{sep_token}《绿宝石失窃案》是《丁丁历险记》系列的第几部?{sep_token}《绿宝石失窃案》的作者是谁?"}
114 | ...
115 | ]}
116 | ```
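Each record pairs a passage (`source_text`) with its reference questions joined by the literal `{sep_token}` placeholder (`target_text`). A minimal sketch of how such a file could be assembled (the variable names here are illustrative, not part of this repo):

``` python
import json

# Illustrative only: join the questions for one passage with the literal
# "{sep_token}" placeholder, exactly as in the sample records above.
passage = "对于某些物理情况,不可能将力的形成归因于势的梯度。..."
questions = ["拉力、压缩和拉力是什么力?", "静电梯度电势会产生什么?"]
record = {"source_text": passage,
          "target_text": "{sep_token}".join(questions)}

with open("data/train.json", "w", encoding="utf-8") as f:
    json.dump({"data": [record]}, f, ensure_ascii=False)
```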
117 | ##### Training config file
118 | ``` python
119 | >>> qg_config.json
120 | {
121 | "model_name_or_path": "google/mt5-small",
122 | "tokenizer_name": "google/mt5-small",
123 | "text_column": "source_text",
124 | "summary_column": "target_text",
125 | "train_file": "data/train.json",
126 | "validation_file": "data/dev.json",
127 | "output_dir": "data/qg",
128 | "model_type": "mt5",
129 | "overwrite_output_dir": true,
130 | "do_train": true,
131 | "do_eval": true,
132 | "source_prefix": "question generation: ",
133 | "predict_with_generate": true,
134 | "per_device_train_batch_size": 8,
135 | "per_device_eval_batch_size": 8,
136 | "gradient_accumulation_steps": 32,
137 | "learning_rate": 1e-3,
138 | "num_train_epochs": 4,
139 | "max_source_length": 512,
140 | "max_target_length": 200,
141 | "logging_steps": 100,
142 | "seed": 42
143 | }
144 | ```
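With this config, each optimizer update sees per_device_train_batch_size × gradient_accumulation_steps = 8 × 32 = 256 examples per device.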
145 | ##### Launch training
146 | `CUDA_VISIBLE_DEVICES=0 python run_qg.py qg_config.json`
147 |
148 | #### Training a question answering model
149 | ##### Training data format
150 | The data format is the same as SQuAD 2.0.
151 | ``` python
152 | >>> train.json
153 | {'version': 2.0,
154 | 'data': [{'id': 'c398789b7375e0ce7eac86f2b18c3808',
155 | 'question': '隐藏式行车记录仪哪个牌子好',
156 | 'context': '推荐使用360行车记录仪。行车记录仪的好坏,取决于行车记录仪的摄像头配置,配置越高越好,再就是性价比。 行车记录仪配置需要1296p超高清摄像头比较好,这样录制视频清晰度高。再就是价格,性价比高也是可以值得考虑的。 360行车记录仪我使用了一段时间 ,觉得360行车记录仪比较好录得广角比较大,并且便宜实惠 ,价格才299,在360商城可以买到。可以参考对比下。',
157 | 'answers': {'answer_start': [4], 'text': ['360行车记录仪']}}]}
158 | ```
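`answer_start` is a character offset into `context`, so the answer text can be recovered by slicing. For the sample record above:

``` python
>>> sample = {'context': '推荐使用360行车记录仪。行车记录仪的好坏,取决于行车记录仪的摄像头配置,配置越高越好,再就是性价比。',
...           'answers': {'answer_start': [4], 'text': ['360行车记录仪']}}
>>> start = sample['answers']['answer_start'][0]
>>> sample['context'][start:start + len(sample['answers']['text'][0])]
'360行车记录仪'
```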
159 | ##### Training config file
160 | ``` python
161 | >>> qa_config.json
162 | {
163 | "model_name_or_path": "bert-base-chinese",
164 | "tokenizer_name": "bert-base-chinese",
165 | "train_file": "data/train.json",
166 | "validation_file": "data/dev.json",
167 | "output_dir": "data/qa",
168 | "per_device_train_batch_size": 8,
169 | "per_device_eval_batch_size": 8,
170 | "gradient_accumulation_steps": 32,
171 | "overwrite_output_dir": true,
172 | "do_train": true,
173 | "do_eval": true,
174 | "max_answer_length": 200
175 | }
176 | ```
177 | ##### Launch training
178 | `CUDA_VISIBLE_DEVICES=0 python run_qa.py qa_config.json`
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
--------------------------------------------------------------------------------
/README_en.md:
--------------------------------------------------------------------------------
1 |
2 | # Chinese Question Generation Using MT5 Model
3 |
4 |
5 |
6 |
7 | 中文说明 |
8 | English
9 |
10 |
11 | A question generation model built by fine-tuning the mt5 model.
12 |
13 | ## Online demos
14 | You can test the model directly on https://www.algolet.com/applications/qg?accessSource=github
15 |
16 | ## Quick tour
17 | To immediately use models on given inputs, we provide the `question_generation` and `question_answering` `pipeline` APIs.
18 | Pipelines group together a pretrained model with the preprocessing that was used during that model's training.
19 |
20 | Here is how to quickly use a pipeline to generate questions
21 | ``` python
22 | >>> from question_generation import pipeline
23 |
24 | # Allocate a pipeline for question-generation
25 | # for cpu
26 | >>> qg = pipeline("question-generation")
27 | # for gpu, pass device="cuda"
28 | >>> qg = pipeline("question-generation", device="cuda")
29 | # for single text
30 | >>> qg("在一个寒冷的冬天,赶集完回家的农夫在路边发现了一条冻僵了的蛇。他很可怜蛇,就把它放在怀里。当他身上的热气把蛇温暖以后,蛇很快苏醒了,露出了残忍的本性,给了农夫致命的伤害——咬了农夫一口。农夫临死之前说:“我竟然救了一条可怜的毒蛇,就应该受到这种报应啊!”")
31 | ['在寒冷的冬天,农夫在哪里发现了一条可怜的蛇?', '农夫是如何看待蛇的?', '当农夫遇到蛇时,他做了什么?']
32 |
33 | # for batch input
34 | >>> texts = ["在一个寒冷的冬天,赶集完回家的农夫在路边发现了一条冻僵了的蛇。他很可怜蛇,就把它放在怀里。
35 | "当他身上的热气把蛇温暖以后,蛇很快苏醒了,露出了残忍的本性,给了农夫致命的伤害——咬了农夫一口。
36 | "农夫临死之前说:“我竟然救了一条可怜的毒蛇,就应该受到这种报应啊!”"]
37 | >>> qg(texts)
38 | [['在寒冷的冬天,农夫在哪里发现了一条可怜的蛇?', '农夫是如何看待蛇的?', '当农夫遇到蛇时,他做了什么?']]
39 | ```
40 | You can use your own model, or any model fine-tuned for question generation from the Hugging Face hub. Here is the PyTorch version:
41 | ``` python
42 | >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
43 | >>> tokenizer = AutoTokenizer.from_pretrained("algolet/mt5-base-chinese-qg")
44 | >>> model = AutoModelForSeq2SeqLM.from_pretrained("algolet/mt5-base-chinese-qg")
45 | >>> pipe = pipeline("question-generation", model=model, tokenizer=tokenizer)
46 | ```
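The resulting `pipe` is used exactly like the default pipeline above; it accepts a single string or a list of strings, for example:

``` python
>>> questions = pipe(texts)  # `texts` as defined in the batch example above
```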
47 |
48 | Combine `question_generation` with `question_answering` to build an application that automatically
49 | generates and answers questions (see the sketch after the example below).
50 |
51 | Here is how to quickly use a pipeline to answer questions.
52 | ``` python
53 | >>> from question_generation import pipeline
54 |
55 | # Allocate a pipeline for question-answering
56 | # for cpu; cpu is the default if the device argument is omitted
57 | >>> qa = pipeline("question-answering", device="cpu")
58 | # for gpu
59 | >>> qa = pipeline("question-answering", device="cuda")
60 | >>> text = "在一个寒冷的冬天,赶集完回家的农夫在路边发现了一条冻僵了的蛇。他很可怜蛇,就把它放在怀里。当他身上的热气把蛇温暖以后,蛇很快苏醒了,露出了残忍的本性,给了农夫致命的伤害——咬了农夫一口。农夫临死之前说:“我竟然救了一条可怜的毒蛇,就应该受到这种报应啊!”"
61 | # for single qa input
62 | >>> qa({
63 | ... 'question': '在寒冷的冬天,农夫在哪里发现了一条可怜的蛇?',
64 | ... 'context': text
65 | ... })
66 | {'answer': '路边', 'start': 18, 'end': 20, 'score': 1.0}
67 |
68 | # for batch qa inputs
69 | >>> qa([
70 | ... {
71 | ... 'question': '在寒冷的冬天,农夫在哪里发现了一条可怜的蛇?',
72 | ... 'context': text
73 | ... },
74 | ... {
75 | ... 'question': '农夫是如何看待蛇的?',
76 | ... 'context': text
77 | ... },
78 | ... {
79 | ... 'question': '当农夫遇到蛇时,他做了什么?',
80 | ... 'context': text
81 | ... }])
82 | [{'answer': '路边', 'start': 18, 'end': 20, 'score': 1.0},
83 | {'answer': '我竟然救了一条可怜的毒蛇,就应该受到这种报应',
84 | 'start': 102,
85 | 'end': 124,
86 | 'score': 0.9996},
87 | {'answer': '放在怀里', 'start': 40, 'end': 44, 'score': 0.9995}]
88 | ```
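In each returned dict, `answer` is the extracted span and `start`/`end` are its character offsets in `context`. As mentioned above, the two pipelines can also be chained: generate questions from a passage, then answer each of them against the same passage. A minimal sketch using the `qg` and `qa` pipelines defined above:

``` python
>>> questions = qg(text)
>>> answers = qa([{'question': q, 'context': text} for q in questions])
```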
89 |
90 | ## Installation
91 | This repository requires pytorch>=1.3, transformers>=4.12.5 and datasets>=1.15.1.
92 | CUDA build of pytorch:
93 | ```bash
94 | pip3 install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
95 | ```
96 | CPU build of pytorch:
97 | ```bash
98 | pip3 install torch==1.10.0+cpu torchvision==0.11.1+cpu torchaudio==0.10.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
99 | ```
100 | transformers and datasets
101 | ```bash
102 | pip install transformers
103 | pip install datasets
104 | ```
105 | this repository
106 | ```bash
107 | pip install question_generation
108 | ```
109 |
110 | ## Model Training
111 | #### How to train a qg model
112 | ##### Sample data
113 | ``` python
114 | >>> train.json
115 | {"data": [{"source_text": "对于某些物理情况,不可能将力的形成归因于势的梯度。这通常是由于宏观物理的考虑,屈服力产生于微观状态的宏观统计平均值。例如,摩擦是由原子间大量静电势的梯度引起的,但表现为独立于任何宏观位置矢量的力模型。非保守力除摩擦力外,还包括其他接触力、拉力、压缩力和阻力。然而,对于任何足够详细的描述,所有这些力都是保守力的结果,因为每一个宏观力都是微观势梯度的净结果。",
116 | "target_text": "拉力、压缩和拉力是什么力?{sep_token}静电梯度电势会产生什么?{sep_token}为什么这些力是无法建模的呢?"}
117 | {"source_text": "绿宝石失窃案 (法语: Les Bijoux de la Castafiore ;英语: The Castafiore Emerald )是丁丁历险记的第21部作品。作者是比利时漫画家埃尔热。本作与之前的丁丁历险记有著很大的不同,丁丁首次进行没有离开自己家的冒险,同时故事中没有明显的反派角色,充满了喜剧色彩。丁丁和船长原本在城堡悠闲度假,却因歌后突然造访而弄得鸡飞狗跳;媒体对歌后的行踪极度关注,穷追猛打;歌后一颗珍贵的绿宝石又突然失踪,引起了一波接一波的疑团,究竟谁的嫌疑最大?是船长刚刚收留的一伙吉卜赛人?是偷偷混入记者群中的神秘男子?是歌后的贴身女仆?还是行迹鬼祟的钢琴师?",
118 | "target_text": "故事中引起众多谜团的原因是?{sep_token}此部作品与以往不同的地方在于哪里?{sep_token}丁丁和船长的悠闲假期因何被打乱?{sep_token}《绿宝石失窃案》是《丁丁历险记》系列的第几部?{sep_token}《绿宝石失窃案》的作者是谁?"}
119 | ...
120 | ]}
121 | ```
122 | ##### Example config
123 | ``` python
124 | >>> qg_config.json
125 | {
126 | "model_name_or_path": "google/mt5-small",
127 | "tokenizer_name": "google/mt5-small",
128 | "text_column": "source_text",
129 | "summary_column": "target_text",
130 | "train_file": "data/train.json",
131 | "validation_file": "data/dev.json",
132 | "output_dir": "data/qg",
133 | "model_type": "mt5",
134 | "overwrite_output_dir": true,
135 | "do_train": true,
136 | "do_eval": true,
137 | "source_prefix": "question generation: ",
138 | "predict_with_generate": true,
139 | "per_device_train_batch_size": 8,
140 | "per_device_eval_batch_size": 8,
141 | "gradient_accumulation_steps": 32,
142 | "learning_rate": 1e-3,
143 | "num_train_epochs": 4,
144 | "max_source_length": 512,
145 | "max_target_length": 200,
146 | "logging_steps": 100,
147 | "seed": 42
148 | }
149 | ```
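`run_qg.py` is not reproduced in this listing, but assuming it follows the usual transformers seq2seq convention for `source_prefix`, each training input is fed to the model with the same prefix that the inference pipeline in `question_generation/pipelines/question_generation.py` prepends:

``` python
model_input = "question generation: " + source_text  # source_text: one record's passage (illustrative)
```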
150 | ##### Example command
151 | ```
152 | CUDA_VISIBLE_DEVICES=0 python run_qg.py qg_config.json
153 | ```
154 |
155 |
156 | #### How to train a qa model
157 | ##### Sample data
158 | ``` python
159 | >>> train.json
160 | {'version': 2.0,
161 | 'data': [{'id': 'c398789b7375e0ce7eac86f2b18c3808',
162 | 'question': '隐藏式行车记录仪哪个牌子好',
163 | 'context': '推荐使用360行车记录仪。行车记录仪的好坏,取决于行车记录仪的摄像头配置,配置越高越好,再就是性价比。 行车记录仪配置需要1296p超高清摄像头比较好,这样录制视频清晰度高。再就是价格,性价比高也是可以值得考虑的。 360行车记录仪我使用了一段时间 ,觉得360行车记录仪比较好录得广角比较大,并且便宜实惠 ,价格才299,在360商城可以买到。可以参考对比下。',
164 | 'answers': {'answer_start': [4], 'text': ['360行车记录仪']}}]}
165 | ```
166 | ##### Example config
167 | ``` python
168 | >>> qa_config.json
169 | {
170 | "model_name_or_path": "bert-base-chinese",
171 | "tokenizer_name": "bert-base-chinese",
172 | "train_file": "data/train.json",
173 | "validation_file": "data/dev.json",
174 | "output_dir": "data/qa",
175 | "per_device_train_batch_size": 8,
176 | "per_device_eval_batch_size": 8,
177 | "gradient_accumulation_steps": 32,
178 | "overwrite_output_dir": true,
179 | "do_train": true,
180 | "do_eval": true,
181 | "max_answer_length": 200
182 | }
183 | ```
184 | ##### Example command
185 | ```
186 | CUDA_VISIBLE_DEVICES=0 python run_qa.py qa_config.json
187 | ```
188 |
189 |
190 |
191 |
192 |
193 |
194 |
--------------------------------------------------------------------------------
/question_generation/__init__.py:
--------------------------------------------------------------------------------
1 | # Pipelines
2 | from .pipelines import pipeline
3 |
--------------------------------------------------------------------------------
/question_generation/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from .question_generation import QuestionGenerationPipeline
3 | from .question_answering import QuestionAnsweringPipeline
4 | from typing import Optional, Union
5 | from transformers import PreTrainedTokenizer, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 | SUPPORTED_TASKS = {
10 | "question-generation": {
11 | "impl": QuestionGenerationPipeline,
12 | "pt": AutoModelForSeq2SeqLM,
13 | "default": {
14 | "model": "algolet/mt5-base-chinese-qg"
15 | }
16 | },
17 | "question-answering": {
18 | "impl": QuestionAnsweringPipeline,
19 | "pt": AutoModelForQuestionAnswering,
20 | "default": {
21 | "model": "luhua/chinese_pretrain_mrc_macbert_large"
22 | }
23 | }
24 | }
25 |
26 |
27 | def pipeline(
28 | task: str,
29 | model: Optional = None,
30 | tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
31 | device: str = "cpu"
32 | ):
33 | if task is None and model is None:
34 | raise RuntimeError(
35 | "Impossible to instantiate a pipeline without either a task or a model"
36 | "being specified."
37 | "Please provide a task class or a model"
38 | )
39 |
40 | if model is None and tokenizer is not None:
41 | raise RuntimeError(
42 | "Impossible to instantiate a pipeline with tokenizer specified but not the model "
43 | "as the provided tokenizer may not be compatible with the default model. "
44 | "Please provide a PreTrainedModel class or a path/identifier to a pretrained model when providing tokenizer."
45 | )
46 |
47 | if task not in SUPPORTED_TASKS:
48 | raise KeyError(f"Unknown task {task}, available tasks are {list(SUPPORTED_TASKS.keys())}")
49 | targeted_task = SUPPORTED_TASKS[task]
50 | pipeline_class = targeted_task["impl"]
51 |
52 | if model is None:
53 | model = targeted_task["default"]["model"]
54 |
55 | model_name = model if isinstance(model, str) else None
56 | model_classes = targeted_task["pt"]
57 |
58 | if tokenizer is None:
59 | if isinstance(model_name, str):
60 | tokenizer = model_name
61 | else:
62 | raise Exception(
63 | "Impossible to guess which tokenizer to use. "
64 | "Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
65 | )
66 |
67 | # Instantiate tokenizer if needed
68 | if isinstance(tokenizer, (str, tuple)):
69 | if isinstance(tokenizer, tuple):
70 | # For tuple we have (tokenizer name, {kwargs})
71 | tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], **tokenizer[1])
72 | else:
73 | tokenizer = AutoTokenizer.from_pretrained(tokenizer)
74 |
75 | if isinstance(model, str):
76 | model = model_classes.from_pretrained(model)
77 | return pipeline_class(model=model, tokenizer=tokenizer, device=device)
78 |
--------------------------------------------------------------------------------
/question_generation/pipelines/question_answering.py:
--------------------------------------------------------------------------------
1 | from datasets import Dataset
2 | import logging
3 | import torch
4 | from .utils_qa import postprocess_qa_predictions
5 | from typing import Union, List, Dict
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | class QuestionAnsweringPipeline:
11 | def __init__(self,
12 | tokenizer,
13 | model,
14 | max_seq_length: int = 384,
15 | doc_stride: int = 32,
16 | version_2_with_negative: bool = True,
17 | n_best_size: int = 1,
18 | max_answer_length: int = 30,
19 | null_score_diff_threshold: int = 0,
20 | device: str = "cpu"):
21 | self.tokenizer = tokenizer
22 | self.model = model
23 | self.max_seq_length = max_seq_length
24 | self.doc_stride = doc_stride
25 | self.question_column_name = "question"
26 | self.context_column_name = "context"
27 | self.version_2_with_negative = version_2_with_negative
28 | self.n_best_size = n_best_size
29 | self.max_answer_length = max_answer_length
30 | self.null_score_diff_threshold = null_score_diff_threshold
31 | if device == "cpu":
32 | self.device = "cpu"
33 | elif device == "cuda":
34 | self.device = torch.device("cuda")
35 | self.model = self.model.to(self.device)
36 | else:
37 | raise Exception(
38 | "device should be cup or cuda"
39 | )
40 | self.model.eval()
41 | # Validation preprocessing
42 |
43 | def prepare_validation_features(self, examples):
44 | # Some of the questions have lots of whitespace on the left, which is not useful and will make the
45 | # truncation of the context fail (the tokenized question will take a lot of space). So we remove that
46 | # left whitespace
47 | examples[self.question_column_name] = [q.lstrip() for q in examples[self.question_column_name]]
48 |
49 | # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
50 | # in one example possibly giving several features when a context is long, each of those features having a
51 | # context that overlaps a bit the context of the previous feature.
52 | tokenized_examples = self.tokenizer(
53 | examples[self.question_column_name],
54 | examples[self.context_column_name],
55 | truncation="only_second",
56 | max_length=self.max_seq_length,
57 | stride=self.doc_stride,
58 | return_overflowing_tokens=True,
59 | return_offsets_mapping=True,
60 | padding=True
61 | )
62 |
63 | # Since one example might give us several features if it has a long context, we need a map from a feature to
64 | # its corresponding example. This key gives us just that.
65 | sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
66 |
67 | # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
68 | # corresponding example_id and we will store the offset mappings.
69 | tokenized_examples["example_id"] = []
70 |
71 | for i in range(len(tokenized_examples["input_ids"])):
72 | # Grab the sequence corresponding to that example (to know what is the context and what is the question).
73 | sequence_ids = tokenized_examples.sequence_ids(i)
74 | context_index = 1
75 |
76 | # One example can give several spans, this is the index of the example containing this span of text.
77 | sample_index = sample_mapping[i]
78 | tokenized_examples["example_id"].append(examples["id"][sample_index])
79 |
80 | # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
81 | # position is part of the context or not.
82 | tokenized_examples["offset_mapping"][i] = [
83 | (o if sequence_ids[k] == context_index else None)
84 | for k, o in enumerate(tokenized_examples["offset_mapping"][i])
85 | ]
86 | return tokenized_examples
87 |
88 | # region Metrics and Post-processing:
89 |
90 | def post_processing_function(self, examples, features, predictions, stage="eval"):
91 | # Post-processing: we match the start logits and end logits to answers in the original context.
92 | predictions = postprocess_qa_predictions(
93 | examples=examples,
94 | features=features,
95 | predictions=predictions,
96 | version_2_with_negative=self.version_2_with_negative,
97 | n_best_size=self.n_best_size,
98 | max_answer_length=self.max_answer_length,
99 | null_score_diff_threshold=self.null_score_diff_threshold,
100 | prefix=stage,
101 | )
102 | # Format the result to the format the metric expects.
103 | formatted_predictions = [{"id": k, "prediction": v} for k, v in predictions.items()]
104 | return formatted_predictions
105 |
106 | def __call__(self, input: Union[Dict, List[Dict]], *args, **kwargs):
107 | if isinstance(input, dict):
108 | questions = [input["question"]]
109 | contexts = [input["context"]]
110 | else:
111 | questions = [item["question"] for item in input]
112 | contexts = [item["context"] for item in input]
113 | ids = [i for i in range(len(questions))]
114 | examples = {"id": ids, "question": questions, "context": contexts}
115 | processed_datasets = self.prepare_validation_features(examples)
116 | inputs = {
117 | "input_ids": torch.tensor(processed_datasets["input_ids"]),
118 | "attention_mask": torch.tensor(processed_datasets["attention_mask"]),
119 | }
120 |
121 | input_ids = inputs["input_ids"].to(self.device)
122 | input_masks = inputs["attention_mask"].to(self.device)
123 |
124 | with torch.no_grad():
125 | eval_predictions = self.model(input_ids=input_ids,
126 | attention_mask=input_masks)
127 | examples = Dataset.from_dict(examples)
128 | flat_examples = []
129 | for idx in range(len(examples["id"])):
130 | flat_examples.append({"id": examples["id"][idx],
131 | "question": examples["question"][idx],
132 | "context": examples["context"][idx]})
133 |
134 | flat_processed_datasets = []
135 | for idx in range(len(processed_datasets["input_ids"])):
136 | flat_processed_datasets.append({"input_ids": processed_datasets["input_ids"][idx],
137 | "token_type_ids": processed_datasets["token_type_ids"][idx],
138 | "attention_mask": processed_datasets["attention_mask"][idx],
139 | "offset_mapping": processed_datasets["offset_mapping"][idx],
140 | "example_id": processed_datasets["example_id"][idx]})
141 | post_processed_eval = self.post_processing_function(
142 | flat_examples,
143 | flat_processed_datasets,
144 | (eval_predictions.start_logits.cpu().numpy(), eval_predictions.end_logits.cpu().numpy()),
145 | )
146 | res = []
147 | for item in post_processed_eval:
148 | if not item["prediction"]:
149 | res.append(dict())
150 | else:
151 | if item["prediction"]["start"] == item["prediction"]["end"]:
152 | res.append(dict())
153 | else:
154 | res.append({"answer": item["prediction"]["text"],
155 | "start": item["prediction"]["start"],
156 | "end": item["prediction"]["end"],
157 | "score": item["prediction"]["probability"]})
158 | if isinstance(input, dict):
159 | return res[0]
160 | return res
161 |
--------------------------------------------------------------------------------
/question_generation/pipelines/question_generation.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Union, List, Dict
3 | from torch import Tensor
4 | import torch
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | class QuestionGenerationPipeline:
10 | def __init__(self,
11 | tokenizer,
12 | model,
13 | max_source_length: int = 512,
14 | max_target_length: int = 200,
15 | device: str = "cpu"):
16 | self.tokenizer = tokenizer
17 | self.model = model
18 | if device == "cpu":
19 | self.device = "cpu"
20 | elif device == "cuda":
21 | self.device = torch.device("cuda")
22 | self.model = self.model.to(self.device)
23 | else:
24 | raise Exception(
25 | "device should be cpu or cuda"
26 | )
27 | self.model.eval()
28 | self.max_source_length = max_source_length
29 | self.max_target_length = max_target_length
30 | self.source_prefix = "question generation: "
31 |
32 | def preprocess(self, examples: List[str]) -> Dict[str, Tensor]:
33 | added_prefix = [self.source_prefix + example for example in examples]
34 | inputs = self.tokenizer(added_prefix,
35 | return_tensors='pt',
36 | padding=True,
37 | truncation=True,
38 | max_length=self.max_source_length)
39 | return inputs
40 |
41 | def post_processing_function(self, outs):
42 | questions = self.tokenizer.batch_decode(outs, skip_special_tokens=True)
43 | separated_questions = []
44 | for each_batch_questions in questions:
45 | question = [q.strip() for q in each_batch_questions.split("<sep>") if q.strip() != ""]
46 | if len(each_batch_questions) > 5 and each_batch_questions[-5:] != "<sep>":  # a complete question must end with "<sep>"
47 | question = question[:-1]
48 | separated_questions.append(question)
49 |
50 | return separated_questions
51 |
52 | def __call__(self, input: Union[str, List[str]], *args, **kwargs):
53 | logger.info("*** Prediction ***")
54 | if isinstance(input, str):
55 | examples = [input]
56 | else:
57 | examples = input
58 | examples = self.preprocess(examples)
59 | input_ids = examples["input_ids"].to(self.device)
60 | input_mask = examples["attention_mask"].to(self.device)
61 |
62 | with torch.no_grad():
63 | outs = self.model.generate(input_ids=input_ids,
64 | attention_mask=input_mask,
65 | max_length=self.max_target_length,
66 | no_repeat_ngram_size=4,
67 | num_beams=4)
68 |
69 | post_processed = self.post_processing_function(outs)
70 | if isinstance(input, str):
71 | post_processed = post_processed[0]
72 | return post_processed
73 |
--------------------------------------------------------------------------------
/question_generation/pipelines/utils_qa.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Team All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """
16 | Post-processing utilities for question answering.
17 | """
18 | import collections
19 | import json
20 | import logging
21 | import os
22 | from typing import Optional, Tuple
23 |
24 | import numpy as np
25 | from tqdm.auto import tqdm
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 |
30 | def postprocess_qa_predictions(
31 | examples,
32 | features,
33 | predictions: Tuple[np.ndarray, np.ndarray],
34 | version_2_with_negative: bool = False,
35 | n_best_size: int = 20,
36 | max_answer_length: int = 30,
37 | null_score_diff_threshold: float = 0.0,
38 | output_dir: Optional[str] = None,
39 | prefix: Optional[str] = None,
40 | log_level: Optional[int] = logging.WARNING,
41 | ):
42 | """
43 | Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
44 | original contexts. This is the base postprocessing functions for models that only return start and end logits.
45 |
46 | Args:
47 | examples: The non-preprocessed dataset (see the main script for more information).
48 | features: The processed dataset (see the main script for more information).
49 | predictions (:obj:`Tuple[np.ndarray, np.ndarray]`):
50 | The predictions of the model: two arrays containing the start logits and the end logits respectively. Its
51 | first dimension must match the number of elements of :obj:`features`.
52 | version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):
53 | Whether or not the underlying dataset contains examples with no answers.
54 | n_best_size (:obj:`int`, `optional`, defaults to 20):
55 | The total number of n-best predictions to generate when looking for an answer.
56 | max_answer_length (:obj:`int`, `optional`, defaults to 30):
57 | The maximum length of an answer that can be generated. This is needed because the start and end predictions
58 | are not conditioned on one another.
59 | null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0):
60 | The threshold used to select the null answer: if the best answer has a score that is less than the score of
61 | the null answer minus this threshold, the null answer is selected for this example (note that the score of
62 | the null answer for an example giving several features is the minimum of the scores for the null answer on
63 | each feature: all features must be aligned on the fact they `want` to predict a null answer).
64 |
65 | Only useful when :obj:`version_2_with_negative` is :obj:`True`.
66 | output_dir (:obj:`str`, `optional`):
67 | If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if
68 | :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null
69 | answers, are saved in `output_dir`.
70 | prefix (:obj:`str`, `optional`):
71 | If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
72 | log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
73 | ``logging`` log level (e.g., ``logging.WARNING``)
74 | """
75 |
76 | assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)."
77 | all_start_logits, all_end_logits = predictions
78 |
79 | assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features."
80 |
81 | # Build a map example to its corresponding features.
82 | # example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
83 | example_id_to_index = {k["id"]: i for i, k in enumerate(examples)}
84 | features_per_example = collections.defaultdict(list)
85 |
86 | for i, feature in enumerate(features):
87 | features_per_example[example_id_to_index[feature["example_id"]]].append(i)
88 |
89 | # The dictionaries we have to fill.
90 | all_predictions = collections.OrderedDict()
91 | if version_2_with_negative:
92 | scores_diff_json = collections.OrderedDict()
93 |
94 | # Logging.
95 | logger.setLevel(log_level)
96 | logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
97 |
98 | # Let's loop over all the examples!
99 | for example_index, example in enumerate(examples):
100 | # Those are the indices of the features associated to the current example.
101 | feature_indices = features_per_example[example_index]
102 |
103 | min_null_prediction = None
104 | prelim_predictions = []
105 |
106 | # Looping through all the features associated to the current example.
107 | for feature_index in feature_indices:
108 | # We grab the predictions of the model for this feature.
109 | start_logits = all_start_logits[feature_index]
110 | end_logits = all_end_logits[feature_index]
111 | # This is what will allow us to map some of the positions in our logits to spans of text in the original
112 | # context.
113 | offset_mapping = features[feature_index]["offset_mapping"]
114 | # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context
115 | # available in the current feature.
116 | token_is_max_context = features[feature_index].get("token_is_max_context", None)
117 |
118 | # Update minimum null prediction.
119 | feature_null_score = start_logits[0] + end_logits[0]
120 | if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
121 | min_null_prediction = {
122 | "offsets": (0, 0),
123 | "score": feature_null_score,
124 | "start_logit": start_logits[0],
125 | "end_logit": end_logits[0],
126 | }
127 |
128 | # Go through all possibilities for the `n_best_size` greater start and end logits.
129 | start_indexes = np.argsort(start_logits)[-1: -n_best_size - 1: -1].tolist()
130 | end_indexes = np.argsort(end_logits)[-1: -n_best_size - 1: -1].tolist()
131 | for start_index in start_indexes:
132 | for end_index in end_indexes:
133 | # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
134 | # to part of the input_ids that are not in the context.
135 | if (
136 | start_index >= len(offset_mapping)
137 | or end_index >= len(offset_mapping)
138 | or offset_mapping[start_index] is None
139 | or offset_mapping[end_index] is None
140 | ):
141 | continue
142 | # Don't consider answers with a length that is either < 0 or > max_answer_length.
143 | if end_index < start_index or end_index - start_index + 1 > max_answer_length:
144 | continue
145 | # Don't consider answers that don't have the maximum context available (if such information is
146 | # provided).
147 | if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
148 | continue
149 | prelim_predictions.append(
150 | {
151 | "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
152 | "score": start_logits[start_index] + end_logits[end_index],
153 | "start_logit": start_logits[start_index].item(),
154 | "end_logit": end_logits[end_index].item(),
155 | }
156 | )
157 |
158 | if version_2_with_negative:
159 | # Add the minimum null prediction
160 | prelim_predictions.append(min_null_prediction)
161 | null_score = min_null_prediction["score"]
162 |
163 | # Only keep the best `n_best_size` predictions.
164 | predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
165 |
166 | # Add back the minimum null prediction if it was removed because of its low score.
167 | if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions):
168 | predictions.append(min_null_prediction)
169 |
170 | # Use the offsets to gather the answer text in the original context.
171 | context = example["context"]
172 | for pred in predictions:
173 | offsets = pred.pop("offsets")
174 | pred["text"] = context[offsets[0]: offsets[1]]
175 | pred["start"] = offsets[0]
176 | pred["end"] = offsets[1]
177 |
178 | # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
179 | # failure.
180 | if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""):
181 | predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0, "start": 0,
182 | "end": 0})
183 |
184 | # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
185 | # the LogSumExp trick).
186 | scores = np.array([pred.pop("score") for pred in predictions])
187 | exp_scores = np.exp(scores - np.max(scores))
188 | probs = exp_scores / exp_scores.sum()
189 |
190 | # Include the probabilities in our predictions.
191 | for prob, pred in zip(probs, predictions):
192 | pred["probability"] = prob.item()
193 |
194 | # Pick the best prediction. If the null answer is not possible, this is easy.
195 | if not version_2_with_negative:
196 | # Keep the full prediction dict (not just its text) so the caller also gets offsets and probability.
197 | all_predictions[example["id"]] = predictions[0]
198 | else:
199 | # Otherwise we first need to find the best non-empty prediction.
200 | i = 0
201 | while predictions[i]["text"] == "":
202 | i += 1
203 | best_non_null_pred = predictions[i]
204 |
205 | # Then we compare to the null prediction using the threshold.
206 | score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
207 | scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable.
208 | if score_diff > null_score_diff_threshold:
209 | all_predictions[example["id"]] = {}
210 | else:
211 | all_predictions[example["id"]] = best_non_null_pred
212 | return all_predictions
213 |
214 |
215 | def postprocess_qa_predictions_with_beam_search(
216 | examples,
217 | features,
218 | predictions: Tuple[np.ndarray, np.ndarray],
219 | version_2_with_negative: bool = False,
220 | n_best_size: int = 20,
221 | max_answer_length: int = 30,
222 | start_n_top: int = 5,
223 | end_n_top: int = 5,
224 | output_dir: Optional[str] = None,
225 | prefix: Optional[str] = None,
226 | log_level: Optional[int] = logging.WARNING,
227 | ):
228 | """
229 | Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the
230 | original contexts. This is the postprocessing functions for models that return start and end logits, indices, as well as
231 | cls token predictions.
232 |
233 | Args:
234 | examples: The non-preprocessed dataset (see the main script for more information).
235 | features: The processed dataset (see the main script for more information).
236 | predictions (:obj:`Tuple[np.ndarray, np.ndarray]`):
237 | The predictions of the model: two arrays containing the start logits and the end logits respectively. Its
238 | first dimension must match the number of elements of :obj:`features`.
239 | version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):
240 | Whether or not the underlying dataset contains examples with no answers.
241 | n_best_size (:obj:`int`, `optional`, defaults to 20):
242 | The total number of n-best predictions to generate when looking for an answer.
243 | max_answer_length (:obj:`int`, `optional`, defaults to 30):
244 | The maximum length of an answer that can be generated. This is needed because the start and end predictions
245 | are not conditioned on one another.
246 | start_n_top (:obj:`int`, `optional`, defaults to 5):
247 | The number of top start logits to keep when searching for the :obj:`n_best_size` predictions.
248 | end_n_top (:obj:`int`, `optional`, defaults to 5):
249 | The number of top end logits to keep when searching for the :obj:`n_best_size` predictions.
250 | output_dir (:obj:`str`, `optional`):
251 | If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if
252 | :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null
253 | answers, are saved in `output_dir`.
254 | prefix (:obj:`str`, `optional`):
255 | If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
256 | log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
257 | ``logging`` log level (e.g., ``logging.WARNING``)
258 | """
259 | assert len(predictions) == 5, "`predictions` should be a tuple with five elements."
260 | start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions
261 |
262 | assert len(predictions[0]) == len(
263 | features
264 | ), f"Got {len(predictions[0])} predicitions and {len(features)} features."
265 |
266 | # Build a map example to its corresponding features.
267 | example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
268 | features_per_example = collections.defaultdict(list)
269 | for i, feature in enumerate(features):
270 | features_per_example[example_id_to_index[feature["example_id"]]].append(i)
271 |
272 | # The dictionaries we have to fill.
273 | all_predictions = collections.OrderedDict()
274 | all_nbest_json = collections.OrderedDict()
275 | scores_diff_json = collections.OrderedDict() if version_2_with_negative else None
276 |
277 | # Logging.
278 | logger.setLevel(log_level)
279 | logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
280 |
281 | # Let's loop over all the examples!
282 | for example_index, example in enumerate(tqdm(examples)):
283 | # Those are the indices of the features associated to the current example.
284 | feature_indices = features_per_example[example_index]
285 |
286 | min_null_score = None
287 | prelim_predictions = []
288 |
289 | # Looping through all the features associated to the current example.
290 | for feature_index in feature_indices:
291 | # We grab the predictions of the model for this feature.
292 | start_log_prob = start_top_log_probs[feature_index]
293 | start_indexes = start_top_index[feature_index]
294 | end_log_prob = end_top_log_probs[feature_index]
295 | end_indexes = end_top_index[feature_index]
296 | feature_null_score = cls_logits[feature_index]
297 | # This is what will allow us to map some of the positions in our logits to spans of text in the original
298 | # context.
299 | offset_mapping = features[feature_index]["offset_mapping"]
300 | # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context
301 | # available in the current feature.
302 | token_is_max_context = features[feature_index].get("token_is_max_context", None)
303 |
304 | # Update minimum null prediction
305 | if min_null_score is None or feature_null_score < min_null_score:
306 | min_null_score = feature_null_score
307 |
308 | # Go through all possibilities for the `n_start_top`/`n_end_top` greater start and end logits.
309 | for i in range(start_n_top):
310 | for j in range(end_n_top):
311 | start_index = int(start_indexes[i])
312 | j_index = i * end_n_top + j
313 | end_index = int(end_indexes[j_index])
314 | # Don't consider out-of-scope answers (last part of the test should be unnecessary because of the
315 | # p_mask but let's not take any risk)
316 | if (
317 | start_index >= len(offset_mapping)
318 | or end_index >= len(offset_mapping)
319 | or offset_mapping[start_index] is None
320 | or offset_mapping[end_index] is None
321 | ):
322 | continue
323 | # Don't consider answers with a length negative or > max_answer_length.
324 | if end_index < start_index or end_index - start_index + 1 > max_answer_length:
325 | continue
326 | # Don't consider answers that don't have the maximum context available (if such information is
327 | # provided).
328 | if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
329 | continue
330 | prelim_predictions.append(
331 | {
332 | "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
333 | "score": start_log_prob[i] + end_log_prob[j_index],
334 | "start_log_prob": start_log_prob[i],
335 | "end_log_prob": end_log_prob[j_index],
336 | }
337 | )
338 |
339 | # Only keep the best `n_best_size` predictions.
340 | predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
341 |
342 | # Use the offsets to gather the answer text in the original context.
343 | context = example["context"]
344 | for pred in predictions:
345 | offsets = pred.pop("offsets")
346 | pred["text"] = context[offsets[0]: offsets[1]]
347 |
348 | # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
349 | # failure.
350 | if len(predictions) == 0:
351 | predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6})
352 |
353 | # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
354 | # the LogSumExp trick).
355 | scores = np.array([pred.pop("score") for pred in predictions])
356 | exp_scores = np.exp(scores - np.max(scores))
357 | probs = exp_scores / exp_scores.sum()
358 |
359 | # Include the probabilities in our predictions.
360 | for prob, pred in zip(probs, predictions):
361 | pred["probability"] = prob
362 |
363 | # Pick the best prediction and set the probability for the null answer.
364 | all_predictions[example["id"]] = predictions[0]["text"]
365 | if version_2_with_negative:
366 | scores_diff_json[example["id"]] = float(min_null_score)
367 |
368 | # Make `predictions` JSON-serializable by casting np.float back to float.
369 | all_nbest_json[example["id"]] = [
370 | {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
371 | for pred in predictions
372 | ]
373 |
374 | # If we have an output_dir, let's save all those dicts.
375 | if output_dir is not None:
376 | assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
377 |
378 | prediction_file = os.path.join(
379 | output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
380 | )
381 | nbest_file = os.path.join(
382 | output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
383 | )
384 | if version_2_with_negative:
385 | null_odds_file = os.path.join(
386 | output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
387 | )
388 |
389 | logger.info(f"Saving predictions to {prediction_file}.")
390 | with open(prediction_file, "w") as writer:
391 | writer.write(json.dumps(all_predictions, indent=4) + "\n")
392 | logger.info(f"Saving nbest_preds to {nbest_file}.")
393 | with open(nbest_file, "w") as writer:
394 | writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
395 | if version_2_with_negative:
396 | logger.info(f"Saving null_odds to {null_odds_file}.")
397 | with open(null_odds_file, "w") as writer:
398 | writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
399 |
400 | return all_predictions, scores_diff_json
401 |
--------------------------------------------------------------------------------
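Note on the output files above: the post-processing routine writes `predictions.json` (example id mapped to the best answer text), `nbest_predictions.json` (example id mapped to a list of candidates with their probabilities) and, when `version_2_with_negative` is set, `null_odds.json` (example id mapped to the null score). The sketch below is illustrative only, not repository code; it assumes the files were written without a `prefix` (with `prefix="eval"` the names become `eval_predictions.json`, and so on).

import json

# Inspect the n-best candidates saved by the QA post-processing.
with open("nbest_predictions.json") as f:
    nbest = json.load(f)

# Each entry is a list of candidate answers carrying "text", "probability",
# "start_log_prob" and "end_log_prob".
for example_id, candidates in list(nbest.items())[:3]:
    best = max(candidates, key=lambda c: c["probability"])
    print(example_id, repr(best["text"]), round(best["probability"], 4))
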
/question_generation/run_qa.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2020 The HuggingFace Team All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """
17 | Fine-tuning the library models for question answering using a slightly adapted version of the 🤗 Trainer.
18 | """
19 | # You can also adapt this script for your own question answering task. Pointers for this are left as comments.
20 |
21 | import logging
22 | import os
23 | import sys
24 | from dataclasses import dataclass, field
25 | from typing import Optional
26 |
27 | import datasets
28 | from datasets import load_dataset, load_metric
29 |
30 | import transformers
31 | from train_qa import QuestionAnsweringTrainer
32 | from transformers import (
33 | AutoConfig,
34 | AutoModelForQuestionAnswering,
35 | AutoTokenizer,
36 | DataCollatorWithPadding,
37 | EvalPrediction,
38 | HfArgumentParser,
39 | PreTrainedTokenizerFast,
40 | TrainingArguments,
41 | default_data_collator,
42 | set_seed,
43 | )
44 | from transformers.trainer_utils import get_last_checkpoint
45 | from utils_qa import postprocess_qa_predictions
46 |
47 | os.environ["WANDB_DISABLED"] = "true"
48 |
49 | logger = logging.getLogger(__name__)
50 |
51 |
52 | @dataclass
53 | class ModelArguments:
54 | """
55 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
56 | """
57 |
58 | model_name_or_path: str = field(
59 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
60 | )
61 | config_name: Optional[str] = field(
62 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
63 | )
64 | tokenizer_name: Optional[str] = field(
65 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
66 | )
67 | cache_dir: Optional[str] = field(
68 | default=None,
69 | metadata={"help": "Path to directory to store the pretrained models downloaded from huggingface.co"},
70 | )
71 | model_revision: str = field(
72 | default="main",
73 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
74 | )
75 | use_auth_token: bool = field(
76 | default=False,
77 | metadata={
78 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
79 | "with private models)."
80 | },
81 | )
82 |
83 |
84 | @dataclass
85 | class DataTrainingArguments:
86 | """
87 | Arguments pertaining to what data we are going to input our model for training and eval.
88 | """
89 |
90 | dataset_name: Optional[str] = field(
91 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
92 | )
93 | dataset_config_name: Optional[str] = field(
94 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
95 | )
96 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a csv or json file)."})
97 | validation_file: Optional[str] = field(
98 | default=None,
99 | metadata={"help": "An optional input evaluation data file to evaluate on (a csv or json file)."},
100 | )
101 | test_file: Optional[str] = field(
102 | default=None,
103 | metadata={"help": "An optional input test data file to evaluate on (a csv or json file)."},
104 | )
105 | overwrite_cache: bool = field(
106 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
107 | )
108 | preprocessing_num_workers: Optional[int] = field(
109 | default=None,
110 | metadata={"help": "The number of processes to use for the preprocessing."},
111 | )
112 | max_seq_length: int = field(
113 | default=384,
114 | metadata={
115 | "help": "The maximum total input sequence length after tokenization. Sequences longer "
116 | "than this will be truncated, sequences shorter will be padded."
117 | },
118 | )
119 | pad_to_max_length: bool = field(
120 | default=True,
121 | metadata={
122 | "help": "Whether to pad all samples to `max_seq_length`. "
123 | "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
124 | "be faster on GPU but will be slower on TPU)."
125 | },
126 | )
127 | max_train_samples: Optional[int] = field(
128 | default=None,
129 | metadata={
130 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
131 | "value if set."
132 | },
133 | )
134 | max_eval_samples: Optional[int] = field(
135 | default=None,
136 | metadata={
137 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
138 | "value if set."
139 | },
140 | )
141 | max_predict_samples: Optional[int] = field(
142 | default=None,
143 | metadata={
144 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
145 | "value if set."
146 | },
147 | )
148 | version_2_with_negative: bool = field(
149 | default=True, metadata={"help": "If true, some of the examples do not have an answer."}
150 | )
151 | null_score_diff_threshold: float = field(
152 | default=0.0,
153 | metadata={
154 | "help": "The threshold used to select the null answer: if the best answer has a score that is less than "
155 | "the score of the null answer minus this threshold, the null answer is selected for this example. "
156 | "Only useful when `version_2_with_negative=True`."
157 | },
158 | )
159 | doc_stride: int = field(
160 | default=128,
161 | metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
162 | )
163 | n_best_size: int = field(
164 | default=20,
165 | metadata={"help": "The total number of n-best predictions to generate when looking for an answer."},
166 | )
167 | max_answer_length: int = field(
168 | default=30,
169 | metadata={
170 | "help": "The maximum length of an answer that can be generated. This is needed because the start "
171 | "and end predictions are not conditioned on one another."
172 | },
173 | )
174 |
175 | def __post_init__(self):
176 | if (
177 | self.dataset_name is None
178 | and self.train_file is None
179 | and self.validation_file is None
180 | and self.test_file is None
181 | ):
182 | raise ValueError("Need either a dataset name or a training/validation/test file.")
183 | else:
184 | if self.train_file is not None:
185 | extension = self.train_file.split(".")[-1]
186 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
187 | if self.validation_file is not None:
188 | extension = self.validation_file.split(".")[-1]
189 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
190 | if self.test_file is not None:
191 | extension = self.test_file.split(".")[-1]
192 | assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
193 |
194 |
195 | def main():
196 | # See all possible arguments in src/transformers/training_args.py
197 | # or by passing the --help flag to this script.
198 | # We now keep distinct sets of args, for a cleaner separation of concerns.
199 |
200 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
201 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
202 | # If we pass only one argument to the script and it's the path to a json file,
203 | # let's parse it to get our arguments.
204 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
205 | else:
206 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
207 |
208 | # Setup logging
209 | logging.basicConfig(
210 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
211 | datefmt="%m/%d/%Y %H:%M:%S",
212 | handlers=[logging.StreamHandler(sys.stdout)],
213 | )
214 |
215 | log_level = training_args.get_process_log_level()
216 | logger.setLevel(log_level)
217 | datasets.utils.logging.set_verbosity(log_level)
218 | transformers.utils.logging.set_verbosity(log_level)
219 | transformers.utils.logging.enable_default_handler()
220 | transformers.utils.logging.enable_explicit_format()
221 |
222 | # Log on each process the small summary:
223 | logger.warning(
224 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
225 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
226 | )
227 | logger.info(f"Training/evaluation parameters {training_args}")
228 |
229 | # Detecting last checkpoint.
230 | last_checkpoint = None
231 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
232 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
233 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
234 | raise ValueError(
235 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
236 | "Use --overwrite_output_dir to overcome."
237 | )
238 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
239 | logger.info(
240 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
241 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
242 | )
243 |
244 | # Set seed before initializing model.
245 | set_seed(training_args.seed)
246 |
247 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
248 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
249 | # (the dataset will be downloaded automatically from the datasets Hub).
250 | #
251 | # For CSV/JSON files, this script will use the columns called 'question', 'context' and 'answers' (or the
252 | # first three columns if those names are not found). You can easily tweak this behavior (see below).
253 | #
254 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently
255 | # download the dataset.
256 | if data_args.dataset_name is not None:
257 | # Downloading and loading a dataset from the hub.
258 | raw_datasets = load_dataset(
259 | data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
260 | )
261 | else:
262 | data_files = {}
263 | if data_args.train_file is not None:
264 | data_files["train"] = data_args.train_file
265 | extension = data_args.train_file.split(".")[-1]
266 |
267 | if data_args.validation_file is not None:
268 | data_files["validation"] = data_args.validation_file
269 | extension = data_args.validation_file.split(".")[-1]
270 | if data_args.test_file is not None:
271 | data_files["test"] = data_args.test_file
272 | extension = data_args.test_file.split(".")[-1]
273 | raw_datasets = load_dataset(extension, data_files=data_files, field="data", cache_dir=model_args.cache_dir)
274 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
275 | # https://huggingface.co/docs/datasets/loading_datasets.html.
276 |
277 | # Load pretrained model and tokenizer
278 | #
279 | # Distributed training:
280 | # The .from_pretrained methods guarantee that only one local process can concurrently
281 | # download model & vocab.
282 | config = AutoConfig.from_pretrained(
283 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
284 | cache_dir=model_args.cache_dir,
285 | revision=model_args.model_revision,
286 | use_auth_token=True if model_args.use_auth_token else None,
287 | )
288 | tokenizer = AutoTokenizer.from_pretrained(
289 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
290 | cache_dir=model_args.cache_dir,
291 | use_fast=True,
292 | revision=model_args.model_revision,
293 | use_auth_token=True if model_args.use_auth_token else None,
294 | )
295 | model = AutoModelForQuestionAnswering.from_pretrained(
296 | model_args.model_name_or_path,
297 | from_tf=bool(".ckpt" in model_args.model_name_or_path),
298 | config=config,
299 | cache_dir=model_args.cache_dir,
300 | revision=model_args.model_revision,
301 | use_auth_token=True if model_args.use_auth_token else None,
302 | )
303 |
304 | # Tokenizer check: this script requires a fast tokenizer.
305 | if not isinstance(tokenizer, PreTrainedTokenizerFast):
306 | raise ValueError(
307 | "This example script only works for models that have a fast tokenizer. Check out the big table of models "
308 | "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
309 | "requirement."
310 | )
311 |
312 | # Preprocessing the datasets.
313 | # Preprocessing is slightly different for training and evaluation.
314 | if training_args.do_train:
315 | column_names = raw_datasets["train"].column_names
316 | elif training_args.do_eval:
317 | column_names = raw_datasets["validation"].column_names
318 | else:
319 | column_names = raw_datasets["test"].column_names
320 | question_column_name = "question" if "question" in column_names else column_names[0]
321 | context_column_name = "context" if "context" in column_names else column_names[1]
322 | answer_column_name = "answers" if "answers" in column_names else column_names[2]
323 |
324 | # Padding side determines if we do (question|context) or (context|question).
325 | pad_on_right = tokenizer.padding_side == "right"
326 |
327 | if data_args.max_seq_length > tokenizer.model_max_length:
328 | logger.warning(
329 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
330 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
331 | )
332 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
333 |
334 | # Training preprocessing
335 | def prepare_train_features(examples):
336 | # Some of the questions have lots of whitespace on the left, which is not useful and will make the
337 | # truncation of the context fail (the tokenized question will take a lot of space). So we remove that
338 | # left whitespace
339 | examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
340 |
341 | # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
342 | # in one example possibly giving several features when a context is long, each of those features having a
343 | # context that overlaps a bit the context of the previous feature.
344 | tokenized_examples = tokenizer(
345 | examples[question_column_name if pad_on_right else context_column_name],
346 | examples[context_column_name if pad_on_right else question_column_name],
347 | truncation="only_second" if pad_on_right else "only_first",
348 | max_length=max_seq_length,
349 | stride=data_args.doc_stride,
350 | return_overflowing_tokens=True,
351 | return_offsets_mapping=True,
352 | padding="max_length" if data_args.pad_to_max_length else False,
353 | )
354 |
355 | # Since one example might give us several features if it has a long context, we need a map from a feature to
356 | # its corresponding example. This key gives us just that.
357 | sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
358 | # The offset mappings will give us a map from token to character position in the original context. This will
359 | # help us compute the start_positions and end_positions.
360 | offset_mapping = tokenized_examples.pop("offset_mapping")
361 |
362 | # Let's label those examples!
363 | tokenized_examples["start_positions"] = []
364 | tokenized_examples["end_positions"] = []
365 |
366 | for i, offsets in enumerate(offset_mapping):
367 | # We will label impossible answers with the index of the CLS token.
368 | input_ids = tokenized_examples["input_ids"][i]
369 | cls_index = input_ids.index(tokenizer.cls_token_id)
370 |
371 | # Grab the sequence corresponding to that example (to know what is the context and what is the question).
372 | sequence_ids = tokenized_examples.sequence_ids(i)
373 |
374 | # One example can give several spans, this is the index of the example containing this span of text.
375 | sample_index = sample_mapping[i]
376 | answers = examples[answer_column_name][sample_index]
377 | # If no answers are given, set the cls_index as answer.
378 | if len(answers["answer_start"]) == 0:
379 | tokenized_examples["start_positions"].append(cls_index)
380 | tokenized_examples["end_positions"].append(cls_index)
381 | else:
382 | # Start/end character index of the answer in the text.
383 | start_char = answers["answer_start"][0]
384 | end_char = start_char + len(answers["text"][0])
385 |
386 | # Start token index of the current span in the text.
387 | token_start_index = 0
388 | while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
389 | token_start_index += 1
390 |
391 | # End token index of the current span in the text.
392 | token_end_index = len(input_ids) - 1
393 | while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
394 | token_end_index -= 1
395 |
396 | # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
397 | if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
398 | tokenized_examples["start_positions"].append(cls_index)
399 | tokenized_examples["end_positions"].append(cls_index)
400 | else:
401 | # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
402 | # Note: we could go after the last offset if the answer is the last word (edge case).
403 | while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
404 | token_start_index += 1
405 | tokenized_examples["start_positions"].append(token_start_index - 1)
406 | while offsets[token_end_index][1] >= end_char:
407 | token_end_index -= 1
408 | tokenized_examples["end_positions"].append(token_end_index + 1)
409 |
410 | return tokenized_examples
411 |
412 | if training_args.do_train:
413 | if "train" not in raw_datasets:
414 | raise ValueError("--do_train requires a train dataset")
415 | train_dataset = raw_datasets["train"]
416 | if data_args.max_train_samples is not None:
417 | # We will select samples from the whole data if the argument is specified
418 | train_dataset = train_dataset.select(range(data_args.max_train_samples))
419 | # Create train feature from dataset
420 | with training_args.main_process_first(desc="train dataset map pre-processing"):
421 | train_dataset = train_dataset.map(
422 | prepare_train_features,
423 | batched=True,
424 | num_proc=data_args.preprocessing_num_workers,
425 | remove_columns=column_names,
426 | load_from_cache_file=not data_args.overwrite_cache,
427 | desc="Running tokenizer on train dataset",
428 | )
429 | if data_args.max_train_samples is not None:
430 | # The number of samples might increase during feature creation, so we select only the specified max samples
431 | train_dataset = train_dataset.select(range(data_args.max_train_samples))
432 |
433 | # Validation preprocessing
434 | def prepare_validation_features(examples):
435 | # Some of the questions have lots of whitespace on the left, which is not useful and will make the
436 | # truncation of the context fail (the tokenized question will take a lot of space). So we remove that
437 | # left whitespace
438 | examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
439 |
440 | # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
441 | # in one example possibly giving several features when a context is long, each of those features having a
442 | # context that overlaps a bit the context of the previous feature.
443 | tokenized_examples = tokenizer(
444 | examples[question_column_name if pad_on_right else context_column_name],
445 | examples[context_column_name if pad_on_right else question_column_name],
446 | truncation="only_second" if pad_on_right else "only_first",
447 | max_length=max_seq_length,
448 | stride=data_args.doc_stride,
449 | return_overflowing_tokens=True,
450 | return_offsets_mapping=True,
451 | padding="max_length" if data_args.pad_to_max_length else False,
452 | )
453 |
454 | # Since one example might give us several features if it has a long context, we need a map from a feature to
455 | # its corresponding example. This key gives us just that.
456 | sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
457 |
458 | # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
459 | # corresponding example_id and we will store the offset mappings.
460 | tokenized_examples["example_id"] = []
461 |
462 | for i in range(len(tokenized_examples["input_ids"])):
463 | # Grab the sequence corresponding to that example (to know what is the context and what is the question).
464 | sequence_ids = tokenized_examples.sequence_ids(i)
465 | context_index = 1 if pad_on_right else 0
466 |
467 | # One example can give several spans, this is the index of the example containing this span of text.
468 | sample_index = sample_mapping[i]
469 | tokenized_examples["example_id"].append(examples["id"][sample_index])
470 |
471 | # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
472 | # position is part of the context or not.
473 | tokenized_examples["offset_mapping"][i] = [
474 | (o if sequence_ids[k] == context_index else None)
475 | for k, o in enumerate(tokenized_examples["offset_mapping"][i])
476 | ]
477 |
478 | return tokenized_examples
479 |
480 | if training_args.do_eval:
481 | if "validation" not in raw_datasets:
482 | raise ValueError("--do_eval requires a validation dataset")
483 | eval_examples = raw_datasets["validation"]
484 | if data_args.max_eval_samples is not None:
485 | # We will select samples from the whole data
486 | eval_examples = eval_examples.select(range(data_args.max_eval_samples))
487 | # Validation Feature Creation
488 | with training_args.main_process_first(desc="validation dataset map pre-processing"):
489 | eval_dataset = eval_examples.map(
490 | prepare_validation_features,
491 | batched=True,
492 | num_proc=data_args.preprocessing_num_workers,
493 | remove_columns=column_names,
494 | load_from_cache_file=not data_args.overwrite_cache,
495 | desc="Running tokenizer on validation dataset",
496 | )
497 | if data_args.max_eval_samples is not None:
498 | # The number of samples might increase during feature creation, so we select the required samples again
499 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
500 |
501 | if training_args.do_predict:
502 | if "test" not in raw_datasets:
503 | raise ValueError("--do_predict requires a test dataset")
504 | predict_examples = raw_datasets["test"]
505 | if data_args.max_predict_samples is not None:
506 | # We will select samples from the whole data
507 | predict_examples = predict_examples.select(range(data_args.max_predict_samples))
508 | # Predict Feature Creation
509 | with training_args.main_process_first(desc="prediction dataset map pre-processing"):
510 | predict_dataset = predict_examples.map(
511 | prepare_validation_features,
512 | batched=True,
513 | num_proc=data_args.preprocessing_num_workers,
514 | remove_columns=column_names,
515 | load_from_cache_file=not data_args.overwrite_cache,
516 | desc="Running tokenizer on prediction dataset",
517 | )
518 | if data_args.max_predict_samples is not None:
519 | # The number of samples might increase during feature creation, so we select the required samples again
520 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
521 |
522 | # Data collator
523 | # We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
524 | # collator.
525 | data_collator = (
526 | default_data_collator
527 | if data_args.pad_to_max_length
528 | else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
529 | )
530 |
531 | # Post-processing:
532 | def post_processing_function(examples, features, predictions, stage="eval"):
533 | # Post-processing: we match the start logits and end logits to answers in the original context.
534 | predictions = postprocess_qa_predictions(
535 | examples=examples,
536 | features=features,
537 | predictions=predictions,
538 | version_2_with_negative=data_args.version_2_with_negative,
539 | n_best_size=data_args.n_best_size,
540 | max_answer_length=data_args.max_answer_length,
541 | null_score_diff_threshold=data_args.null_score_diff_threshold,
542 | output_dir=training_args.output_dir,
543 | log_level=log_level,
544 | prefix=stage,
545 | )
546 | # Format the result to the format the metric expects.
547 | if data_args.version_2_with_negative:
548 | formatted_predictions = [
549 | {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
550 | ]
551 | else:
552 | formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
553 |
554 | references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
555 | return EvalPrediction(predictions=formatted_predictions, label_ids=references)
556 |
557 | metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
558 |
559 | def compute_metrics(p: EvalPrediction):
560 | return metric.compute(predictions=p.predictions, references=p.label_ids)
561 |
562 | # Initialize our Trainer
563 | trainer = QuestionAnsweringTrainer(
564 | model=model,
565 | args=training_args,
566 | train_dataset=train_dataset if training_args.do_train else None,
567 | eval_dataset=eval_dataset if training_args.do_eval else None,
568 | eval_examples=eval_examples if training_args.do_eval else None,
569 | tokenizer=tokenizer,
570 | data_collator=data_collator,
571 | post_process_function=post_processing_function,
572 | compute_metrics=compute_metrics,
573 | )
574 |
575 | # Training
576 | if training_args.do_train:
577 | checkpoint = None
578 | if training_args.resume_from_checkpoint is not None:
579 | checkpoint = training_args.resume_from_checkpoint
580 | elif last_checkpoint is not None:
581 | checkpoint = last_checkpoint
582 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
583 | trainer.save_model() # Saves the tokenizer too for easy upload
584 |
585 | metrics = train_result.metrics
586 | max_train_samples = (
587 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
588 | )
589 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
590 |
591 | trainer.log_metrics("train", metrics)
592 | trainer.save_metrics("train", metrics)
593 | trainer.save_state()
594 |
595 | # Evaluation
596 | if training_args.do_eval:
597 | logger.info("*** Evaluate ***")
598 | metrics = trainer.evaluate()
599 |
600 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
601 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
602 |
603 | trainer.log_metrics("eval", metrics)
604 | trainer.save_metrics("eval", metrics)
605 |
606 | # Prediction
607 | if training_args.do_predict:
608 | logger.info("*** Predict ***")
609 | results = trainer.predict(predict_dataset, predict_examples)
610 | metrics = results.metrics
611 |
612 | max_predict_samples = (
613 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
614 | )
615 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
616 |
617 | trainer.log_metrics("predict", metrics)
618 | trainer.save_metrics("predict", metrics)
619 |
620 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "question-answering"}
621 | if data_args.dataset_name is not None:
622 | kwargs["dataset_tags"] = data_args.dataset_name
623 | if data_args.dataset_config_name is not None:
624 | kwargs["dataset_args"] = data_args.dataset_config_name
625 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
626 | else:
627 | kwargs["dataset"] = data_args.dataset_name
628 |
629 | if training_args.push_to_hub:
630 | trainer.push_to_hub(**kwargs)
631 | else:
632 | trainer.create_model_card(**kwargs)
633 |
634 |
635 | if __name__ == "__main__":
636 | main()
637 |
--------------------------------------------------------------------------------
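When `--train_file`/`--validation_file` are used instead of a Hub dataset, run_qa.py loads them with `load_dataset("json", ..., field="data")` and later looks for `question`, `context` and `answers` columns (plus an `id` used at evaluation time). The snippet below is an illustrative sketch of one way to produce such a file; the record content is made up.

import json

# One SQuAD-style record. "answers" holds parallel "text"/"answer_start" lists;
# empty lists mark an unanswerable question when version_2_with_negative is on.
record = {
    "id": "example-0",
    "question": "Who released the transformers library?",
    "context": "The transformers library is released by Hugging Face.",
    "answers": {"text": ["Hugging Face"], "answer_start": [40]},
}

# field="data" means the records must sit under a top-level "data" key.
with open("train.json", "w", encoding="utf-8") as f:
    json.dump({"data": [record]}, f, ensure_ascii=False, indent=2)
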
/question_generation/run_qg.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2021 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """
17 | Fine-tuning the library models for question generation.
18 | """
19 | import logging
20 | import os
21 | import sys
22 | from dataclasses import dataclass, field
23 | from typing import Optional
24 |
25 | import datasets
26 | from datasets import load_dataset
27 | import numpy as np
28 |
29 | import transformers
30 | from transformers import (
31 | AutoConfig,
32 | AutoModelForSeq2SeqLM,
33 | AutoTokenizer,
34 | DataCollatorForSeq2Seq,
35 | HfArgumentParser,
36 | Seq2SeqTrainer,
37 | Seq2SeqTrainingArguments,
38 | set_seed,
39 | )
40 | from rouge import Rouge
41 | from transformers.trainer_utils import get_last_checkpoint
42 |
43 | os.environ["WANDB_DISABLED"] = "true"
44 |
45 | logger = logging.getLogger(__name__)
46 |
47 |
48 | @dataclass
49 | class ModelArguments:
50 | """
51 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
52 | """
53 |
54 | model_name_or_path: str = field(
55 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
56 | )
57 | config_name: Optional[str] = field(
58 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
59 | )
60 | tokenizer_name: Optional[str] = field(
61 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
62 | )
63 | cache_dir: Optional[str] = field(
64 | default=None,
65 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
66 | )
67 | use_fast_tokenizer: bool = field(
68 | default=True,
69 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
70 | )
71 | model_revision: str = field(
72 | default="main",
73 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
74 | )
75 | use_auth_token: bool = field(
76 | default=False,
77 | metadata={
78 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
79 | "with private models)."
80 | },
81 | )
82 | resize_position_embeddings: Optional[bool] = field(
83 | default=None,
84 | metadata={
85 | "help": "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
86 | "the model's position embeddings."
87 | },
88 | )
89 |
90 |
91 | @dataclass
92 | class DataTrainingArguments:
93 | """
94 | Arguments pertaining to what data we are going to input our model for training and eval.
95 | """
96 |
97 | dataset_name: Optional[str] = field(
98 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
99 | )
100 | dataset_config_name: Optional[str] = field(
101 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
102 | )
103 | text_column: Optional[str] = field(
104 | default=None,
105 | metadata={"help": "The name of the column in the datasets containing the source texts (for question generation)."},
106 | )
107 | summary_column: Optional[str] = field(
108 | default=None,
109 | metadata={"help": "The name of the column in the datasets containing the target questions (for question generation)."},
110 | )
111 | train_file: Optional[str] = field(
112 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
113 | )
114 | validation_file: Optional[str] = field(
115 | default=None,
116 | metadata={
117 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on "
118 | "(a jsonlines or csv file)."
119 | },
120 | )
121 | test_file: Optional[str] = field(
122 | default=None,
123 | metadata={
124 | "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)."
125 | },
126 | )
127 | overwrite_cache: bool = field(
128 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
129 | )
130 | preprocessing_num_workers: Optional[int] = field(
131 | default=None,
132 | metadata={"help": "The number of processes to use for the preprocessing."},
133 | )
134 | max_source_length: Optional[int] = field(
135 | default=1024,
136 | metadata={
137 | "help": "The maximum total input sequence length after tokenization. Sequences longer "
138 | "than this will be truncated, sequences shorter will be padded."
139 | },
140 | )
141 | max_target_length: Optional[int] = field(
142 | default=128,
143 | metadata={
144 | "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
145 | "than this will be truncated, sequences shorter will be padded."
146 | },
147 | )
148 | val_max_target_length: Optional[int] = field(
149 | default=None,
150 | metadata={
151 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
152 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
153 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
154 | "during ``evaluate`` and ``predict``."
155 | },
156 | )
157 | pad_to_max_length: bool = field(
158 | default=False,
159 | metadata={
160 | "help": "Whether to pad all samples to model maximum sentence length. "
161 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
162 | "efficient on GPU but very bad for TPU."
163 | },
164 | )
165 | max_train_samples: Optional[int] = field(
166 | default=None,
167 | metadata={
168 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
169 | "value if set."
170 | },
171 | )
172 | max_eval_samples: Optional[int] = field(
173 | default=None,
174 | metadata={
175 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
176 | "value if set."
177 | },
178 | )
179 | max_predict_samples: Optional[int] = field(
180 | default=None,
181 | metadata={
182 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
183 | "value if set."
184 | },
185 | )
186 | num_beams: Optional[int] = field(
187 | default=None,
188 | metadata={
189 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
190 | "which is used during ``evaluate`` and ``predict``."
191 | },
192 | )
193 | ignore_pad_token_for_loss: bool = field(
194 | default=True,
195 | metadata={
196 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
197 | },
198 | )
199 | source_prefix: Optional[str] = field(
200 | default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
201 | )
202 |
203 | def __post_init__(self):
204 | if self.dataset_name is None and self.train_file is None and self.validation_file is None:
205 | raise ValueError("Need either a dataset name or a training/validation file.")
206 | else:
207 | if self.train_file is not None:
208 | extension = self.train_file.split(".")[-1]
209 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
210 | if self.validation_file is not None:
211 | extension = self.validation_file.split(".")[-1]
212 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
213 | if self.val_max_target_length is None:
214 | self.val_max_target_length = self.max_target_length
215 |
216 |
217 | def main():
218 | # See all possible arguments in src/transformers/training_args.py
219 | # or by passing the --help flag to this script.
220 | # We now keep distinct sets of args, for a cleaner separation of concerns.
221 |
222 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
223 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
224 | # If we pass only one argument to the script and it's the path to a json file,
225 | # let's parse it to get our arguments.
226 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
227 | else:
228 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
229 |
230 | # Setup logging
231 | logging.basicConfig(
232 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
233 | datefmt="%m/%d/%Y %H:%M:%S",
234 | handlers=[logging.StreamHandler(sys.stdout)],
235 | )
236 | log_level = training_args.get_process_log_level()
237 | logger.setLevel(log_level)
238 | datasets.utils.logging.set_verbosity(log_level)
239 | transformers.utils.logging.set_verbosity(log_level)
240 | transformers.utils.logging.enable_default_handler()
241 | transformers.utils.logging.enable_explicit_format()
242 |
243 | # Log on each process the small summary:
244 | logger.warning(
245 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
246 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
247 | )
248 | logger.info(f"Training/evaluation parameters {training_args}")
249 | # Detecting last checkpoint.
250 | last_checkpoint = None
251 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
252 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
253 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
254 | raise ValueError(
255 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
256 | "Use --overwrite_output_dir to overcome."
257 | )
258 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
259 | logger.info(
260 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
261 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
262 | )
263 |
264 | # Set seed before initializing model.
265 | set_seed(training_args.seed)
266 |
267 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
268 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
269 | # (the dataset will be downloaded automatically from the datasets Hub).
270 | #
271 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the
272 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
273 | #
274 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently
275 | # download the dataset.
276 | if data_args.dataset_name is not None:
277 | # Downloading and loading a dataset from the hub.
278 | raw_datasets = load_dataset(
279 | data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
280 | )
281 | else:
282 | data_files = {}
283 | if data_args.train_file is not None:
284 | data_files["train"] = data_args.train_file
285 | extension = data_args.train_file.split(".")[-1]
286 | if data_args.validation_file is not None:
287 | data_files["validation"] = data_args.validation_file
288 | extension = data_args.validation_file.split(".")[-1]
289 | if data_args.test_file is not None:
290 | data_files["test"] = data_args.test_file
291 | extension = data_args.test_file.split(".")[-1]
292 | raw_datasets = load_dataset(extension, data_files=data_files, field="data", cache_dir=model_args.cache_dir)
293 |
294 | # Load pretrained model and tokenizer
295 | config = AutoConfig.from_pretrained(
296 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
297 | cache_dir=model_args.cache_dir,
298 | revision=model_args.model_revision,
299 | use_auth_token=True if model_args.use_auth_token else None,
300 | )
301 | tokenizer = AutoTokenizer.from_pretrained(
302 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
303 | cache_dir=model_args.cache_dir,
304 | use_fast=model_args.use_fast_tokenizer,
305 | revision=model_args.model_revision,
306 | use_auth_token=True if model_args.use_auth_token else None,
307 | )
308 | tokenizer.add_tokens("<sep>")  # separator token inserted between generated questions
309 | model = AutoModelForSeq2SeqLM.from_pretrained(
310 | model_args.model_name_or_path,
311 | from_tf=bool(".ckpt" in model_args.model_name_or_path),
312 | config=config,
313 | cache_dir=model_args.cache_dir,
314 | revision=model_args.model_revision,
315 | use_auth_token=True if model_args.use_auth_token else None,
316 | )
317 |
318 | model.resize_token_embeddings(len(tokenizer))
319 |
320 | if model.config.decoder_start_token_id is None:
321 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
322 |
323 | if (
324 | hasattr(model.config, "max_position_embeddings")
325 | and model.config.max_position_embeddings < data_args.max_source_length
326 | ):
327 | if model_args.resize_position_embeddings is None:
328 | logger.warning(
329 | f"Increasing the model's number of position embedding vectors from {model.config.max_position_embeddings} "
330 | f"to {data_args.max_source_length}."
331 | )
332 | model.resize_position_embeddings(data_args.max_source_length)
333 | elif model_args.resize_position_embeddings:
334 | model.resize_position_embeddings(data_args.max_source_length)
335 | else:
336 | raise ValueError(
337 | f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has {model.config.max_position_embeddings}"
338 | f" position encodings. Consider either reducing `--max_source_length` to {model.config.max_position_embeddings} or passing "
339 | "`--resize_position_embeddings` to automatically resize the model's position encodings."
340 | )
341 |
342 | prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
343 |
344 | # Preprocessing the datasets.
345 | # We need to tokenize inputs and targets.
346 | if training_args.do_train:
347 | column_names = raw_datasets["train"].column_names
348 | elif training_args.do_eval:
349 | column_names = raw_datasets["validation"].column_names
350 | elif training_args.do_predict:
351 | column_names = raw_datasets["test"].column_names
352 | else:
353 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
354 | return
355 |
356 | # Get the column names for input/target.
357 | text_column = data_args.text_column
358 | if text_column not in column_names:
359 | raise ValueError(
360 | f"`--text_column` value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
361 | )
362 | summary_column = data_args.summary_column
363 | if summary_column not in column_names:
364 | raise ValueError(
365 | f"`--summary_column` value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
366 | )
367 |
368 | # Temporarily set max_target_length for training.
369 | max_target_length = data_args.max_target_length
370 | padding = "max_length" if data_args.pad_to_max_length else False
371 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
372 | logger.warning(
373 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for "
374 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
375 | )
376 |
377 | def preprocess_function(examples):
378 | inputs = examples[text_column]
379 | targets = examples[summary_column]
380 | inputs = [prefix + inp for inp in inputs]
381 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
382 |
383 | # Setup the tokenizer for targets
384 | with tokenizer.as_target_tokenizer():
385 | labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
386 |
387 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
388 | # padding in the loss.
389 | if padding == "max_length" and data_args.ignore_pad_token_for_loss:
390 | labels["input_ids"] = [
391 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
392 | ]
393 |
394 | model_inputs["labels"] = labels["input_ids"]
395 | return model_inputs
396 |
397 | def _add_special_tokens(example):
398 | example['target_text'] = example['target_text'] + "{sep_token}"
399 | example['target_text'] = example['target_text'].replace("{sep_token}", "<sep>")
400 | return example
401 |
402 | if training_args.do_train:
403 | if "train" not in raw_datasets:
404 | raise ValueError("--do_train requires a train dataset")
405 | train_dataset = raw_datasets["train"]
406 | if data_args.max_train_samples is not None:
407 | train_dataset = train_dataset.select(range(data_args.max_train_samples))
408 | with training_args.main_process_first(desc="train dataset map pre-processing"):
409 | train_dataset = train_dataset.map(_add_special_tokens)
410 | train_dataset = train_dataset.map(
411 | preprocess_function,
412 | batched=True,
413 | num_proc=data_args.preprocessing_num_workers,
414 | remove_columns=column_names,
415 | load_from_cache_file=not data_args.overwrite_cache,
416 | desc="Running tokenizer on train dataset",
417 | )
418 |
419 | if training_args.do_eval:
420 | max_target_length = data_args.val_max_target_length
421 | if "validation" not in raw_datasets:
422 | raise ValueError("--do_eval requires a validation dataset")
423 | eval_dataset = raw_datasets["validation"]
424 | if data_args.max_eval_samples is not None:
425 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
426 | with training_args.main_process_first(desc="validation dataset map pre-processing"):
427 | eval_dataset = eval_dataset.map(_add_special_tokens)
428 | eval_dataset = eval_dataset.map(
429 | preprocess_function,
430 | batched=True,
431 | num_proc=data_args.preprocessing_num_workers,
432 | remove_columns=column_names,
433 | load_from_cache_file=not data_args.overwrite_cache,
434 | desc="Running tokenizer on validation dataset",
435 | )
436 |
437 | if training_args.do_predict:
438 | max_target_length = data_args.val_max_target_length
439 | if "test" not in raw_datasets:
440 | raise ValueError("--do_predict requires a test dataset")
441 | predict_dataset = raw_datasets["test"]
442 | if data_args.max_predict_samples is not None:
443 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
444 | with training_args.main_process_first(desc="prediction dataset map pre-processing"):
445 | predict_dataset = predict_dataset.map(_add_special_tokens)
446 | predict_dataset = predict_dataset.map(
447 | preprocess_function,
448 | batched=True,
449 | num_proc=data_args.preprocessing_num_workers,
450 | remove_columns=column_names,
451 | load_from_cache_file=not data_args.overwrite_cache,
452 | desc="Running tokenizer on prediction dataset",
453 | )
454 |
455 | # Data collator
456 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
457 | data_collator = DataCollatorForSeq2Seq(
458 | tokenizer,
459 | model=model,
460 | label_pad_token_id=label_pad_token_id,
461 | pad_to_multiple_of=8 if training_args.fp16 else None,
462 | )
463 |
464 | # Metric
465 | def postprocess_text(preds, labels):
466 | preds = [pred.strip() for pred in preds]
467 | labels = [label.strip() for label in labels]
468 | preds = ["\n".join(pred.split("<sep>")) for pred in preds]
469 | labels = ["\n".join(label.split("<sep>")) for label in labels]
470 | return preds, labels
471 |
472 | def compute_metrics(eval_preds):
473 | metric = Rouge()
474 | preds, labels = eval_preds
475 | if isinstance(preds, tuple):
476 | preds = preds[0]
477 | decoded_preds = [" ".join(tokenizer.convert_ids_to_tokens(pred)) for pred in preds]
478 | if data_args.ignore_pad_token_for_loss:
479 | # Replace -100 in the labels as we can't decode them.
480 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
481 | decoded_labels = [" ".join(tokenizer.convert_ids_to_tokens(label)) for label in labels]
482 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
483 | result = metric.get_scores(decoded_preds, decoded_labels, avg=True)
484 | result = {k: round(v["f"], 4) for k, v in result.items()}
485 | return result
486 |
487 | # Initialize our Trainer
488 | trainer = Seq2SeqTrainer(
489 | model=model,
490 | args=training_args,
491 | train_dataset=train_dataset if training_args.do_train else None,
492 | eval_dataset=eval_dataset if training_args.do_eval else None,
493 | tokenizer=tokenizer,
494 | data_collator=data_collator,
495 | compute_metrics=compute_metrics if training_args.predict_with_generate else None,
496 | )
497 |
498 | # Training
499 | if training_args.do_train:
500 | checkpoint = None
501 | if training_args.resume_from_checkpoint is not None:
502 | checkpoint = training_args.resume_from_checkpoint
503 | elif last_checkpoint is not None:
504 | checkpoint = last_checkpoint
505 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
506 | trainer.save_model() # Saves the tokenizer too for easy upload
507 |
508 | metrics = train_result.metrics
509 | max_train_samples = (
510 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
511 | )
512 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
513 |
514 | trainer.log_metrics("train", metrics)
515 | trainer.save_metrics("train", metrics)
516 | trainer.save_state()
517 |
518 | # Evaluation
519 | results = {}
520 | max_length = (
521 | training_args.generation_max_length
522 | if training_args.generation_max_length is not None
523 | else data_args.val_max_target_length
524 | )
525 | num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
526 | if training_args.do_eval:
527 | logger.info("*** Evaluate ***")
528 | metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
529 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
530 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
531 |
532 | trainer.log_metrics("eval", metrics)
533 | trainer.save_metrics("eval", metrics)
534 |
535 | if training_args.do_predict:
536 | logger.info("*** Predict ***")
537 |
538 | predict_results = trainer.predict(
539 | predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
540 | )
541 | metrics = predict_results.metrics
542 | max_predict_samples = (
543 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
544 | )
545 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
546 |
547 | trainer.log_metrics("predict", metrics)
548 | trainer.save_metrics("predict", metrics)
549 |
550 | if trainer.is_world_process_zero():
551 | if training_args.predict_with_generate:
552 | predictions = tokenizer.batch_decode(
553 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
554 | )
555 | predictions = [pred.strip() for pred in predictions]
556 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
557 | with open(output_prediction_file, "w") as writer:
558 | writer.write("\n".join(predictions))
559 |
560 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"}
561 | if data_args.dataset_name is not None:
562 | kwargs["dataset_tags"] = data_args.dataset_name
563 | if data_args.dataset_config_name is not None:
564 | kwargs["dataset_args"] = data_args.dataset_config_name
565 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
566 | else:
567 | kwargs["dataset"] = data_args.dataset_name
568 |
569 | if training_args.push_to_hub:
570 | trainer.push_to_hub(**kwargs)
571 | else:
572 | trainer.create_model_card(**kwargs)
573 |
574 | return results
575 |
576 |
577 | if __name__ == "__main__":
578 | main()
579 |
--------------------------------------------------------------------------------
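The question-generation preprocessing above expects a column literally named `target_text` (hard-coded in `_add_special_tokens`) plus whichever column is passed as `--text_column`; targets may contain several questions joined by the `{sep_token}` placeholder, which the script rewrites into the separator token. Below is a rough sketch of one record, assuming the same `data`-keyed JSON layout as run_qa.py, a hypothetical source column named `source_text` (passed via `--text_column source_text --summary_column target_text`) and `<sep>` as the separator; the exact formatting of the source text is defined by the pipeline code, not by this script.

import json

# One illustrative record: the target holds two questions joined by the
# {sep_token} placeholder that _add_special_tokens converts to <sep>.
record = {
    "source_text": "The transformers library is released by Hugging Face.",
    "target_text": "Who released the transformers library?{sep_token}What does Hugging Face release?",
}

with open("qg_train.json", "w", encoding="utf-8") as f:
    json.dump({"data": [record]}, f, ensure_ascii=False, indent=2)
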
/question_generation/train_qa.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Team All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """
16 | A subclass of `Trainer` specific to Question-Answering tasks
17 | """
18 |
19 | from transformers import Trainer, is_torch_tpu_available
20 | from transformers.trainer_utils import PredictionOutput
21 |
22 |
23 | if is_torch_tpu_available():
24 | import torch_xla.core.xla_model as xm
25 | import torch_xla.debug.metrics as met
26 |
27 |
28 | class QuestionAnsweringTrainer(Trainer):
29 | def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs):
30 | super().__init__(*args, **kwargs)
31 | self.eval_examples = eval_examples
32 | self.post_process_function = post_process_function
33 |
34 | def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"):
35 | eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
36 | eval_dataloader = self.get_eval_dataloader(eval_dataset)
37 | eval_examples = self.eval_examples if eval_examples is None else eval_examples
38 |
39 |         # Temporarily disable metric computation; we compute the metrics ourselves after the loop.
40 | compute_metrics = self.compute_metrics
41 | self.compute_metrics = None
42 | eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
43 | try:
44 | output = eval_loop(
45 | eval_dataloader,
46 | description="Evaluation",
47 | # No point gathering the predictions if there are no metrics, otherwise we defer to
48 | # self.args.prediction_loss_only
49 | prediction_loss_only=True if compute_metrics is None else None,
50 | ignore_keys=ignore_keys,
51 | )
52 | finally:
53 | self.compute_metrics = compute_metrics
54 |
55 | if self.post_process_function is not None and self.compute_metrics is not None:
56 | eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
57 | metrics = self.compute_metrics(eval_preds)
58 |
59 | # Prefix all keys with metric_key_prefix + '_'
60 | for key in list(metrics.keys()):
61 | if not key.startswith(f"{metric_key_prefix}_"):
62 | metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
63 |
64 | self.log(metrics)
65 | else:
66 | metrics = {}
67 |
68 | if self.args.tpu_metrics_debug or self.args.debug:
69 | # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
70 | xm.master_print(met.metrics_report())
71 |
72 | self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
73 | return metrics
74 |
75 | def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"):
76 | predict_dataloader = self.get_test_dataloader(predict_dataset)
77 |
78 |         # Temporarily disable metric computation; we compute the metrics ourselves after the loop.
79 | compute_metrics = self.compute_metrics
80 | self.compute_metrics = None
81 | eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
82 | try:
83 | output = eval_loop(
84 | predict_dataloader,
85 | description="Prediction",
86 | # No point gathering the predictions if there are no metrics, otherwise we defer to
87 | # self.args.prediction_loss_only
88 | prediction_loss_only=True if compute_metrics is None else None,
89 | ignore_keys=ignore_keys,
90 | )
91 | finally:
92 | self.compute_metrics = compute_metrics
93 |
94 | if self.post_process_function is None or self.compute_metrics is None:
95 | return output
96 |
97 | predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict")
98 | metrics = self.compute_metrics(predictions)
99 |
100 | # Prefix all keys with metric_key_prefix + '_'
101 | for key in list(metrics.keys()):
102 | if not key.startswith(f"{metric_key_prefix}_"):
103 | metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
104 |
105 | return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
106 |
--------------------------------------------------------------------------------
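
The actual wiring of this class presumably happens in run_qa.py, which is not shown in this section. What the code above does fix is the contract: post_process_function receives the raw examples, the tokenized features and the raw predictions and must return a transformers.EvalPrediction; it needs a default for its fourth argument, because evaluate() passes three positional arguments while predict() passes "predict" as a fourth. compute_metrics then maps that EvalPrediction to a plain dict, whose keys the methods above prepend with metric_key_prefix. A minimal, hypothetical sketch of conforming callables for SQuAD-style data:

    from datasets import load_metric
    from transformers import EvalPrediction

    squad_metric = load_metric("squad")

    def post_processing_function(examples, features, predictions, stage="eval"):
        # Placeholder post-processing: the real version maps start/end logits back to
        # answer spans. evaluate() calls this with three arguments, predict() adds "predict".
        formatted = [{"id": ex["id"], "prediction_text": ""} for ex in examples]
        references = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
        return EvalPrediction(predictions=formatted, label_ids=references)

    def compute_metrics(p: EvalPrediction):
        # Returns unprefixed keys ("exact_match", "f1"); the trainer adds the prefix.
        return squad_metric.compute(predictions=p.predictions, references=p.label_ids)

Both callables are passed to the constructor together with eval_examples, e.g. QuestionAnsweringTrainer(..., eval_examples=eval_examples, post_process_function=post_processing_function, compute_metrics=compute_metrics); the class then takes care of disabling metric computation inside the prediction loop and restoring it afterwards.
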
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.12.5
2 | datasets==1.15.1
--------------------------------------------------------------------------------
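
These pins are exact, while setup.py below only declares lower bounds on the same versions. A small optional check, in plain Python, that an environment actually matches the pins:

    # Verify the environment matches the exact pins in requirements.txt.
    import datasets
    import transformers

    assert transformers.__version__ == "4.12.5", transformers.__version__
    assert datasets.__version__ == "1.15.1", datasets.__version__
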
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description_file = README.md
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | requirements = [
4 | "transformers>=4.12.5",
5 | "datasets>=1.15.1",
6 | "torch>=1.3"
7 | ]
8 |
9 | setup(
10 | name="question_generation",
11 | version="1.0.5",
12 | author="algolet",
13 | author_email="wei.cai@algolet.com",
14 | description="Question Generation and Question Answering Pipeline",
15 | long_description=open("README.md", "r", encoding="utf-8").read(),
16 | long_description_content_type="text/markdown",
17 | license="Apache",
18 | url="https://github.com/algolet/question_generation",
19 | packages=find_packages(),
20 | install_requires=requirements,
21 | python_requires=">=3.6.0",
22 | classifiers=[
23 | "Intended Audience :: Developers",
24 | "Intended Audience :: Education",
25 | "Intended Audience :: Science/Research",
26 | "License :: OSI Approved :: Apache Software License",
27 | "Programming Language :: Python :: 3",
28 | "Programming Language :: Python :: 3.6",
29 | "Programming Language :: Python :: 3.7",
30 | "Programming Language :: Python :: 3.8",
31 | "Programming Language :: Python :: 3.9",
32 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
33 | ]
34 | )
35 |
36 |
37 |
38 |
--------------------------------------------------------------------------------
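
After installing from this file (for example with pip install . from the repository root), the declared metadata and package layout can be checked from Python. A small sketch, assuming Python 3.8+ for importlib.metadata (the package itself allows 3.6, where the importlib_metadata backport offers the same API):

    from importlib import metadata

    # Distribution metadata as declared in setup.py.
    print(metadata.version("question_generation"))   # "1.0.5"
    print(metadata.requires("question_generation"))  # the three requirements declared above

    # find_packages() should have picked up the package and its pipelines subpackage.
    import question_generation
    import question_generation.pipelines
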