├── .gitignore ├── LICENSE ├── README.md ├── README_en.md ├── question_generation ├── __init__.py ├── pipelines │ ├── __init__.py │ ├── question_answering.py │ ├── question_generation.py │ └── utils_qa.py ├── run_qa.py ├── run_qg.py └── train_qa.py ├── requirements.txt ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | dist/ 3 | build/ 4 | __pycache__/ 5 | *.egg-info/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2021- The Algolet team. All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. 
For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. 
You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright [yyyy] [name of copyright owner] 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 基于mt5的中文问题生成任务
3 | 
4 | 
5 | 
6 | 中文说明 | English
7 | 
8 | 
9 | 
10 | 基于预训练模型mt5精调的问题生成模型
11 | 
12 | ## 在线测试
13 | 可以直接在线使用我们的模型 https://www.algolet.com/applications/qg?accessSource=github
14 | 
15 | ## 使用说明
16 | 我们提供了`question_generation` 和 `question_answering`的`pipeline` API,通过调用对应的pipeline,可以轻松实现相应任务
17 | 
18 | 下面是如何使用问题生成pipeline
19 | ``` python
20 | >>> from question_generation import pipeline
21 | 
22 | # Allocate a pipeline for question-generation
23 | # cpu版本,如果不传入device参数,默认是cpu版本
24 | >>> qg = pipeline("question-generation", device="cpu")
25 | # gpu版本
26 | >>> qg = pipeline("question-generation", device="cuda")
27 | # for single text
28 | >>> qg("在一个寒冷的冬天,赶集完回家的农夫在路边发现了一条冻僵了的蛇。他很可怜蛇,就把它放在怀里。当他身上的热气把蛇温暖以后,蛇很快苏醒了,露出了残忍的本性,给了农夫致命的伤害——咬了农夫一口。农夫临死之前说:“我竟然救了一条可怜的毒蛇,就应该受到这种报应啊!”")
29 | ['在寒冷的冬天,农夫在哪里发现了一条可怜的蛇?', '农夫是如何看待蛇的?', '当农夫遇到蛇时,他做了什么?']
30 | 
31 | # for batch input
32 | >>> texts = ["在一个寒冷的冬天,赶集完回家的农夫在路边发现了一条冻僵了的蛇。他很可怜蛇,就把它放在怀里。当他身上的热气把蛇温暖以后,蛇很快苏醒了,露出了残忍的本性,给了农夫致命的伤害——咬了农夫一口。农夫临死之前说:“我竟然救了一条可怜的毒蛇,就应该受到这种报应啊!”"]
33 | >>> qg(texts)
34 | [['在寒冷的冬天,农夫在哪里发现了一条可怜的蛇?', '农夫是如何看待蛇的?', '当农夫遇到蛇时,他做了什么?']]
35 | ```
36 | 可以使用你自己训练的模型,或者下载huggingface hub中已经微调好的模型。PyTorch版本的使用方式如下:
37 | ``` python
38 | >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
39 | >>> tokenizer = AutoTokenizer.from_pretrained("algolet/mt5-base-chinese-qg")
40 | >>> model = AutoModelForSeq2SeqLM.from_pretrained("algolet/mt5-base-chinese-qg")
41 | >>> pipe = pipeline("question-generation", model=model, tokenizer=tokenizer)
42 | ```
43 | 同时,我们也提供了问答pipeline,可以与问题生成模块集成,产生问题自动生成与回答的应用。
44 | 
45 | 下面是如何使用问答pipeline
46 | ``` python
47 | >>> from question_generation import pipeline
48 | 
49 | # Allocate a pipeline for question-answering
50 | >>> question_answerer = pipeline("question-answering")
51 | >>> text = "在一个寒冷的冬天,赶集完回家的农夫在路边发现了一条冻僵了的蛇。他很可怜蛇,就把它放在怀里。当他身上的热气把蛇温暖以后,蛇很快苏醒了,露出了残忍的本性,给了农夫致命的伤害——咬了农夫一口。农夫临死之前说:“我竟然救了一条可怜的毒蛇,就应该受到这种报应啊!”"
52 | # for single qa input
53 | >>> question_answerer({
54 | ...     'question': '在寒冷的冬天,农夫在哪里发现了一条可怜的蛇?',
55 | ...     'context': text
56 | ... })
57 | {'answer': '路边', 'start': 18, 'end': 20, 'score': 1.0}
58 | 
59 | # for batch qa inputs
60 | >>> question_answerer([
61 | ...     {
62 | ...         'question': '在寒冷的冬天,农夫在哪里发现了一条可怜的蛇?',
63 | ...         'context': text
64 | ...     },
65 | ...     {
66 | ...         'question': '农夫是如何看待蛇的?',
67 | ...         'context': text
68 | ...     },
69 | ...     {
70 | ...         'question': '当农夫遇到蛇时,他做了什么?',
71 | ...         'context': text
72 | ... }])
73 | [{'answer': '路边', 'start': 18, 'end': 20, 'score': 1.0},
74 |  {'answer': '我竟然救了一条可怜的毒蛇,就应该受到这种报应',
75 |   'start': 102,
76 |   'end': 124,
77 |   'score': 0.9996},
78 |  {'answer': '放在怀里', 'start': 40, 'end': 44, 'score': 0.9995}]
79 | ```
80 | qa的返回中,`answer`为答案,`start`和`end`分别为答案在原文中的开始位置和结束位置
81 | 
82 | ## 安装说明
83 | 需要安装pytorch>=1.3, transformers>=4.12.5 和 datasets>=1.15.1, pip安装速度如果较慢,可使用阿里源,在安装命令后添加 -i https://mirrors.aliyun.com/pypi/simple/
84 | 
85 | cuda版pytorch安装
86 | ```bash
87 | pip3 install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
88 | ```
89 | cpu版pytorch安装
90 | ```bash
91 | pip3 install torch==1.10.0+cpu torchvision==0.11.1+cpu torchaudio==0.10.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
92 | ```
93 | 安装transformers和datasets
94 | ```bash
95 | pip install transformers
96 | pip install datasets
97 | ```
98 | 安装本项目
99 | ```bash
100 | pip install question_generation
101 | ```
102 | 
103 | ## 模型训练
104 | 你可以训练自己的问题生成模型和问答模型
105 | 
106 | #### 问题生成模型训练
107 | ##### 训练数据格式
108 | ``` python
109 | >>> train.json
110 | {"data": [{"source_text": "对于某些物理情况,不可能将力的形成归因于势的梯度。这通常是由于宏观物理的考虑,屈服力产生于微观状态的宏观统计平均值。例如,摩擦是由原子间大量静电势的梯度引起的,但表现为独立于任何宏观位置矢量的力模型。非保守力除摩擦力外,还包括其他接触力、拉力、压缩力和阻力。然而,对于任何足够详细的描述,所有这些力都是保守力的结果,因为每一个宏观力都是微观势梯度的净结果。",
111 | "target_text": "拉力、压缩和拉力是什么力?{sep_token}静电梯度电势会产生什么?{sep_token}为什么这些力是无法建模的呢?"},
112 | {"source_text": "绿宝石失窃案 (法语: Les Bijoux de la Castafiore ;英语: The Castafiore Emerald )是丁丁历险记的第21部作品。作者是比利时漫画家埃尔热。本作与之前的丁丁历险记有著很大的不同,丁丁首次进行没有离开自己家的冒险,同时故事中没有明显的反派角色,充满了喜剧色彩。丁丁和船长原本在城堡悠闲度假,却因歌后突然造访而弄得鸡飞狗跳;媒体对歌后的行踪极度关注,穷追猛打;歌后一颗珍贵的绿宝石又突然失踪,引起了一波接一波的疑团,究竟谁的嫌疑最大?是船长刚刚收留的一伙吉卜赛人?是偷偷混入记者群中的神秘男子?是歌后的贴身女仆?还是行迹鬼祟的钢琴师?",
113 | "target_text": "故事中引起众多谜团的原因是?{sep_token}此部作品与以往不同的地方在于哪里?{sep_token}丁丁和船长的悠闲假期因何被打乱?{sep_token}《绿宝石失窃案》是《丁丁历险记》系列的第几部?{sep_token}《绿宝石失窃案》的作者是谁?"}
114 | ...
115 | ]}
116 | ```
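`target_text` 由同一段文本对应的多个问题拼接而成,问题之间用字面量 `{sep_token}` 占位符分隔,与上面的示例一致。下面是一段构造训练数据的示意代码(假设你已有 `(文本, 问题列表)` 形式的原始数据,其中的字段名和示例样本仅作说明):

``` python
import json

# 原始数据:每条样本是一段文本和若干针对该文本的问题(数据结构为示例,仅作说明)
raw_samples = [
    {"text": "在一个寒冷的冬天,赶集完回家的农夫在路边发现了一条冻僵了的蛇。",
     "questions": ["农夫在哪里发现了蛇?", "农夫是在什么季节发现蛇的?"]},
]

data = []
for sample in raw_samples:
    data.append({
        "source_text": sample["text"],
        # 多个问题之间用 {sep_token} 占位符连接,格式与上文的 train.json 示例一致
        "target_text": "{sep_token}".join(sample["questions"]),
    })

with open("data/train.json", "w", encoding="utf-8") as f:
    json.dump({"data": data}, f, ensure_ascii=False)
```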
117 | ##### 训练配置文件
118 | ``` python
119 | >>> qg_config.json
120 | {
121 | "model_name_or_path": "google/mt5-small",
122 | "tokenizer_name": "google/mt5-small",
123 | "text_column": "source_text",
124 | "summary_column": "target_text",
125 | "train_file": "data/train.json",
126 | "validation_file": "data/dev.json",
127 | "output_dir": "data/qg",
128 | "model_type": "mt5",
129 | "overwrite_output_dir": true,
130 | "do_train": true,
131 | "do_eval": true,
132 | "source_prefix": "question generation: ",
133 | "predict_with_generate": true,
134 | "per_device_train_batch_size": 8,
135 | "per_device_eval_batch_size": 8,
136 | "gradient_accumulation_steps": 32,
137 | "learning_rate": 1e-3,
138 | "num_train_epochs": 4,
139 | "max_source_length": 512,
140 | "max_target_length": 200,
141 | "logging_steps": 100,
142 | "seed": 42
143 | }
144 | ```
145 | ##### 启动训练
146 | CUDA_VISIBLE_DEVICES=0 python run_qg.py qg_config.json
147 | 
148 | #### 问答模型训练
149 | ##### 训练数据格式
150 | 与 SQuAD 2.0 的数据格式一致
151 | ``` python
152 | >>> train.json
153 | {'version': 2.0,
154 | 'data': [{'id': 'c398789b7375e0ce7eac86f2b18c3808',
155 | 'question': '隐藏式行车记录仪哪个牌子好',
156 | 'context': '推荐使用360行车记录仪。行车记录仪的好坏,取决于行车记录仪的摄像头配置,配置越高越好,再就是性价比。 行车记录仪配置需要1296p超高清摄像头比较好,这样录制视频清晰度高。再就是价格,性价比高也是可以值得考虑的。 360行车记录仪我使用了一段时间 ,觉得360行车记录仪比较好录得广角比较大,并且便宜实惠 ,价格才299,在360商城可以买到。可以参考对比下。',
157 | 'answers': {'answer_start': [4], 'text': ['360行车记录仪']}}]}
158 | ```
159 | ##### 训练配置文件
160 | ``` python
161 | >>> qa_config.json
162 | {
163 | "model_name_or_path": "bert-base-chinese",
164 | "tokenizer_name": "bert-base-chinese",
165 | "train_file": "data/train.json",
166 | "validation_file": "data/dev.json",
167 | "output_dir": "data/qa",
168 | "per_device_train_batch_size": 8,
169 | "per_device_eval_batch_size": 8,
170 | "gradient_accumulation_steps": 32,
171 | "overwrite_output_dir": true,
172 | "do_train": true,
173 | "do_eval": true,
174 | "max_answer_length": 200
175 | }
176 | ```
177 | ##### 启动训练
178 | CUDA_VISIBLE_DEVICES=0 python run_qa.py qa_config.json
179 | 
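#### 使用训练好的模型
训练完成后,可以把输出目录中的模型加载到对应的pipeline中使用。下面是一段示意代码,假设问题生成模型保存在上文 qg_config.json 配置的 `data/qg` 目录中:
``` python
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
>>> from question_generation import pipeline

# 从训练输出目录加载模型(路径来自上文配置中的 output_dir,按需替换)
>>> tokenizer = AutoTokenizer.from_pretrained("data/qg")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("data/qg")
>>> qg = pipeline("question-generation", model=model, tokenizer=tokenizer)
```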

-------------------------------------------------------------------------------- /README_en.md: --------------------------------------------------------------------------------
1 | 
2 | 
3 | Chinese Question Generation Using MT5 Model
4 | 
5 | 
6 | 
7 | 中文说明 | English
8 | 
9 | 
10 | 

11 | Question Generation model by fine-tuning the mt5 model.
12 | 
13 | ## Online demos
14 | You can test the model directly at https://www.algolet.com/applications/qg?accessSource=github
15 | 
16 | ## Quick tour
17 | To immediately use models on given inputs, we provide the `question_generation` and `question_answering` `pipeline` APIs.
18 | Pipelines group together a pretrained model with the preprocessing that was used during that model's training.
19 | 
20 | Here is how to quickly use a pipeline to generate questions
21 | ``` python
22 | >>> from question_generation import pipeline
23 | 
24 | # Allocate a pipeline for question-generation
25 | # for cpu (the default if no device argument is passed)
26 | >>> qg = pipeline("question-generation")
27 | # for gpu
28 | >>> qg = pipeline("question-generation", device="cuda")
29 | # for single text
30 | >>> qg("在一个寒冷的冬天,赶集完回家的农夫在路边发现了一条冻僵了的蛇。他很可怜蛇,就把它放在怀里。当他身上的热气把蛇温暖以后,蛇很快苏醒了,露出了残忍的本性,给了农夫致命的伤害——咬了农夫一口。农夫临死之前说:“我竟然救了一条可怜的毒蛇,就应该受到这种报应啊!”")
31 | ['在寒冷的冬天,农夫在哪里发现了一条可怜的蛇?', '农夫是如何看待蛇的?', '当农夫遇到蛇时,他做了什么?']
32 | 
33 | # for batch input
34 | >>> texts = ["在一个寒冷的冬天,赶集完回家的农夫在路边发现了一条冻僵了的蛇。他很可怜蛇,就把它放在怀里。当他身上的热气把蛇温暖以后,蛇很快苏醒了,露出了残忍的本性,给了农夫致命的伤害——咬了农夫一口。农夫临死之前说:“我竟然救了一条可怜的毒蛇,就应该受到这种报应啊!”"]
37 | >>> qg(texts)
38 | [['在寒冷的冬天,农夫在哪里发现了一条可怜的蛇?', '农夫是如何看待蛇的?', '当农夫遇到蛇时,他做了什么?']]
39 | ```
40 | To use your own model, or any model already fine-tuned for question generation, load it explicitly. Here is the PyTorch version:
41 | ``` python
42 | >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
43 | >>> tokenizer = AutoTokenizer.from_pretrained("algolet/mt5-base-chinese-qg")
44 | >>> model = AutoModelForSeq2SeqLM.from_pretrained("algolet/mt5-base-chinese-qg")
45 | >>> pipe = pipeline("question-generation", model=model, tokenizer=tokenizer)
46 | ```
47 | 
48 | Combining `question_generation` with `question_answering`,
49 | you can build an application that automatically generates and answers questions.
50 | 
51 | Here is how to quickly use a pipeline to answer questions.
52 | ``` python
53 | >>> from question_generation import pipeline
54 | 
55 | # Allocate a pipeline for question-answering
56 | # for cpu (the default if no device argument is passed)
57 | >>> question_answerer = pipeline("question-answering", device="cpu")
58 | # for gpu
59 | >>> question_answerer = pipeline("question-answering", device="cuda")
60 | >>> text = "在一个寒冷的冬天,赶集完回家的农夫在路边发现了一条冻僵了的蛇。他很可怜蛇,就把它放在怀里。当他身上的热气把蛇温暖以后,蛇很快苏醒了,露出了残忍的本性,给了农夫致命的伤害——咬了农夫一口。农夫临死之前说:“我竟然救了一条可怜的毒蛇,就应该受到这种报应啊!”"
61 | # for single qa input
62 | >>> question_answerer({
63 | ...     'question': '在寒冷的冬天,农夫在哪里发现了一条可怜的蛇?',
64 | ...     'context': text
65 | ... })
66 | {'answer': '路边', 'start': 18, 'end': 20, 'score': 1.0}
67 | 
68 | # for batch qa inputs
69 | >>> question_answerer([
70 | ...     {
71 | ...         'question': '在寒冷的冬天,农夫在哪里发现了一条可怜的蛇?',
72 | ...         'context': text
73 | ...     },
74 | ...     {
75 | ...         'question': '农夫是如何看待蛇的?',
76 | ...         'context': text
77 | ...     },
78 | ...     {
79 | ...         'question': '当农夫遇到蛇时,他做了什么?',
80 | ...         'context': text
81 | ... }])
82 | [{'answer': '路边', 'start': 18, 'end': 20, 'score': 1.0},
83 |  {'answer': '我竟然救了一条可怜的毒蛇,就应该受到这种报应',
84 |   'start': 102,
85 |   'end': 124,
86 |   'score': 0.9996},
87 |  {'answer': '放在怀里', 'start': 40, 'end': 44, 'score': 0.9995}]
88 | ```
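The two pipelines can be chained into a simple end-to-end demo: generate questions from a passage, then answer each generated question against the same passage. The following is a minimal sketch based only on the pipeline APIs shown above:

``` python
>>> from question_generation import pipeline

>>> qg = pipeline("question-generation")
>>> question_answerer = pipeline("question-answering")

>>> text = "在一个寒冷的冬天,赶集完回家的农夫在路边发现了一条冻僵了的蛇。他很可怜蛇,就把它放在怀里。当他身上的热气把蛇温暖以后,蛇很快苏醒了,露出了残忍的本性,给了农夫致命的伤害——咬了农夫一口。农夫临死之前说:“我竟然救了一条可怜的毒蛇,就应该受到这种报应啊!”"
# generate a list of questions from the passage
>>> questions = qg(text)
# answer each generated question against the same passage;
# each result is a dict with 'answer', 'start', 'end' and 'score'
>>> answers = question_answerer([{"question": q, "context": text} for q in questions])
```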
89 | 
90 | ## Installation
91 | This repository needs pytorch>=1.3, transformers>=4.12.5 and datasets>=1.15.1.
92 | cuda torch
93 | ```bash
94 | pip3 install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
95 | ```
96 | cpu torch
97 | ```bash
98 | pip3 install torch==1.10.0+cpu torchvision==0.11.1+cpu torchaudio==0.10.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
99 | ```
100 | transformers and datasets
101 | ```bash
102 | pip install transformers
103 | pip install datasets
104 | ```
105 | this repository
106 | ```bash
107 | pip install question_generation
108 | ```
109 | 
110 | ## Model Training
111 | #### How to train a qg model
112 | ##### Sample data
113 | ``` python
114 | >>> train.json
115 | {"data": [{"source_text": "对于某些物理情况,不可能将力的形成归因于势的梯度。这通常是由于宏观物理的考虑,屈服力产生于微观状态的宏观统计平均值。例如,摩擦是由原子间大量静电势的梯度引起的,但表现为独立于任何宏观位置矢量的力模型。非保守力除摩擦力外,还包括其他接触力、拉力、压缩力和阻力。然而,对于任何足够详细的描述,所有这些力都是保守力的结果,因为每一个宏观力都是微观势梯度的净结果。",
116 | "target_text": "拉力、压缩和拉力是什么力?{sep_token}静电梯度电势会产生什么?{sep_token}为什么这些力是无法建模的呢?"},
117 | {"source_text": "绿宝石失窃案 (法语: Les Bijoux de la Castafiore ;英语: The Castafiore Emerald )是丁丁历险记的第21部作品。作者是比利时漫画家埃尔热。本作与之前的丁丁历险记有著很大的不同,丁丁首次进行没有离开自己家的冒险,同时故事中没有明显的反派角色,充满了喜剧色彩。丁丁和船长原本在城堡悠闲度假,却因歌后突然造访而弄得鸡飞狗跳;媒体对歌后的行踪极度关注,穷追猛打;歌后一颗珍贵的绿宝石又突然失踪,引起了一波接一波的疑团,究竟谁的嫌疑最大?是船长刚刚收留的一伙吉卜赛人?是偷偷混入记者群中的神秘男子?是歌后的贴身女仆?还是行迹鬼祟的钢琴师?",
118 | "target_text": "故事中引起众多谜团的原因是?{sep_token}此部作品与以往不同的地方在于哪里?{sep_token}丁丁和船长的悠闲假期因何被打乱?{sep_token}《绿宝石失窃案》是《丁丁历险记》系列的第几部?{sep_token}《绿宝石失窃案》的作者是谁?"}
119 | ...
120 | ]} 121 | ``` 122 | ##### Example config 123 | ``` python 124 | >>> qg_config.json 125 | { 126 | "model_name_or_path": "google/mt5-small", 127 | "tokenizer_name": "google/mt5-small", 128 | "text_column": "source_text", 129 | "summary_column": "target_text", 130 | "train_file": "data/train.json", 131 | "validation_file": "data/dev.json", 132 | "output_dir": "data/qg", 133 | "model_type": "mt5", 134 | "overwrite_output_dir": true, 135 | "do_train": true, 136 | "do_eval": true, 137 | "source_prefix": "question generation: ", 138 | "predict_with_generate": true, 139 | "per_device_train_batch_size": 8, 140 | "per_device_eval_batch_size": 8, 141 | "gradient_accumulation_steps": 32, 142 | "learning_rate": 1e-3, 143 | "num_train_epochs": 4, 144 | "max_source_length": 512, 145 | "max_target_length": 200, 146 | "logging_steps": 100, 147 | "seed": 42 148 | } 149 | ``` 150 | ##### Example command 151 | ``` 152 | CUDA_VISIBLE_DEVICES=0 python run_qg.py qg_config.json 153 | ``` 154 | 155 | 156 | #### How to train a qa model 157 | ##### Sample data 158 | ``` python 159 | >>> train.json 160 | {'version': 2.0, 161 | 'data': [{'id': 'c398789b7375e0ce7eac86f2b18c3808', 162 | 'question': '隐藏式行车记录仪哪个牌子好', 163 | 'context': '推荐使用360行车记录仪。行车记录仪的好坏,取决于行车记录仪的摄像头配置,配置越高越好,再就是性价比。 行车记录仪配置需要1296p超高清摄像头比较好,这样录制视频清晰度高。再就是价格,性价比高也是可以值得考虑的。 360行车记录仪我使用了一段时间 ,觉得360行车记录仪比较好录得广角比较大,并且便宜实惠 ,价格才299,在360商城可以买到。可以参考对比下。', 164 | 'answers': {'answer_start': [4], 'text': ['360行车记录仪']}}]} 165 | ``` 166 | ##### Example config 167 | ``` python 168 | >>> qa_config.json 169 | { 170 | "model_name_or_path": "bert-base-chinese", 171 | "tokenizer_name": "bert-base-chinese", 172 | "train_file": "data/train.json", 173 | "validation_file": "data/dev.json", 174 | "output_dir": "data/qa", 175 | "per_device_train_batch_size": 8, 176 | "per_device_eval_batch_size": 8, 177 | "gradient_accumulation_steps": 32, 178 | "overwrite_output_dir": true, 179 | "do_train": true, 180 | "do_eval": true, 181 | "max_answer_length": 200 182 | } 183 | ``` 184 | ##### Example command 185 | ``` 186 | CUDA_VISIBLE_DEVICES=0 python run_qa.py qa_config.json 187 | ``` 188 | 189 | 190 | 191 | 192 | 193 | 194 | -------------------------------------------------------------------------------- /question_generation/__init__.py: -------------------------------------------------------------------------------- 1 | # Pipelines 2 | from .pipelines import pipeline 3 | -------------------------------------------------------------------------------- /question_generation/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from .question_generation import QuestionGenerationPipeline 3 | from .question_answering import QuestionAnsweringPipeline 4 | from typing import Optional, Union 5 | from transformers import PreTrainedTokenizer, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | SUPPORTED_TASKS = { 10 | "question-generation": { 11 | "impl": QuestionGenerationPipeline, 12 | "pt": AutoModelForSeq2SeqLM, 13 | "default": { 14 | "model": "algolet/mt5-base-chinese-qg" 15 | } 16 | }, 17 | "question-answering": { 18 | "impl": QuestionAnsweringPipeline, 19 | "pt": AutoModelForQuestionAnswering, 20 | "default": { 21 | "model": "luhua/chinese_pretrain_mrc_macbert_large" 22 | } 23 | } 24 | } 25 | 26 | 27 | def pipeline( 28 | task: str, 29 | model: Optional = None, 30 | tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, 31 | device: str = 
"cpu" 32 | ): 33 | if task is None and model is None: 34 | raise RuntimeError( 35 | "Impossible to instantiate a pipeline without either a task or a model" 36 | "being specified." 37 | "Please provide a task class or a model" 38 | ) 39 | 40 | if model is None and tokenizer is not None: 41 | raise RuntimeError( 42 | "Impossible to instantiate a pipeline with tokenizer specified but not the model " 43 | "as the provided tokenizer may not be compatible with the default model. " 44 | "Please provide a PreTrainedModel class or a path/identifier to a pretrained model when providing tokenizer." 45 | ) 46 | 47 | if task not in SUPPORTED_TASKS: 48 | raise KeyError(f"Unknown task {task}, available tasks are {list(SUPPORTED_TASKS.keys())}") 49 | targeted_task = SUPPORTED_TASKS[task] 50 | pipeline_class = targeted_task["impl"] 51 | 52 | if model is None: 53 | model = targeted_task["default"]["model"] 54 | 55 | model_name = model if isinstance(model, str) else None 56 | model_classes = targeted_task["pt"] 57 | 58 | if tokenizer is None: 59 | if isinstance(model_name, str): 60 | tokenizer = model_name 61 | else: 62 | raise Exception( 63 | "Impossible to guess which tokenizer to use. " 64 | "Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer." 65 | ) 66 | 67 | # Instantiate tokenizer if needed 68 | if isinstance(tokenizer, (str, tuple)): 69 | if isinstance(tokenizer, tuple): 70 | # For tuple we have (tokenizer name, {kwargs}) 71 | tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], **tokenizer[1]) 72 | else: 73 | tokenizer = AutoTokenizer.from_pretrained(tokenizer) 74 | 75 | if isinstance(model, str): 76 | model = model_classes.from_pretrained(model) 77 | return pipeline_class(model=model, tokenizer=tokenizer, device=device) 78 | -------------------------------------------------------------------------------- /question_generation/pipelines/question_answering.py: -------------------------------------------------------------------------------- 1 | from datasets import Dataset 2 | import logging 3 | import torch 4 | from .utils_qa import postprocess_qa_predictions 5 | from typing import Union, List, Dict 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class QuestionAnsweringPipeline: 11 | def __init__(self, 12 | tokenizer, 13 | model, 14 | max_seq_length: int = 384, 15 | doc_stride: int = 32, 16 | version_2_with_negative: bool = True, 17 | n_best_size: int = 1, 18 | max_answer_length: int = 30, 19 | null_score_diff_threshold: int = 0, 20 | device: str = "cpu"): 21 | self.tokenizer = tokenizer 22 | self.model = model 23 | self.max_seq_length = max_seq_length 24 | self.doc_stride = doc_stride 25 | self.question_column_name = "question" 26 | self.context_column_name = "context" 27 | self.version_2_with_negative = version_2_with_negative 28 | self.n_best_size = n_best_size 29 | self.max_answer_length = max_answer_length 30 | self.null_score_diff_threshold = null_score_diff_threshold 31 | if device == "cpu": 32 | self.device = "cpu" 33 | elif device == "cuda": 34 | self.device = torch.device("cuda") 35 | self.model = self.model.to(self.device) 36 | else: 37 | raise Exception( 38 | "device should be cup or cuda" 39 | ) 40 | self.model.eval() 41 | # Validation preprocessing 42 | 43 | def prepare_validation_features(self, examples): 44 | # Some of the questions have lots of whitespace on the left, which is not useful and will make the 45 | # truncation of the context fail (the tokenized question will take a lots of space). 
So we remove that 46 | # left whitespace 47 | examples[self.question_column_name] = [q.lstrip() for q in examples[self.question_column_name]] 48 | 49 | # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results 50 | # in one example possible giving several features when a context is long, each of those features having a 51 | # context that overlaps a bit the context of the previous feature. 52 | tokenized_examples = self.tokenizer( 53 | examples[self.question_column_name], 54 | examples[self.context_column_name], 55 | truncation="only_second", 56 | max_length=self.max_seq_length, 57 | stride=self.doc_stride, 58 | return_overflowing_tokens=True, 59 | return_offsets_mapping=True, 60 | padding=True 61 | ) 62 | 63 | # Since one example might give us several features if it has a long context, we need a map from a feature to 64 | # its corresponding example. This key gives us just that. 65 | sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") 66 | 67 | # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the 68 | # correspnding example_id and we will store the offset mappings. 69 | tokenized_examples["example_id"] = [] 70 | 71 | for i in range(len(tokenized_examples["input_ids"])): 72 | # Grab the sequence corresponding to that example (to know what is the context and what is the question). 73 | sequence_ids = tokenized_examples.sequence_ids(i) 74 | context_index = 1 75 | 76 | # One example can give several spans, this is the index of the example containing this span of text. 77 | sample_index = sample_mapping[i] 78 | tokenized_examples["example_id"].append(examples["id"][sample_index]) 79 | 80 | # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token 81 | # position is part of the context or not. 82 | tokenized_examples["offset_mapping"][i] = [ 83 | (o if sequence_ids[k] == context_index else None) 84 | for k, o in enumerate(tokenized_examples["offset_mapping"][i]) 85 | ] 86 | return tokenized_examples 87 | 88 | # region Metrics and Post-processing: 89 | 90 | def post_processing_function(self, examples, features, predictions, stage="eval"): 91 | # Post-processing: we match the start logits and end logits to answers in the original context. 92 | predictions = postprocess_qa_predictions( 93 | examples=examples, 94 | features=features, 95 | predictions=predictions, 96 | version_2_with_negative=self.version_2_with_negative, 97 | n_best_size=self.n_best_size, 98 | max_answer_length=self.max_answer_length, 99 | null_score_diff_threshold=self.null_score_diff_threshold, 100 | prefix=stage, 101 | ) 102 | # Format the result to the format the metric expects. 
103 | formatted_predictions = [{"id": k, "prediction": v} for k, v in predictions.items()] 104 | return formatted_predictions 105 | 106 | def __call__(self, input: Union[Dict, List[Dict]], *args, **kwargs): 107 | if isinstance(input, dict): 108 | questions = [input["question"]] 109 | contexts = [input["context"]] 110 | else: 111 | questions = [item["question"] for item in input] 112 | contexts = [item["context"] for item in input] 113 | ids = [i for i in range(len(questions))] 114 | examples = {"id": ids, "question": questions, "context": contexts} 115 | processed_datasets = self.prepare_validation_features(examples) 116 | inputs = { 117 | "input_ids": torch.tensor(processed_datasets["input_ids"]), 118 | "attention_mask": torch.tensor(processed_datasets["attention_mask"]), 119 | } 120 | 121 | input_ids = inputs["input_ids"].to(self.device) 122 | input_masks = inputs["attention_mask"].to(self.device) 123 | 124 | with torch.no_grad(): 125 | eval_predictions = self.model(input_ids=input_ids, 126 | attention_mask=input_masks) 127 | examples = Dataset.from_dict(examples) 128 | flat_examples = [] 129 | for idx in range(len(examples["id"])): 130 | flat_examples.append({"id": examples["id"][idx], 131 | "question": examples["question"][idx], 132 | "context": examples["context"][idx]}) 133 | 134 | flat_processed_datasets = [] 135 | for idx in range(len(processed_datasets["input_ids"])): 136 | flat_processed_datasets.append({"input_ids": processed_datasets["input_ids"][idx], 137 | "token_type_ids": processed_datasets["token_type_ids"][idx], 138 | "attention_mask": processed_datasets["attention_mask"][idx], 139 | "offset_mapping": processed_datasets["offset_mapping"][idx], 140 | "example_id": processed_datasets["example_id"][idx]}) 141 | post_processed_eval = self.post_processing_function( 142 | flat_examples, 143 | flat_processed_datasets, 144 | (eval_predictions.start_logits.cpu().numpy(), eval_predictions.end_logits.cpu().numpy()), 145 | ) 146 | res = [] 147 | for item in post_processed_eval: 148 | if not item["prediction"]: 149 | res.append(dict()) 150 | else: 151 | if item["prediction"]["start"] == item["prediction"]["end"]: 152 | res.append(dict()) 153 | else: 154 | res.append({"answer": item["prediction"]["text"], 155 | "start": item["prediction"]["start"], 156 | "end": item["prediction"]["end"], 157 | "score": item["prediction"]["probability"]}) 158 | if isinstance(input, dict): 159 | return res[0] 160 | return res 161 | -------------------------------------------------------------------------------- /question_generation/pipelines/question_generation.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Union, List, Dict 3 | from torch import Tensor 4 | import torch 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class QuestionGenerationPipeline: 10 | def __init__(self, 11 | tokenizer, 12 | model, 13 | max_source_length: int = 512, 14 | max_target_length: int = 200, 15 | device: str = "cpu"): 16 | self.tokenizer = tokenizer 17 | self.model = model 18 | if device == "cpu": 19 | self.device = "cpu" 20 | elif device == "cuda": 21 | self.device = torch.device("cuda") 22 | self.model = self.model.to(self.device) 23 | else: 24 | raise Exception( 25 | "device should be cpu or cuda" 26 | ) 27 | self.model.eval() 28 | self.max_source_length = max_source_length 29 | self.max_target_length = max_target_length 30 | self.source_prefix = "question generation: " 31 | 32 | def preprocess(self, examples: List[str]) -> Dict[str, 
Tensor]:
33 |         added_prefix = ["question generation: " + example for example in examples]
34 |         inputs = self.tokenizer(added_prefix,
35 |                                 return_tensors='pt',
36 |                                 padding=True,
37 |                                 truncation=True,
38 |                                 max_length=self.max_source_length)
39 |         return inputs
40 | 
41 |     def post_processing_function(self, outs):
42 |         questions = self.tokenizer.batch_decode(outs, skip_special_tokens=True)
43 |         separated_questions = []
44 |         for each_batch_questions in questions:
45 |             question = [q.strip() for q in each_batch_questions.split("<sep>") if q.strip() != ""]
46 |             if len(each_batch_questions) > 5 and each_batch_questions[-5:] != "<sep>":  # 一个完整的问题需以<sep>结尾
47 |                 question = question[:-1]
48 |             separated_questions.append(question)
49 | 
50 |         return separated_questions
51 | 
52 |     def __call__(self, input: Union[str, List[str]], *args, **kwargs):
53 |         logger.info("*** Prediction ***")
54 |         if isinstance(input, str):
55 |             examples = [input]
56 |         else:
57 |             examples = input
58 |         examples = self.preprocess(examples)
59 |         input_ids = examples["input_ids"].to(self.device)
60 |         input_mask = examples["attention_mask"].to(self.device)
61 | 
62 |         with torch.no_grad():
63 |             outs = self.model.generate(input_ids=input_ids,
64 |                                        attention_mask=input_mask,
65 |                                        max_length=self.max_target_length,
66 |                                        no_repeat_ngram_size=4,
67 |                                        num_beams=4)
68 | 
69 |         post_processed = self.post_processing_function(outs)
70 |         if isinstance(input, str):
71 |             post_processed = post_processed[0]
72 |         return post_processed
73 | 
-------------------------------------------------------------------------------- /question_generation/pipelines/utils_qa.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Team All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """
16 | Post-processing utilities for question answering.
17 | """
18 | import collections
19 | import json
20 | import logging
21 | import os
22 | from typing import Optional, Tuple
23 | 
24 | import numpy as np
25 | from tqdm.auto import tqdm
26 | 
27 | logger = logging.getLogger(__name__)
28 | 
29 | 
30 | def postprocess_qa_predictions(
31 |     examples,
32 |     features,
33 |     predictions: Tuple[np.ndarray, np.ndarray],
34 |     version_2_with_negative: bool = False,
35 |     n_best_size: int = 20,
36 |     max_answer_length: int = 30,
37 |     null_score_diff_threshold: float = 0.0,
38 |     output_dir: Optional[str] = None,
39 |     prefix: Optional[str] = None,
40 |     log_level: Optional[int] = logging.WARNING,
41 | ):
42 |     """
43 |     Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
44 |     original contexts. This is the base postprocessing function for models that only return start and end logits.
45 | 
46 |     Args:
47 |         examples: The non-preprocessed dataset (see the main script for more information).
48 |         features: The processed dataset (see the main script for more information).
49 | predictions (:obj:`Tuple[np.ndarray, np.ndarray]`): 50 | The predictions of the model: two arrays containing the start logits and the end logits respectively. Its 51 | first dimension must match the number of elements of :obj:`features`. 52 | version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`): 53 | Whether or not the underlying dataset contains examples with no answers. 54 | n_best_size (:obj:`int`, `optional`, defaults to 20): 55 | The total number of n-best predictions to generate when looking for an answer. 56 | max_answer_length (:obj:`int`, `optional`, defaults to 30): 57 | The maximum length of an answer that can be generated. This is needed because the start and end predictions 58 | are not conditioned on one another. 59 | null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0): 60 | The threshold used to select the null answer: if the best answer has a score that is less than the score of 61 | the null answer minus this threshold, the null answer is selected for this example (note that the score of 62 | the null answer for an example giving several features is the minimum of the scores for the null answer on 63 | each feature: all features must be aligned on the fact they `want` to predict a null answer). 64 | 65 | Only useful when :obj:`version_2_with_negative` is :obj:`True`. 66 | output_dir (:obj:`str`, `optional`): 67 | If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if 68 | :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null 69 | answers, are saved in `output_dir`. 70 | prefix (:obj:`str`, `optional`): 71 | If provided, the dictionaries mentioned above are saved with `prefix` added to their names. 72 | log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): 73 | ``logging`` log level (e.g., ``logging.WARNING``) 74 | """ 75 | 76 | assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)." 77 | all_start_logits, all_end_logits = predictions 78 | 79 | assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features." 80 | 81 | # Build a map example to its corresponding features. 82 | # example_id_to_index = {k: i for i, k in enumerate(examples["id"])} 83 | example_id_to_index = {k["id"]: i for i, k in enumerate(examples)} 84 | features_per_example = collections.defaultdict(list) 85 | 86 | for i, feature in enumerate(features): 87 | features_per_example[example_id_to_index[feature["example_id"]]].append(i) 88 | 89 | # The dictionaries we have to fill. 90 | all_predictions = collections.OrderedDict() 91 | if version_2_with_negative: 92 | scores_diff_json = collections.OrderedDict() 93 | 94 | # Logging. 95 | logger.setLevel(log_level) 96 | logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") 97 | 98 | # Let's loop over all the examples! 99 | for example_index, example in enumerate(examples): 100 | # Those are the indices of the features associated to the current example. 101 | feature_indices = features_per_example[example_index] 102 | 103 | min_null_prediction = None 104 | prelim_predictions = [] 105 | 106 | # Looping through all the features associated to the current example. 107 | for feature_index in feature_indices: 108 | # We grab the predictions of the model for this feature. 
109 | start_logits = all_start_logits[feature_index] 110 | end_logits = all_end_logits[feature_index] 111 | # This is what will allow us to map some the positions in our logits to span of texts in the original 112 | # context. 113 | offset_mapping = features[feature_index]["offset_mapping"] 114 | # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context 115 | # available in the current feature. 116 | token_is_max_context = features[feature_index].get("token_is_max_context", None) 117 | 118 | # Update minimum null prediction. 119 | feature_null_score = start_logits[0] + end_logits[0] 120 | if min_null_prediction is None or min_null_prediction["score"] > feature_null_score: 121 | min_null_prediction = { 122 | "offsets": (0, 0), 123 | "score": feature_null_score, 124 | "start_logit": start_logits[0], 125 | "end_logit": end_logits[0], 126 | } 127 | 128 | # Go through all possibilities for the `n_best_size` greater start and end logits. 129 | start_indexes = np.argsort(start_logits)[-1: -n_best_size - 1: -1].tolist() 130 | end_indexes = np.argsort(end_logits)[-1: -n_best_size - 1: -1].tolist() 131 | for start_index in start_indexes: 132 | for end_index in end_indexes: 133 | # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond 134 | # to part of the input_ids that are not in the context. 135 | if ( 136 | start_index >= len(offset_mapping) 137 | or end_index >= len(offset_mapping) 138 | or offset_mapping[start_index] is None 139 | or offset_mapping[end_index] is None 140 | ): 141 | continue 142 | # Don't consider answers with a length that is either < 0 or > max_answer_length. 143 | if end_index < start_index or end_index - start_index + 1 > max_answer_length: 144 | continue 145 | # Don't consider answer that don't have the maximum context available (if such information is 146 | # provided). 147 | if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): 148 | continue 149 | prelim_predictions.append( 150 | { 151 | "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), 152 | "score": start_logits[start_index] + end_logits[end_index], 153 | "start_logit": start_logits[start_index].item(), 154 | "end_logit": end_logits[end_index].item(), 155 | } 156 | ) 157 | 158 | if version_2_with_negative: 159 | # Add the minimum null prediction 160 | prelim_predictions.append(min_null_prediction) 161 | null_score = min_null_prediction["score"] 162 | 163 | # Only keep the best `n_best_size` predictions. 164 | predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size] 165 | 166 | # Add back the minimum null prediction if it was removed because of its low score. 167 | if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions): 168 | predictions.append(min_null_prediction) 169 | 170 | # Use the offsets to gather the answer text in the original context. 171 | context = example["context"] 172 | for pred in predictions: 173 | offsets = pred.pop("offsets") 174 | pred["text"] = context[offsets[0]: offsets[1]] 175 | pred["start"] = offsets[0] 176 | pred["end"] = offsets[1] 177 | 178 | # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid 179 | # failure. 
180 | if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""): 181 | predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0, "start": 0, 182 | "end": 0}) 183 | 184 | # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using 185 | # the LogSumExp trick). 186 | scores = np.array([pred.pop("score") for pred in predictions]) 187 | exp_scores = np.exp(scores - np.max(scores)) 188 | probs = exp_scores / exp_scores.sum() 189 | 190 | # Include the probabilities in our predictions. 191 | for prob, pred in zip(probs, predictions): 192 | pred["probability"] = prob.item() 193 | 194 | # Pick the best prediction. If the null answer is not possible, this is easy. 195 | if not version_2_with_negative: 196 | all_predictions[example["id"]] = predictions[0]["text"] 197 | all_predictions[example["id"]] = predictions[0] 198 | else: 199 | # Otherwise we first need to find the best non-empty prediction. 200 | i = 0 201 | while predictions[i]["text"] == "": 202 | i += 1 203 | best_non_null_pred = predictions[i] 204 | 205 | # Then we compare to the null prediction using the threshold. 206 | score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"] 207 | scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable. 208 | if score_diff > null_score_diff_threshold: 209 | all_predictions[example["id"]] = {} 210 | else: 211 | all_predictions[example["id"]] = best_non_null_pred 212 | return all_predictions 213 | 214 | 215 | def postprocess_qa_predictions_with_beam_search( 216 | examples, 217 | features, 218 | predictions: Tuple[np.ndarray, np.ndarray], 219 | version_2_with_negative: bool = False, 220 | n_best_size: int = 20, 221 | max_answer_length: int = 30, 222 | start_n_top: int = 5, 223 | end_n_top: int = 5, 224 | output_dir: Optional[str] = None, 225 | prefix: Optional[str] = None, 226 | log_level: Optional[int] = logging.WARNING, 227 | ): 228 | """ 229 | Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the 230 | original contexts. This is the postprocessing functions for models that return start and end logits, indices, as well as 231 | cls token predictions. 232 | 233 | Args: 234 | examples: The non-preprocessed dataset (see the main script for more information). 235 | features: The processed dataset (see the main script for more information). 236 | predictions (:obj:`Tuple[np.ndarray, np.ndarray]`): 237 | The predictions of the model: two arrays containing the start logits and the end logits respectively. Its 238 | first dimension must match the number of elements of :obj:`features`. 239 | version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`): 240 | Whether or not the underlying dataset contains examples with no answers. 241 | n_best_size (:obj:`int`, `optional`, defaults to 20): 242 | The total number of n-best predictions to generate when looking for an answer. 243 | max_answer_length (:obj:`int`, `optional`, defaults to 30): 244 | The maximum length of an answer that can be generated. This is needed because the start and end predictions 245 | are not conditioned on one another. 246 | start_n_top (:obj:`int`, `optional`, defaults to 5): 247 | The number of top start logits too keep when searching for the :obj:`n_best_size` predictions. 
248 | end_n_top (:obj:`int`, `optional`, defaults to 5): 249 | The number of top end logits too keep when searching for the :obj:`n_best_size` predictions. 250 | output_dir (:obj:`str`, `optional`): 251 | If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if 252 | :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null 253 | answers, are saved in `output_dir`. 254 | prefix (:obj:`str`, `optional`): 255 | If provided, the dictionaries mentioned above are saved with `prefix` added to their names. 256 | log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): 257 | ``logging`` log level (e.g., ``logging.WARNING``) 258 | """ 259 | assert len(predictions) == 5, "`predictions` should be a tuple with five elements." 260 | start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions 261 | 262 | assert len(predictions[0]) == len( 263 | features 264 | ), f"Got {len(predictions[0])} predicitions and {len(features)} features." 265 | 266 | # Build a map example to its corresponding features. 267 | example_id_to_index = {k: i for i, k in enumerate(examples["id"])} 268 | features_per_example = collections.defaultdict(list) 269 | for i, feature in enumerate(features): 270 | features_per_example[example_id_to_index[feature["example_id"]]].append(i) 271 | 272 | # The dictionaries we have to fill. 273 | all_predictions = collections.OrderedDict() 274 | all_nbest_json = collections.OrderedDict() 275 | scores_diff_json = collections.OrderedDict() if version_2_with_negative else None 276 | 277 | # Logging. 278 | logger.setLevel(log_level) 279 | logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") 280 | 281 | # Let's loop over all the examples! 282 | for example_index, example in enumerate(tqdm(examples)): 283 | # Those are the indices of the features associated to the current example. 284 | feature_indices = features_per_example[example_index] 285 | 286 | min_null_score = None 287 | prelim_predictions = [] 288 | 289 | # Looping through all the features associated to the current example. 290 | for feature_index in feature_indices: 291 | # We grab the predictions of the model for this feature. 292 | start_log_prob = start_top_log_probs[feature_index] 293 | start_indexes = start_top_index[feature_index] 294 | end_log_prob = end_top_log_probs[feature_index] 295 | end_indexes = end_top_index[feature_index] 296 | feature_null_score = cls_logits[feature_index] 297 | # This is what will allow us to map some the positions in our logits to span of texts in the original 298 | # context. 299 | offset_mapping = features[feature_index]["offset_mapping"] 300 | # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context 301 | # available in the current feature. 302 | token_is_max_context = features[feature_index].get("token_is_max_context", None) 303 | 304 | # Update minimum null prediction 305 | if min_null_score is None or feature_null_score < min_null_score: 306 | min_null_score = feature_null_score 307 | 308 | # Go through all possibilities for the `n_start_top`/`n_end_top` greater start and end logits. 
309 | for i in range(start_n_top): 310 | for j in range(end_n_top): 311 | start_index = int(start_indexes[i]) 312 | j_index = i * end_n_top + j 313 | end_index = int(end_indexes[j_index]) 314 | # Don't consider out-of-scope answers (last part of the test should be unnecessary because of the 315 | # p_mask but let's not take any risk) 316 | if ( 317 | start_index >= len(offset_mapping) 318 | or end_index >= len(offset_mapping) 319 | or offset_mapping[start_index] is None 320 | or offset_mapping[end_index] is None 321 | ): 322 | continue 323 | # Don't consider answers with a length negative or > max_answer_length. 324 | if end_index < start_index or end_index - start_index + 1 > max_answer_length: 325 | continue 326 | # Don't consider answer that don't have the maximum context available (if such information is 327 | # provided). 328 | if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): 329 | continue 330 | prelim_predictions.append( 331 | { 332 | "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), 333 | "score": start_log_prob[i] + end_log_prob[j_index], 334 | "start_log_prob": start_log_prob[i], 335 | "end_log_prob": end_log_prob[j_index], 336 | } 337 | ) 338 | 339 | # Only keep the best `n_best_size` predictions. 340 | predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size] 341 | 342 | # Use the offsets to gather the answer text in the original context. 343 | context = example["context"] 344 | for pred in predictions: 345 | offsets = pred.pop("offsets") 346 | pred["text"] = context[offsets[0]: offsets[1]] 347 | 348 | # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid 349 | # failure. 350 | if len(predictions) == 0: 351 | predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6}) 352 | 353 | # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using 354 | # the LogSumExp trick). 355 | scores = np.array([pred.pop("score") for pred in predictions]) 356 | exp_scores = np.exp(scores - np.max(scores)) 357 | probs = exp_scores / exp_scores.sum() 358 | 359 | # Include the probabilities in our predictions. 360 | for prob, pred in zip(probs, predictions): 361 | pred["probability"] = prob 362 | 363 | # Pick the best prediction and set the probability for the null answer. 364 | all_predictions[example["id"]] = predictions[0]["text"] 365 | if version_2_with_negative: 366 | scores_diff_json[example["id"]] = float(min_null_score) 367 | 368 | # Make `predictions` JSON-serializable by casting np.float back to float. 369 | all_nbest_json[example["id"]] = [ 370 | {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()} 371 | for pred in predictions 372 | ] 373 | 374 | # If we have an output_dir, let's save all those dicts. 375 | if output_dir is not None: 376 | assert os.path.isdir(output_dir), f"{output_dir} is not a directory." 
377 | 378 | prediction_file = os.path.join( 379 | output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" 380 | ) 381 | nbest_file = os.path.join( 382 | output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json" 383 | ) 384 | if version_2_with_negative: 385 | null_odds_file = os.path.join( 386 | output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json" 387 | ) 388 | 389 | logger.info(f"Saving predictions to {prediction_file}.") 390 | with open(prediction_file, "w") as writer: 391 | writer.write(json.dumps(all_predictions, indent=4) + "\n") 392 | logger.info(f"Saving nbest_preds to {nbest_file}.") 393 | with open(nbest_file, "w") as writer: 394 | writer.write(json.dumps(all_nbest_json, indent=4) + "\n") 395 | if version_2_with_negative: 396 | logger.info(f"Saving null_odds to {null_odds_file}.") 397 | with open(null_odds_file, "w") as writer: 398 | writer.write(json.dumps(scores_diff_json, indent=4) + "\n") 399 | 400 | return all_predictions, scores_diff_json 401 | -------------------------------------------------------------------------------- /question_generation/run_qa.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Team All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for question answering using a slightly adapted version of the 🤗 Trainer. 18 | """ 19 | # You can also adapt this script on your own question answering task. Pointers for this are left as comments. 20 | 21 | import logging 22 | import os 23 | import sys 24 | from dataclasses import dataclass, field 25 | from typing import Optional 26 | 27 | import datasets 28 | from datasets import load_dataset, load_metric 29 | 30 | import transformers 31 | from trainer_qa import QuestionAnsweringTrainer 32 | from transformers import ( 33 | AutoConfig, 34 | AutoModelForQuestionAnswering, 35 | AutoTokenizer, 36 | DataCollatorWithPadding, 37 | EvalPrediction, 38 | HfArgumentParser, 39 | PreTrainedTokenizerFast, 40 | TrainingArguments, 41 | default_data_collator, 42 | set_seed, 43 | ) 44 | from transformers.trainer_utils import get_last_checkpoint 45 | from utils_qa import postprocess_qa_predictions 46 | 47 | os.environ["WANDB_DISABLED"] = "true" 48 | 49 | logger = logging.getLogger(__name__) 50 | 51 | 52 | @dataclass 53 | class ModelArguments: 54 | """ 55 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
56 | """ 57 | 58 | model_name_or_path: str = field( 59 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 60 | ) 61 | config_name: Optional[str] = field( 62 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 63 | ) 64 | tokenizer_name: Optional[str] = field( 65 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 66 | ) 67 | cache_dir: Optional[str] = field( 68 | default=None, 69 | metadata={"help": "Path to directory to store the pretrained models downloaded from huggingface.co"}, 70 | ) 71 | model_revision: str = field( 72 | default="main", 73 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 74 | ) 75 | use_auth_token: bool = field( 76 | default=False, 77 | metadata={ 78 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 79 | "with private models)." 80 | }, 81 | ) 82 | 83 | 84 | @dataclass 85 | class DataTrainingArguments: 86 | """ 87 | Arguments pertaining to what data we are going to input our model for training and eval. 88 | """ 89 | 90 | dataset_name: Optional[str] = field( 91 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 92 | ) 93 | dataset_config_name: Optional[str] = field( 94 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 95 | ) 96 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) 97 | validation_file: Optional[str] = field( 98 | default=None, 99 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 100 | ) 101 | test_file: Optional[str] = field( 102 | default=None, 103 | metadata={"help": "An optional input test data file to evaluate the perplexity on (a text file)."}, 104 | ) 105 | overwrite_cache: bool = field( 106 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 107 | ) 108 | preprocessing_num_workers: Optional[int] = field( 109 | default=None, 110 | metadata={"help": "The number of processes to use for the preprocessing."}, 111 | ) 112 | max_seq_length: int = field( 113 | default=384, 114 | metadata={ 115 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 116 | "than this will be truncated, sequences shorter will be padded." 117 | }, 118 | ) 119 | pad_to_max_length: bool = field( 120 | default=True, 121 | metadata={ 122 | "help": "Whether to pad all samples to `max_seq_length`. " 123 | "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can " 124 | "be faster on GPU but will be slower on TPU)." 125 | }, 126 | ) 127 | max_train_samples: Optional[int] = field( 128 | default=None, 129 | metadata={ 130 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 131 | "value if set." 132 | }, 133 | ) 134 | max_eval_samples: Optional[int] = field( 135 | default=None, 136 | metadata={ 137 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 138 | "value if set." 
139 | }, 140 | ) 141 | max_predict_samples: Optional[int] = field( 142 | default=None, 143 | metadata={ 144 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 145 | "value if set." 146 | }, 147 | ) 148 | version_2_with_negative: bool = field( 149 | default=True, metadata={"help": "If true, some of the examples do not have an answer."} 150 | ) 151 | null_score_diff_threshold: float = field( 152 | default=0.0, 153 | metadata={ 154 | "help": "The threshold used to select the null answer: if the best answer has a score that is less than " 155 | "the score of the null answer minus this threshold, the null answer is selected for this example. " 156 | "Only useful when `version_2_with_negative=True`." 157 | }, 158 | ) 159 | doc_stride: int = field( 160 | default=128, 161 | metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."}, 162 | ) 163 | n_best_size: int = field( 164 | default=20, 165 | metadata={"help": "The total number of n-best predictions to generate when looking for an answer."}, 166 | ) 167 | max_answer_length: int = field( 168 | default=30, 169 | metadata={ 170 | "help": "The maximum length of an answer that can be generated. This is needed because the start " 171 | "and end predictions are not conditioned on one another." 172 | }, 173 | ) 174 | 175 | def __post_init__(self): 176 | if ( 177 | self.dataset_name is None 178 | and self.train_file is None 179 | and self.validation_file is None 180 | and self.test_file is None 181 | ): 182 | raise ValueError("Need either a dataset name or a training/validation file/test_file.") 183 | else: 184 | if self.train_file is not None: 185 | extension = self.train_file.split(".")[-1] 186 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 187 | if self.validation_file is not None: 188 | extension = self.validation_file.split(".")[-1] 189 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 190 | if self.test_file is not None: 191 | extension = self.test_file.split(".")[-1] 192 | assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." 193 | 194 | 195 | def main(): 196 | # See all possible arguments in src/transformers/training_args.py 197 | # or by passing the --help flag to this script. 198 | # We now keep distinct sets of args, for a cleaner separation of concerns. 199 | 200 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 201 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 202 | # If we pass only one argument to the script and it's the path to a json file, 203 | # let's parse it to get our arguments. 
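# The parser accepts either a single JSON file holding all arguments or regular command-line flags; for example
# (file, model and directory names below are only illustrative):
#   python run_qa.py qa_args.json
#   python run_qa.py --model_name_or_path <model> --train_file <train.json> --do_train --output_dir <out_dir>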
204 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 205 | else: 206 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 207 | 208 | # Setup logging 209 | logging.basicConfig( 210 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 211 | datefmt="%m/%d/%Y %H:%M:%S", 212 | handlers=[logging.StreamHandler(sys.stdout)], 213 | ) 214 | 215 | log_level = training_args.get_process_log_level() 216 | logger.setLevel(log_level) 217 | datasets.utils.logging.set_verbosity(log_level) 218 | transformers.utils.logging.set_verbosity(log_level) 219 | transformers.utils.logging.enable_default_handler() 220 | transformers.utils.logging.enable_explicit_format() 221 | 222 | # Log on each process the small summary: 223 | logger.warning( 224 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 225 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 226 | ) 227 | logger.info(f"Training/evaluation parameters {training_args}") 228 | 229 | # Detecting last checkpoint. 230 | last_checkpoint = None 231 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 232 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 233 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 234 | raise ValueError( 235 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 236 | "Use --overwrite_output_dir to overcome." 237 | ) 238 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 239 | logger.info( 240 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 241 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 242 | ) 243 | 244 | # Set seed before initializing model. 245 | set_seed(training_args.seed) 246 | 247 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 248 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 249 | # (the dataset will be downloaded automatically from the datasets Hub). 250 | # 251 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called 252 | # 'text' is found. You can easily tweak this behavior (see below). 253 | # 254 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 255 | # download the dataset. 256 | if data_args.dataset_name is not None: 257 | # Downloading and loading a dataset from the hub. 
258 | raw_datasets = load_dataset( 259 | data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir 260 | ) 261 | else: 262 | data_files = {} 263 | if data_args.train_file is not None: 264 | data_files["train"] = data_args.train_file 265 | extension = data_args.train_file.split(".")[-1] 266 | 267 | if data_args.validation_file is not None: 268 | data_files["validation"] = data_args.validation_file 269 | extension = data_args.validation_file.split(".")[-1] 270 | if data_args.test_file is not None: 271 | data_files["test"] = data_args.test_file 272 | extension = data_args.test_file.split(".")[-1] 273 | raw_datasets = load_dataset(extension, data_files=data_files, field="data", cache_dir=model_args.cache_dir) 274 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 275 | # https://huggingface.co/docs/datasets/loading_datasets.html. 276 | 277 | # Load pretrained model and tokenizer 278 | # 279 | # Distributed training: 280 | # The .from_pretrained methods guarantee that only one local process can concurrently 281 | # download model & vocab. 282 | config = AutoConfig.from_pretrained( 283 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 284 | cache_dir=model_args.cache_dir, 285 | revision=model_args.model_revision, 286 | use_auth_token=True if model_args.use_auth_token else None, 287 | ) 288 | tokenizer = AutoTokenizer.from_pretrained( 289 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 290 | cache_dir=model_args.cache_dir, 291 | use_fast=True, 292 | revision=model_args.model_revision, 293 | use_auth_token=True if model_args.use_auth_token else None, 294 | ) 295 | model = AutoModelForQuestionAnswering.from_pretrained( 296 | model_args.model_name_or_path, 297 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 298 | config=config, 299 | cache_dir=model_args.cache_dir, 300 | revision=model_args.model_revision, 301 | use_auth_token=True if model_args.use_auth_token else None, 302 | ) 303 | 304 | # Tokenizer check: this script requires a fast tokenizer. 305 | if not isinstance(tokenizer, PreTrainedTokenizerFast): 306 | raise ValueError( 307 | "This example script only works for models that have a fast tokenizer. Check out the big table of models " 308 | "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this " 309 | "requirement" 310 | ) 311 | 312 | # Preprocessing the datasets. 313 | # Preprocessing is slightly different for training and evaluation. 314 | if training_args.do_train: 315 | column_names = raw_datasets["train"].column_names 316 | elif training_args.do_eval: 317 | column_names = raw_datasets["validation"].column_names 318 | else: 319 | column_names = raw_datasets["test"].column_names 320 | question_column_name = "question" if "question" in column_names else column_names[0] 321 | context_column_name = "context" if "context" in column_names else column_names[1] 322 | answer_column_name = "answers" if "answers" in column_names else column_names[2] 323 | 324 | # Padding side determines if we do (question|context) or (context|question). 325 | pad_on_right = tokenizer.padding_side == "right" 326 | 327 | if data_args.max_seq_length > tokenizer.model_max_length: 328 | logger.warning( 329 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the " 330 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 331 | )
332 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 333 | 334 | # Training preprocessing 335 | def prepare_train_features(examples): 336 | # Some of the questions have lots of whitespace on the left, which is not useful and will make the 337 | # truncation of the context fail (the tokenized question will take a lot of space). So we remove that 338 | # left whitespace. 339 | examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]] 340 | 341 | # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results 342 | # in one example possibly giving several features when a context is long, each of those features having a 343 | # context that overlaps a bit with the context of the previous feature. 344 | tokenized_examples = tokenizer( 345 | examples[question_column_name if pad_on_right else context_column_name], 346 | examples[context_column_name if pad_on_right else question_column_name], 347 | truncation="only_second" if pad_on_right else "only_first", 348 | max_length=max_seq_length, 349 | stride=data_args.doc_stride, 350 | return_overflowing_tokens=True, 351 | return_offsets_mapping=True, 352 | padding="max_length" if data_args.pad_to_max_length else False, 353 | ) 354 | 355 | # Since one example might give us several features if it has a long context, we need a map from a feature to 356 | # its corresponding example. This key gives us just that. 357 | sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") 358 | # The offset mappings will give us a map from token to character position in the original context. This will 359 | # help us compute the start_positions and end_positions. 360 | offset_mapping = tokenized_examples.pop("offset_mapping") 361 | 362 | # Let's label those examples! 363 | tokenized_examples["start_positions"] = [] 364 | tokenized_examples["end_positions"] = [] 365 | 366 | for i, offsets in enumerate(offset_mapping): 367 | # We will label impossible answers with the index of the CLS token. 368 | input_ids = tokenized_examples["input_ids"][i] 369 | cls_index = input_ids.index(tokenizer.cls_token_id) 370 | 371 | # Grab the sequence corresponding to that example (to know what is the context and what is the question). 372 | sequence_ids = tokenized_examples.sequence_ids(i) 373 | 374 | # One example can give several spans, this is the index of the example containing this span of text. 375 | sample_index = sample_mapping[i] 376 | answers = examples[answer_column_name][sample_index] 377 | # If no answers are given, set the cls_index as answer. 378 | if len(answers["answer_start"]) == 0: 379 | tokenized_examples["start_positions"].append(cls_index) 380 | tokenized_examples["end_positions"].append(cls_index) 381 | else: 382 | # Start/end character index of the answer in the text. 383 | start_char = answers["answer_start"][0] 384 | end_char = start_char + len(answers["text"][0]) 385 | 386 | # Start token index of the current span in the text. 387 | token_start_index = 0 388 | while sequence_ids[token_start_index] != (1 if pad_on_right else 0): 389 | token_start_index += 1 390 | 391 | # End token index of the current span in the text. 392 | token_end_index = len(input_ids) - 1 393 | while sequence_ids[token_end_index] != (1 if pad_on_right else 0): 394 | token_end_index -= 1 395 | 396 | # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
397 | if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char): 398 | tokenized_examples["start_positions"].append(cls_index) 399 | tokenized_examples["end_positions"].append(cls_index) 400 | else: 401 | # Otherwise move the token_start_index and token_end_index to the two ends of the answer. 402 | # Note: we could go after the last offset if the answer is the last word (edge case). 403 | while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char: 404 | token_start_index += 1 405 | tokenized_examples["start_positions"].append(token_start_index - 1) 406 | while offsets[token_end_index][1] >= end_char: 407 | token_end_index -= 1 408 | tokenized_examples["end_positions"].append(token_end_index + 1) 409 | 410 | return tokenized_examples 411 | 412 | if training_args.do_train: 413 | if "train" not in raw_datasets: 414 | raise ValueError("--do_train requires a train dataset") 415 | train_dataset = raw_datasets["train"] 416 | if data_args.max_train_samples is not None: 417 | # We will select samples from the whole data if the argument is specified 418 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 419 | # Create train features from dataset 420 | with training_args.main_process_first(desc="train dataset map pre-processing"): 421 | train_dataset = train_dataset.map( 422 | prepare_train_features, 423 | batched=True, 424 | num_proc=data_args.preprocessing_num_workers, 425 | remove_columns=column_names, 426 | load_from_cache_file=not data_args.overwrite_cache, 427 | desc="Running tokenizer on train dataset", 428 | ) 429 | if data_args.max_train_samples is not None: 430 | # The number of samples might increase during feature creation, so we select only the specified max samples 431 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 432 | 433 | # Validation preprocessing 434 | def prepare_validation_features(examples): 435 | # Some of the questions have lots of whitespace on the left, which is not useful and will make the 436 | # truncation of the context fail (the tokenized question will take a lot of space). So we remove that 437 | # left whitespace. 438 | examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]] 439 | 440 | # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results 441 | # in one example possibly giving several features when a context is long, each of those features having a 442 | # context that overlaps a bit with the context of the previous feature. 443 | tokenized_examples = tokenizer( 444 | examples[question_column_name if pad_on_right else context_column_name], 445 | examples[context_column_name if pad_on_right else question_column_name], 446 | truncation="only_second" if pad_on_right else "only_first", 447 | max_length=max_seq_length, 448 | stride=data_args.doc_stride, 449 | return_overflowing_tokens=True, 450 | return_offsets_mapping=True, 451 | padding="max_length" if data_args.pad_to_max_length else False, 452 | ) 453 | 454 | # Since one example might give us several features if it has a long context, we need a map from a feature to 455 | # its corresponding example. This key gives us just that. 456 | sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") 457 | 458 | # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the 459 | # corresponding example_id and we will store the offset mappings.
460 | tokenized_examples["example_id"] = [] 461 | 462 | for i in range(len(tokenized_examples["input_ids"])): 463 | # Grab the sequence corresponding to that example (to know what is the context and what is the question). 464 | sequence_ids = tokenized_examples.sequence_ids(i) 465 | context_index = 1 if pad_on_right else 0 466 | 467 | # One example can give several spans, this is the index of the example containing this span of text. 468 | sample_index = sample_mapping[i] 469 | tokenized_examples["example_id"].append(examples["id"][sample_index]) 470 | 471 | # Set to None the offset_mapping entries that are not part of the context so it's easy to determine if a token 472 | # position is part of the context or not. 473 | tokenized_examples["offset_mapping"][i] = [ 474 | (o if sequence_ids[k] == context_index else None) 475 | for k, o in enumerate(tokenized_examples["offset_mapping"][i]) 476 | ] 477 | 478 | return tokenized_examples 479 | 480 | if training_args.do_eval: 481 | if "validation" not in raw_datasets: 482 | raise ValueError("--do_eval requires a validation dataset") 483 | eval_examples = raw_datasets["validation"] 484 | if data_args.max_eval_samples is not None: 485 | # We will select samples from the whole data 486 | eval_examples = eval_examples.select(range(data_args.max_eval_samples)) 487 | # Validation Feature Creation 488 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 489 | eval_dataset = eval_examples.map( 490 | prepare_validation_features, 491 | batched=True, 492 | num_proc=data_args.preprocessing_num_workers, 493 | remove_columns=column_names, 494 | load_from_cache_file=not data_args.overwrite_cache, 495 | desc="Running tokenizer on validation dataset", 496 | ) 497 | if data_args.max_eval_samples is not None: 498 | # During feature creation the number of samples might increase, so we select only the required samples again 499 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 500 | 501 | if training_args.do_predict: 502 | if "test" not in raw_datasets: 503 | raise ValueError("--do_predict requires a test dataset") 504 | predict_examples = raw_datasets["test"] 505 | if data_args.max_predict_samples is not None: 506 | # We will select samples from the whole data 507 | predict_examples = predict_examples.select(range(data_args.max_predict_samples)) 508 | # Predict Feature Creation 509 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 510 | predict_dataset = predict_examples.map( 511 | prepare_validation_features, 512 | batched=True, 513 | num_proc=data_args.preprocessing_num_workers, 514 | remove_columns=column_names, 515 | load_from_cache_file=not data_args.overwrite_cache, 516 | desc="Running tokenizer on prediction dataset", 517 | ) 518 | if data_args.max_predict_samples is not None: 519 | # During feature creation the number of samples might increase, so we select only the required samples again 520 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) 521 | 522 | # Data collator 523 | # We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data 524 | # collator.
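# When `pad_to_max_length` is False, `DataCollatorWithPadding` pads each batch only up to its longest sequence;
# `pad_to_multiple_of=8` keeps the padded length a multiple of 8 so fp16 training can use tensor cores efficiently.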
525 | data_collator = ( 526 | default_data_collator 527 | if data_args.pad_to_max_length 528 | else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None) 529 | ) 530 | 531 | # Post-processing: 532 | def post_processing_function(examples, features, predictions, stage="eval"): 533 | # Post-processing: we match the start logits and end logits to answers in the original context. 534 | predictions = postprocess_qa_predictions( 535 | examples=examples, 536 | features=features, 537 | predictions=predictions, 538 | version_2_with_negative=data_args.version_2_with_negative, 539 | n_best_size=data_args.n_best_size, 540 | max_answer_length=data_args.max_answer_length, 541 | null_score_diff_threshold=data_args.null_score_diff_threshold, 542 | output_dir=training_args.output_dir, 543 | log_level=log_level, 544 | prefix=stage, 545 | ) 546 | # Format the result to the format the metric expects. 547 | if data_args.version_2_with_negative: 548 | formatted_predictions = [ 549 | {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items() 550 | ] 551 | else: 552 | formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()] 553 | 554 | references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples] 555 | return EvalPrediction(predictions=formatted_predictions, label_ids=references) 556 | 557 | metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad") 558 | 559 | def compute_metrics(p: EvalPrediction): 560 | return metric.compute(predictions=p.predictions, references=p.label_ids) 561 | 562 | # Initialize our Trainer 563 | trainer = QuestionAnsweringTrainer( 564 | model=model, 565 | args=training_args, 566 | train_dataset=train_dataset if training_args.do_train else None, 567 | eval_dataset=eval_dataset if training_args.do_eval else None, 568 | eval_examples=eval_examples if training_args.do_eval else None, 569 | tokenizer=tokenizer, 570 | data_collator=data_collator, 571 | post_process_function=post_processing_function, 572 | compute_metrics=compute_metrics, 573 | ) 574 | 575 | # Training 576 | if training_args.do_train: 577 | checkpoint = None 578 | if training_args.resume_from_checkpoint is not None: 579 | checkpoint = training_args.resume_from_checkpoint 580 | elif last_checkpoint is not None: 581 | checkpoint = last_checkpoint 582 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 583 | trainer.save_model() # Saves the tokenizer too for easy upload 584 | 585 | metrics = train_result.metrics 586 | max_train_samples = ( 587 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 588 | ) 589 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 590 | 591 | trainer.log_metrics("train", metrics) 592 | trainer.save_metrics("train", metrics) 593 | trainer.save_state() 594 | 595 | # Evaluation 596 | if training_args.do_eval: 597 | logger.info("*** Evaluate ***") 598 | metrics = trainer.evaluate() 599 | 600 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 601 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 602 | 603 | trainer.log_metrics("eval", metrics) 604 | trainer.save_metrics("eval", metrics) 605 | 606 | # Prediction 607 | if training_args.do_predict: 608 | logger.info("*** Predict ***") 609 | results = trainer.predict(predict_dataset, predict_examples) 610 | metrics = results.metrics 611 | 612 | 
max_predict_samples = ( 613 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 614 | ) 615 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 616 | 617 | trainer.log_metrics("predict", metrics) 618 | trainer.save_metrics("predict", metrics) 619 | 620 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "question-answering"} 621 | if data_args.dataset_name is not None: 622 | kwargs["dataset_tags"] = data_args.dataset_name 623 | if data_args.dataset_config_name is not None: 624 | kwargs["dataset_args"] = data_args.dataset_config_name 625 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 626 | else: 627 | kwargs["dataset"] = data_args.dataset_name 628 | 629 | if training_args.push_to_hub: 630 | trainer.push_to_hub(**kwargs) 631 | else: 632 | trainer.create_model_card(**kwargs) 633 | 634 | 635 | if __name__ == "__main__": 636 | main() 637 | -------------------------------------------------------------------------------- /question_generation/run_qg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for question generation. 18 | """ 19 | import logging 20 | import os 21 | import sys 22 | from dataclasses import dataclass, field 23 | from typing import Optional 24 | 25 | import datasets 26 | from datasets import load_dataset 27 | import numpy as np 28 | 29 | import transformers 30 | from transformers import ( 31 | AutoConfig, 32 | AutoModelForSeq2SeqLM, 33 | AutoTokenizer, 34 | DataCollatorForSeq2Seq, 35 | HfArgumentParser, 36 | Seq2SeqTrainer, 37 | Seq2SeqTrainingArguments, 38 | set_seed, 39 | ) 40 | from rouge import Rouge 41 | from transformers.trainer_utils import get_last_checkpoint 42 | 43 | os.environ["WANDB_DISABLED"] = "true" 44 | 45 | logger = logging.getLogger(__name__) 46 | 47 | 48 | @dataclass 49 | class ModelArguments: 50 | """ 51 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
52 | """ 53 | 54 | model_name_or_path: str = field( 55 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 56 | ) 57 | config_name: Optional[str] = field( 58 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 59 | ) 60 | tokenizer_name: Optional[str] = field( 61 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 62 | ) 63 | cache_dir: Optional[str] = field( 64 | default=None, 65 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 66 | ) 67 | use_fast_tokenizer: bool = field( 68 | default=True, 69 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 70 | ) 71 | model_revision: str = field( 72 | default="main", 73 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 74 | ) 75 | use_auth_token: bool = field( 76 | default=False, 77 | metadata={ 78 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 79 | "with private models)." 80 | }, 81 | ) 82 | resize_position_embeddings: Optional[bool] = field( 83 | default=None, 84 | metadata={ 85 | "help": "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 86 | "the model's position embeddings." 87 | }, 88 | ) 89 | 90 | 91 | @dataclass 92 | class DataTrainingArguments: 93 | """ 94 | Arguments pertaining to what data we are going to input our model for training and eval. 95 | """ 96 | 97 | dataset_name: Optional[str] = field( 98 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 99 | ) 100 | dataset_config_name: Optional[str] = field( 101 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 102 | ) 103 | text_column: Optional[str] = field( 104 | default=None, 105 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 106 | ) 107 | summary_column: Optional[str] = field( 108 | default=None, 109 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 110 | ) 111 | train_file: Optional[str] = field( 112 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 113 | ) 114 | validation_file: Optional[str] = field( 115 | default=None, 116 | metadata={ 117 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 118 | "(a jsonlines or csv file)." 119 | }, 120 | ) 121 | test_file: Optional[str] = field( 122 | default=None, 123 | metadata={ 124 | "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)." 125 | }, 126 | ) 127 | overwrite_cache: bool = field( 128 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 129 | ) 130 | preprocessing_num_workers: Optional[int] = field( 131 | default=None, 132 | metadata={"help": "The number of processes to use for the preprocessing."}, 133 | ) 134 | max_source_length: Optional[int] = field( 135 | default=1024, 136 | metadata={ 137 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 138 | "than this will be truncated, sequences shorter will be padded." 
139 | }, 140 | ) 141 | max_target_length: Optional[int] = field( 142 | default=128, 143 | metadata={ 144 | "help": "The maximum total sequence length for target text after tokenization. Sequences longer " 145 | "than this will be truncated, sequences shorter will be padded." 146 | }, 147 | ) 148 | val_max_target_length: Optional[int] = field( 149 | default=None, 150 | metadata={ 151 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " 152 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 153 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 154 | "during ``evaluate`` and ``predict``." 155 | }, 156 | ) 157 | pad_to_max_length: bool = field( 158 | default=False, 159 | metadata={ 160 | "help": "Whether to pad all samples to model maximum sentence length. " 161 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 162 | "efficient on GPU but very bad for TPU." 163 | }, 164 | ) 165 | max_train_samples: Optional[int] = field( 166 | default=None, 167 | metadata={ 168 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 169 | "value if set." 170 | }, 171 | ) 172 | max_eval_samples: Optional[int] = field( 173 | default=None, 174 | metadata={ 175 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 176 | "value if set." 177 | }, 178 | ) 179 | max_predict_samples: Optional[int] = field( 180 | default=None, 181 | metadata={ 182 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 183 | "value if set." 184 | }, 185 | ) 186 | num_beams: Optional[int] = field( 187 | default=None, 188 | metadata={ 189 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 190 | "which is used during ``evaluate`` and ``predict``." 191 | }, 192 | ) 193 | ignore_pad_token_for_loss: bool = field( 194 | default=True, 195 | metadata={ 196 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 197 | }, 198 | ) 199 | source_prefix: Optional[str] = field( 200 | default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 201 | ) 202 | 203 | def __post_init__(self): 204 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 205 | raise ValueError("Need either a dataset name or a training/validation file.") 206 | else: 207 | if self.train_file is not None: 208 | extension = self.train_file.split(".")[-1] 209 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 210 | if self.validation_file is not None: 211 | extension = self.validation_file.split(".")[-1] 212 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 213 | if self.val_max_target_length is None: 214 | self.val_max_target_length = self.max_target_length 215 | 216 | 217 | def main(): 218 | # See all possible arguments in src/transformers/training_args.py 219 | # or by passing the --help flag to this script. 220 | # We now keep distinct sets of args, for a cleaner separation of concerns. 
221 | 222 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 223 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 224 | # If we pass only one argument to the script and it's the path to a json file, 225 | # let's parse it to get our arguments. 226 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 227 | else: 228 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 229 | 230 | # Setup logging 231 | logging.basicConfig( 232 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 233 | datefmt="%m/%d/%Y %H:%M:%S", 234 | handlers=[logging.StreamHandler(sys.stdout)], 235 | ) 236 | log_level = training_args.get_process_log_level() 237 | logger.setLevel(log_level) 238 | datasets.utils.logging.set_verbosity(log_level) 239 | transformers.utils.logging.set_verbosity(log_level) 240 | transformers.utils.logging.enable_default_handler() 241 | transformers.utils.logging.enable_explicit_format() 242 | 243 | # Log on each process the small summary: 244 | logger.warning( 245 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 246 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 247 | ) 248 | logger.info(f"Training/evaluation parameters {training_args}") 249 | # Detecting last checkpoint. 250 | last_checkpoint = None 251 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 252 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 253 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 254 | raise ValueError( 255 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 256 | "Use --overwrite_output_dir to overcome." 257 | ) 258 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 259 | logger.info( 260 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 261 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 262 | ) 263 | 264 | # Set seed before initializing model. 265 | set_seed(training_args.seed) 266 | 267 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 268 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 269 | # (the dataset will be downloaded automatically from the datasets Hub). 270 | # 271 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the 272 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments). 273 | # 274 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 275 | # download the dataset. 276 | if data_args.dataset_name is not None: 277 | # Downloading and loading a dataset from the hub. 
278 | raw_datasets = load_dataset( 279 | data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir 280 | ) 281 | else: 282 | data_files = {} 283 | if data_args.train_file is not None: 284 | data_files["train"] = data_args.train_file 285 | extension = data_args.train_file.split(".")[-1] 286 | if data_args.validation_file is not None: 287 | data_files["validation"] = data_args.validation_file 288 | extension = data_args.validation_file.split(".")[-1] 289 | if data_args.test_file is not None: 290 | data_files["test"] = data_args.test_file 291 | extension = data_args.test_file.split(".")[-1] 292 | raw_datasets = load_dataset(extension, data_files=data_files, field="data", cache_dir=model_args.cache_dir) 293 | 294 | # Load pretrained model and tokenizer 295 | config = AutoConfig.from_pretrained( 296 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 297 | cache_dir=model_args.cache_dir, 298 | revision=model_args.model_revision, 299 | use_auth_token=True if model_args.use_auth_token else None, 300 | ) 301 | tokenizer = AutoTokenizer.from_pretrained( 302 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 303 | cache_dir=model_args.cache_dir, 304 | use_fast=model_args.use_fast_tokenizer, 305 | revision=model_args.model_revision, 306 | use_auth_token=True if model_args.use_auth_token else None, 307 | ) 308 | tokenizer.add_tokens("<sep>")  # "<sep>" is the assumed separator between generated questions in the target text 309 | model = AutoModelForSeq2SeqLM.from_pretrained( 310 | model_args.model_name_or_path, 311 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 312 | config=config, 313 | cache_dir=model_args.cache_dir, 314 | revision=model_args.model_revision, 315 | use_auth_token=True if model_args.use_auth_token else None, 316 | ) 317 | 318 | model.resize_token_embeddings(len(tokenizer)) 319 | 320 | if model.config.decoder_start_token_id is None: 321 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") 322 | 323 | if ( 324 | hasattr(model.config, "max_position_embeddings") 325 | and model.config.max_position_embeddings < data_args.max_source_length 326 | ): 327 | if model_args.resize_position_embeddings is None: 328 | logger.warning( 329 | f"Increasing the model's number of position embedding vectors from {model.config.max_position_embeddings} " 330 | f"to {data_args.max_source_length}." 331 | ) 332 | model.resize_position_embeddings(data_args.max_source_length) 333 | elif model_args.resize_position_embeddings: 334 | model.resize_position_embeddings(data_args.max_source_length) 335 | else: 336 | raise ValueError( 337 | f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has {model.config.max_position_embeddings}" 338 | f" position encodings. Consider either reducing `--max_source_length` to {model.config.max_position_embeddings} or to automatically " 339 | "resize the model's position encodings by passing `--resize_position_embeddings`." 340 | ) 341 | 342 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 343 | 344 | # Preprocessing the datasets. 345 | # We need to tokenize inputs and targets. 346 | if training_args.do_train: 347 | column_names = raw_datasets["train"].column_names 348 | elif training_args.do_eval: 349 | column_names = raw_datasets["validation"].column_names 350 | elif training_args.do_predict: 351 | column_names = raw_datasets["test"].column_names 352 | else: 353 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
354 | return 355 | 356 | # Get the column names for input/target. 357 | text_column = data_args.text_column 358 | if text_column not in column_names: 359 | raise ValueError( 360 | f"'--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}" 361 | ) 362 | summary_column = data_args.summary_column 363 | if summary_column not in column_names: 364 | raise ValueError( 365 | f"'--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}" 366 | ) 367 | 368 | # Temporarily set max_target_length for training. 369 | max_target_length = data_args.max_target_length 370 | padding = "max_length" if data_args.pad_to_max_length else False 371 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): 372 | logger.warning( 373 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for " 374 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" 375 | ) 376 | 377 | def preprocess_function(examples): 378 | inputs = examples[text_column] 379 | targets = examples[summary_column] 380 | inputs = [prefix + inp for inp in inputs] 381 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) 382 | 383 | # Setup the tokenizer for targets 384 | with tokenizer.as_target_tokenizer(): 385 | labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True) 386 | 387 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 388 | # padding in the loss. 389 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 390 | labels["input_ids"] = [ 391 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 392 | ] 393 | 394 | model_inputs["labels"] = labels["input_ids"] 395 | return model_inputs 396 | 397 | def _add_special_tokens(example): 398 | example['target_text'] = example['target_text'] + "{sep_token}" 399 | example['target_text'] = example['target_text'].replace("{sep_token}", "<sep>")  # assumed separator token, matching the token added to the tokenizer above 400 | return example 401 | 402 | if training_args.do_train: 403 | if "train" not in raw_datasets: 404 | raise ValueError("--do_train requires a train dataset") 405 | train_dataset = raw_datasets["train"] 406 | if data_args.max_train_samples is not None: 407 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 408 | with training_args.main_process_first(desc="train dataset map pre-processing"): 409 | train_dataset = train_dataset.map(_add_special_tokens) 410 | train_dataset = train_dataset.map( 411 | preprocess_function, 412 | batched=True, 413 | num_proc=data_args.preprocessing_num_workers, 414 | remove_columns=column_names, 415 | load_from_cache_file=not data_args.overwrite_cache, 416 | desc="Running tokenizer on train dataset", 417 | ) 418 | 419 | if training_args.do_eval: 420 | max_target_length = data_args.val_max_target_length 421 | if "validation" not in raw_datasets: 422 | raise ValueError("--do_eval requires a validation dataset") 423 | eval_dataset = raw_datasets["validation"] 424 | if data_args.max_eval_samples is not None: 425 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 426 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 427 | eval_dataset = eval_dataset.map(_add_special_tokens) 428 | eval_dataset = eval_dataset.map(
429 | preprocess_function, 430 | batched=True, 431 | num_proc=data_args.preprocessing_num_workers, 432 | remove_columns=column_names, 433 | load_from_cache_file=not data_args.overwrite_cache, 434 | desc="Running tokenizer on validation dataset", 435 | ) 436 | 437 | if training_args.do_predict: 438 | max_target_length = data_args.val_max_target_length 439 | if "test" not in raw_datasets: 440 | raise ValueError("--do_predict requires a test dataset") 441 | predict_dataset = raw_datasets["test"] 442 | if data_args.max_predict_samples is not None: 443 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) 444 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 445 | predict_dataset = predict_dataset.map(_add_special_tokens) 446 | predict_dataset = predict_dataset.map( 447 | preprocess_function, 448 | batched=True, 449 | num_proc=data_args.preprocessing_num_workers, 450 | remove_columns=column_names, 451 | load_from_cache_file=not data_args.overwrite_cache, 452 | desc="Running tokenizer on prediction dataset", 453 | ) 454 | 455 | # Data collator 456 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 457 | data_collator = DataCollatorForSeq2Seq( 458 | tokenizer, 459 | model=model, 460 | label_pad_token_id=label_pad_token_id, 461 | pad_to_multiple_of=8 if training_args.fp16 else None, 462 | ) 463 | 464 | # Metric 465 | def postprocess_text(preds, labels): 466 | preds = [pred.strip() for pred in preds] 467 | labels = [label.strip() for label in labels] 468 | preds = ["\n".join(pred.split("<sep>")) for pred in preds]  # split multi-question outputs on the assumed "<sep>" separator 469 | labels = ["\n".join(label.split("<sep>")) for label in labels] 470 | return preds, labels 471 | 472 | def compute_metrics(eval_preds): 473 | metric = Rouge() 474 | preds, labels = eval_preds 475 | if isinstance(preds, tuple): 476 | preds = preds[0] 477 | decoded_preds = [" ".join(tokenizer.convert_ids_to_tokens(pred)) for pred in preds] 478 | if data_args.ignore_pad_token_for_loss: 479 | # Replace -100 in the labels as we can't decode them.
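# (-100 is the ignore index used for the loss and by `DataCollatorForSeq2Seq` above, so it must be swapped back to
# the real pad token id before `convert_ids_to_tokens` can decode the label sequences.)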
480 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 481 | decoded_labels = [" ".join(tokenizer.convert_ids_to_tokens(label)) for label in labels] 482 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) 483 | result = metric.get_scores(decoded_preds, decoded_labels, avg=True) 484 | result = {k: round(v["f"], 4) for k, v in result.items()} 485 | return result 486 | 487 | # Initialize our Trainer 488 | trainer = Seq2SeqTrainer( 489 | model=model, 490 | args=training_args, 491 | train_dataset=train_dataset if training_args.do_train else None, 492 | eval_dataset=eval_dataset if training_args.do_eval else None, 493 | tokenizer=tokenizer, 494 | data_collator=data_collator, 495 | compute_metrics=compute_metrics if training_args.predict_with_generate else None, 496 | ) 497 | 498 | # Training 499 | if training_args.do_train: 500 | checkpoint = None 501 | if training_args.resume_from_checkpoint is not None: 502 | checkpoint = training_args.resume_from_checkpoint 503 | elif last_checkpoint is not None: 504 | checkpoint = last_checkpoint 505 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 506 | trainer.save_model() # Saves the tokenizer too for easy upload 507 | 508 | metrics = train_result.metrics 509 | max_train_samples = ( 510 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 511 | ) 512 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 513 | 514 | trainer.log_metrics("train", metrics) 515 | trainer.save_metrics("train", metrics) 516 | trainer.save_state() 517 | 518 | # Evaluation 519 | results = {} 520 | max_length = ( 521 | training_args.generation_max_length 522 | if training_args.generation_max_length is not None 523 | else data_args.val_max_target_length 524 | ) 525 | num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams 526 | if training_args.do_eval: 527 | logger.info("*** Evaluate ***") 528 | metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval") 529 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 530 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 531 | 532 | trainer.log_metrics("eval", metrics) 533 | trainer.save_metrics("eval", metrics) 534 | 535 | if training_args.do_predict: 536 | logger.info("*** Predict ***") 537 | 538 | predict_results = trainer.predict( 539 | predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams 540 | ) 541 | metrics = predict_results.metrics 542 | max_predict_samples = ( 543 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 544 | ) 545 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 546 | 547 | trainer.log_metrics("predict", metrics) 548 | trainer.save_metrics("predict", metrics) 549 | 550 | if trainer.is_world_process_zero(): 551 | if training_args.predict_with_generate: 552 | predictions = tokenizer.batch_decode( 553 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True 554 | ) 555 | predictions = [pred.strip() for pred in predictions] 556 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") 557 | with open(output_prediction_file, "w") as writer: 558 | writer.write("\n".join(predictions)) 559 | 560 | kwargs = {"finetuned_from": model_args.model_name_or_path,
"tasks": "summarization"} 561 | if data_args.dataset_name is not None: 562 | kwargs["dataset_tags"] = data_args.dataset_name 563 | if data_args.dataset_config_name is not None: 564 | kwargs["dataset_args"] = data_args.dataset_config_name 565 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 566 | else: 567 | kwargs["dataset"] = data_args.dataset_name 568 | 569 | if training_args.push_to_hub: 570 | trainer.push_to_hub(**kwargs) 571 | else: 572 | trainer.create_model_card(**kwargs) 573 | 574 | return results 575 | 576 | 577 | if __name__ == "__main__": 578 | main() 579 | -------------------------------------------------------------------------------- /question_generation/train_qa.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Team All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | A subclass of `Trainer` specific to Question-Answering tasks 17 | """ 18 | 19 | from transformers import Trainer, is_torch_tpu_available 20 | from transformers.trainer_utils import PredictionOutput 21 | 22 | 23 | if is_torch_tpu_available(): 24 | import torch_xla.core.xla_model as xm 25 | import torch_xla.debug.metrics as met 26 | 27 | 28 | class QuestionAnsweringTrainer(Trainer): 29 | def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs): 30 | super().__init__(*args, **kwargs) 31 | self.eval_examples = eval_examples 32 | self.post_process_function = post_process_function 33 | 34 | def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"): 35 | eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset 36 | eval_dataloader = self.get_eval_dataloader(eval_dataset) 37 | eval_examples = self.eval_examples if eval_examples is None else eval_examples 38 | 39 | # Temporarily disable metric computation, we will do it in the loop here. 
40 |         compute_metrics = self.compute_metrics
41 |         self.compute_metrics = None
42 |         eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
43 |         try:
44 |             output = eval_loop(
45 |                 eval_dataloader,
46 |                 description="Evaluation",
47 |                 # No point gathering the predictions if there are no metrics, otherwise we defer to
48 |                 # self.args.prediction_loss_only
49 |                 prediction_loss_only=True if compute_metrics is None else None,
50 |                 ignore_keys=ignore_keys,
51 |             )
52 |         finally:
53 |             self.compute_metrics = compute_metrics
54 | 
55 |         if self.post_process_function is not None and self.compute_metrics is not None:
56 |             eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
57 |             metrics = self.compute_metrics(eval_preds)
58 | 
59 |             # Prefix all keys with metric_key_prefix + '_'
60 |             for key in list(metrics.keys()):
61 |                 if not key.startswith(f"{metric_key_prefix}_"):
62 |                     metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
63 | 
64 |             self.log(metrics)
65 |         else:
66 |             metrics = {}
67 | 
68 |         if self.args.tpu_metrics_debug or self.args.debug:
69 |             # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
70 |             xm.master_print(met.metrics_report())
71 | 
72 |         self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
73 |         return metrics
74 | 
75 |     def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"):
76 |         predict_dataloader = self.get_test_dataloader(predict_dataset)
77 | 
78 |         # Temporarily disable metric computation, we will do it in the loop here.
79 |         compute_metrics = self.compute_metrics
80 |         self.compute_metrics = None
81 |         eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
82 |         try:
83 |             output = eval_loop(
84 |                 predict_dataloader,
85 |                 description="Prediction",
86 |                 # No point gathering the predictions if there are no metrics, otherwise we defer to
87 |                 # self.args.prediction_loss_only
88 |                 prediction_loss_only=True if compute_metrics is None else None,
89 |                 ignore_keys=ignore_keys,
90 |             )
91 |         finally:
92 |             self.compute_metrics = compute_metrics
93 | 
94 |         if self.post_process_function is None or self.compute_metrics is None:
95 |             return output
96 | 
97 |         predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict")
98 |         metrics = self.compute_metrics(predictions)
99 | 
100 |         # Prefix all keys with metric_key_prefix + '_'
101 |         for key in list(metrics.keys()):
102 |             if not key.startswith(f"{metric_key_prefix}_"):
103 |                 metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
104 | 
105 |         return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.12.5
2 | datasets==1.15.1
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.md
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | requirements = [
4 |     "transformers>=4.12.5",
"datasets>=1.15.1", 6 | "torch>=1.3" 7 | ] 8 | 9 | setup( 10 | name="question_generation", 11 | version="1.0.5", 12 | author="algolet", 13 | author_email="wei.cai@algolet.com", 14 | description="Question Generation and Question Answering Pipeline", 15 | long_description=open("README.md", "r", encoding="utf-8").read(), 16 | long_description_content_type="text/markdown", 17 | license="Apache", 18 | url="https://github.com/algolet/question_generation", 19 | packages=find_packages(), 20 | install_requires=requirements, 21 | python_requires=">=3.6.0", 22 | classifiers=[ 23 | "Intended Audience :: Developers", 24 | "Intended Audience :: Education", 25 | "Intended Audience :: Science/Research", 26 | "License :: OSI Approved :: Apache Software License", 27 | "Programming Language :: Python :: 3", 28 | "Programming Language :: Python :: 3.6", 29 | "Programming Language :: Python :: 3.7", 30 | "Programming Language :: Python :: 3.8", 31 | "Programming Language :: Python :: 3.9", 32 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 33 | ] 34 | ) 35 | 36 | 37 | 38 | --------------------------------------------------------------------------------