├── .gitignore ├── LICENSE ├── README-zh.md ├── README.md ├── data └── CPED │ ├── speakers.txt │ ├── test_split.csv │ ├── train_split.csv │ └── valid_split.csv ├── envs └── py3.8_torch1.9.0_ignite0.4.8_tensorflow2.2.0_cuda10.2_transformers4.18.0_paddlepaddle-gpu_2.3.0.yml ├── erc_baseline └── README.md ├── images ├── dataset_comparison.png └── dataset_staticstics.png ├── pec_baseline ├── README.md ├── models │ ├── README.md │ ├── __init__.py │ ├── base_model.py │ ├── gpt │ │ ├── __init__.py │ │ ├── modeling_openai.py │ │ └── tokenization_openai.py │ ├── gpt2 │ │ ├── __init__.py │ │ ├── modeling_gpt2.py │ │ └── tokenization_gpt2.py │ └── model_parameters.py ├── train_model.py ├── train_model_script.sh └── utils │ ├── README.md │ ├── __init__.py │ ├── base_util.py │ ├── cped_dataset.py │ ├── cped_util.py │ └── dataset_statistics.py └── prc_baseline └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # dataset 132 | CPED_cache_for_CpedDataset 133 | logs/ 134 | runs/ 135 | results/ 136 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README-zh.md: -------------------------------------------------------------------------------- 1 | # [CPED](https://github.com/scutcyr/CPED) 2 | [![made-with-python](https://img.shields.io/badge/Made%20with-Python-red.svg)](#python) [![arxiv](https://img.shields.io/badge/arXiv-2205.14727-b31b1b.svg)](https://arxiv.org/abs/2205.14727) [![GitHub stars](https://img.shields.io/github/stars/scutcyr/CPED)](https://github.com/scutcyr/CPED/stargazers) [![GitHub license](https://img.shields.io/github/license/scutcyr/CPED)](https://github.com/scutcyr/CPED/blob/main/LICENSE) ![GitHub repo size](https://img.shields.io/github/repo-size/scutcyr/CPED) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) ![GitHub last commit](https://img.shields.io/github/last-commit/scutcyr/CPED) 3 | 4 | README: [English](https://github.com/scutcyr/CPED/blob/main/README.md) | [中文](https://github.com/scutcyr/CPED/blob/main/README-zh.md) 5 | 该仓库提供下面的论文的实现细节: 6 | **[CPED: A Large-Scale Chinese Personalized and Emotional Dialogue Dataset for Conversational AI](https://arxiv.org/abs/2205.14727)** 7 | 8 | 更多信息请参考我们的[论文](https://arxiv.org/abs/2205.14727)。 9 | 10 | 数据集已经同步发布在千言平台: [https://www.luge.ai/#/luge/dataDetail?id=41](https://www.luge.ai/#/luge/dataDetail?id=41) 11 | 12 | ## 目录 13 | * 简介 14 | * 数据集统计学特性 15 | * 任务定义 16 | * 实验结果 17 | * 使用方法 18 | 19 | ## 简介 20 | 我们构建了一个命名为**CPED**的数据集,该数据集源于40部中文电视剧。 21 | CPED包括与情感、个性特质相关的多源知识,包括:13类情绪、性别、大五人格、19类对话动作以及其他知识。下表给出了CPED与其他常见数据集的比较。 22 | 23 | * 我们构建了一个多轮的中文个性情感对话数据集CPED。据我们所知,CPED是首个中文个性情感对话数据集。它包括超过1.2万个对话,超过13.3万个语句,并且是多模态的。因此,该数据集可以用在复杂的对话理解任务以及拟人化的对话生成任务研究。 24 | * CPED提供了3类属性标注(姓名、性别、年龄),大五人格特质标注,2类情感标注(3分类粗粒度情感、13分类细粒度情感),以及对话动作DA标注。人格特质和情感可以用作开放域对话生成的先验外部知识。提升对话系统的拟人化水平。 25 | * 我们在论文中提出了3个任务:对话中的人格识别(PRC),对话中的情感识别(ERC),以及个性情感对话生成(PEC),一系列实验验证了人格以及情感对于对话生成的重要性。 26 | 27 | ![dataset_comparison](./images/dataset_comparison.png) 28 | 29 | ## 数据集统计学特性 30 | 为了让对话系统学习情感表达和个性表达能力,我们提供了下表中列出的多种类型的注释标签。 31 | 32 | | # of annos. | Labels | Num. | 33 | |:-----------:|:-------|:----:| 34 | | Sentiment | positive, neutral, and negative | 3 | 35 | | Emotion | happy, grateful, relaxed, other-positive, neutral, angry, sad, feared, depressed, disgusted, astonished, worried and other-negative | 13 | 36 | | Gender | male, female, and unknown | 3 | 37 | | Age group | children, teenager, young, middle-aged, elderly and unknown | 6 | 38 | | Big Five | high, low, and unknown | 3 | 39 | | DA | greeting (g), question (q), answer (ans), statement-opinion (sv), statement-non-opinion (sd), apology (fa), command (c), agreement/acceptance (aa), disagreement (dag), acknowledge (a), appreciation (ba), interjection (ij), conventional-closing (fc), thanking (ft), quotation (^q), reject(rj), irony (ir), comfort (cf) and other (oth) | 19 | 40 | | Scene | home, office, school, mall, hospital, restaurant, sports-venue, entertainment-venue, car, outdoor and other-scene | 11 | 41 | 42 | 43 | CPED数据集中性别、年龄、3分类情感、13分类细粒度情绪和DA的统计学分布如下图所示。 44 | ![](./images/dataset_staticstics.png) 45 | 46 | CPED的各项统计信息如下表所示. 47 | | 统计项 | 训练集 | 验证集 | 测试集 | 48 | |-----------------------|---------|---------|---------| 49 | | 模态 | (v,a,t) | (v,a,t) | (v,a,t) | 50 | | 电视剧 | 26 | 5 | 9 | 51 | | 对话 | 8,086 | 934 | 2,815 | 52 | | 语句 | 94,187 | 11,137 | 27,438 | 53 | | 说话人 | 273 | 38 | 81 | 54 | | 每个对话的平均句子数 | 11.6 | 11.9 | 9.7 | 55 | | 对话的最大句子数 | 75 | 31 | 34 | 56 | | 每个对话的平均情感类别数 | 2.8 | 3.4 | 3.2 | 57 | | 每个对话的平均DA类别数 | 3.6 | 3.7 | 3.2 | 58 | | 平均句子长度 | 8.3 | 8.2 | 8.3 | 59 | | 最大句子长度 | 127 | 42 | 45 | 60 | | 语句的平均语音长度 | 2.1s | 2.12s | 2.21s | 61 | 62 | 63 | 64 | ## 任务定义 65 | CPED可以用于对话理解任务和对话生成任务的评估,例如说话人建模、对话中的个性识别、对话中的情感识别、对话中的DA识别、回复的情感预测、情感对话生成、个性会话生成、移情对话生成等,CPED还可以应用于多模态人格或情感识别、多模态对话生成。它将对促进认知智能的发展起到积极的作用。 66 | 我们在本项目当中引入3种任务,如下所示: 67 | * **ERC**: [对话中的情感识别任务](https://paperswithcode.com/task/emotion-recognition-in-conversation) 68 | * **PRC**: [对话中的人格(个性)识别任务](https://paperswithcode.com/task/personality-recognition-in-conversation) 69 | * **PEC**: [个性情感对话生成任务](https://paperswithcode.com/task/personalized-and-emotional-conversation) 70 | 71 | 72 | 73 | ## 使用方法 74 | 如果你使用conda配置虚拟环境,你可以通过以下命令创建运行baseline模型的python虚拟环境: 75 | ```bash 76 | conda create -n py38 python=3.8 77 | conda activate py38 78 | pip install torch==1.9.0+cu102 torchvision==0.10.0+cu102 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html 79 | pip install tensorflow==2.2.0 80 | pip install transformers==4.18.0 81 | python -m pip install paddlepaddle-gpu==2.3.0 -i https://mirror.baidu.com/pypi/simple 82 | pip install pytorch-ignite==0.4.8 83 | pip install notebook 84 | pip install pandas 85 | pip install chardet 86 | pip install matplotlib==3.5.2 87 | python -m pip install paddlenlp -i https://mirrors.aliyun.com/pypi/simple/ 88 | python -m pip install ppasr -i https://mirrors.aliyun.com/pypi/simple/ -U 89 | pip install nltk 90 | pip install bert-score 91 | ``` 92 | 93 | 部分依赖包的使用版本如下所示: 94 | ```bash 95 | python=3.8 96 | torch==1.9.0+cu102 97 | torchvision==0.10.0+cu102 98 | torchaudio==0.9.0 99 | tensorflow==2.2.0 100 | tensorboard==2.2.2 101 | transformers==4.18.0 102 | paddlepaddle-gpu==2.3.0 103 | paddlenlp==2.3.2 104 | pytorch-ignite==0.4.8 105 | matplotlib==3.5.2 106 | notebook==6.4.11 107 | pandas==1.4.2 108 | chardet==4.0.0 109 | nltk==3.7 110 | bert-score==0.3.11 111 | ``` 112 | 113 | 114 | 115 | 如果你在研究当中使用到CPED数据集或者本项目,请引用以下论文: 116 | ``` 117 | @article{chen2022cped, 118 | title={{CPED}: A Large-Scale Chinese Personalized and Emotional Dialogue Dataset for Conversational AI}, 119 | author={Yirong Chen and Weiquan Fan and Xiaofen Xing and Jianxin Pang and Minlie Huang and Wenjing Han and Qianfeng Tie and Xiangmin Xu}, 120 | journal={arXiv preprint arXiv:2205.14727}, 121 | year={2022}, 122 | url={https://arxiv.org/abs/2205.14727} 123 | } 124 | ``` 125 | 126 | >>> 人体数据感知教育部工程研究中心 127 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [CPED](https://github.com/scutcyr/CPED) 2 | [![made-with-python](https://img.shields.io/badge/Made%20with-Python-red.svg)](#python) [![arxiv](https://img.shields.io/badge/arXiv-2205.14727-b31b1b.svg)](https://arxiv.org/abs/2205.14727) [![GitHub stars](https://img.shields.io/github/stars/scutcyr/CPED)](https://github.com/scutcyr/CPED/stargazers) [![GitHub license](https://img.shields.io/github/license/scutcyr/CPED)](https://github.com/scutcyr/CPED/blob/main/LICENSE) ![GitHub repo size](https://img.shields.io/github/repo-size/scutcyr/CPED) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) ![GitHub last commit](https://img.shields.io/github/last-commit/scutcyr/CPED) 3 | 4 | 5 | README: [English](https://github.com/scutcyr/CPED/blob/main/README.md) | [中文](https://github.com/scutcyr/CPED/blob/main/README-zh.md) 6 | This repository provides the implementation details for the paper: 7 | **[CPED: A Large-Scale Chinese Personalized and Emotional Dialogue Dataset for Conversational AI](https://arxiv.org/abs/2205.14727)** 8 | 9 | For more information, please refer to our [paper](https://arxiv.org/abs/2205.14727). 10 | 11 | The dataset is also available in luge.ai: [https://www.luge.ai/#/luge/dataDetail?id=41](https://www.luge.ai/#/luge/dataDetail?id=41) 12 | 13 | ## Contents 14 | * Introduction 15 | * Dataset Statistics 16 | * Task Definition 17 | * Evaluation Results 18 | * Usage 19 | 20 | ## Introduction 21 | We construct a dataset named **CPED** from 40 Chinese TV shows. CPED consists of multisource knowledge related to empathy and personal characteristic. This knowledge covers 13 emotions, gender, Big Five personality traits, 19 dialogue acts and other knowledge. The table below shows a comparison of CPED with some other common conversation data sets. 22 | 23 | * We build a multiturn Chinese Personalized and Emotional Dialogue dataset called CPED. To the best of our knowledge, CPED is the first Chinese personalized and emotional dialogue dataset. CPED contains 12K dialogues and 133K utterances with multi-modal context. Therefore, it can be used in both complicated dialogue understanding and human-like conversation generation. 24 | * CPED has been annotated with 3 character attributes (name, gender age), Big Five personality traits, 2 types of dynamic emotional information (sentiment and emotion) and DAs. The personality traits and emotions can be used as prior external knowledge for open-domain conversation generation, making the conversation system have a good command of personification capabilities. 25 | * We propose three tasks for CPED: **personality recognition in conversations (PRC)**, **emotion recognition in conversations (ERC)**, and **personalized and emotional conversation (PEC)**. A set of experiments verify the importance of using personalities and emotions as prior external knowledge for conversation generation. 26 | 27 | ![dataset_comparison](./images/dataset_comparison.png) 28 | 29 | ## Dataset Statistics 30 | In order for the dialogue system to learn emotional expression and personalized expression abilities, we provide multiple types of annotation labels listed in the following Table. 31 | 32 | | # of annos. | Labels | Num. | 33 | |:-----------:|:-------|:----:| 34 | | Sentiment | positive, neutral, and negative | 3 | 35 | | Emotion | happy, grateful, relaxed, other-positive, neutral, angry, sad, feared, depressed, disgusted, astonished, worried and other-negative | 13 | 36 | | Gender | male, female, and unknown | 3 | 37 | | Age group | children, teenager, young, middle-aged, elderly and unknown | 6 | 38 | | Big Five | high, low, and unknown | 3 | 39 | | DA | greeting (g), question (q), answer (ans), statement-opinion (sv), statement-non-opinion (sd), apology (fa), command (c), agreement/acceptance (aa), disagreement (dag), acknowledge (a), appreciation (ba), interjection (ij), conventional-closing (fc), thanking (ft), quotation (^q), reject(rj), irony (ir), comfort (cf) and other (oth) | 19 | 40 | | Scene | home, office, school, mall, hospital, restaurant, sports-venue, entertainment-venue, car, outdoor and other-scene | 11 | 41 | 42 | 43 | Distribution of Gender, Age Group, Sentiment, Emotion and DA in CPED Dataset are shown in the following figure. 44 | ![](./images/dataset_staticstics.png) 45 | 46 | The statistics of CPED are listed in the following table. 47 | | Statistics | Train | Dev | Test | 48 | |---------------------------------|---------|---------|---------| 49 | | # of modalities | (v,a,t) | (v,a,t) | (v,a,t) | 50 | | # of TV plays | 26 | 5 | 9 | 51 | | # of dialogues | 8,086 | 934 | 2,815 | 52 | | # of utterances | 94,187 | 11,137 | 27,438 | 53 | | # of speakers | 273 | 38 | 81 | 54 | | Avg. # utt. per dial. | 11.6 | 11.9 | 9.7 | 55 | | Max # utt. per dial. | 75 | 31 | 34 | 56 | | Avg. # of emot. per dial. | 2.8 | 3.4 | 3.2 | 57 | | Avg. # of DAs per dial. | 3.6 | 3.7 | 3.2 | 58 | | Avg. utt. length | 8.3 | 8.2 | 8.3 | 59 | | Max utt. length | 127 | 42 | 45 | 60 | | Avg. duration of an utterance | 2.1s | 2.12s | 2.21s | 61 | 62 | 63 | ## Task Definition 64 | CPED allows evaluation of both conversational cognitive tasks and conversation generation tasks, e.g. speaker modeling, personality recognition in conversations, emotion recognition in conversations, DA recognition in conversations, emotion prediction for response, emotional conversation generation, personalized conversation generation, empathetic conversation etc. By being multimodal, CPED can also be applied in multimodal personality or emotion recognition, multimodal conversation generation. It will play a positive role in promoting the development of cognitive intelligence. 65 | We introduced 3 tasks in the project: 66 | * **ERC**: [Emotion Recognition in Conversation](https://paperswithcode.com/task/emotion-recognition-in-conversation) 67 | * **PRC**: [Personality Recognition in Conversation](https://paperswithcode.com/task/personality-recognition-in-conversation) 68 | * **PEC**: [Personalized and Emotional Conversation](https://paperswithcode.com/task/personalized-and-emotional-conversation) 69 | 70 | 71 | 72 | ## Usage 73 | You can create the python virtual environment through the following bash script: 74 | ```bash 75 | conda create -n py38 python=3.8 76 | conda activate py38 77 | pip install torch==1.9.0+cu102 torchvision==0.10.0+cu102 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html 78 | pip install tensorflow==2.2.0 79 | pip install transformers==4.18.0 80 | python -m pip install paddlepaddle-gpu==2.3.0 -i https://mirror.baidu.com/pypi/simple 81 | pip install pytorch-ignite==0.4.8 82 | pip install notebook 83 | pip install pandas 84 | pip install chardet 85 | pip install matplotlib==3.5.2 86 | python -m pip install paddlenlp -i https://mirrors.aliyun.com/pypi/simple/ 87 | python -m pip install ppasr -i https://mirrors.aliyun.com/pypi/simple/ -U 88 | pip install nltk 89 | pip install bert-score 90 | ``` 91 | 92 | some version of the used packages are as follows: 93 | ```bash 94 | python=3.8 95 | torch==1.9.0+cu102 96 | torchvision==0.10.0+cu102 97 | torchaudio==0.9.0 98 | tensorflow==2.2.0 99 | tensorboard==2.2.2 100 | transformers==4.18.0 101 | paddlepaddle-gpu==2.3.0 102 | paddlenlp==2.3.2 103 | pytorch-ignite==0.4.8 104 | matplotlib==3.5.2 105 | notebook==6.4.11 106 | pandas==1.4.2 107 | chardet==4.0.0 108 | nltk==3.7 109 | bert-score==0.3.11 110 | ``` 111 | 112 | 113 | 114 | 115 | Please cite our paper if you use CPED or this project: 116 | ``` 117 | @article{chen2022cped, 118 | title={{CPED}: A Large-Scale Chinese Personalized and Emotional Dialogue Dataset for Conversational AI}, 119 | author={Yirong Chen and Weiquan Fan and Xiaofen Xing and Jianxin Pang and Minlie Huang and Wenjing Han and Qianfeng Tie and Xiangmin Xu}, 120 | journal={arXiv preprint arXiv:2205.14727}, 121 | year={2022}, 122 | url={https://arxiv.org/abs/2205.14727} 123 | } 124 | ``` 125 | 126 | >>> Engineering Research Ceter of Ministry of Education on Human Body Perception 127 | -------------------------------------------------------------------------------- /data/CPED/speakers.txt: -------------------------------------------------------------------------------- 1 | 于德伟 2 | 王柏川 3 | 袁大头 4 | 外卖小哥 5 | 林磊儿 6 | 贾小朵 7 | 刘梅 8 | 孙强 9 | 果然 10 | 钱三一 11 | 谭宗明 12 | 李月如 13 | 张医生 14 | 方芳 15 | 毕胜男 16 | 邓小可母亲 17 | 王胜男 18 | 彭永辉 19 | 刘静 20 | 关雎尔 21 | 宋倩 22 | 刘珍珍 23 | 崔朕父亲 24 | 秦羽墨 25 | 童娟 26 | 诸葛大力 27 | 林大为 28 | 王珊 29 | 丁香 30 | 陈建 31 | 方一凡 32 | 陆展博 33 | 王岩 34 | 陆母 35 | 刘铮 36 | 赵美兰 37 | 高健 38 | 罗琦 39 | 石天冬 40 | 美薇 41 | 陈爱莲 42 | 陆小山 43 | 夏峰 44 | 黄浩达 45 | 应勤 46 | 陆永福 47 | 林妙妙 48 | 唐悠悠 49 | 赵海棠 50 | 安娜 51 | 毕大千 52 | 方圆 53 | 王东方 54 | 苏明哲 55 | 乔英子 56 | 陈晓伟 57 | 白洁 58 | 方骏 59 | 江达琳 60 | 唐鹏母亲 61 | 陆晓婷 62 | 李理 63 | 小白 64 | 唐鹏 65 | 蔡钟闲 66 | 栗伟正 67 | 张盛 68 | 苑春丽 69 | 黄晓峰 70 | 诺澜 71 | 王长水 72 | 姗姗 73 | 乔卫东 74 | 孙雅娴 75 | 罗槟 76 | 蔡钟闲妻子 77 | 樊斌 78 | 徐笑 79 | 陈可妈妈 80 | 校长 81 | 金志明 82 | 梁晓慧 83 | 大磊 84 | 唐娇 85 | 孙超越 86 | 于扬 87 | 叶珊 88 | 苏明成 89 | 罗丹 90 | 张天皓 91 | 汪涵 92 | 乔伊林 93 | 路易斯 94 | 周芳 95 | 童文洁助理 96 | 江长恩 97 | 唐爽 98 | 周佳成 99 | 咖喱酱 100 | 张一男 101 | 顾大鹏 102 | 程洪斗 103 | 关谷神奇 104 | 陈佳 105 | 吴芳妮 106 | 江焱 107 | 陶语桐爸爸 108 | 莫小康 109 | 卢家凯 110 | 杨桃 111 | 方依依 112 | 蓝红 113 | 顾末易 114 | 陆小贝 115 | 蒙志远 116 | 周小北 117 | 何西 118 | 崔朕妈妈 119 | 邓小可 120 | 徐豆豆 121 | 张伟 122 | 戴曦 123 | 焦阳 124 | 张超 125 | 张一男妈妈 126 | 童心 127 | 林宛瑜 128 | 姚美娟 129 | 杨呈远妈 130 | 斯嘉丽 131 | 封印 132 | 秦羽墨前男友 133 | 顾映真 134 | 舍友 135 | 南海 136 | 崔朕 137 | 蒂娜 138 | 苏大强 139 | 赵启平 140 | 毕然 141 | 蓝馨 142 | 老罗 143 | 周格格 144 | 沈画 145 | 彭凯旋 146 | 包奕凡 147 | 邱莹莹 148 | 小熊 149 | 雷瑛 150 | 那依 151 | 袁帅 152 | 何东 153 | 朱丽 154 | 艾小天 155 | 卫哲 156 | 丽萨 157 | 泰勒 158 | 柳青 159 | 麦飞 160 | 李闻雨 161 | 黄越彬 162 | 郑海潮 163 | 张亮忠 164 | 林晓珑 165 | 陈可依 166 | 刘旭刚 167 | 胡七星 168 | 琴琴 169 | 小艾 170 | 夏雨 171 | 安迪 172 | 其他 173 | 陶语桐妈妈 174 | 司徒末 175 | 杨嘟嘟 176 | 沈英杰 177 | 王媛 178 | 奥特 179 | 刘炎 180 | 肖一飞 181 | 陆曼妮 182 | 田坤 183 | 叶昭君 184 | 季胜利 185 | 于春晓 186 | 陈美嘉 187 | 郑远东 188 | 冯晶晶 189 | 夏冰 190 | 何北 191 | 张小宇 192 | 韩文静 193 | 赵冲 194 | 段西风 195 | 肖威 196 | 邓小琪 197 | 吴非 198 | 曲筱绡 199 | 福子 200 | 冀遇 201 | 樊胜美 202 | 程俊 203 | 莫正源 204 | 朴恩熙 205 | 牛美丽 206 | 于建国 207 | 沈青川 208 | 宁巴拉 209 | 成晓峰 210 | 袁则强 211 | 胡一菲 212 | 宋宁宇 213 | 江天昊 214 | 宋光明 215 | 方朵朵 216 | 程皓 217 | 戴天高 218 | 邓文宣 219 | 傅沛 220 | 小梦 221 | 罗浩 222 | 郝泽宇 223 | 杨呈远 224 | 吴佳妮 225 | 苏青 226 | 齐大胜 227 | 果长山 228 | 苏珊 229 | 廖佳敏 230 | 田坤前女友 231 | 童小麒 232 | 梁晓慧爸爸 233 | 刘星 234 | 李三弟 235 | 顾婕 236 | 童文洁 237 | 福方树 238 | 王佳佳 239 | 赵小亮 240 | 马邦尼 241 | 曾小贤 242 | 戴娜 243 | 魏山山 244 | 童文洁老板 245 | 夏天 246 | 罗海燕 247 | 薛素梅 248 | 宋暖 249 | 王珊珊父亲 250 | 于小强 251 | 胡一统 252 | 陆小贝母亲 253 | 吕子乔 254 | 梁伊 255 | 刘光耀 256 | 徐天 257 | 向飞 258 | 冯兰芝 259 | 丛卉 260 | 于果 261 | 魏渭 262 | 罗玥 263 | 刘慧芸 264 | 蒂娜妈妈 265 | 黄芷陶 266 | 李老师 267 | 季杨杨 268 | 表姐 269 | 郝敏 270 | 夏东海 271 | 刘兰芝 272 | 李三妹 273 | 斯黛拉 274 | 孙总 275 | 陆长山 276 | 何赛 277 | 鲍家明 278 | 权筝 279 | 余峥 280 | 栗娜 281 | 刘栋 282 | 张铭阳 283 | 王涛 284 | 姚梅 285 | 王博 286 | 陶语桐 287 | 舒晴 288 | 李枫 289 | 邹男 290 | 夏雪 291 | 林君 292 | 高红 293 | 罗素 294 | 刀美岚 295 | 赵小川 296 | 姚澜 297 | 宋大楠 298 | 潘芸 299 | 欧阳雨露 300 | 邹北业 301 | 罗茜茜 302 | 苏明玉 303 | 王珊珊 -------------------------------------------------------------------------------- /envs/py3.8_torch1.9.0_ignite0.4.8_tensorflow2.2.0_cuda10.2_transformers4.18.0_paddlepaddle-gpu_2.3.0.yml: -------------------------------------------------------------------------------- 1 | name: py38 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2022.4.26=h06a4308_0 8 | - certifi=2022.5.18.1=py38h06a4308_0 9 | - ld_impl_linux-64=2.38=h1181459_1 10 | - libffi=3.3=he6710b0_2 11 | - libgcc-ng=11.2.0=h1234567_1 12 | - libgomp=11.2.0=h1234567_1 13 | - libstdcxx-ng=11.2.0=h1234567_1 14 | - ncurses=6.3=h7f8727e_2 15 | - openssl=1.1.1o=h7f8727e_0 16 | - pip=21.2.4=py38h06a4308_0 17 | - python=3.8.13=h12debd9_0 18 | - readline=8.1.2=h7f8727e_1 19 | - setuptools=61.2.0=py38h06a4308_0 20 | - sqlite=3.38.3=hc218d9a_0 21 | - tk=8.6.12=h1ccaba5_0 22 | - wheel=0.37.1=pyhd3eb1b0_0 23 | - xz=5.2.5=h7f8727e_1 24 | - zlib=1.2.12=h7f8727e_2 25 | - pip: 26 | - absl-py==1.1.0 27 | - aiohttp==3.8.1 28 | - aiosignal==1.2.0 29 | - appdirs==1.4.4 30 | - argon2-cffi==21.3.0 31 | - argon2-cffi-bindings==21.2.0 32 | - astor==0.8.1 33 | - asttokens==2.0.5 34 | - astunparse==1.6.3 35 | - async-timeout==4.0.2 36 | - attrs==21.4.0 37 | - audioread==2.1.9 38 | - babel==2.10.1 39 | - backcall==0.2.0 40 | - bce-python-sdk==0.8.64 41 | - beautifulsoup4==4.11.1 42 | - bert-score==0.3.11 43 | - bleach==5.0.0 44 | - cachetools==4.2.4 45 | - cffi==1.15.0 46 | - cfgv==3.3.1 47 | - chardet==4.0.0 48 | - charset-normalizer==2.0.12 49 | - click==8.1.3 50 | - cn2an==0.5.17 51 | - colorama==0.4.4 52 | - colorlog==6.6.0 53 | - cycler==0.11.0 54 | - datasets==2.2.2 55 | - debugpy==1.6.0 56 | - decorator==5.1.1 57 | - defusedxml==0.7.1 58 | - dill==0.3.4 59 | - distlib==0.3.4 60 | - entrypoints==0.4 61 | - executing==0.8.3 62 | - fastjsonschema==2.15.3 63 | - filelock==3.7.1 64 | - flake8==4.0.1 65 | - flask==2.1.2 66 | - flask-babel==2.0.0 67 | - fonttools==4.33.3 68 | - frozenlist==1.3.0 69 | - fsspec==2022.5.0 70 | - future==0.18.2 71 | - gast==0.3.3 72 | - google-auth==1.35.0 73 | - google-auth-oauthlib==0.4.6 74 | - google-pasta==0.2.0 75 | - grpcio==1.46.3 76 | - h5py==2.10.0 77 | - huggingface-hub==0.7.0 78 | - identify==2.5.1 79 | - idna==3.3 80 | - importlib-metadata==4.11.4 81 | - importlib-resources==5.7.1 82 | - ipykernel==6.13.0 83 | - ipython==8.4.0 84 | - ipython-genutils==0.2.0 85 | - itsdangerous==2.1.2 86 | - jedi==0.18.1 87 | - jieba==0.42.1 88 | - jinja2==3.1.2 89 | - joblib==1.1.0 90 | - jsonschema==4.6.0 91 | - jupyter-client==7.3.2 92 | - jupyter-core==4.10.0 93 | - jupyterlab-pygments==0.2.2 94 | - keras-preprocessing==1.1.2 95 | - kiwisolver==1.4.2 96 | - librosa==0.8.0 97 | - llvmlite==0.38.1 98 | - markdown==3.3.7 99 | - markupsafe==2.1.1 100 | - matplotlib==3.5.2 101 | - matplotlib-inline==0.1.3 102 | - mccabe==0.6.1 103 | - mistune==0.8.4 104 | - mock==4.0.3 105 | - multidict==6.0.2 106 | - multiprocess==0.70.12.2 107 | - nbclient==0.6.4 108 | - nbconvert==6.5.0 109 | - nbformat==5.4.0 110 | - nest-asyncio==1.5.5 111 | - nltk==3.7 112 | - nodeenv==1.6.0 113 | - notebook==6.4.11 114 | - numba==0.55.2 115 | - numpy==1.22.4 116 | - oauthlib==3.2.0 117 | - opt-einsum==3.3.0 118 | - packaging==21.3 119 | - paddle-bfloat==0.1.2 120 | - paddle2onnx==0.9.7 121 | - paddlefsl==1.1.0 122 | - paddlenlp==2.3.2 123 | - paddlepaddle-gpu==2.3.0 124 | - paddlespeech-feat==0.1.0 125 | - pandas==1.4.2 126 | - pandocfilters==1.5.0 127 | - parso==0.8.3 128 | - pexpect==4.8.0 129 | - pickleshare==0.7.5 130 | - pillow==9.1.1 131 | - platformdirs==2.5.2 132 | - pooch==1.6.0 133 | - ppasr==0.1.5 134 | - pre-commit==2.19.0 135 | - proces==0.1.2 136 | - prometheus-client==0.14.1 137 | - prompt-toolkit==3.0.29 138 | - protobuf==3.20.0 139 | - psutil==5.9.1 140 | - ptyprocess==0.7.0 141 | - pure-eval==0.2.2 142 | - pyarrow==8.0.0 143 | - pyasn1==0.4.8 144 | - pyasn1-modules==0.2.8 145 | - pycodestyle==2.8.0 146 | - pycparser==2.21 147 | - pycryptodome==3.14.1 148 | - pydub==0.25.1 149 | - pyflakes==2.4.0 150 | - pygments==2.12.0 151 | - pyparsing==3.0.9 152 | - pyrsistent==0.18.1 153 | - python-dateutil==2.8.2 154 | - python-levenshtein==0.12.2 155 | - pytorch-ignite==0.4.8 156 | - pytz==2022.1 157 | - pyyaml==6.0 158 | - pyzmq==23.1.0 159 | - regex==2022.6.2 160 | - requests==2.27.1 161 | - requests-oauthlib==1.3.1 162 | - resampy==0.2.2 163 | - responses==0.18.0 164 | - rsa==4.8 165 | - ruamel-yaml==0.17.21 166 | - ruamel-yaml-clib==0.2.6 167 | - sacremoses==0.0.53 168 | - scikit-learn==1.1.1 169 | - scipy==1.8.1 170 | - send2trash==1.8.0 171 | - sentencepiece==0.1.96 172 | - seqeval==1.2.2 173 | - shellcheck-py==0.8.0.4 174 | - six==1.16.0 175 | - soundfile==0.10.3.post1 176 | - soupsieve==2.3.2.post1 177 | - stack-data==0.2.0 178 | - tensorboard==2.2.2 179 | - tensorboard-plugin-wit==1.8.1 180 | - tensorflow==2.2.0 181 | - tensorflow-estimator==2.2.0 182 | - termcolor==1.1.0 183 | - terminado==0.15.0 184 | - threadpoolctl==3.1.0 185 | - tinycss2==1.1.1 186 | - tokenizers==0.12.1 187 | - toml==0.10.2 188 | - torch==1.9.0+cu102 189 | - torchaudio==0.9.0 190 | - torchvision==0.10.0+cu102 191 | - tornado==6.1 192 | - tqdm==4.59.0 193 | - traitlets==5.2.2.post1 194 | - transformers==4.18.0 195 | - typing-extensions==4.2.0 196 | - urllib3==1.26.9 197 | - virtualenv==20.14.1 198 | - visualdl==2.2.3 199 | - wcwidth==0.2.5 200 | - webencodings==0.5.1 201 | - webrtcvad==2.0.10 202 | - werkzeug==2.1.2 203 | - wrapt==1.14.1 204 | - xxhash==3.0.0 205 | - yarl==1.7.2 206 | - zhconv==1.4.3 207 | - zipp==3.8.0 208 | prefix: /home/phd-chen.yirong/anaconda3/envs/py38 209 | -------------------------------------------------------------------------------- /erc_baseline/README.md: -------------------------------------------------------------------------------- 1 | # ERC: Emotion Recognition in Conversation 2 | # 对话中的情感识别任务 3 | * **Author: Chen Yirong ** 4 | * **Date: 2022.06.01** 5 | 6 | **Note:** Code and model will be released in the future! 7 | 8 | 9 | ## Task Definition 10 | ## 任务定义 11 | See [https://paperswithcode.com/task/emotion-recognition-in-conversation](https://paperswithcode.com/task/emotion-recognition-in-conversation) -------------------------------------------------------------------------------- /images/dataset_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scutcyr/CPED/1e4b81c28a123f22387e06664f37e5dc9322380f/images/dataset_comparison.png -------------------------------------------------------------------------------- /images/dataset_staticstics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scutcyr/CPED/1e4b81c28a123f22387e06664f37e5dc9322380f/images/dataset_staticstics.png -------------------------------------------------------------------------------- /pec_baseline/README.md: -------------------------------------------------------------------------------- 1 | # PEC: Personalized and Emotional Conversation 2 | # 个性情感对话生成任务 3 | * **Author: Chen Yirong ** 4 | * **Date: 2022.06.01** 5 | 6 | **Note:** Code and model will be released in the future! 7 | 8 | ## Task Definition 9 | ## 任务定义 10 | See [https://paperswithcode.com/task/personalized-and-emotional-conversation](https://paperswithcode.com/task/personalized-and-emotional-conversation) 11 | 12 | 13 | ## Folder Description 14 | ## 目录说明 15 | * ```config```: Storing the configuration file of the models. 16 | * ```models```: Storing Python code that defines the models. 17 | * ```results```: Storing the result for testing the trained model. 18 | * ```runs```: Storing model parameters and related configurations generated during training or after training. 19 | * ```utils```: Storing Python code for reading dataset. 20 | -------------------------------------------------------------------------------- /pec_baseline/models/README.md: -------------------------------------------------------------------------------- 1 | # models 2 | # 模型设计模块 3 | * **Author: Chen Yirong ** 4 | * **Date: 2022.03.21** 5 | 6 | ## 架构说明 7 | 每个子文件夹存放一个模型,其中,文件夹命名使用小写字母+下划线+数字的组合,例如:```gpt```、```gpt2```、```gpt_per```。 8 | 每个模型由3个文件组成,假设该模型命名为```xxx```: 9 | * ```__init__.py```: 对外提供可访问的接口 10 | * ```modeling_xxx.py```: 模型类的定义, 11 | * ```tokenization_xxx.py```: 模型相关的tokenization类(不是必须的) 12 | 13 | ### 基础文件 14 | 15 | 16 | -------------------------------------------------------------------------------- /pec_baseline/models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 South China University of Technology and 3 | # Engineering Research Ceter of Ministry of Education on Human Body Perception. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # Models 18 | # Author: Chen Yirong 19 | # Date: 2022.04.06 20 | 21 | __version__ = "1.0.0" 22 | 23 | # 关键包版本说明: 24 | # pytorch: 1.9.0+ 25 | # transformers: 4.11.3+ 26 | 27 | from .base_model import (is_torch_available) 28 | 29 | # 模型类 30 | if is_torch_available(): 31 | from . import ( 32 | gpt, 33 | gpt2, 34 | cvgpt 35 | ) 36 | 37 | # Model parameters calculating and freezing 38 | from .model_parameters import (count_trainable_parameters, count_total_parameters, show_trainable_parameters, 39 | set_freeze_by_names, freeze_by_model_name, unfreeze_by_model_name) -------------------------------------------------------------------------------- /pec_baseline/models/base_model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 South China University of Technology and 3 | # Engineering Research Ceter of Ministry of Education on Human Body Perception. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # Basic model configuration file 18 | # File: base_model.py 19 | # Used for model configuration 20 | # 用于数据集读取的基础方法 21 | # Author: Chen Yirong 22 | # Date: 2022.04.06 23 | 24 | import os 25 | import logging 26 | import importlib.util 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | _torch_available = importlib.util.find_spec("torch") is not None 31 | 32 | def is_torch_available(): 33 | return _torch_available -------------------------------------------------------------------------------- /pec_baseline/models/gpt/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | # 关键包版本说明: 21 | # pytorch: 1.9.0+ 22 | # transformers: 4.11.3+ 23 | 24 | 25 | from typing import TYPE_CHECKING 26 | 27 | from transformers.file_utils import is_torch_available 28 | 29 | if is_torch_available(): 30 | from .modeling_openai import ( 31 | OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST, 32 | OpenAIGPTDoubleHeadsModel, 33 | OpenAIGPTForSequenceClassification, 34 | OpenAIGPTLMHeadModel, 35 | OpenAIGPTModel, 36 | OpenAIGPTPreTrainedModel, 37 | load_tf_weights_in_openai_gpt, 38 | ) 39 | -------------------------------------------------------------------------------- /pec_baseline/models/gpt/modeling_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch OpenAI GPT model.""" 17 | 18 | 19 | # 关键包版本说明: 20 | # pytorch: 1.9.0+ 21 | # transformers: 4.11.3 22 | 23 | 24 | import json 25 | import math 26 | import os 27 | from dataclasses import dataclass 28 | from typing import Optional, Tuple 29 | 30 | import torch 31 | from torch import nn 32 | from torch.nn import CrossEntropyLoss, MSELoss 33 | from torch.cuda.amp import autocast as autocast # 用于使用自动混合精度,要求torch版本为1.6+ 34 | 35 | from transformers.activations import gelu_new, silu 36 | from transformers.file_utils import ( 37 | ModelOutput, 38 | add_code_sample_docstrings, 39 | add_start_docstrings, 40 | add_start_docstrings_to_model_forward, 41 | replace_return_docstrings, 42 | ) 43 | from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput 44 | from transformers.modeling_utils import ( 45 | Conv1D, 46 | PreTrainedModel, 47 | SequenceSummary, 48 | find_pruneable_heads_and_indices, 49 | prune_conv1d_layer, 50 | ) 51 | from transformers.utils import logging 52 | from transformers import OpenAIGPTConfig 53 | 54 | 55 | logger = logging.get_logger(__name__) 56 | 57 | _CHECKPOINT_FOR_DOC = "openai-gpt" 58 | _CONFIG_FOR_DOC = "OpenAIGPTConfig" 59 | _TOKENIZER_FOR_DOC = "OpenAIGPTTokenizer" 60 | 61 | OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ 62 | "openai-gpt", 63 | # See all OpenAI GPT models at https://huggingface.co/models?filter=openai-gpt 64 | ] 65 | 66 | 67 | def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path): 68 | """Load tf pre-trained weights in a pytorch model (from NumPy arrays here)""" 69 | import re 70 | 71 | import numpy as np 72 | 73 | if ".ckpt" in openai_checkpoint_folder_path: 74 | openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path) 75 | 76 | logger.info(f"Loading weights from {openai_checkpoint_folder_path}") 77 | 78 | with open(openai_checkpoint_folder_path + "/parameters_names.json", "r", encoding="utf-8") as names_handle: 79 | names = json.load(names_handle) 80 | with open(openai_checkpoint_folder_path + "/params_shapes.json", "r", encoding="utf-8") as shapes_handle: 81 | shapes = json.load(shapes_handle) 82 | offsets = np.cumsum([np.prod(shape) for shape in shapes]) 83 | init_params = [np.load(openai_checkpoint_folder_path + f"/params_{n}.npy") for n in range(10)] 84 | init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1] 85 | init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)] 86 | 87 | # This was used when we had a single embedding matrix for positions and tokens 88 | # init_params[0] = np.concatenate([init_params[1], init_params[0]], 0) 89 | # del init_params[1] 90 | init_params = [arr.squeeze() for arr in init_params] 91 | 92 | try: 93 | assert model.tokens_embed.weight.shape == init_params[1].shape 94 | assert model.positions_embed.weight.shape == init_params[0].shape 95 | except AssertionError as e: 96 | e.args += (model.tokens_embed.weight.shape, init_params[1].shape) 97 | e.args += (model.positions_embed.weight.shape, init_params[0].shape) 98 | raise 99 | 100 | model.tokens_embed.weight.data = torch.from_numpy(init_params[1]) 101 | model.positions_embed.weight.data = torch.from_numpy(init_params[0]) 102 | names.pop(0) 103 | # Pop position and token embedding arrays 104 | init_params.pop(0) 105 | init_params.pop(0) 106 | 107 | for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]): 108 | name = name[6:] # skip "model/" 109 | assert name[-2:] == ":0" 110 | name = name[:-2] 111 | name = name.split("/") 112 | pointer = model 113 | for m_name in name: 114 | if re.fullmatch(r"[A-Za-z]+\d+", m_name): 115 | scope_names = re.split(r"(\d+)", m_name) 116 | else: 117 | scope_names = [m_name] 118 | if scope_names[0] == "g": 119 | pointer = getattr(pointer, "weight") 120 | elif scope_names[0] == "b": 121 | pointer = getattr(pointer, "bias") 122 | elif scope_names[0] == "w": 123 | pointer = getattr(pointer, "weight") 124 | else: 125 | pointer = getattr(pointer, scope_names[0]) 126 | if len(scope_names) >= 2: 127 | num = int(scope_names[1]) 128 | pointer = pointer[num] 129 | try: 130 | assert ( 131 | pointer.shape == array.shape 132 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" 133 | except AssertionError as e: 134 | e.args += (pointer.shape, array.shape) 135 | raise 136 | try: 137 | assert ( 138 | pointer.shape == array.shape 139 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" 140 | except AssertionError as e: 141 | e.args += (pointer.shape, array.shape) 142 | raise 143 | logger.info(f"Initialize PyTorch weight {name}") 144 | pointer.data = torch.from_numpy(array) 145 | return model 146 | 147 | 148 | ACT_FNS = {"relu": nn.ReLU, "silu": silu, "gelu": gelu_new, "swish": silu} 149 | 150 | 151 | class Attention(nn.Module): 152 | def __init__(self, nx, n_ctx, config, scale=False): 153 | super().__init__() 154 | n_state = nx # in Attention: n_state=768 (nx=n_embd) 155 | # [switch nx => n_state from Block to Attention to keep identical to TF implementation] 156 | assert n_state % config.n_head == 0 157 | self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) 158 | self.n_head = config.n_head 159 | self.split_size = n_state 160 | self.scale = scale 161 | 162 | self.c_attn = Conv1D(n_state * 3, nx) 163 | self.c_proj = Conv1D(n_state, nx) 164 | self.attn_dropout = nn.Dropout(config.attn_pdrop) 165 | self.resid_dropout = nn.Dropout(config.resid_pdrop) 166 | self.pruned_heads = set() 167 | 168 | def prune_heads(self, heads): 169 | if len(heads) == 0: 170 | return 171 | heads, index = find_pruneable_heads_and_indices( 172 | heads, self.n_head, self.split_size // self.n_head, self.pruned_heads 173 | ) 174 | index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) 175 | # Prune conv1d layers 176 | self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) 177 | self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) 178 | # Update hyper params 179 | self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) 180 | self.n_head = self.n_head - len(heads) 181 | self.pruned_heads = self.pruned_heads.union(heads) 182 | 183 | def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False): 184 | w = torch.matmul(q, k) 185 | if self.scale: 186 | w = w / math.sqrt(v.size(-1)) 187 | # w = w * self.bias + -1e9 * (1 - self.bias) # TF implementation method: mask_attn_weights 188 | # XD: self.b may be larger than w, so we need to crop it 189 | b = self.bias[:, :, : w.size(-2), : w.size(-1)] 190 | w = w * b + -1e4 * (1 - b) 191 | 192 | if attention_mask is not None: 193 | # Apply the attention mask 194 | w = w + attention_mask 195 | 196 | w = nn.Softmax(dim=-1)(w) 197 | w = self.attn_dropout(w) 198 | 199 | # Mask heads if we want to 200 | if head_mask is not None: 201 | w = w * head_mask 202 | 203 | outputs = [torch.matmul(w, v)] 204 | if output_attentions: 205 | outputs.append(w) 206 | return outputs 207 | 208 | def merge_heads(self, x): 209 | x = x.permute(0, 2, 1, 3).contiguous() 210 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) 211 | return x.view(*new_x_shape) # in Tensorflow implementation: fct merge_states 212 | 213 | def split_heads(self, x, k=False): 214 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) 215 | x = x.view(*new_x_shape) # in Tensorflow implementation: fct split_states 216 | if k: 217 | return x.permute(0, 2, 3, 1) 218 | else: 219 | return x.permute(0, 2, 1, 3) 220 | 221 | def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False): 222 | x = self.c_attn(x) 223 | query, key, value = x.split(self.split_size, dim=2) 224 | query = self.split_heads(query) 225 | key = self.split_heads(key, k=True) 226 | value = self.split_heads(value) 227 | 228 | attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions) 229 | a = attn_outputs[0] 230 | 231 | a = self.merge_heads(a) 232 | a = self.c_proj(a) 233 | a = self.resid_dropout(a) 234 | 235 | outputs = [a] + attn_outputs[1:] 236 | return outputs # a, (attentions) 237 | 238 | 239 | class MLP(nn.Module): 240 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) 241 | super().__init__() 242 | nx = config.n_embd 243 | self.c_fc = Conv1D(n_state, nx) 244 | self.c_proj = Conv1D(nx, n_state) 245 | self.act = ACT_FNS[config.afn] 246 | self.dropout = nn.Dropout(config.resid_pdrop) 247 | 248 | def forward(self, x): 249 | h = self.act(self.c_fc(x)) 250 | h2 = self.c_proj(h) 251 | return self.dropout(h2) 252 | 253 | 254 | class Block(nn.Module): 255 | def __init__(self, n_ctx, config, scale=False): 256 | super().__init__() 257 | nx = config.n_embd 258 | self.attn = Attention(nx, n_ctx, config, scale) 259 | self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) 260 | self.mlp = MLP(4 * nx, config) 261 | self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) 262 | 263 | def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False): 264 | attn_outputs = self.attn( 265 | x, 266 | attention_mask=attention_mask, 267 | head_mask=head_mask, 268 | output_attentions=output_attentions, 269 | ) 270 | a = attn_outputs[0] 271 | 272 | n = self.ln_1(x + a) 273 | m = self.mlp(n) 274 | h = self.ln_2(n + m) 275 | 276 | outputs = [h] + attn_outputs[1:] 277 | return outputs 278 | 279 | 280 | class OpenAIGPTPreTrainedModel(PreTrainedModel): 281 | """ 282 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 283 | models. 284 | """ 285 | 286 | config_class = OpenAIGPTConfig 287 | load_tf_weights = load_tf_weights_in_openai_gpt 288 | base_model_prefix = "transformer" 289 | _keys_to_ignore_on_load_missing = [r"position_ids"] 290 | 291 | def _init_weights(self, module): 292 | """Initialize the weights.""" 293 | if isinstance(module, (nn.Linear, Conv1D)): 294 | # Slightly different from the TF version which uses truncated_normal for initialization 295 | # cf https://github.com/pytorch/pytorch/pull/5617 296 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 297 | if module.bias is not None: 298 | module.bias.data.zero_() 299 | elif isinstance(module, nn.Embedding): 300 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 301 | if module.padding_idx is not None: 302 | module.weight.data[module.padding_idx].zero_() 303 | elif isinstance(module, nn.LayerNorm): 304 | module.bias.data.zero_() 305 | module.weight.data.fill_(1.0) 306 | 307 | 308 | @dataclass 309 | class OpenAIGPTDoubleHeadsModelOutput(ModelOutput): 310 | """ 311 | Base class for outputs of models predicting if two sentences are consecutive or not. 312 | 313 | Args: 314 | loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided): 315 | Language modeling loss. 316 | mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided): 317 | Multiple choice classification loss. 318 | logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): 319 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 320 | mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`): 321 | Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). 322 | hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): 323 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) 324 | of shape :obj:`(batch_size, sequence_length, hidden_size)`. 325 | 326 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 327 | attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): 328 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, 329 | sequence_length, sequence_length)`. 330 | 331 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 332 | heads. 333 | """ 334 | 335 | loss: Optional[torch.FloatTensor] = None 336 | mc_loss: Optional[torch.FloatTensor] = None 337 | logits: torch.FloatTensor = None 338 | mc_logits: torch.FloatTensor = None 339 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 340 | attentions: Optional[Tuple[torch.FloatTensor]] = None 341 | 342 | 343 | OPENAI_GPT_START_DOCSTRING = r""" 344 | 345 | This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic 346 | methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, 347 | pruning heads etc.) 348 | 349 | This model is also a PyTorch `torch.nn.Module `__ 350 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to 351 | general usage and behavior. 352 | 353 | Parameters: 354 | config (:class:`~transformers.OpenAIGPTConfig`): Model configuration class with all the parameters of the model. 355 | Initializing with a config file does not load the weights associated with the model, only the 356 | configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model 357 | weights. 358 | """ 359 | 360 | OPENAI_GPT_INPUTS_DOCSTRING = r""" 361 | Args: 362 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): 363 | Indices of input sequence tokens in the vocabulary. 364 | 365 | Indices can be obtained using :class:`~transformers.OpenAIGPTTokenizer`. See 366 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 367 | details. 368 | 369 | `What are input IDs? <../glossary.html#input-ids>`__ 370 | attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 371 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 372 | 373 | - 1 for tokens that are **not masked**, 374 | - 0 for tokens that are **masked**. 375 | 376 | `What are attention masks? <../glossary.html#attention-mask>`__ 377 | token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 378 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 379 | 1]``: 380 | 381 | - 0 corresponds to a `sentence A` token, 382 | - 1 corresponds to a `sentence B` token. 383 | 384 | `What are token type IDs? <../glossary.html#token-type-ids>`_ 385 | position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 386 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, 387 | config.max_position_embeddings - 1]``. 388 | 389 | `What are position IDs? <../glossary.html#position-ids>`__ 390 | head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): 391 | Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: 392 | 393 | - 1 indicates the head is **not masked**, 394 | - 0 indicates the head is **masked**. 395 | 396 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 397 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. 398 | This is useful if you want more control over how to convert :obj:`input_ids` indices into associated 399 | vectors than the model's internal embedding lookup matrix. 400 | output_attentions (:obj:`bool`, `optional`): 401 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned 402 | tensors for more detail. 403 | output_hidden_states (:obj:`bool`, `optional`): 404 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for 405 | more detail. 406 | return_dict (:obj:`bool`, `optional`): 407 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 408 | """ 409 | 410 | 411 | @add_start_docstrings( 412 | "The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.", 413 | OPENAI_GPT_START_DOCSTRING, 414 | ) 415 | class OpenAIGPTModel(OpenAIGPTPreTrainedModel): 416 | def __init__(self, config): 417 | super().__init__(config) 418 | 419 | self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd) 420 | self.positions_embed = nn.Embedding(config.n_positions, config.n_embd) 421 | self.drop = nn.Dropout(config.embd_pdrop) 422 | self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]) 423 | 424 | self.register_buffer("position_ids", torch.arange(config.n_positions)) 425 | self.init_weights() 426 | 427 | def get_input_embeddings(self): 428 | return self.tokens_embed 429 | 430 | def set_input_embeddings(self, new_embeddings): 431 | self.tokens_embed = new_embeddings 432 | 433 | def _prune_heads(self, heads_to_prune): 434 | """ 435 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 436 | """ 437 | for layer, heads in heads_to_prune.items(): 438 | self.h[layer].attn.prune_heads(heads) 439 | 440 | @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) 441 | @add_code_sample_docstrings( 442 | checkpoint=_CHECKPOINT_FOR_DOC, 443 | output_type=BaseModelOutput, 444 | config_class=_CONFIG_FOR_DOC, 445 | ) 446 | @autocast() 447 | def forward( 448 | self, 449 | input_ids=None, 450 | attention_mask=None, 451 | token_type_ids=None, 452 | position_ids=None, 453 | head_mask=None, 454 | inputs_embeds=None, 455 | output_attentions=None, 456 | output_hidden_states=None, 457 | return_dict=None, 458 | ): 459 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 460 | output_hidden_states = ( 461 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 462 | ) 463 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 464 | 465 | if input_ids is not None and inputs_embeds is not None: 466 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 467 | elif input_ids is not None: 468 | input_shape = input_ids.size() 469 | input_ids = input_ids.view(-1, input_shape[-1]) 470 | elif inputs_embeds is not None: 471 | input_shape = inputs_embeds.size()[:-1] 472 | else: 473 | raise ValueError("You have to specify either input_ids or inputs_embeds") 474 | 475 | if position_ids is None: 476 | # Code is different from when we had a single embedding matrix from position and token embeddings 477 | position_ids = self.position_ids[None, : input_shape[-1]] 478 | 479 | # Attention mask. 480 | if attention_mask is not None: 481 | # We create a 3D attention mask from a 2D tensor mask. 482 | # Sizes are [batch_size, 1, 1, to_seq_length] 483 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 484 | # this attention mask is more simple than the triangular masking of causal attention 485 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 486 | attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 487 | 488 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 489 | # masked positions, this operation will create a tensor which is 0.0 for 490 | # positions we want to attend and -10000.0 for masked positions. 491 | # Since we are adding it to the raw scores before the softmax, this is 492 | # effectively the same as removing these entirely. 493 | attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 494 | attention_mask = (1.0 - attention_mask) * -10000.0 495 | 496 | # Prepare head mask if needed 497 | head_mask = self.get_head_mask(head_mask, self.config.n_layer) 498 | 499 | if inputs_embeds is None: 500 | inputs_embeds = self.tokens_embed(input_ids) 501 | position_embeds = self.positions_embed(position_ids) 502 | if token_type_ids is not None: 503 | token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) 504 | token_type_embeds = self.tokens_embed(token_type_ids) 505 | else: 506 | token_type_embeds = 0 507 | hidden_states = inputs_embeds + position_embeds + token_type_embeds 508 | hidden_states = self.drop(hidden_states) 509 | 510 | output_shape = input_shape + (hidden_states.size(-1),) 511 | 512 | all_attentions = () if output_attentions else None 513 | all_hidden_states = () if output_hidden_states else None 514 | for i, block in enumerate(self.h): 515 | if output_hidden_states: 516 | all_hidden_states = all_hidden_states + (hidden_states,) 517 | 518 | outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions=output_attentions) 519 | hidden_states = outputs[0] 520 | if output_attentions: 521 | all_attentions = all_attentions + (outputs[1],) 522 | 523 | hidden_states = hidden_states.view(*output_shape) 524 | # Add last layer 525 | if output_hidden_states: 526 | all_hidden_states = all_hidden_states + (hidden_states,) 527 | 528 | if not return_dict: 529 | return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) 530 | 531 | return BaseModelOutput( 532 | last_hidden_state=hidden_states, 533 | hidden_states=all_hidden_states, 534 | attentions=all_attentions, 535 | ) 536 | 537 | 538 | @add_start_docstrings( 539 | """ 540 | OpenAI GPT Model transformer with a language modeling head on top (linear layer with weights tied to the input 541 | embeddings). 542 | """, 543 | OPENAI_GPT_START_DOCSTRING, 544 | ) 545 | class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): 546 | def __init__(self, config): 547 | super().__init__(config) 548 | self.transformer = OpenAIGPTModel(config) 549 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 550 | 551 | self.init_weights() 552 | 553 | def get_output_embeddings(self): 554 | return self.lm_head 555 | 556 | def set_output_embeddings(self, new_embeddings): 557 | self.lm_head = new_embeddings 558 | 559 | @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) 560 | @add_code_sample_docstrings( 561 | checkpoint=_CHECKPOINT_FOR_DOC, 562 | output_type=CausalLMOutput, 563 | config_class=_CONFIG_FOR_DOC, 564 | ) 565 | @autocast() 566 | def forward( 567 | self, 568 | input_ids=None, 569 | attention_mask=None, 570 | token_type_ids=None, 571 | position_ids=None, 572 | head_mask=None, 573 | inputs_embeds=None, 574 | labels=None, 575 | output_attentions=None, 576 | output_hidden_states=None, 577 | return_dict=None, 578 | ): 579 | r""" 580 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 581 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set 582 | ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to 583 | ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` 584 | """ 585 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 586 | 587 | transformer_outputs = self.transformer( 588 | input_ids, 589 | attention_mask=attention_mask, 590 | token_type_ids=token_type_ids, 591 | position_ids=position_ids, 592 | head_mask=head_mask, 593 | inputs_embeds=inputs_embeds, 594 | output_attentions=output_attentions, 595 | output_hidden_states=output_hidden_states, 596 | return_dict=return_dict, 597 | ) 598 | hidden_states = transformer_outputs[0] 599 | lm_logits = self.lm_head(hidden_states) 600 | 601 | loss = None 602 | if labels is not None: 603 | # Shift so that tokens < n predict n 604 | shift_logits = lm_logits[..., :-1, :].contiguous() 605 | shift_labels = labels[..., 1:].contiguous() 606 | #print("shift_logits=", shift_logits) 607 | #print("shift_labels=", shift_labels) 608 | # Flatten the tokens 609 | loss_fct = CrossEntropyLoss(ignore_index=-1) 610 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 611 | 612 | if not return_dict: 613 | output = (lm_logits,) + transformer_outputs[1:] 614 | return ((loss,) + output) if loss is not None else output 615 | 616 | return CausalLMOutput( 617 | loss=loss, 618 | logits=lm_logits, 619 | hidden_states=transformer_outputs.hidden_states, 620 | attentions=transformer_outputs.attentions, 621 | ) 622 | 623 | 624 | @add_start_docstrings( 625 | """ 626 | OpenAI GPT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for 627 | RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the 628 | input embeddings, the classification head takes as input the input of a specified classification token index in the 629 | input sequence). 630 | """, 631 | OPENAI_GPT_START_DOCSTRING, 632 | ) 633 | class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): 634 | def __init__(self, config): 635 | super().__init__(config) 636 | 637 | config.num_labels = 1 638 | self.transformer = OpenAIGPTModel(config) 639 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 640 | self.multiple_choice_head = SequenceSummary(config) 641 | 642 | self.init_weights() 643 | 644 | def get_output_embeddings(self): 645 | return self.lm_head 646 | 647 | def set_output_embeddings(self, new_embeddings): 648 | self.lm_head = new_embeddings 649 | 650 | @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) 651 | @replace_return_docstrings(output_type=OpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) 652 | @autocast() 653 | def forward( 654 | self, 655 | input_ids=None, 656 | attention_mask=None, 657 | token_type_ids=None, 658 | position_ids=None, 659 | head_mask=None, 660 | inputs_embeds=None, 661 | mc_token_ids=None, 662 | labels=None, 663 | mc_labels=None, 664 | output_attentions=None, 665 | output_hidden_states=None, 666 | return_dict=None, 667 | ): 668 | r""" 669 | mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input): 670 | Index of the classification token in each input sequence. Selected in the range ``[0, input_ids.size(-1) - 671 | 1]``. 672 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 673 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set 674 | ``labels = input_ids`` Indices are selected in ``[-1, 0, ..., config.vocab_size]`` All labels set to 675 | ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` 676 | mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`): 677 | Labels for computing the multiple choice classification loss. Indices should be in ``[0, ..., 678 | num_choices]`` where `num_choices` is the size of the second dimension of the input tensors. (see 679 | `input_ids` above) 680 | 681 | Return: 682 | 683 | Examples:: 684 | 685 | >>> from transformers import OpenAIGPTTokenizer, OpenAIGPTDoubleHeadsModel 686 | >>> import torch 687 | 688 | >>> tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt') 689 | >>> model = OpenAIGPTDoubleHeadsModel.from_pretrained('openai-gpt') 690 | >>> tokenizer.add_special_tokens({'cls_token': '[CLS]'}) # Add a [CLS] to the vocabulary (we should train it also!) 691 | >>> model.resize_token_embeddings(len(tokenizer)) 692 | 693 | >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] 694 | >>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices 695 | >>> mc_token_ids = torch.tensor([input_ids.size(-1)-1, input_ids.size(-1)-1]).unsqueeze(0) # Batch size 1 696 | 697 | >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) 698 | >>> lm_logits = outputs.lm_logits 699 | >>> mc_logits = outputs.mc_logits 700 | """ 701 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 702 | 703 | transformer_outputs = self.transformer( 704 | input_ids, 705 | attention_mask=attention_mask, 706 | token_type_ids=token_type_ids, 707 | position_ids=position_ids, 708 | head_mask=head_mask, 709 | inputs_embeds=inputs_embeds, 710 | output_attentions=output_attentions, 711 | output_hidden_states=output_hidden_states, 712 | return_dict=return_dict, 713 | ) 714 | hidden_states = transformer_outputs[0] 715 | 716 | lm_logits = self.lm_head(hidden_states) 717 | mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) 718 | 719 | lm_loss, mc_loss = None, None 720 | if mc_labels is not None: 721 | loss_fct = CrossEntropyLoss(ignore_index=-1) 722 | mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) 723 | if labels is not None: 724 | shift_logits = lm_logits[..., :-1, :].contiguous() 725 | shift_labels = labels[..., 1:].contiguous() 726 | loss_fct = CrossEntropyLoss(ignore_index=-1) 727 | lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 728 | 729 | if not return_dict: 730 | output = (lm_logits, mc_logits) + transformer_outputs[1:] 731 | if mc_loss is not None: 732 | output = (mc_loss,) + output 733 | return ((lm_loss,) + output) if lm_loss is not None else output 734 | 735 | return OpenAIGPTDoubleHeadsModelOutput( 736 | loss=lm_loss, 737 | mc_loss=mc_loss, 738 | logits=lm_logits, 739 | mc_logits=mc_logits, 740 | hidden_states=transformer_outputs.hidden_states, 741 | attentions=transformer_outputs.attentions, 742 | ) 743 | 744 | 745 | @add_start_docstrings( 746 | """ 747 | The Original OpenAI GPT Model transformer with a sequence classification head on top (linear layer). 748 | :class:`~transformers.OpenAIGPTForSequenceClassification` uses the last token in order to do the classification, as 749 | other causal models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the 750 | position of the last token. If a :obj:`pad_token_id` is defined in the configuration, it finds the last token that 751 | is not a padding token in each row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each 752 | row of the batch. Since it cannot guess the padding tokens when :obj:`inputs_embeds` are passed instead of 753 | :obj:`input_ids`, it does the same (take the last value in each row of the batch). 754 | """, 755 | OPENAI_GPT_START_DOCSTRING, 756 | ) 757 | class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel): 758 | def __init__(self, config): 759 | super().__init__(config) 760 | self.num_labels = config.num_labels 761 | self.transformer = OpenAIGPTModel(config) 762 | self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) 763 | 764 | self.init_weights() 765 | 766 | @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) 767 | @add_code_sample_docstrings( 768 | checkpoint=_CHECKPOINT_FOR_DOC, 769 | output_type=SequenceClassifierOutput, 770 | config_class=_CONFIG_FOR_DOC, 771 | ) 772 | @autocast() 773 | def forward( 774 | self, 775 | input_ids=None, 776 | attention_mask=None, 777 | token_type_ids=None, 778 | position_ids=None, 779 | head_mask=None, 780 | inputs_embeds=None, 781 | labels=None, 782 | output_attentions=None, 783 | output_hidden_states=None, 784 | return_dict=None, 785 | ): 786 | r""" 787 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 788 | Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., 789 | config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), 790 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 791 | """ 792 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 793 | 794 | transformer_outputs = self.transformer( 795 | input_ids, 796 | attention_mask=attention_mask, 797 | token_type_ids=token_type_ids, 798 | position_ids=position_ids, 799 | head_mask=head_mask, 800 | inputs_embeds=inputs_embeds, 801 | output_attentions=output_attentions, 802 | output_hidden_states=output_hidden_states, 803 | return_dict=return_dict, 804 | ) 805 | 806 | hidden_states = transformer_outputs[0] 807 | logits = self.score(hidden_states) 808 | 809 | if input_ids is not None: 810 | batch_size, sequence_length = input_ids.shape[:2] 811 | else: 812 | batch_size, sequence_length = inputs_embeds.shape[:2] 813 | 814 | assert ( 815 | self.config.pad_token_id is not None or batch_size == 1 816 | ), "Cannot handle batch sizes > 1 if no padding token is defined." 817 | if self.config.pad_token_id is None: 818 | sequence_lengths = -1 819 | else: 820 | if input_ids is not None: 821 | sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 822 | else: 823 | sequence_lengths = -1 824 | logger.warning( 825 | f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " 826 | f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" 827 | ) 828 | 829 | pooled_logits = logits[range(batch_size), sequence_lengths] 830 | 831 | loss = None 832 | if labels is not None: 833 | if self.num_labels == 1: 834 | # We are doing regression 835 | loss_fct = MSELoss() 836 | loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1)) 837 | else: 838 | loss_fct = CrossEntropyLoss(ignore_index=-1) 839 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 840 | 841 | if not return_dict: 842 | output = (pooled_logits,) + transformer_outputs[1:] 843 | return ((loss,) + output) if loss is not None else output 844 | 845 | return SequenceClassifierOutput( 846 | loss=loss, 847 | logits=pooled_logits, 848 | hidden_states=transformer_outputs.hidden_states, 849 | attentions=transformer_outputs.attentions, 850 | ) 851 | -------------------------------------------------------------------------------- /pec_baseline/models/gpt/tokenization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | 17 | 18 | # 关键包版本说明: 19 | # pytorch: 1.9.0+ 20 | # transformers: 4.11.3+ 21 | 22 | import json 23 | import os 24 | import re 25 | from typing import Optional, Tuple 26 | 27 | from transformers import PreTrainedTokenizer 28 | from transformers.utils import logging 29 | from transformers import BasicTokenizer 30 | 31 | 32 | logger = logging.get_logger(__name__) 33 | 34 | VOCAB_FILES_NAMES = { 35 | "vocab_file": "vocab.json", 36 | "merges_file": "merges.txt", 37 | } 38 | 39 | PRETRAINED_VOCAB_FILES_MAP = { 40 | "vocab_file": {"openai-gpt": "https://huggingface.co/openai-gpt/resolve/main/vocab.json"}, 41 | "merges_file": {"openai-gpt": "https://huggingface.co/openai-gpt/resolve/main/merges.txt"}, 42 | } 43 | 44 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 45 | "openai-gpt": 512, 46 | } 47 | 48 | 49 | def get_pairs(word): 50 | """ 51 | Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length 52 | strings) 53 | """ 54 | pairs = set() 55 | prev_char = word[0] 56 | for char in word[1:]: 57 | pairs.add((prev_char, char)) 58 | prev_char = char 59 | return pairs 60 | 61 | 62 | def text_standardize(text): 63 | """ 64 | fixes some issues the spacy tokenizer had on books corpus also does some whitespace standardization 65 | """ 66 | text = text.replace("—", "-") 67 | text = text.replace("–", "-") 68 | text = text.replace("―", "-") 69 | text = text.replace("…", "...") 70 | text = text.replace("´", "'") 71 | text = re.sub(r"""(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)""", r" \1 ", text) 72 | text = re.sub(r"\s*\n\s*", " \n ", text) 73 | text = re.sub(r"[^\S\n]+", " ", text) 74 | return text.strip() 75 | 76 | 77 | class OpenAIGPTTokenizer(PreTrainedTokenizer): 78 | """ 79 | Construct a GPT Tokenizer. Based on Byte-Pair-Encoding with the following peculiarities: 80 | 81 | - lowercases all inputs, 82 | - uses :obj:`SpaCy` tokenizer and :obj:`ftfy` for pre-BPE tokenization if they are installed, fallback to BERT's 83 | :obj:`BasicTokenizer` if not. 84 | 85 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. 86 | Users should refer to this superclass for more information regarding those methods. 87 | 88 | Args: 89 | vocab_file (:obj:`str`): 90 | Path to the vocabulary file. 91 | merges_file (:obj:`str`): 92 | Path to the merges file. 93 | unk_token (:obj:`str`, `optional`, defaults to :obj:`""`): 94 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 95 | token instead. 96 | """ 97 | 98 | vocab_files_names = VOCAB_FILES_NAMES 99 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 100 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 101 | model_input_names = ["input_ids", "attention_mask"] 102 | 103 | def __init__(self, vocab_file, merges_file, unk_token="", **kwargs): 104 | super().__init__(unk_token=unk_token, **kwargs) 105 | 106 | try: 107 | import ftfy 108 | from spacy.lang.en import English 109 | 110 | _nlp = English() 111 | self.nlp = _nlp.Defaults.create_tokenizer(_nlp) 112 | self.fix_text = ftfy.fix_text 113 | except ImportError: 114 | logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.") 115 | self.nlp = BasicTokenizer(do_lower_case=True) 116 | self.fix_text = None 117 | 118 | with open(vocab_file, encoding="utf-8") as vocab_handle: 119 | self.encoder = json.load(vocab_handle) 120 | self.decoder = {v: k for k, v in self.encoder.items()} 121 | with open(merges_file, encoding="utf-8") as merges_handle: 122 | merges = merges_handle.read().split("\n")[1:-1] 123 | merges = [tuple(merge.split()) for merge in merges] 124 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 125 | self.cache = {} 126 | 127 | @property 128 | def do_lower_case(self): 129 | return True 130 | 131 | @property 132 | def vocab_size(self): 133 | return len(self.encoder) 134 | 135 | def get_vocab(self): 136 | return dict(self.encoder, **self.added_tokens_encoder) 137 | 138 | def bpe(self, token): 139 | word = tuple(token[:-1]) + (token[-1] + "",) 140 | if token in self.cache: 141 | return self.cache[token] 142 | pairs = get_pairs(word) 143 | 144 | if not pairs: 145 | return token + "" 146 | 147 | while True: 148 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 149 | if bigram not in self.bpe_ranks: 150 | break 151 | first, second = bigram 152 | new_word = [] 153 | i = 0 154 | while i < len(word): 155 | try: 156 | j = word.index(first, i) 157 | except ValueError: 158 | new_word.extend(word[i:]) 159 | break 160 | else: 161 | new_word.extend(word[i:j]) 162 | i = j 163 | 164 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 165 | new_word.append(first + second) 166 | i += 2 167 | else: 168 | new_word.append(word[i]) 169 | i += 1 170 | new_word = tuple(new_word) 171 | word = new_word 172 | if len(word) == 1: 173 | break 174 | else: 175 | pairs = get_pairs(word) 176 | word = " ".join(word) 177 | if word == "\n ": 178 | word = "\n" 179 | self.cache[token] = word 180 | return word 181 | 182 | def _tokenize(self, text): 183 | """Tokenize a string.""" 184 | split_tokens = [] 185 | if self.fix_text is None: 186 | # Using BERT's BasicTokenizer 187 | text = self.nlp.tokenize(text) 188 | for token in text: 189 | split_tokens.extend([t for t in self.bpe(token).split(" ")]) 190 | else: 191 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) 192 | text = self.nlp(text_standardize(self.fix_text(text))) 193 | for token in text: 194 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(" ")]) 195 | return split_tokens 196 | 197 | def _convert_token_to_id(self, token): 198 | """Converts a token (str) in an id using the vocab.""" 199 | return self.encoder.get(token, self.encoder.get(self.unk_token)) 200 | 201 | def _convert_id_to_token(self, index): 202 | """Converts an id in a token (BPE) using the vocab.""" 203 | return self.decoder.get(index, self.unk_token) 204 | 205 | def convert_tokens_to_string(self, tokens): 206 | """Converts a sequence of tokens (string) in a single string.""" 207 | out_string = "".join(tokens).replace("", " ").strip() 208 | return out_string 209 | 210 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: 211 | if not os.path.isdir(save_directory): 212 | logger.error(f"Vocabulary path ({save_directory}) should be a directory") 213 | return 214 | vocab_file = os.path.join( 215 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] 216 | ) 217 | merge_file = os.path.join( 218 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] 219 | ) 220 | 221 | with open(vocab_file, "w", encoding="utf-8") as f: 222 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 223 | 224 | index = 0 225 | with open(merge_file, "w", encoding="utf-8") as writer: 226 | writer.write("#version: 0.2\n") 227 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 228 | if index != token_index: 229 | logger.warning( 230 | f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." 231 | " Please check that the tokenizer is not corrupted!" 232 | ) 233 | index = token_index 234 | writer.write(" ".join(bpe_tokens) + "\n") 235 | index += 1 236 | 237 | return vocab_file, merge_file 238 | -------------------------------------------------------------------------------- /pec_baseline/models/gpt2/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | # 关键包版本说明: 21 | # pytorch: 1.9.0+ 22 | # transformers: 4.11.3 23 | 24 | 25 | from typing import TYPE_CHECKING 26 | 27 | from transformers.file_utils import is_torch_available 28 | 29 | 30 | if is_torch_available(): 31 | from .modeling_gpt2 import ( 32 | GPT2_PRETRAINED_MODEL_ARCHIVE_LIST, 33 | GPT2DoubleHeadsModel, 34 | GPT2ForSequenceClassification, 35 | GPT2ForTokenClassification, 36 | GPT2LMHeadModel, 37 | GPT2Model, 38 | GPT2PreTrainedModel, 39 | load_tf_weights_in_gpt2, 40 | ) 41 | -------------------------------------------------------------------------------- /pec_baseline/models/gpt2/tokenization_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | 17 | 18 | # 关键包版本说明: 19 | # pytorch: 1.9.0+ 20 | # transformers: 4.11.3 21 | 22 | 23 | import json 24 | import os 25 | from functools import lru_cache 26 | from typing import TYPE_CHECKING, List, Optional, Tuple 27 | 28 | import regex as re 29 | 30 | from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer 31 | from transformers.utils import logging 32 | 33 | 34 | if TYPE_CHECKING: 35 | from transformers.pipelines.conversational import Conversation 36 | 37 | logger = logging.get_logger(__name__) 38 | 39 | VOCAB_FILES_NAMES = { 40 | "vocab_file": "vocab.json", 41 | "merges_file": "merges.txt", 42 | } 43 | 44 | PRETRAINED_VOCAB_FILES_MAP = { 45 | "vocab_file": { 46 | "gpt2": "https://huggingface.co/gpt2/resolve/main/vocab.json", 47 | "gpt2-medium": "https://huggingface.co/gpt2-medium/resolve/main/vocab.json", 48 | "gpt2-large": "https://huggingface.co/gpt2-large/resolve/main/vocab.json", 49 | "gpt2-xl": "https://huggingface.co/gpt2-xl/resolve/main/vocab.json", 50 | "distilgpt2": "https://huggingface.co/distilgpt2/resolve/main/vocab.json", 51 | }, 52 | "merges_file": { 53 | "gpt2": "https://huggingface.co/gpt2/resolve/main/merges.txt", 54 | "gpt2-medium": "https://huggingface.co/gpt2-medium/resolve/main/merges.txt", 55 | "gpt2-large": "https://huggingface.co/gpt2-large/resolve/main/merges.txt", 56 | "gpt2-xl": "https://huggingface.co/gpt2-xl/resolve/main/merges.txt", 57 | "distilgpt2": "https://huggingface.co/distilgpt2/resolve/main/merges.txt", 58 | }, 59 | } 60 | 61 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 62 | "gpt2": 1024, 63 | "gpt2-medium": 1024, 64 | "gpt2-large": 1024, 65 | "gpt2-xl": 1024, 66 | "distilgpt2": 1024, 67 | } 68 | 69 | 70 | @lru_cache() 71 | def bytes_to_unicode(): 72 | """ 73 | Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control 74 | characters the bpe code barfs on. 75 | 76 | The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab 77 | if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for 78 | decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup 79 | tables between utf-8 bytes and unicode strings. 80 | """ 81 | bs = ( 82 | list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) 83 | ) 84 | cs = bs[:] 85 | n = 0 86 | for b in range(2 ** 8): 87 | if b not in bs: 88 | bs.append(b) 89 | cs.append(2 ** 8 + n) 90 | n += 1 91 | cs = [chr(n) for n in cs] 92 | return dict(zip(bs, cs)) 93 | 94 | 95 | def get_pairs(word): 96 | """ 97 | Return set of symbol pairs in a word. 98 | 99 | Word is represented as tuple of symbols (symbols being variable-length strings). 100 | """ 101 | pairs = set() 102 | prev_char = word[0] 103 | for char in word[1:]: 104 | pairs.add((prev_char, char)) 105 | prev_char = char 106 | return pairs 107 | 108 | 109 | class GPT2Tokenizer(PreTrainedTokenizer): 110 | """ 111 | Construct a GPT-2 tokenizer. Based on byte-level Byte-Pair-Encoding. 112 | 113 | This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will 114 | be encoded differently whether it is at the beginning of the sentence (without space) or not: 115 | 116 | :: 117 | 118 | >>> from transformers import GPT2Tokenizer 119 | >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 120 | >>> tokenizer("Hello world")['input_ids'] 121 | [15496, 995] 122 | >>> tokenizer(" Hello world")['input_ids'] 123 | [18435, 995] 124 | 125 | You can get around that behavior by passing ``add_prefix_space=True`` when instantiating this tokenizer or when you 126 | call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance. 127 | 128 | .. note:: 129 | 130 | When used with ``is_split_into_words=True``, this tokenizer will add a space before each word (even the first 131 | one). 132 | 133 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. 134 | Users should refer to this superclass for more information regarding those methods. 135 | 136 | Args: 137 | vocab_file (:obj:`str`): 138 | Path to the vocabulary file. 139 | merges_file (:obj:`str`): 140 | Path to the merges file. 141 | errors (:obj:`str`, `optional`, defaults to :obj:`"replace"`): 142 | Paradigm to follow when decoding bytes to UTF-8. See `bytes.decode 143 | `__ for more information. 144 | unk_token (:obj:`str`, `optional`, defaults to :obj:`<|endoftext|>`): 145 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 146 | token instead. 147 | bos_token (:obj:`str`, `optional`, defaults to :obj:`<|endoftext|>`): 148 | The beginning of sequence token. 149 | eos_token (:obj:`str`, `optional`, defaults to :obj:`<|endoftext|>`): 150 | The end of sequence token. 151 | add_prefix_space (:obj:`bool`, `optional`, defaults to :obj:`False`): 152 | Whether or not to add an initial space to the input. This allows to treat the leading word just as any 153 | other word. (GPT2 tokenizer detect beginning of words by the preceding space). 154 | """ 155 | 156 | vocab_files_names = VOCAB_FILES_NAMES 157 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 158 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 159 | model_input_names = ["input_ids", "attention_mask"] 160 | 161 | def __init__( 162 | self, 163 | vocab_file, 164 | merges_file, 165 | errors="replace", 166 | unk_token="<|endoftext|>", 167 | bos_token="<|endoftext|>", 168 | eos_token="<|endoftext|>", 169 | add_prefix_space=False, 170 | **kwargs 171 | ): 172 | bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token 173 | eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token 174 | unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token 175 | super().__init__( 176 | errors=errors, 177 | unk_token=unk_token, 178 | bos_token=bos_token, 179 | eos_token=eos_token, 180 | add_prefix_space=add_prefix_space, 181 | **kwargs, 182 | ) 183 | 184 | with open(vocab_file, encoding="utf-8") as vocab_handle: 185 | self.encoder = json.load(vocab_handle) 186 | self.decoder = {v: k for k, v in self.encoder.items()} 187 | self.errors = errors # how to handle errors in decoding 188 | self.byte_encoder = bytes_to_unicode() 189 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 190 | with open(merges_file, encoding="utf-8") as merges_handle: 191 | bpe_merges = merges_handle.read().split("\n")[1:-1] 192 | bpe_merges = [tuple(merge.split()) for merge in bpe_merges] 193 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 194 | self.cache = {} 195 | self.add_prefix_space = add_prefix_space 196 | 197 | # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 198 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 199 | 200 | @property 201 | def vocab_size(self): 202 | return len(self.encoder) 203 | 204 | def get_vocab(self): 205 | return dict(self.encoder, **self.added_tokens_encoder) 206 | 207 | def bpe(self, token): 208 | if token in self.cache: 209 | return self.cache[token] 210 | word = tuple(token) 211 | pairs = get_pairs(word) 212 | 213 | if not pairs: 214 | return token 215 | 216 | while True: 217 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 218 | if bigram not in self.bpe_ranks: 219 | break 220 | first, second = bigram 221 | new_word = [] 222 | i = 0 223 | while i < len(word): 224 | try: 225 | j = word.index(first, i) 226 | except ValueError: 227 | new_word.extend(word[i:]) 228 | break 229 | else: 230 | new_word.extend(word[i:j]) 231 | i = j 232 | 233 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 234 | new_word.append(first + second) 235 | i += 2 236 | else: 237 | new_word.append(word[i]) 238 | i += 1 239 | new_word = tuple(new_word) 240 | word = new_word 241 | if len(word) == 1: 242 | break 243 | else: 244 | pairs = get_pairs(word) 245 | word = " ".join(word) 246 | self.cache[token] = word 247 | return word 248 | 249 | def _tokenize(self, text): 250 | """Tokenize a string.""" 251 | bpe_tokens = [] 252 | for token in re.findall(self.pat, text): 253 | token = "".join( 254 | self.byte_encoder[b] for b in token.encode("utf-8") 255 | ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) 256 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) 257 | return bpe_tokens 258 | 259 | def _convert_token_to_id(self, token): 260 | """Converts a token (str) in an id using the vocab.""" 261 | return self.encoder.get(token, self.encoder.get(self.unk_token)) 262 | 263 | def _convert_id_to_token(self, index): 264 | """Converts an index (integer) in a token (str) using the vocab.""" 265 | return self.decoder.get(index) 266 | 267 | def convert_tokens_to_string(self, tokens): 268 | """Converts a sequence of tokens (string) in a single string.""" 269 | text = "".join(tokens) 270 | text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) 271 | return text 272 | 273 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: 274 | if not os.path.isdir(save_directory): 275 | logger.error(f"Vocabulary path ({save_directory}) should be a directory") 276 | return 277 | vocab_file = os.path.join( 278 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] 279 | ) 280 | merge_file = os.path.join( 281 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] 282 | ) 283 | 284 | with open(vocab_file, "w", encoding="utf-8") as f: 285 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 286 | 287 | index = 0 288 | with open(merge_file, "w", encoding="utf-8") as writer: 289 | writer.write("#version: 0.2\n") 290 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 291 | if index != token_index: 292 | logger.warning( 293 | f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." 294 | " Please check that the tokenizer is not corrupted!" 295 | ) 296 | index = token_index 297 | writer.write(" ".join(bpe_tokens) + "\n") 298 | index += 1 299 | 300 | return vocab_file, merge_file 301 | 302 | def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): 303 | add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) 304 | if is_split_into_words or add_prefix_space: 305 | text = " " + text 306 | return (text, kwargs) 307 | 308 | def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: 309 | input_ids = [] 310 | for is_user, text in conversation.iter_texts(): 311 | input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id]) 312 | if len(input_ids) > self.model_max_length: 313 | input_ids = input_ids[-self.model_max_length :] 314 | return input_ids 315 | -------------------------------------------------------------------------------- /pec_baseline/models/model_parameters.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 South China University of Technology and 3 | # Engineering Research Ceter of Ministry of Education on Human Body Perception. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # Model parameters processing file 18 | # File: model_parameters.py 19 | # Used for model parameters analysis 20 | # Author: Chen Yirong 21 | # Date: 2022.04.06 22 | 23 | def count_trainable_parameters(model): 24 | '''获取需要训练的参数数量 25 | 使用示例:print(f'The model has {count_trainable_parameters(model):,} trainable parameters') 26 | ''' 27 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 28 | 29 | def count_total_parameters(model): 30 | '''获取模型总的参数数量 31 | 使用示例:print(f'The model has {count_total_parameters(model):,} total parameters') 32 | ''' 33 | return sum(p.numel() for p in model.parameters()) 34 | 35 | def show_trainable_parameters(model): 36 | for name, param in model.named_parameters(): 37 | if param.requires_grad: 38 | print(name) 39 | 40 | # 冻结模型参数 41 | # 参考:https://blog.csdn.net/weixin_41712499/article/details/111295683?utm_medium=distribute.pc_relevant.none-task-blog-2~default~baidujs_title~default-5.no_search_link&spm=1001.2101.3001.4242 42 | def set_freeze_by_names(model, layer_names, freeze=True): 43 | if not isinstance(layer_names, Iterable): 44 | layer_names = [layer_names] 45 | for name, child in model.named_children(): 46 | if name not in layer_names: 47 | continue 48 | for param in child.parameters(): 49 | param.requires_grad = not freeze 50 | 51 | def freeze_by_names(model, layer_names): 52 | set_freeze_by_names(model, layer_names, True) 53 | 54 | def unfreeze_by_names(model, layer_names): 55 | set_freeze_by_names(model, layer_names, False) 56 | 57 | def set_freeze_by_idxs(model, idxs, freeze=True): 58 | if not isinstance(idxs, Iterable): 59 | idxs = [idxs] 60 | num_child = len(list(model.children())) 61 | idxs = tuple(map(lambda idx: num_child + idx if idx < 0 else idx, idxs)) 62 | for idx, child in enumerate(model.children()): 63 | if idx not in idxs: 64 | continue 65 | for param in child.parameters(): 66 | param.requires_grad = not freeze 67 | 68 | def freeze_by_idxs(model, idxs): 69 | set_freeze_by_idxs(model, idxs, True) 70 | 71 | def unfreeze_by_idxs(model, idxs): 72 | set_freeze_by_idxs(model, idxs, False) 73 | 74 | def freeze_by_model_name(model, model_name): 75 | for name, param in model.named_parameters(): 76 | if name.startswith(model_name): 77 | param.requires_grad = False 78 | 79 | def unfreeze_by_model_name(model, model_name): 80 | for name, param in model.named_parameters(): 81 | if name.startswith(model_name): 82 | param.requires_grad = True -------------------------------------------------------------------------------- /pec_baseline/train_model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 South China University of Technology and 3 | # Engineering Research Ceter of Ministry of Education on Human Body Perception. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # Model training code 18 | # File: train_model.py 19 | # Used for training model 20 | # Author: Chen Yirong 21 | # Date: 2022.04.06 22 | 23 | # 关键包版本说明: 24 | # pytorch: 1.9.0+ 25 | # pytorch-ignite: 0.4.8 26 | # transformers: 4.18.0 27 | 28 | import os 29 | import json 30 | import time 31 | import math 32 | import torch 33 | import socket 34 | import random 35 | import logging 36 | import numpy as np 37 | from pprint import pformat 38 | from argparse import ArgumentParser # 用于函数文件传递参数 39 | from torch.optim.lr_scheduler import LambdaLR, CyclicLR, OneCycleLR 40 | from torch.nn.parallel import DistributedDataParallel # 用于分布式模型训练 41 | from torch.cuda.amp import autocast as autocast # 用于使用自动混合精度,要求torch版本为1.6+ 42 | from ignite.engine import Engine, Events 43 | from ignite.handlers import ModelCheckpoint, EarlyStopping, Checkpoint 44 | from ignite.metrics import Loss, MetricsLambda, RunningAverage 45 | from ignite.contrib.handlers import ProgressBar, PiecewiseLinear, LRScheduler 46 | from ignite.contrib.handlers import TensorboardLogger, global_step_from_engine 47 | from ignite.contrib.handlers.tensorboard_logger import OutputHandler, OptimizerParamsHandler 48 | 49 | from transformers import (WEIGHTS_NAME, CONFIG_NAME, BertTokenizer, OpenAIGPTTokenizer, OpenAIGPTConfig, GPT2Config) # , AdamW 50 | 51 | from torch.optim import AdamW 52 | 53 | # 本项目自主撰写的python包 54 | from models import (count_trainable_parameters, count_total_parameters, show_trainable_parameters, freeze_by_model_name, unfreeze_by_model_name) 55 | # 从transformers的代码修改,适配混合精度训练 56 | from models.gpt import OpenAIGPTLMHeadModel 57 | from models.gpt2 import GPT2LMHeadModel 58 | # 读取数据集的dataloader构建函数 59 | from utils import build_cped_dataloaders 60 | 61 | logger = logging.getLogger(__file__) 62 | 63 | def setup_seed(seed): 64 | torch.manual_seed(seed) 65 | torch.cuda.manual_seed_all(seed) 66 | np.random.seed(seed) 67 | random.seed(seed) 68 | torch.backends.cudnn.deterministic = True 69 | 70 | setup_seed(2022) 71 | 72 | 73 | def average_distributed_scalar(scalar, args): 74 | """ Average a scalar over the nodes if we are in distributed training. We use this for distributed evaluation. """ 75 | if args.local_rank == -1: 76 | return scalar 77 | scalar_t = torch.tensor(scalar, dtype=torch.float, device=args.device) / torch.distributed.get_world_size() 78 | torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM) 79 | return scalar_t.item() 80 | 81 | def score_function(engine): 82 | '''最小化ppl,也就是最大化负的ppl 83 | 84 | ''' 85 | return -engine.state.metrics["average_ppl"] 86 | 87 | 88 | 89 | def train(): 90 | '''train() 91 | 封装好的训练模型过程 92 | 93 | ''' 94 | # 参数定义 95 | parser = ArgumentParser() 96 | # 模型类型以及路径 97 | parser.add_argument("--model_type", type=str, default="GPT", choices=['GPT', 'GPT-2'], help="Type of Model(模型类型名称)") 98 | parser.add_argument("--model_checkpoint", type=str, default="/home/phd-chen.yirong/PretrainedModel/models.huggingface.co/bert-base-chinese", help="Path or URL of the model") 99 | parser.add_argument("--gpt_model_checkpoint", type=str, default="/home/phd-chen.yirong/PretrainedModel/CDial-GPT/LCCD_GPT_FOR_GPTSPEAKERROBOT", help="Path or URL of the GPT model used to initialized the GPT part of the UiBot.") 100 | parser.add_argument("--bert_model_checkpoint", type=str, default="./runs/SPEAKERBERT", help="Path or URL of the BERT model used to initialized the BERT part of the UiBot.") 101 | parser.add_argument('--log_file', '-log_file', type=str, default="./logs", help="Output logs to a file under this path") 102 | 103 | # 数据集名称、路径等配置 104 | parser.add_argument("--dataset_name", type=str, default="CPED", choices=['CPED', 'MELD', 'CPED-shuffle'], help="Name of Dataset(数据集名称)") 105 | parser.add_argument("--data_path", type=str, default="./data/CPED/", help="dir of the dataset(数据集保存路径,可以是目录或者文件名)") 106 | parser.add_argument("--cache_path", type=str, default="./data/CPED_cache_for_CpedDataset", help="path of the dataset cache(数据集缓存文件的保存路径,必须为文件名)") 107 | parser.add_argument("--use_speaker_name_as_speaker_list", action='store_true', 108 | help="If true using speaker name as speaker_list") 109 | parser.add_argument("--emotion_type", type=str, default="Emotion", choices=['Sentiment', 'BaseEmotion', 'Emotion'], help="Type of Emotion") 110 | parser.add_argument("--da_type", type=str, default="DA", choices=['DA', 'BaseDA'], help="Type of DA") 111 | parser.add_argument('--with_emotion', action='store_true', help="use emotion as token_type") 112 | parser.add_argument('--with_da', action='store_true', help="use da as token_type") 113 | parser.add_argument('--with_current_speaker', action='store_true', help="use current speaker as control signal") 114 | parser.add_argument('--with_current_persona', action='store_true', help="use current persona as control signal") 115 | parser.add_argument('--with_current_emotion', action='store_true', help="use current emotion as control signal") 116 | parser.add_argument('--with_current_da', action='store_true', help="use current da as control signal") 117 | parser.add_argument('--set_eda_in_speaker', action='store_true', help="set eda in speaker") 118 | parser.add_argument('--set_current_speaker_mask', action='store_true', help="set current_speaker_mask") 119 | 120 | # 训练模型配置 121 | parser.add_argument('--find_unused_parameters', action='store_true', help="If True find_unused_parameters") 122 | parser.add_argument('--show_parameters', action='store_true', help="If True show model parameters") 123 | parser.add_argument("--from_step", type=int, default=-1, help="Init learning rate from this step") 124 | parser.add_argument('--pretrained', action='store_true', help="If False train from scratch") 125 | ## Adamw优化器参数配置:正则化、beta1、beta2、eps数值设置 126 | parser.add_argument('--L2_regularization', action='store_true', help="If False train without L2 Regularization") 127 | parser.add_argument("--L2_weight_decay", type=float, default=1e-2, help="L2 weight decay") 128 | parser.add_argument("--adamw_beta1", type=float, default=0.9, help="Adam's beta1 parameter") 129 | parser.add_argument("--adamw_beta2", type=float, default=0.999, help="Adam's beta2 parameter") 130 | parser.add_argument("--adamw_eps", type=float, default=1e-6, help="Adam's epsilon for numerical stability") 131 | 132 | # 损失函数的各部分占比 133 | parser.add_argument("--alpha_nll", type=float, default=1.0, help="alpha_nll") 134 | parser.add_argument("--alpha_emotion", type=float, default=1.0, help="alpha_emotion") 135 | parser.add_argument("--alpha_da", type=float, default=1.0, help="alpha_da") 136 | parser.add_argument("--alpha_per_gen", type=float, default=1.0, help="alpha_per_gen") 137 | parser.add_argument("--alpha_per_neu", type=float, default=1.0, help="alpha_per_neu") 138 | parser.add_argument("--alpha_per_ext", type=float, default=1.0, help="alpha_per_ext") 139 | parser.add_argument("--alpha_per_ope", type=float, default=1.0, help="alpha_per_ope") 140 | parser.add_argument("--alpha_per_agr", type=float, default=1.0, help="alpha_per_agr") 141 | parser.add_argument("--alpha_per_con", type=float, default=1.0, help="alpha_per_con") 142 | 143 | # 冻结模型的部分层 144 | parser.add_argument('--freeze_model', action='store_true', help="If True freeze some layers of the model") 145 | parser.add_argument('--freeze_start_layer', type=int, default=0, help="冻结指定的层范围,格式为start-end,其中start取值范围为0~11,end取值范围为0~11,start<=end") 146 | parser.add_argument('--freeze_end_layer', type=int, default=11, help="冻结指定的层范围,格式为start-end,其中start取值范围为0~11,end取值范围为0~11,start<=end") 147 | 148 | 149 | parser.add_argument("--not_return_dict", action='store_true', help="If False the model return dict as result(适配最新版本transformers的模型输出)") 150 | parser.add_argument("--do_unscale", action='store_true', help="Calling scaler.unscale_(optimizer) before clipping enables you to clip unscaled gradients as usual(梯度裁剪)") 151 | parser.add_argument("--retain_graph", action='store_true', help="If true using retain_graph=True in loss.backward") 152 | parser.add_argument("--num_workers", type=int, default=4, help="Number of subprocesses for data loading") 153 | parser.add_argument("--n_epochs", type=int, default=32, help="Number of training epochs") 154 | parser.add_argument("--train_batch_size", type=int, default=1, help="Batch size for training") 155 | parser.add_argument("--valid_batch_size", type=int, default=1, help="Batch size for validation") 156 | parser.add_argument("--test_batch_size", type=int, default=1, help="Batch size for testing") 157 | parser.add_argument("--max_history", type=int, default=25, help="Number of previous exchanges to keep in history") 158 | 159 | parser.add_argument("--n_emd", type=int, default=768, help="Number of n_emd in config file (for noam)") 160 | # 学习率设置 161 | parser.add_argument("--scheduler", type=str, default="noam", choices=['noam', 'linear', 'cyclic', '1cycle','fixedlr'], help="method of optim") 162 | parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate") 163 | parser.add_argument("--base_lr", type=float, default=1e-3, help="Initial learning rate which is the lower boundary in the cycle for each parameter group.") 164 | parser.add_argument("--max_lr", type=float, default=5e-3, help="Upper learning rate boundaries in the cycle for each parameter group.") 165 | ## CyclicLR 166 | parser.add_argument("--cycliclr_mode", type=str, default="triangular2", choices=['triangular', 'triangular2', 'exp_range'], help="mode of CyclicLR, see https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CyclicLR.html#torch.optim.lr_scheduler.CyclicLR") 167 | 168 | parser.add_argument("--eval_before_start", action='store_true', 169 | help="If true start with a first evaluation before training") 170 | parser.add_argument("--warmup_steps", type=int, default=5000, help="Warm up steps") 171 | parser.add_argument("--valid_steps", type=int, default=5000, help="Perfom validation every X steps") 172 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Accumulate gradients on several steps") 173 | parser.add_argument("--max_norm", type=float, default=2.0, help="Clipping gradient norm") 174 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", 175 | help="Device (cuda or cpu)") 176 | parser.add_argument("--autocast", action='store_true', 177 | help="If true using autocast to automatically mix accuracy to accelerate training(开启自动混合精度加速训练)") 178 | parser.add_argument("--local_rank", type=int, default=-1, 179 | help="Local rank for distributed training (-1: not distributed)") 180 | 181 | args = parser.parse_args() 182 | 183 | # 日志文件配置 184 | # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. 185 | # logger.info => log main process only, logger.warning => log all processes 186 | # the name of log file looks like './logs/Jan11_22-55-46_gpu144_GPT_CPED.log' 187 | # The log information is output to a specified file, which is convenient 188 | # for viewing the log when the program is running time specified by the ```at``` command. 189 | if not os.path.exists(args.log_file): 190 | # 不存在log目录则创建 191 | os.makedirs(args.log_file) 192 | log_file_name_or_tensorboard_dir_name = str(time.strftime('%b%d_%H-%M-%S',time.localtime(time.time())))+'_'+str(socket.gethostname())+'_'+args.model_type+'_'+args.dataset_name 193 | logging_file_name = os.path.join(args.log_file,log_file_name_or_tensorboard_dir_name+'.log') 194 | logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN, 195 | filename=logging_file_name) # output the log information to the log file 196 | logger.warning("Running process %d", args.local_rank) 197 | logger.info("Arguments: %s", pformat(args)) 198 | 199 | # 数据集的训练集、验证集、测试集的文件名称 200 | cped_filenames = {"train":"train_split.csv", 201 | "valid":"valid_split.csv", 202 | "test":"test_split.csv"} 203 | meld_filenames = {"train":"train_sent_emo.csv", 204 | "valid":"dev_sent_emo.csv", 205 | "test":"test_sent_emo.csv"} 206 | cped_shuffle_filenames = {"train":"train_shuffle_split.csv", 207 | "valid":"valid_shuffle_split.csv", 208 | "test":"test_shuffle_split.csv"} 209 | if args.dataset_name == "MELD": 210 | filenames = meld_filenames 211 | elif args.dataset_name == "CPED": 212 | filenames = cped_filenames 213 | elif args.dataset_name == "CPED-shuffle": 214 | filenames = cped_shuffle_filenames 215 | 216 | # 多机分布式训练配置 217 | # Initialize distributed training if needed 218 | args.distributed = (args.local_rank != -1) 219 | if args.distributed: 220 | torch.cuda.set_device(args.local_rank) 221 | args.device = torch.device("cuda", args.local_rank) 222 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 223 | 224 | logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning") 225 | 226 | # 根据模型类型确定模型类以及tokenizer 227 | if args.model_type == 'GPT': 228 | model_class = OpenAIGPTLMHeadModel 229 | config_class = OpenAIGPTConfig 230 | tokenizer_class = BertTokenizer 231 | 232 | elif args.model_type == 'GPT-2': 233 | model_class = GPT2LMHeadModel 234 | config_class = GPT2Config 235 | tokenizer_class = BertTokenizer 236 | 237 | 238 | '''此处增加模型 239 | elif args.model_type == 'MODELNAME': 240 | model_class = ModelClass 241 | config_class = ModelConfigClass 242 | tokenizer_class = TokenizerClass 243 | ''' 244 | 245 | 246 | # 初始化模型参数 247 | if args.pretrained: # 加载预训练模型参数 248 | tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint, do_lower_case=True) 249 | model = model_class.from_pretrained(args.model_checkpoint) 250 | else: # 不是从预训练模型中初始化 251 | print("不是从预训练模型中初始化") 252 | 253 | # 输出模型结构与参数信息 254 | if args.show_parameters: 255 | print(model) # 输出网络结构 256 | # 输出模型总的参数量 257 | # https://huggingface.co/docs/transformers/main_classes/model#transformers.modeling_utils.ModuleUtilsMixin.num_parameters 258 | total_params = model.num_parameters() #直接调用模型自身函数计算模型的参数量 259 | print(f'{total_params:,} total parameters.') 260 | 261 | # 冻结模型的部分层 262 | if args.freeze_model: 263 | freeze_by_model_name(model, "transformer.tokens_embed") 264 | freeze_by_model_name(model, "transformer.positions_embed") 265 | for i in range(args.freeze_start_layer,args.freeze_end_layer+1): 266 | # 冻结args.freeze_start_layer~args.freeze_end_layer层 267 | freeze_by_model_name(model, "transformer.h."+str(i)+".") # 加点在后面,防止冻结1~6时,把第10、11层也冻结了 268 | 269 | # 查看模型的可训练的层 270 | print("可训练的层如下所示:") 271 | for name, param in model.named_parameters(): 272 | if param.requires_grad: 273 | print(name,':',param.size()) 274 | # 计算模型可训练的参数量 275 | # total_trainable_params = count_trainable_parameters(model) 276 | total_trainable_params = model.num_parameters(only_trainable=True) 277 | print(f'{total_trainable_params:,} total trainable parameters.') 278 | 279 | model.to(args.device) 280 | 281 | 282 | 283 | if args.L2_regularization: # L2 Regularization 284 | # 旧版本的L2正则化 285 | # reference: https://blog.csdn.net/mch2869253130/article/details/105994044 286 | # 不参与L2正则化的列表 287 | # optimizer = AdamW([{'params': model.parameters(), 'weight_decay': args.L2_weight_decay, 'initial_lr': args.lr}], lr=args.lr, betas=(0.9, 0.999), eps=1e-6, weight_decay=args.L2_weight_decay, correct_bias=True) 288 | 289 | # 新版本的 290 | # 参考:https://arxiv.org/pdf/1711.05101.pdf 291 | # https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html?highlight=adamw#torch.optim.AdamW 292 | # transformers 版本:from transformers import AdamW 293 | # optimizer = AdamW([{'params': model.parameters(), 'initial_lr': args.lr}], lr=args.lr, betas=(args.adamw_beta1, args.adamw_beta2), eps=args.adamw_eps, weight_decay=args.L2_weight_decay, correct_bias=True) 294 | # pytorch版本:from torch.optim import AdamW 295 | optimizer = AdamW([{'params': model.parameters(), 'initial_lr': args.lr}], lr=args.lr, betas=(args.adamw_beta1, args.adamw_beta2), eps=args.adamw_eps, weight_decay=args.L2_weight_decay) 296 | ''' 297 | no_decay = ['bias', 'bias_ih_l0', 'bias_hh_l0', 'LayerNorm.weight','layernorm_1.weight'] 298 | parameters_list = [] 299 | optimizer_grouped_parameters = [ 300 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.L2_weight_decay, 'initial_lr': args.lr}, 301 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'initial_lr': args.lr} 302 | ] 303 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, correct_bias=True) 304 | ''' 305 | else: # not L2 Regularization 306 | # transformers 版本:from transformers import AdamW 307 | # optimizer = AdamW([{'params': model.parameters(), 'initial_lr': args.lr}], lr=args.lr, betas=(args.adamw_beta1, args.adamw_beta2), eps=args.adamw_eps, weight_decay=args.L2_weight_decay, correct_bias=True) 308 | # pytorch版本:from torch.optim import AdamW 309 | optimizer = AdamW([{'params': model.parameters(), 'initial_lr': args.lr}], lr=args.lr, betas=(args.adamw_beta1, args.adamw_beta2), eps=args.adamw_eps) 310 | 311 | if args.autocast: 312 | # 混合精度训练 313 | # 参考: https://pytorch.org/docs/1.9.0/amp.html?highlight=torch%20cuda%20amp%20gradscaler 314 | # https://pytorch.org/docs/1.9.0/notes/amp_examples.html#amp-examples 315 | scaler = torch.cuda.amp.GradScaler() # pytorch版本要求:1.6+ 316 | 317 | if args.distributed: 318 | # Add "find_unused_parameters=True" to avoid the following error 319 | # ERROR:ignite.engine.engine.Engine:Current run is terminating due to exception: 320 | # Expected to have finished reduction in the prior iteration before starting a new one. 321 | # This error indicates that your module has parameters that were not used in producing loss. 322 | # You can enable unused parameter detection by (1) passing the keyword argument 323 | # `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; 324 | if args.find_unused_parameters: 325 | model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) 326 | else: 327 | model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) # ,find_unused_parameters=True 328 | 329 | 330 | logger.info("Prepare datasets...") 331 | if args.dataset_name == 'CPED': 332 | loader_class = build_cped_dataloaders 333 | 334 | if args.model_type == 'GPT' or args.model_type == 'GPT-2': 335 | train_loader, valid_loader, train_sampler, valid_sampler = build_cped_dataloaders(args, tokenizer, logger, load_test=False, filenames=filenames) 336 | test_loader, test_sampler = build_cped_dataloaders(args, tokenizer, logger, load_test=True, filenames=filenames) 337 | 338 | 339 | 340 | # Training function and trainer 341 | def update(engine, batch): 342 | model.train() 343 | # 从batch中读取数据 344 | if args.model_type == 'GPT' or args.model_type == 'GPT-2': 345 | input_ids, token_type_ids, emotion_ids, da_ids, current_speaker_id, current_persona_ids, current_emotion_id, current_da_id, lm_labels = tuple(None if input_tensor==None else input_tensor.to(args.device) for input_tensor in batch) 346 | 347 | 348 | # 模型进行前向计算 349 | '''参考代码 350 | if args.model_type == 'SPEAKERBERT': 351 | if args.autocast: 352 | with autocast(): 353 | (lm_loss), *_ = model(input_ids=input_ids, masked_lm_labels=lm_labels, speaker_type_ids=speaker_type_ids) 354 | else: 355 | (lm_loss), *_ = model(input_ids=input_ids, masked_lm_labels=lm_labels, speaker_type_ids=speaker_type_ids) 356 | ''' 357 | if args.model_type == 'GPT' or args.model_type == 'GPT-2': 358 | if args.autocast: 359 | with autocast(): 360 | CausalLMOutput = model(input_ids=input_ids, labels=lm_labels, token_type_ids=token_type_ids) 361 | else: 362 | CausalLMOutput = model(input_ids=input_ids, labels=lm_labels, token_type_ids=token_type_ids) 363 | lm_loss = CausalLMOutput.loss 364 | #print("results=", lm_loss) 365 | 366 | 367 | # 反向传递 368 | if args.autocast: 369 | with autocast(): 370 | loss = lm_loss / args.gradient_accumulation_steps 371 | else: 372 | loss = lm_loss / args.gradient_accumulation_steps 373 | if args.autocast: # 混合精度训练,要求:torch1.6+ 374 | scaler.scale(loss).backward(retain_graph=args.retain_graph) # retain_graph here is unrelated to amp, it's present because in this both backward() calls share some sections of graph. 375 | 376 | if engine.state.iteration % args.gradient_accumulation_steps == 0: 377 | # 参考:https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-accumulation 378 | if args.do_unscale: 379 | # 参考:https://pytorch.org/docs/1.9.0/notes/amp_examples.html#amp-examples 380 | # Unscales the gradients of optimizer's assigned params in-place 381 | scaler.unscale_(optimizer) 382 | # Since the gradients of optimizer's assigned params are unscaled, clips as usual 383 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm) 384 | scaler.step(optimizer) 385 | scaler.update() 386 | optimizer.zero_grad() 387 | 388 | else: 389 | loss.backward(retain_graph=args.retain_graph) # 增加:retain_graph=True,以解决:RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time. 390 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm) # 梯度剪切函数,为防止梯度爆炸 391 | if engine.state.iteration % args.gradient_accumulation_steps == 0: 392 | optimizer.step() 393 | optimizer.zero_grad() 394 | 395 | return loss.item(), optimizer.param_groups[0]['lr'] 396 | 397 | trainer = Engine(update) 398 | 399 | # Evaluation function and evaluator (evaluator output is the input of the metrics) 400 | def inference(engine, batch): 401 | model.eval() 402 | # 从batch中读取数据 403 | if args.model_type == 'GPT' or args.model_type == 'GPT-2': 404 | input_ids, token_type_ids, emotion_ids, da_ids, current_speaker_id, current_persona_ids, current_emotion_id, current_da_id, lm_labels = tuple(None if input_tensor==None else input_tensor.to(args.device) for input_tensor in batch) 405 | 406 | with torch.no_grad(): 407 | # 模型进行前向计算 408 | if args.model_type == 'GPT' or args.model_type == 'GPT-2': 409 | if args.autocast: 410 | with autocast(): 411 | CausalLMOutput = model(input_ids=input_ids, token_type_ids=token_type_ids) 412 | else: 413 | CausalLMOutput = model(input_ids=input_ids, token_type_ids=token_type_ids) 414 | 415 | 416 | lm_logits = CausalLMOutput.logits 417 | lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1)) 418 | lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1) 419 | return lm_logits_flat_shifted, lm_labels_flat_shifted 420 | 421 | evaluator = Engine(inference) 422 | 423 | # Evaluation function and evaluator (evaluator output is the input of the metrics) 424 | def test(engine, batch): 425 | model.eval() 426 | # 从batch中读取数据 427 | if args.model_type == 'GPT' or args.model_type == 'GPT-2': 428 | input_ids, token_type_ids, emotion_ids, da_ids, current_speaker_id, current_persona_ids, current_emotion_id, current_da_id, lm_labels = tuple(None if input_tensor==None else input_tensor.to(args.device) for input_tensor in batch) 429 | 430 | with torch.no_grad(): 431 | # 模型进行前向计算 432 | if args.model_type == 'GPT' or args.model_type == 'GPT-2': 433 | if args.autocast: 434 | with autocast(): 435 | CausalLMOutput = model(input_ids=input_ids, token_type_ids=token_type_ids) 436 | else: 437 | CausalLMOutput = model(input_ids=input_ids, token_type_ids=token_type_ids) 438 | lm_logits = CausalLMOutput.logits 439 | lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1)) 440 | lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1) 441 | return lm_logits_flat_shifted, lm_labels_flat_shifted 442 | 443 | testor = Engine(test) 444 | 445 | 446 | # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch 447 | trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(valid_loader)) 448 | trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: testor.run(test_loader)) 449 | 450 | if args.n_epochs < 1: 451 | trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(valid_loader)) 452 | trainer.add_event_handler(Events.COMPLETED, lambda _: testor.run(test_loader)) 453 | 454 | if args.eval_before_start: 455 | trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(valid_loader)) 456 | trainer.add_event_handler(Events.COMPLETED, lambda _: testor.run(test_loader)) 457 | 458 | 459 | # Evaluation during training 460 | @trainer.on(Events.ITERATION_STARTED) 461 | def log_iterations(engine): 462 | # if engine.state.iteration % max(int(0.1 * len(train_loader)), 1) == 0: 463 | if engine.state.iteration % args.valid_steps == 0: 464 | evaluator.run(valid_loader) 465 | testor.run(test_loader) 466 | 467 | # Make sure distributed data samplers split the dataset nicely between the distributed processes 468 | if args.distributed: 469 | trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch)) 470 | evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch)) 471 | testor.add_event_handler(Events.EPOCH_STARTED, lambda engine: test_sampler.set_epoch(engine.state.epoch)) 472 | 473 | # noam decrease the learning rate 474 | # d_model = model.config.n_embd 475 | # 参考论文《transformer is all you need》第5.3节 476 | d_model = args.n_emd 477 | noam_lambda = lambda step: ( 478 | d_model ** (-0.5) * min((step + 1) ** (-0.5), (step + 1) * args.warmup_steps ** (-1.5))) 479 | noam_scheduler = LambdaLR(optimizer, lr_lambda=noam_lambda, last_epoch=args.from_step) 480 | if args.scheduler == "noam": 481 | scheduler = LRScheduler(noam_scheduler) 482 | if args.scheduler == "linear": 483 | scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)]) 484 | if args.scheduler == "cyclic": 485 | # 文章https://arxiv.org/pdf/1506.01186.pdf 486 | scheduler = LRScheduler(CyclicLR(optimizer, base_lr=args.base_lr, max_lr=args.max_lr, step_size_up=2000, step_size_down=2000, mode=args.cycliclr_mode, gamma=1.0, scale_fn=None, scale_mode='cycle', cycle_momentum=False, base_momentum=0.8, max_momentum=0.9, last_epoch=-1)) 487 | if args.scheduler == "1cycle": 488 | scheduler = LRScheduler(OneCycleLR(optimizer, args.lr, total_steps=args.n_epochs, epochs=None, steps_per_epoch=None, pct_start=0.3, anneal_strategy='cos', cycle_momentum=False, base_momentum=0.85, max_momentum=0.95, div_factor=25.0, final_div_factor=10000.0, last_epoch=-1)) 489 | if args.scheduler == "fixedlr": 490 | scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), args.lr)]) 491 | 492 | trainer.add_event_handler(Events.ITERATION_STARTED, scheduler) 493 | 494 | 495 | 496 | RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss") 497 | RunningAverage(output_transform=lambda x: x[1]).attach(trainer, "lr") 498 | metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0], x[1]))} 499 | metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args)}) 500 | metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"]) 501 | for name, metric in metrics.items(): 502 | metric.attach(evaluator, name) 503 | metric.attach(testor, name) 504 | 505 | # On the main process: add progress bar, tensorboard, checkpoints 506 | # And save model, configuration and tokenizer before we start to train 507 | if args.local_rank in [-1, 0]: 508 | pbar = ProgressBar(persist=True, mininterval=2) 509 | pbar.attach(trainer, metric_names=["loss", "lr"]) 510 | evaluator.add_event_handler(Events.COMPLETED, 511 | lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics))) 512 | testor.add_event_handler(Events.COMPLETED, 513 | lambda _: pbar.log_message("Test: %s" % pformat(testor.state.metrics))) 514 | 515 | # tb_logger = TensorboardLogger(log_dir=None, comment='_'+args.model_type+'_'+args.dataset_name) 516 | # 统一日志名称和tensorboard输出文件夹名称, 2022.04.11 517 | tb_logger = TensorboardLogger(log_dir=os.path.join("./runs/",log_file_name_or_tensorboard_dir_name)) 518 | 519 | tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), 520 | event_name=Events.ITERATION_COMPLETED) 521 | 522 | tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED) 523 | 524 | tb_logger.attach_output_handler(evaluator, 525 | event_name=Events.EPOCH_COMPLETED, 526 | tag="validation", 527 | metric_names=list(metrics.keys()), 528 | global_step_transform=global_step_from_engine(trainer)) 529 | tb_logger.attach_output_handler(testor, 530 | event_name=Events.EPOCH_COMPLETED, 531 | tag="test", 532 | metric_names=list(metrics.keys()), 533 | global_step_transform=global_step_from_engine(trainer)) 534 | ''' 535 | tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys())), 536 | event_name=Events.EPOCH_COMPLETED, global_step_transform=global_step_from_engine(trainer)) 537 | tb_logger.attach(testor, log_handler=OutputHandler(tag="test", metric_names=list(metrics.keys())), 538 | event_name=Events.EPOCH_COMPLETED, global_step_transform=global_step_from_engine(trainer)) 539 | ''' 540 | 541 | # 连续3个epoch,测试集的ppl没有下降就停止训练 542 | early_stop_handler = EarlyStopping(patience=3, score_function=score_function, trainer=trainer) 543 | testor.add_event_handler(Events.COMPLETED, early_stop_handler) 544 | 545 | #checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint', save_interval=1, n_saved=3) 546 | # ignite v0.4.8版本 547 | 548 | best_model_handler = Checkpoint( 549 | {"model": model}, 550 | tb_logger.writer.log_dir, 551 | filename_prefix="best", 552 | n_saved=2, 553 | global_step_transform=global_step_from_engine(trainer), 554 | score_name="test_ppl", 555 | score_function=score_function, 556 | ) 557 | testor.add_event_handler(Events.COMPLETED, best_model_handler) 558 | 559 | ''' 560 | # save model after evaluation 561 | testor.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, { 562 | 'mymodel': getattr(model, 'module', model)}) 563 | evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, { 564 | 'mymodel': getattr(model, 'module', model)}) 565 | trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, { 566 | 'mymodel': getattr(model, 'module', model)}) # "getattr" take care of distributed encapsulation 567 | ''' 568 | 569 | torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin') 570 | getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.log_dir, CONFIG_NAME)) 571 | 572 | #tokenizer.save_vocabulary(tb_logger.writer.log_dir) 573 | # save the new tokens vacab 574 | tokenizer.save_pretrained(tb_logger.writer.log_dir) 575 | with open(tb_logger.writer.log_dir + "/training_args.json",'w',encoding='utf-8') as json_file: 576 | json.dump(pformat(args),json_file,ensure_ascii=False) 577 | 578 | # Run the training 579 | trainer.run(train_loader, max_epochs=args.n_epochs) 580 | 581 | # On the main process: close tensorboard logger and rename the last checkpoint 582 | # (for easy re-loading with OpenAIGPTModel.from_pretrained method) 583 | if args.local_rank in [-1, 0] and args.n_epochs > 0: 584 | # 重命名模型名称为WEIGHTS_NAME指定的字符串 585 | # ignite0.4.8有所调整 586 | # os.rename(os.path.join(tb_logger.writer.log_dir, checkpoint_handler._saved[-1][1]), 587 | # os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)) # TODO: PR in ignite to have better access to saved file paths (cleaner) 588 | os.rename(os.path.join(tb_logger.writer.log_dir, best_model_handler.last_checkpoint), 589 | os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)) 590 | 591 | tb_logger.close() 592 | 593 | 594 | if __name__ == "__main__": 595 | train() -------------------------------------------------------------------------------- /pec_baseline/train_model_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # coding=utf-8 3 | # Copyright 2021 South China University of Technology and 4 | # Engineering Research Ceter of Ministry of Education on Human Body Perception. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | # Author: Chen Yirong 19 | # Date: 2022.04.06 20 | 21 | # Train model with torch.distributed.launch by using torch.distributed.launch 22 | # Trained on Ubuntu18.04 with 2 GeForce RTX 2080ti GPUs and 125G Memory 23 | # --n_epochs=120 --warmup_steps=10000 --lr=5.5e-3 --pretrained --train_batch_size=4 --valid_batch_size=2 --test_batch_size=1 --num_workers=4 24 | 25 | # CPU个数:2,单个CPU核心数:10 26 | # cat /proc/cpuinfo| grep "physical id"| sort| uniq| wc -l 27 | # cat /proc/cpuinfo| grep "cpu cores"| uniq 28 | 29 | 30 | 31 | ############################################################################# 32 | # 中文数据集CPED 33 | dataset_name=CPED 34 | data_path=../data/CPED 35 | model_checkpoint=~/PretrainedModel/CDial-GPT/CDial-GPT_LCCC-base 36 | cache_path=../data/CPED_cache_for_CpedDataset 37 | 38 | # GPT 39 | CUDA_VISIBLE_DEVICES=5 python -m torch.distributed.launch --nproc_per_node=1 --master_addr 127.0.0.1 --master_port 2301 train_model.py --model_type=GPT --model_checkpoint=~/PretrainedModel/CDial-GPT/CDial-GPT_LCCC-base --pretrained --dataset_name=CPED --data_path=../data/CPED/ --cache_path=../data/CPED_cache_for_CpedDataset --lr=6.25e-5 --scheduler=noam --autocast --train_batch_size=1 --valid_batch_size=1 --test_batch_size=1 --n_epochs=2 40 | -------------------------------------------------------------------------------- /pec_baseline/utils/README.md: -------------------------------------------------------------------------------- 1 | # utils 2 | # 数据读取模块 3 | * **Author: Chen Yirong ** 4 | * **Date: 2022.03.21** 5 | 6 | ## 架构说明 7 | ### 基础文件 8 | base_util.py: 读取数据集的一些共用函数 9 | dataset_statistics.py: 用于数据集的统计等 10 | 11 | ### 数据集处理文件 12 | 每一个数据集(假设名字为:xxx)都有两个.py文件,格式如下: 13 | * xxx_util.py: 封装读取数据集的函数以及一些常量 14 | * xxx_dataset.py: 构建torch.utils.data.Dataset的子类、build_xxx_dataloaders函数 15 | 16 | ### CPED数据集的处理文件: 17 | * cped_util.py: 封装读取CPED数据集的函数以及一些常量 18 | * cped_dataset.py: 构建torch.utils.data.Dataset的子类、build_xxx_dataloaders函数 19 | -------------------------------------------------------------------------------- /pec_baseline/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 South China University of Technology and 3 | # Engineering Research Ceter of Ministry of Education on Human Body Perception. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # Dataset data loading module 18 | # Author: Chen Yirong 19 | # Date: 2022.03.21 20 | 21 | __version__ = "1.0.0" 22 | 23 | # 关键包版本说明: 24 | # pytorch: 1.9.0+ 25 | 26 | 27 | try: 28 | import absl.logging 29 | absl.logging.set_verbosity('info') 30 | absl.logging.set_stderrthreshold('info') 31 | absl.logging._warn_preinit_stderr = False 32 | except: 33 | pass 34 | 35 | import logging 36 | 37 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 38 | 39 | 40 | # Files and general utilities 41 | from .base_util import (load_csv_data_from_dir,shuffle_total_data,combine_csv_files,save_speaker,load_speaker, 42 | convert_speaker_to_id,convert_id_to_speaker,convert_cache_to_csv,is_torch_available) 43 | 44 | from .dataset_statistics import (get_data_for_analysis, get_totaldata_for_analysis, get_row_statistics, cout_dialogue_words) 45 | 46 | 47 | # Dataset utilities 48 | if is_torch_available(): 49 | # 读取CPED数据集的若干种函数 50 | from .cped_util import (cped_get_single_file,cped_get_single_cache_file,cped_get_data_from_dir, 51 | cped_get_single_file_for_bert_gpt,cped_get_data_from_dir_for_bert_gpt, 52 | CPED_SPECIAL_TOKENS,CPED_IGNORE_ID,CPED_DA_TOKENS,CPED_SENTIMENT_TOKENS, 53 | CPED_EMOTION_TOKENS,CPED_DA_TO_TOKENS,CPED_SENTIMENT_TO_TOKENS,CPED_EMOTION_TO_TOKENS, 54 | CPED_DA_TO_ID,CPED_EMOTION_TO_ID,CPED_GENDER_TO_ID,CPED_BIGFIVE_TO_ID,CPED_SPEAKER_TYPE_TO_ID) 55 | 56 | from .cped_dataset import (CpedDataset, build_cped_dataloaders, find_split_id_of_response, create_speaker_type, 57 | convert_emotion_to_tokens, convert_da_to_tokens, set_da_in_speaker, set_emotion_in_speaker) 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /pec_baseline/utils/base_util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 South China University of Technology and 3 | # Engineering Research Ceter of Ministry of Education on Human Body Perception. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # Dataset data loading file 18 | # File: base_util.py 19 | # Used for dataset loading 20 | # 用于数据集读取的基础方法 21 | # Author: Chen Yirong 22 | # Date: 2022.03.21 23 | 24 | import os 25 | import re 26 | import math 27 | import json 28 | import shutil 29 | import random 30 | import collections 31 | import importlib.util 32 | import pandas as pd 33 | import logging 34 | from io import open 35 | from os.path import join 36 | from .dataset_statistics import get_row_statistics 37 | 38 | logger = logging.getLogger(__name__) 39 | 40 | 41 | _torch_available = importlib.util.find_spec("torch") is not None 42 | 43 | 44 | def is_torch_available(): 45 | return _torch_available 46 | 47 | def load_csv_data_from_dir(data_dir="../data/CPED", 48 | file_dict={"train":"train_split.csv", 49 | "valid":"valid_split.csv", 50 | "test":"test_split.csv"}): 51 | '''get_data from dir, which have train_split.csv, valid_split.csv, test_split.csv file 52 | Inputs: 53 | **data_dir**: str, 54 | 55 | ''' 56 | print("Read dataset from ", data_dir) 57 | train_data = pd.read_csv(join(data_dir,file_dict["train"]), encoding="UTF-8-SIG") 58 | valid_data = pd.read_csv(join(data_dir,file_dict["valid"]), encoding="UTF-8-SIG") 59 | test_data = pd.read_csv(join(data_dir,file_dict["test"]), encoding="UTF-8-SIG") 60 | return train_data, valid_data, test_data 61 | 62 | 63 | def shuffle_total_data(data_path, 64 | save_path, 65 | validation_split_percentage=0.1, 66 | test_split_percentage=0.1, 67 | file_names = ["train_shuffle_split.csv", "valid_shuffle_split.csv", "test_shuffle_split.csv"], 68 | regen=False): 69 | '''shuffle_total_data 70 | 功能:将一个.csv文件随机打乱,拆分为训练集、验证集、测试集,分别保存 71 | 输入: 72 | data_path: .csv文件的路径 73 | save_path: 拆分后的文件保存的目录 74 | validation_split_percentage: 验证集比例 75 | 76 | ''' 77 | if regen==False: 78 | print("不进行重复生成!") 79 | return False 80 | else: 81 | # 以下操作删除原先的文件,为危险操作 82 | if os.path.exists(join(save_path,file_names[0])): 83 | os.remove(join(save_path,file_names[0])) 84 | 85 | if os.path.exists(join(save_path,file_names[1])): 86 | os.remove(join(save_path,file_names[1])) 87 | 88 | if os.path.exists(join(save_path,file_names[2])): 89 | os.remove(join(save_path,file_names[2])) 90 | 91 | print("Read dataset from ", data_path) 92 | data = pd.read_csv(data_path, 93 | usecols=["Dialogue_ID","Utterance_ID","Speaker","Sentiment","Emotion","DA","Utterance","Gender","Age","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"], 94 | encoding="UTF-8-SIG") 95 | # 划分为训练集、测试集 96 | keys = list(set(data['Dialogue_ID'])) 97 | random.shuffle(keys) # 随机打乱 98 | validation_split_id = int(len(keys)*(1-validation_split_percentage-test_split_percentage)) 99 | test_split_id = int(len(keys)*(1-test_split_percentage)) 100 | train_keys = keys[:validation_split_id] # 训练集索引 101 | valid_keys = keys[validation_split_id:test_split_id] # 验证集索引 102 | test_keys = keys[test_split_id:] # 测试集索引 103 | train_data = data[data['Dialogue_ID'].isin(train_keys)] 104 | valid_data = data[data['Dialogue_ID'].isin(valid_keys)] 105 | test_data = data[data['Dialogue_ID'].isin(test_keys)] 106 | 107 | train_data.to_csv(join(save_path,file_names[0]), encoding="UTF-8-SIG", index=False) 108 | valid_data.to_csv(join(save_path,file_names[1]), encoding="UTF-8-SIG", index=False) 109 | test_data.to_csv(join(save_path,file_names[2]), encoding="UTF-8-SIG", index=False) 110 | print("已经完成数据集生成!") 111 | 112 | return True 113 | 114 | 115 | # 将指定路径下的所有csv文件合并为一个csv文件 116 | # 实现该函数主要方便汇总统计 117 | def combine_csv_files(data_path="./MELD/", 118 | save_path="./MELD/MELD_total_text.csv", 119 | regen=False): 120 | '''combine_csv_files 121 | 将指定路径下的所有csv文件合并为一个csv文件 122 | 123 | 使用示例: 124 | 125 | combine_csv_files(data_path="./MELD/", 126 | save_name="MELD_total_text", 127 | files=file_names, 128 | save_in="./MELD/") 129 | 130 | ''' 131 | if regen==False: 132 | print("不进行重复生成!") 133 | return False 134 | 135 | 136 | 137 | files = os.listdir("./MELD/") 138 | if os.path.isfile(join(save_in, "%s.csv") % save_name): 139 | return 0 140 | else: 141 | try: 142 | main_list = [] 143 | for i in range(len(files)): 144 | content = pd.read_csv(join(data_path, files[i]), encoding="UTF-8-SIG") 145 | if i == 0: 146 | main_list.extend([list(content.keys())]) 147 | main_list.extend(content.values.tolist()) 148 | 149 | main_dict = {} 150 | for i in list(zip(*main_list)): 151 | main_dict[i[0]] = list(i[1:]) 152 | data_df = pd.DataFrame(main_dict) 153 | data_df.to_csv(join(save_in, "%s.csv") % save_name, encoding="UTF-8-SIG", index=False) 154 | except: 155 | print("合并[%s]时发生错误" % save_name) 156 | 157 | 158 | 159 | def save_speaker(data_path, save_path, row_name="Speaker", regen=False): 160 | '''读取数据集的所有姓名,制作姓名表 161 | 使用示例: 162 | save_speaker(data_path="/148Dataset/Dataset/MELD/MELD/train_sent_emo.csv", 163 | save_path="/148Dataset/Dataset/MELD/MELD/speakers.txt", 164 | row_name="Speaker", 165 | regen=True) 166 | ''' 167 | if os.path.exists(save_path): 168 | if regen == False: 169 | return None 170 | elif regen == True: 171 | os.remove(save_path) 172 | 173 | data = pd.read_csv(data_path, encoding="UTF-8-SIG") 174 | results = get_row_statistics(data,row_name) 175 | print(results["keys"]) 176 | print(results["element_stastics"]) 177 | with open(save_path, 'w') as f: 178 | for i in range(len(results["keys"])): 179 | f.write(results["keys"][i]+'\n') 180 | return True 181 | 182 | 183 | def load_speaker(speakers_file): #speakers.txt 184 | """Loads a speaker file into a dictionary. 185 | speakers_file: 姓名汇总表格 186 | speakers: 返回的列表形式的表格 187 | speakers_to_ids: 返回的字典,通过该字典可以根据姓名获得对应的id 188 | ids_to_speakers:通过该字典可以根据id获得对应的姓名 189 | """ 190 | speakers_to_ids = collections.OrderedDict() 191 | with open(speakers_file, "r", encoding="utf-8") as reader: 192 | speakers = reader.readlines() 193 | for index, token in enumerate(speakers): 194 | token = token.rstrip('\n') 195 | speakers[index] = token 196 | speakers_to_ids[token] = index 197 | ids_to_speakers = collections.OrderedDict([(ids, tok) for tok, ids in speakers_to_ids.items()]) 198 | return speakers, speakers_to_ids, ids_to_speakers 199 | 200 | 201 | def convert_speaker_to_id(speakers_to_ids, speaker, unk_token="其他"): 202 | """ Converts a speaker (str/unicode) in an id using the speakers. """ 203 | return speakers_to_ids.get(speaker, speakers_to_ids.get(unk_token)) 204 | 205 | def convert_id_to_speaker(ids_to_speakers, index, unk_token="其他"): 206 | """Converts an index (integer) in a speaker (string/unicode) using the speakers.""" 207 | return ids_to_speakers.get(index, unk_token) 208 | 209 | def convert_cache_to_csv(dataset_cache,output_dir): 210 | 211 | data = torch.load(dataset_cache) 212 | train_data = data["train"] 213 | valid_data = data["valid"] 214 | test_data = data["test"] 215 | train_data.to_csv(join(output_dir,dataset_cache+"train.csv"),columns=["Dialogue_ID","Utterance_ID","Speaker","Sentiment","Emotion","DA","Utterance","Gender","Age","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"]) 216 | 217 | -------------------------------------------------------------------------------- /pec_baseline/utils/cped_dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 South China University of Technology and 3 | # Engineering Research Ceter of Ministry of Education on Human Body Perception. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # CPED Dataset data loading file 18 | # File: cped_dataset.py 19 | # Used for CPED dataset loading 20 | # Author: Chen Yirong 21 | # Date: 2022.03.29 22 | 23 | # CPED数据集说明 24 | # 数据集存储形式如下: 25 | # CPED/ 顶层文件夹, 26 | # CPED_total_text.csv 全部对话数据汇总的csv格式文件, 27 | # speakers.txt 所有说话人姓名汇总集合的文本文件, 28 | # train_split.csv 训练集,与valid_split.csv、test_split.csv无重叠说话人 29 | # valid_split.csv 验证集,与train_split.csv、test_split.csv无重叠说话人 30 | # test_split.csv 测试集,与train_split.csv、valid_split.csv无重叠说话人 31 | # train_shuffle_split.csv 将CPED_total_text.csv随机打乱以8:1:1切割得到 32 | # valid_shuffle_split.csv 将CPED_total_text.csv随机打乱以8:1:1切割得到 33 | # test_shuffle_split.csv 将CPED_total_text.csv随机打乱以8:1:1切割得到 34 | # 上述两种数据集划分方式可以用于不同的研究场景! 35 | 36 | # 本文件约定函数命名风格:cped_function_name,例如: 37 | # cped_get_total_data、cped_get_single_file 38 | 39 | # 本文件约定类的命名风格:CpedClassName,例如: 40 | # CpedDataset、CpedBertDataset、CpedBertGptDataset 41 | 42 | # 关键包版本说明: 43 | # pytorch: 1.9.0+ 44 | import torch 45 | import logging 46 | from itertools import chain # 将二维列表转换为一维列表,[[1127, 1127, 6432, 6814],[118, 117, 116]]---> [1127, 1127, 6432, 6814, 118, 117, 116] 47 | from torch.utils.data import Dataset # 参考:https://pytorch.org/docs/1.9.0/data.html#torch.utils.data.Dataset 48 | from torch.utils.data import DataLoader 49 | from torch.nn.utils.rnn import pad_sequence 50 | from .cped_util import (tokenize, cped_get_single_file, cped_get_single_cache_file, cped_get_data_from_dir, 51 | cped_get_single_file_for_bert_gpt, cped_get_data_from_dir_for_bert_gpt, 52 | CPED_SPECIAL_TOKENS, CPED_IGNORE_ID, CPED_DA_TOKENS, CPED_SENTIMENT_TOKENS, 53 | CPED_EMOTION_TOKENS, CPED_DA_TO_TOKENS, CPED_SENTIMENT_TO_TOKENS, CPED_EMOTION_TO_TOKENS, 54 | CPED_DA_TO_ID, CPED_EMOTION_TO_ID, CPED_GENDER_TO_ID, CPED_BIGFIVE_TO_ID, CPED_SPEAKER_TYPE_TO_ID) 55 | 56 | 57 | logger = logging.getLogger(__name__) 58 | 59 | 60 | def find_split_id_of_response(speaker_list,responder): 61 | '''find_split_id_of_response 62 | Inputs: 63 | speaker_list: 姓名组成的列表,例如:['诺澜', '诺澜', '胡一菲', '胡一菲', '胡一菲', '诺澜', '诺澜'] 64 | responder: 字符串,表示回复者的姓名,例如:'诺澜' 65 | Outputs: 66 | split_id: 负整数,范围为-1至-len(speaker_list)+1 67 | Examples: 68 | speaker_list = ['诺澜', '诺澜', '胡一菲', '胡一菲', '胡一菲', '诺澜', '诺澜'] 69 | responder = '诺澜' 70 | split_id= find_split_id_of_response(speaker_list,responder) 71 | # 返回结果:-2 72 | # utterance_history = data_index["Token"].tolist()[-max_history_utterances:split_id] 73 | # reponse = data_index["Token"].tolist()[-split_id:] 74 | ''' 75 | split_id = -1 76 | for i in range(-2,-len(speaker_list),-1): 77 | if speaker_list[i] != responder: 78 | return split_id 79 | else: 80 | split_id = split_id-1 81 | return -1 # 极端情形,只有一个人说话,则只认为最后一句为response内容 82 | 83 | def create_speaker_type(speaker_list,responder=None): 84 | '''create_speaker_type: 将姓名列表转换为"[speaker1]"、"[speaker2]"组成的列表 85 | Inputs: 86 | speaker_list: 姓名组成的列表,例如:['诺澜', '诺澜', '胡一菲', '胡一菲', '胡一菲'] 87 | responder: 字符串,表示回复者的姓名,例如:'诺澜' 88 | Outputs: 89 | speaker_type_list: "[speaker1]"、"[speaker2]"组成的列表 90 | ''' 91 | if responder==None: # 不指定responder 92 | speaker2 = speaker_list[-1] # 最后一个句子的说话人被定义为回复者 93 | else: 94 | speaker2 = responder 95 | speaker_type_list = [] 96 | for speaker in speaker_list: 97 | if speaker==speaker2: 98 | speaker_type_list.append("[speaker2]") # "[speaker2]"代表回复者 99 | else: 100 | speaker_type_list.append("[speaker1]") 101 | return speaker_type_list 102 | 103 | 104 | def convert_emotion_to_tokens(emotion_list, 105 | emotion_type="Emotion", 106 | SELECTED_EMOTION_TO_TOKENS={"Emotion":CPED_EMOTION_TO_TOKENS, 107 | "Sentiment":CPED_SENTIMENT_TO_TOKENS}): 108 | '''convert_emotion_to_tokens: 将情感列表转换为词表当中的情感字符 109 | Inputs: 110 | emotion_list: 对话的情感标签列,每一个元素表示对应的句子的原始情感标签,例如:["happy","happy"] 111 | emotion_type: 字符串,表示情感类型,"Emotion"或者"Sentiment",指定了使用SELECTED_EMOTION_TO_TOKENS 112 | 当中的某一种字典用于将原始标签转换为TOKENS标签 113 | SELECTED_EMOTION_TO_TOKENS: 字典,其键为情感类型,值为相应的情感标签转换为Tokens的字典,由.cped_util定义 114 | Outputs: 115 | emotion_tokens_list: 经过转换后的情感tokens列表 116 | ''' 117 | # emotion_tokens_list = [SELECTED_EMOTION_TO_TOKENS[emotion_type][emo] for emo in emotion_list] 118 | emotion_tokens_list = [] 119 | for emo in emotion_list: 120 | if emo not in SELECTED_EMOTION_TO_TOKENS[emotion_type]: 121 | emotion_tokens_list.append("[neutral]") 122 | else: 123 | emotion_tokens_list.append(SELECTED_EMOTION_TO_TOKENS[emotion_type][emo]) 124 | return emotion_tokens_list 125 | 126 | def convert_da_to_tokens(da_list, 127 | da_type="DA", 128 | SELECTED_DA_TO_TOKENS={"DA":CPED_DA_TO_TOKENS}): 129 | '''convert_da_to_tokens: 将DA列表转换为词表当中的DA字符 130 | Inputs: 131 | da_list: 对话的DA标签列,每一个元素表示对应的句子的原始DA标签 132 | da_type: 字符串,表示DA类型,"DA"或者自定义的DA列名称,指定了使用SELECTED_da_TO_TOKENS 133 | 当中的某一种字典用于将原始标签转换为TOKENS标签 134 | SELECTED_DA_TO_TOKENS: 字典,其键为DA类型,值为相应的DA标签转换为Tokens的字典,由.cped_util定义 135 | Outputs: 136 | da_tokens_list: 经过转换后的DA的tokens列表 137 | ''' 138 | da_tokens_list = [SELECTED_DA_TO_TOKENS[da_type][da] for da in da_list] 139 | return da_tokens_list 140 | 141 | 142 | def set_da_in_speaker(da_ids,input_ids,bos, eos, speaker1, speaker2, pad): 143 | '''set_da_in_speaker: 仅在说话人标志位叠加DA Embedding 144 | 145 | ''' 146 | special_token_ids_list = [bos, eos, speaker1, speaker2] 147 | new_da_ids = [] 148 | for i,da in enumerate(da_ids): 149 | if input_ids[i] in special_token_ids_list: 150 | new_da_ids.append(da_ids[i]) 151 | else: 152 | new_da_ids.append(pad) 153 | return new_da_ids 154 | 155 | def set_emotion_in_speaker(emotion_ids,input_ids,bos, eos, speaker1, speaker2, pad): 156 | '''set_emotion_in_speaker: 仅在说话人标志位叠加情感标签 Embedding 157 | 158 | ''' 159 | special_token_ids_list = [bos, eos, speaker1, speaker2] 160 | new_emotion_ids = [] 161 | for i,emotion in enumerate(emotion_ids): 162 | if input_ids[i] in special_token_ids_list: 163 | new_emotion_ids.append(emotion_ids[i]) 164 | else: 165 | new_emotion_ids.append(pad) 166 | return new_emotion_ids 167 | 168 | 169 | class CpedDataset(Dataset): 170 | '''CpedDataset:用于常规对话生成模型,例如CDialGPT,增加情感或者个性或者DA控制,或者增加Embedding进行叠加 171 | 不适用于SPEAKERBERT、BERTGPT等模型 172 | 173 | Inputs: 174 | data: DataFrame,经过预处理后的对话数据,每一行为一个句子,data至少包含'Dialogue_ID'、"Token"两列 175 | 其中,data['Dialogue_ID']表示样本编号 176 | 其中data["Token"]为对应的句子经过tokenizer映射的ids,看起来像下面这样, 177 | [2108, 3342, 3342, 8024, 1962, 1008, 2769, 1420, 1127, 1127, 6432, 6814] 178 | tokenizer: Tokenizer对象,参考:https://huggingface.co/docs/transformers/main_classes/tokenizer 179 | emotion_type: 字符串,指定情感列名,有"Sentiment"、"Emotion两种选择",可自行组建新的情感列 180 | da_type: 字符串,指定DA列名,目前只有"DA"一种,可自行组建新的DA列 181 | persona_type: 列表,指定个性列名 182 | max_history: 最大对话轮数,句子数=2*max_history 183 | batch_first: 布尔类型,指定batch是否在第一维,也就是(batch,...) 184 | lm_labels: 布尔类型,指定是否返回response 185 | with_current_speaker: 布尔类型,指定回复时的说话人 186 | with_current_persona: 布尔类型,指定回复时的个性 187 | with_current_emotion: 布尔类型,指定回复时的情感 188 | with_current_da: 布尔类型,指定回复时的对话动作DA 189 | with_emotion: 布尔类型,指定情感嵌入 190 | with_da=False: 布尔类型,指定DA嵌入 191 | use_speaker_name_as_speaker_list: 布尔类型,指定说话人姓名作为speaker_list 192 | set_eda_in_speaker: 布尔类型,指定在说话人位置嵌入情感或DA 193 | 194 | Outputs: 195 | **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``. 196 | Indices of input sequence tokens in the vocabulary. 197 | **token_type_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``. 198 | A parallel sequence of tokens (can be used to indicate various portions of the inputs). 199 | The embeddings from these tokens will be summed with the respective token embeddings. 200 | Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices) 201 | **emotion_ids**: (`optional`, returned when ``with_emotion=True``) 202 | ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``. 203 | **da_ids**: (`optional`, returned when ``with_da=True``) 204 | ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``. 205 | **current_speaker_id**: (`optional`, returned when ``with_current_speaker=True``) 206 | ``torch.LongTensor`` of shape ``(batch_size, 1)``. 207 | **current_persona_ids**: (`optional`, returned when ``with_current_persona=True``) 208 | ``torch.LongTensor`` of shape ``(batch_size, persona_size)``. 209 | **current_emotion_id**: (`optional`, returned when ``with_current_emotion=True``) 210 | ``torch.LongTensor`` of shape ``(batch_size, 1)``. 211 | **current_da_id**: (`optional`, returned when ``with_current_da=True``) 212 | ``torch.LongTensor`` of shape ``(batch_size, 1)``. 213 | **lm_labels**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)`` 214 | 215 | Examples: 216 | 217 | 218 | ''' 219 | def __init__(self, 220 | data, 221 | tokenizer, 222 | emotion_type="Emotion", 223 | da_type="DA", 224 | persona_type=["Gender","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness","Age"], 225 | max_history=25, # 句子数则为50 226 | batch_first=True, 227 | lm_labels=True, 228 | with_current_speaker=False, 229 | with_current_persona=False, 230 | with_current_emotion=False, 231 | with_current_da=False, 232 | with_emotion=False, 233 | with_da=False, 234 | use_speaker_name_as_speaker_list=False, 235 | set_eda_in_speaker=False, 236 | set_current_speaker_mask=False, 237 | max_word_length=512): # 增加限制总的字符数 238 | self.data = data 239 | self.tokenizer = tokenizer 240 | self.emotion_type = emotion_type # 'Emotion' 情感标签列名 241 | self.da_type = da_type # 'DA' DA标签列名 242 | self.persona_type = persona_type 243 | self.with_current_speaker = with_current_speaker 244 | self.with_current_persona = with_current_persona 245 | self.with_current_emotion = with_current_emotion 246 | self.with_current_da = with_current_da 247 | self.with_emotion=with_emotion # Whether use emotion to help generate dialogue 248 | self.with_da=with_da # Whether use DA to help generate dialogue 249 | self.max_history = max_history # Maximum number of dialogue turns 250 | self.max_history_utterances = 2*max_history # Maximum number of dialogue sentences 251 | self.max_word_length = max_word_length 252 | self.use_speaker_name_as_speaker_list = use_speaker_name_as_speaker_list 253 | self.set_eda_in_speaker = set_eda_in_speaker 254 | self.set_current_speaker_mask = set_current_speaker_mask 255 | self.pad = tokenizer.pad_token_id 256 | self.batch_first = batch_first 257 | self.lm_labels = lm_labels 258 | self.keys = list(set(self.data['Dialogue_ID'])) 259 | self.len = len(self.keys) 260 | 261 | def __len__(self): 262 | return self.len 263 | 264 | def __getitem__(self, index): 265 | dialogue_id = self.keys[index] # 当前对话样本编号 266 | data_index = self.data[self.data['Dialogue_ID']==dialogue_id] 267 | 268 | if len(data_index["Speaker"].tolist()) > self.max_history_utterances: # 实际句子数大于self.max_history_utterances 269 | max_history_utterances = self.max_history_utterances 270 | else: # 实际句子数小于self.max_history_utterances 271 | max_history_utterances = len(data_index["Speaker"].tolist()) 272 | # 判断data_index的“句子数+data_index["Token"]的Token数+2”是否大于self.max_word_length 273 | while len(data_index["Speaker"].tolist()[-max_history_utterances:])+len(list(chain(*data_index["Token"].tolist()[-max_history_utterances:])))+2>self.max_word_length: 274 | max_history_utterances = max_history_utterances - 1 275 | 276 | speaker_name_list = data_index["Speaker"].tolist()[-max_history_utterances:] # 说话人姓名列表 277 | responder = speaker_name_list[-1] # 回复者姓名 278 | responder_token = self.tokenizer.convert_tokens_to_ids(responder) # 整数 279 | 280 | # 找出回复内容与历史对话的分割id 281 | response_split_id = find_split_id_of_response(speaker_name_list,responder) 282 | # 历史对话内容,长这样:[[2108, 3342, 3342, 8024], [1962, 1008, 2769, 1420]] 283 | history_utterance_tokens = data_index["Token"].tolist()[-max_history_utterances:response_split_id] 284 | # 回复内容,长这样:[[1127, 1127, 6432, 6814],[118, 117, 116]]---> [1127, 1127, 6432, 6814, 118, 117, 116] 285 | if self.lm_labels: 286 | response_utterance_tokens = data_index["Token"].tolist()[response_split_id:] 287 | response_utterance_tokens = list(chain(*response_utterance_tokens)) # 二维列表转一维列表 288 | else: 289 | response_utterance_tokens = [] 290 | 291 | # 创建历史对话对应的history_speaker_types 292 | if self.use_speaker_name_as_speaker_list: 293 | # 使用说话人姓名嵌入,需要把说话人姓名加进词表! 294 | history_speaker_types = speaker_name_list[-max_history_utterances:response_split_id] 295 | else: 296 | # 使用函数create_speaker_type创建姓名嵌入 297 | # "[speaker2]"表示回复者,"[speaker1]"表示另一个说话人 298 | history_speaker_types = create_speaker_type(speaker_list=speaker_name_list[-max_history_utterances:response_split_id], responder=responder) 299 | # 将字符串表示转换为id表示 300 | history_speaker_tokens = self.tokenizer.convert_tokens_to_ids(history_speaker_types) 301 | 302 | # 创建历史对话对应的history_emotion_tokens 303 | if self.with_emotion: # 需要把情感标签加进词表! 304 | history_emotion_tokens = convert_emotion_to_tokens(emotion_list=data_index[self.emotion_type].tolist()[-max_history_utterances:response_split_id], 305 | emotion_type=self.emotion_type) 306 | history_emotion_tokens = self.tokenizer.convert_tokens_to_ids(history_emotion_tokens) 307 | else: 308 | history_emotion_tokens = [] 309 | 310 | # 创建历史对话对应的history_da_tokens 311 | if self.with_da: # 需要把DA标签加进词表! 312 | history_da_tokens = convert_da_to_tokens(da_list=data_index[self.da_type].tolist()[-max_history_utterances:response_split_id], 313 | da_type=self.da_type) 314 | history_da_tokens = self.tokenizer.convert_tokens_to_ids(history_da_tokens) 315 | else: 316 | history_da_tokens = [] 317 | 318 | 319 | # 创建用于指定回复的情感、DA、个性 320 | # 以下用于情感、DA与词嵌入共同使用一个Embedding 321 | current_emotion_token = self.tokenizer.convert_tokens_to_ids(data_index[self.emotion_type].tolist()[-1]) 322 | current_da_token = self.tokenizer.convert_tokens_to_ids(data_index[self.da_type].tolist()[-1]) 323 | # 以下用于情感、DA不与词嵌入共同使用一个Embedding 324 | current_emotion_id = CPED_EMOTION_TO_ID[data_index[self.emotion_type].tolist()[-1]] 325 | current_da_id = CPED_DA_TO_ID[data_index[self.da_type].tolist()[-1]] 326 | if self.with_current_persona: 327 | current_gender_id = CPED_GENDER_TO_ID[data_index[self.persona_type[0]].tolist()[-1]] 328 | current_Neuroticism_id = CPED_BIGFIVE_TO_ID[data_index[self.persona_type[1]].tolist()[-1]] 329 | current_Extraversion_id = CPED_BIGFIVE_TO_ID[data_index[self.persona_type[2]].tolist()[-1]] 330 | current_Openness_id = CPED_BIGFIVE_TO_ID[data_index[self.persona_type[3]].tolist()[-1]] 331 | current_Agreeableness_id = CPED_BIGFIVE_TO_ID[data_index[self.persona_type[4]].tolist()[-1]] 332 | current_Conscientiousness_id = CPED_BIGFIVE_TO_ID[data_index[self.persona_type[5]].tolist()[-1]] 333 | current_persona_ids = [current_gender_id,current_Neuroticism_id,current_Extraversion_id,current_Openness_id, 334 | current_Agreeableness_id,current_Conscientiousness_id] 335 | else: 336 | current_persona_ids = [] 337 | 338 | return self.process(history_speaker_tokens, 339 | history_utterance_tokens, 340 | history_emotion_tokens, 341 | history_da_tokens, 342 | responder_token, 343 | current_emotion_token, 344 | current_da_token, 345 | current_emotion_id, 346 | current_da_id, 347 | current_persona_ids, 348 | response_utterance_tokens) 349 | 350 | def process(self, 351 | history_speaker_tokens, 352 | history_utterance_tokens, 353 | history_emotion_tokens, 354 | history_da_tokens, 355 | responder_token, 356 | current_emotion_token, 357 | current_da_token, 358 | current_emotion_id, 359 | current_da_id, 360 | current_persona_ids, 361 | response_utterance_tokens, 362 | with_eos=True): 363 | instance = {} 364 | bos, eos, speaker1, speaker2 = self.tokenizer.convert_tokens_to_ids(CPED_SPECIAL_TOKENS) 365 | speaker_tokens = history_speaker_tokens + [responder_token] 366 | emotion_tokens = history_emotion_tokens + [current_emotion_token] 367 | da_tokens = history_da_tokens + [current_da_token] 368 | sequence = [[bos]] + history_utterance_tokens + [response_utterance_tokens + ([eos] if with_eos else [])] 369 | sequence = [sequence[0]] + [[speaker_tokens[i]] + s 370 | for i, s in enumerate(sequence[1:])] 371 | instance["input_ids"] = list(chain(*sequence)) 372 | instance["token_type_ids"] = [bos] + [speaker_tokens[i] for i, s in 373 | enumerate(sequence[1:]) 374 | for _ in s] 375 | 376 | if self.with_da: 377 | instance["da_ids"] = [bos] + [da_tokens[i] for i, s in 378 | enumerate(sequence[1:]) 379 | for _ in s] 380 | # only set the DA in [speaker1] or [speaker2] 381 | if self.set_eda_in_speaker: 382 | instance["da_ids"] = set_da_in_speaker(instance["da_ids"],instance["input_ids"],bos, eos, speaker1, speaker2, self.pad) 383 | if self.with_emotion: 384 | instance["emotion_ids"] = [bos] + [emotion_tokens[i] for i, s in 385 | enumerate(sequence[1:]) 386 | for _ in s] 387 | # only set the emotion in [speaker1] or [speaker2] 388 | if self.set_eda_in_speaker: 389 | instance["emotion_ids"] = self.set_emotion_in_speaker(instance["emotion_ids"],instance["input_ids"],bos, eos, speaker1, speaker2, self.pad) 390 | if self.with_current_speaker: 391 | instance["current_speaker_id"] = responder_token 392 | 393 | if self.with_current_emotion: 394 | instance["current_emotion_id"] = current_emotion_id 395 | 396 | if self.with_current_da: 397 | instance["current_da_id"] = current_da_id 398 | 399 | if self.with_current_persona: 400 | instance["current_persona_ids"] = current_persona_ids 401 | 402 | instance["lm_labels"] = [-1] * len(instance["input_ids"]) 403 | if self.lm_labels: 404 | instance["lm_labels"] = ([-1] * sum(len(s) for s in sequence[:-1])) + [-1] + sequence[-1][1:] 405 | if self.set_current_speaker_mask: 406 | instance["current_speaker_mask"] = ([-1] * sum(len(s) for s in sequence[:-1])) + [1] + ([-1] * len(sequence[-1][1:]) ) 407 | 408 | return instance 409 | 410 | def collate(self, batch): 411 | input_ids = pad_sequence( 412 | [torch.tensor(instance["input_ids"], dtype=torch.long) for instance in batch], 413 | batch_first=self.batch_first, padding_value=self.pad) 414 | token_type_ids = pad_sequence( 415 | [torch.tensor(instance["token_type_ids"], dtype=torch.long) for instance in batch], 416 | batch_first=self.batch_first, padding_value=self.pad) 417 | 418 | if self.with_emotion: 419 | emotion_ids = pad_sequence( 420 | [torch.tensor(instance["emotion_ids"], dtype=torch.long) for instance in batch], 421 | batch_first=self.batch_first, padding_value=self.pad) 422 | else: 423 | emotion_ids = None 424 | 425 | if self.with_da: 426 | da_ids = pad_sequence( 427 | [torch.tensor(instance["da_ids"], dtype=torch.long) for instance in batch], 428 | batch_first=self.batch_first, padding_value=self.pad) 429 | else: 430 | da_ids = None 431 | 432 | if self.with_current_speaker: 433 | current_speaker_id = torch.tensor( 434 | [torch.tensor(instance["current_speaker_id"], dtype=torch.long) for instance in batch], 435 | dtype=torch.long) 436 | else: 437 | current_speaker_id = None 438 | 439 | if self.with_current_persona: 440 | current_persona_ids = pad_sequence( 441 | [torch.tensor(instance["current_persona_ids"], dtype=torch.long) for instance in batch], 442 | batch_first=self.batch_first, padding_value=1) # padding_value=1 means unknown here 443 | else: 444 | current_persona_ids = None 445 | 446 | if self.with_current_emotion: 447 | current_emotion_id = torch.tensor( 448 | [torch.tensor(instance["current_emotion_id"], dtype=torch.long) for instance in batch], 449 | dtype=torch.long) 450 | else: 451 | current_emotion_id = None 452 | 453 | if self.with_current_da: 454 | current_da_id = torch.tensor( 455 | [torch.tensor(instance["current_da_id"], dtype=torch.long) for instance in batch], 456 | dtype=torch.long) 457 | else: 458 | current_da_id = None 459 | lm_labels = pad_sequence( 460 | [torch.tensor(instance["lm_labels"], dtype=torch.long) for instance in batch], 461 | batch_first=self.batch_first, padding_value=-1) 462 | 463 | if self.set_current_speaker_mask: # for CVGPT 464 | current_speaker_mask = pad_sequence( 465 | [torch.tensor(instance["current_speaker_mask"], dtype=torch.long) for instance in batch], 466 | batch_first=self.batch_first, padding_value=-1) 467 | return input_ids, token_type_ids, emotion_ids, da_ids, current_speaker_id, current_persona_ids, current_emotion_id, current_da_id, lm_labels, current_speaker_mask 468 | else: 469 | return input_ids, token_type_ids, emotion_ids, da_ids, current_speaker_id, current_persona_ids, current_emotion_id, current_da_id, lm_labels 470 | 471 | 472 | def build_cped_dataloaders(args, 473 | tokenizer, 474 | logger, 475 | load_test=False, 476 | filenames={"train":"train_shuffle_split.csv", 477 | "valid":"valid_shuffle_split.csv", 478 | "test":"test_shuffle_split.csv"}): 479 | data,sample = cped_get_data_from_dir(dir_path=args.data_path, 480 | cache_path=args.cache_path, 481 | tokenizer=tokenizer, 482 | logger=logger, 483 | filenames=filenames) 484 | 485 | if load_test==False: 486 | logger.info("Build train and validation dataloaders") 487 | train_data = data["train"] 488 | valid_data = data["valid"] 489 | train_dataset = CpedDataset(data=train_data, 490 | tokenizer=tokenizer, 491 | emotion_type=args.emotion_type, 492 | da_type=args.da_type, 493 | persona_type=["Gender","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"], 494 | max_history=args.max_history, 495 | batch_first=True, 496 | lm_labels=True, 497 | with_current_speaker=args.with_current_speaker, 498 | with_current_persona=args.with_current_persona, 499 | with_current_emotion=args.with_current_emotion, 500 | with_current_da=args.with_current_da, 501 | with_emotion=args.with_emotion, 502 | with_da=args.with_da, 503 | use_speaker_name_as_speaker_list=args.use_speaker_name_as_speaker_list, 504 | set_eda_in_speaker=args.set_eda_in_speaker, 505 | set_current_speaker_mask=args.set_current_speaker_mask) 506 | 507 | 508 | valid_dataset = CpedDataset(data=valid_data, 509 | tokenizer=tokenizer, 510 | emotion_type=args.emotion_type, 511 | da_type=args.da_type, 512 | persona_type=["Gender","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"], 513 | max_history=args.max_history, 514 | batch_first=True, 515 | lm_labels=True, 516 | with_current_speaker=args.with_current_speaker, 517 | with_current_persona=args.with_current_persona, 518 | with_current_emotion=args.with_current_emotion, 519 | with_current_da=args.with_current_da, 520 | with_emotion=args.with_emotion, 521 | with_da=args.with_da, 522 | use_speaker_name_as_speaker_list=args.use_speaker_name_as_speaker_list, 523 | set_eda_in_speaker=args.set_eda_in_speaker, 524 | set_current_speaker_mask=args.set_current_speaker_mask) 525 | 526 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None 527 | valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None 528 | train_loader = DataLoader(train_dataset, 529 | sampler=train_sampler, 530 | collate_fn=train_dataset.collate, 531 | num_workers=args.num_workers, 532 | batch_size=args.train_batch_size, 533 | shuffle=(not args.distributed)) 534 | valid_loader = DataLoader(valid_dataset, 535 | sampler=valid_sampler, 536 | collate_fn=valid_dataset.collate, 537 | num_workers=args.num_workers, 538 | batch_size=args.valid_batch_size, 539 | shuffle=False) 540 | 541 | return train_loader, valid_loader, train_sampler, valid_sampler 542 | 543 | else: 544 | logger.info("Build test dataloaders") 545 | test_data = data["test"] 546 | test_dataset = CpedDataset(data=test_data, 547 | tokenizer=tokenizer, 548 | emotion_type=args.emotion_type, 549 | da_type=args.da_type, 550 | persona_type=["Gender","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"], 551 | max_history=args.max_history, 552 | batch_first=True, 553 | lm_labels=True, 554 | with_current_speaker=args.with_current_speaker, 555 | with_current_persona=args.with_current_persona, 556 | with_current_emotion=args.with_current_emotion, 557 | with_current_da=args.with_current_da, 558 | with_emotion=args.with_emotion, 559 | with_da=args.with_da, 560 | use_speaker_name_as_speaker_list=args.use_speaker_name_as_speaker_list, 561 | set_eda_in_speaker=args.set_eda_in_speaker, 562 | set_current_speaker_mask=args.set_current_speaker_mask) 563 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset) if args.distributed else None 564 | test_loader = DataLoader(test_dataset, 565 | sampler=test_sampler, 566 | collate_fn=test_dataset.collate, 567 | num_workers=args.num_workers, 568 | batch_size=args.test_batch_size, 569 | shuffle=False) 570 | return test_loader, test_sampler 571 | -------------------------------------------------------------------------------- /pec_baseline/utils/cped_util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 South China University of Technology and 3 | # Engineering Research Ceter of Ministry of Education on Human Body Perception. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # CPED Dataset data loading file 18 | # File: cped_util.py 19 | # Used for CPED dataset loading 20 | # Author: Chen Yirong 21 | # Date: 2022.03.22 22 | 23 | # CPED数据集说明 24 | # 数据集存储形式如下: 25 | # CPED/ 顶层文件夹, 26 | # CPED_total_text.csv 全部对话数据汇总的csv格式文件, 27 | # speakers.txt 所有说话人姓名汇总集合的文本文件, 28 | # train_split.csv 训练集,与valid_split.csv、test_split.csv无重叠说话人 29 | # valid_split.csv 验证集,与train_split.csv、test_split.csv无重叠说话人 30 | # test_split.csv 测试集,与train_split.csv、valid_split.csv无重叠说话人 31 | # train_shuffle_split.csv 将CPED_total_text.csv随机打乱以8:1:1切割得到 32 | # valid_shuffle_split.csv 将CPED_total_text.csv随机打乱以8:1:1切割得到 33 | # test_shuffle_split.csv 将CPED_total_text.csv随机打乱以8:1:1切割得到 34 | # 上述两种数据集划分方式可以用于不同的研究场景! 35 | 36 | # 本文件约定函数命名风格:cped_function_name,例如: 37 | # cped_get_total_data、cped_get_single_file 38 | 39 | # 关键包版本说明: 40 | # pytorch: 1.9.0+ 41 | 42 | import os 43 | import re 44 | import torch 45 | import logging 46 | import pandas as pd 47 | 48 | logger = logging.getLogger(__name__) 49 | 50 | 51 | # CPED数据采集用到的一些常量,例如:情感标签、对话动作、 52 | CPED_SPECIAL_TOKENS = ["[CLS]", "[SEP]", "[speaker1]", "[speaker2]"] 53 | CPED_IGNORE_ID = -1 # Tokens with indices set to ``-1`` are ignored,用于训练SpeakerBert 54 | 55 | # 以下列表用于添加到词表当中 56 | CPED_DA_TOKENS = ["[greeting]","[question]","[answer]","[statement-opinion]","[statement-non-opinion]","[apology]", 57 | "[command]","[agreement]","[disagreement]","[acknowledge]","[appreciation]","[interjection]", 58 | "[conventional-closing]","[quotation]","[reject]","[irony]","[comfort]","[thanking]","[da-other]"] # 19 DA labels 59 | 60 | CPED_SENTIMENT_TOKENS = ["[neutral]","[positive]","[negative]"] 61 | 62 | CPED_EMOTION_TOKENS = ["[happy]","[grateful]","[relaxed]","[positive-other]","[anger]","[sadness]","[fear]", 63 | "[depress]","[disgust]","[astonished]","[worried]","[negative-other]","[neutral]"] # 13 emotion labels 64 | 65 | CPED_DA_TO_TOKENS = {'greeting': '[greeting]', 'question': '[question]', 'answer': '[answer]', 66 | 'statement-opinion': '[statement-opinion]', 'statement-non-opinion': '[statement-non-opinion]', 67 | 'apology': '[apology]', 'command': '[command]', 'agreement': '[agreement]', 68 | 'disagreement': '[disagreement]', 'acknowledge': '[acknowledge]', 'appreciation': '[appreciation]', 69 | 'interjection': '[interjection]', 'conventional-closing': '[conventional-closing]', 70 | 'quotation': '[quotation]', 'reject': '[reject]', 'irony': '[irony]', 71 | 'comfort': '[comfort]','thanking':'[thanking]', 'other': '[da-other]'} 72 | 73 | CPED_SENTIMENT_TO_TOKENS = {'neutral': '[neutral]', 'positive': '[positive]', 'negative': '[negative]'} 74 | 75 | CPED_EMOTION_TO_TOKENS = {'happy': '[happy]', 'grateful': '[grateful]', 'relaxed': '[relaxed]', 76 | 'positive-other': '[positive-other]', 'anger': '[anger]', 'sadness': '[sadness]', 77 | 'fear': '[fear]', 'depress': '[depress]', 'disgust': '[disgust]', 78 | 'astonished': '[astonished]', 'worried': '[worried]', 'negative-other': '[negative-other]', 79 | 'neutral': '[neutral]'} 80 | 81 | CPED_DA_TO_ID = {'greeting': 0, 'question': 1, 'answer': 2, 'statement-opinion': 3, 'statement-non-opinion': 4, 82 | 'apology': 5, 'command': 6, 'agreement': 7, 'disagreement': 8, 'acknowledge': 9, 'appreciation': 10, 83 | 'interjection': 11, 'conventional-closing': 12, 'quotation': 13, 'reject': 14, 'irony': 15, 84 | 'comfort': 16,'thanking':17, 'other': 18} 85 | 86 | CPED_EMOTION_TO_ID = {'happy': 0, 'grateful': 1, 'relaxed': 2, 'positive-other': 3, 'anger': 4, 'sadness': 5, 87 | 'fear': 6, 'depress': 7, 'disgust': 8, 'astonished': 9, 'worried': 10, 88 | 'negative-other': 11, 'neutral': 12} 89 | 90 | CPED_GENDER_TO_ID = {'female': 0, 'unknown': 1, 'male': 2} 91 | CPED_BIGFIVE_TO_ID = {'low': 0, 'unknown': 1, 'high': 2} 92 | 93 | CPED_SPEAKER_TYPE_TO_ID ={"[speaker1]": 0, "[speaker2]": 1, "[MASK]": 2} 94 | 95 | # 给语音识别文本加上标点符号 96 | # https://blog.csdn.net/qq_33200967/article/details/122474859 97 | 98 | 99 | 100 | 101 | 102 | def tokenize(utterance, tokenizer): 103 | '''tokenize:使用tokenizer对utterance进行tokenize 104 | Inputs: 105 | utterance: 字符串,一个句子 106 | tokenizer: Tokenizer对象,参考:https://huggingface.co/docs/transformers/main_classes/tokenizer 107 | Outputs: 108 | ids: 列表,列表中的每个元素对应token的id,例如:[2108, 3342, 3342, 8024, 1962, 1008, 2769, 1420, 1127, 1127, 6432, 6814] 109 | Example: 110 | from transformers import BertTokenizer 111 | tokenizer = BertTokenizer.from_pretrained("bert-base-chinese") 112 | ids = tokenize(utterance="季杨杨,好像我听凡凡说过", tokenizer=tokenizer) 113 | print(ids) 114 | # 返回:[2108, 3342, 3342, 8024, 1962, 1008, 2769, 1420, 1127, 1127, 6432, 6814] 115 | ''' 116 | utterance = str(utterance) # 保证为str类型 117 | # 对于问句添加问号 118 | utterance = utterance.replace("吗", "吗?") 119 | utterance = utterance.replace("??", "?") 120 | 121 | # 对于感叹句添加感叹号 122 | utterance = utterance.replace("啊", "啊!") 123 | utterance = utterance.replace("吧", "吧!") 124 | utterance = utterance.replace("啦", "啦!") 125 | utterance = utterance.replace("呀", "呀!") 126 | utterance = utterance.replace("!!", "!") 127 | 128 | # 对于句子中间非问句,非感叹句添加逗号 129 | utterance = utterance.replace(" ", ",") 130 | # 去除重复标点符号 131 | utterance = utterance.split() # 去除全部空格 132 | 133 | utt_list = list(utterance) # "季杨杨,好像我听凡凡说过" --> ['季', '杨', '杨', ',', '好', '像', '我', '听', '凡', '凡', '说', '过'] 134 | 135 | utterance = ' '.join(utt_list) # ['季', '杨', '杨', ',', '好', '像', '我', '听', '凡', '凡', '说', '过']--> “季 杨 杨 , 好 像 我 听 凡 凡 说 过” # 136 | return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(utterance)) 137 | 138 | 139 | def cped_get_single_file(file_path, 140 | tokenizer, 141 | logger, 142 | usecols=["Dialogue_ID","Utterance_ID","Speaker","Sentiment","Emotion","DA","Utterance","Gender","Age","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"], 143 | args=None): 144 | '''cped_get_single_file: 读取指定路径的csv文件,例如:CPED_total_text.csv、train_split.csv、... 145 | Inputs: 146 | file_path: 字符串,指定文件路径 147 | tokenizer: Tokenizer对象,参考:https://huggingface.co/docs/transformers/main_classes/tokenizer 148 | logger: logging日志对象 149 | usecols: 列表,列表中的字符串指定了读取的csv文件的列名,其中 150 | "Dialogue_ID","Utterance_ID","Speaker","Utterance" 151 | 是必需项 152 | args: parser.parse_args()返回的参数字典 153 | Outputs: 154 | data: DataFrame对象 155 | samples: DataFrame对象 156 | Example: 157 | import logging 158 | from transformers import BertTokenizer 159 | tokenizer = BertTokenizer.from_pretrained("bert-base-chinese") 160 | logger = logging.getLogger(__name__) 161 | file_path = "../data/CPED/test_split.csv" 162 | data, samples = cped_get_single_file(file_path, tokenizer, logger) 163 | 164 | ''' 165 | logger.info("Read file from %s", file_path) 166 | data = pd.read_csv(file_path, 167 | usecols=usecols, 168 | encoding="UTF-8-SIG") 169 | samples = data.iloc[0:30] 170 | 171 | logger.info("Start tokenizing and encoding the file") 172 | data["Token"] = [tokenize(s, tokenizer) for s in data["Utterance"]] 173 | logger.info("Finished tokenizing and encoding the dataset") 174 | return data, samples 175 | 176 | 177 | def cped_get_single_cache_file(file_path, 178 | cache_path, 179 | tokenizer, 180 | logger, 181 | usecols=["Dialogue_ID","Utterance_ID","Speaker","Sentiment","Emotion","DA","Utterance","Gender","Age","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"], 182 | args=None): 183 | '''cped_get_single_cache_file: 读取指定路径的csv文件,如果存在cache,则直接读取cache文件, 184 | 例如:CPED_total_text.csv、train_split.csv、... 185 | 这个函数与cped_get_single_file的最大不同就是,第一次读取会保存一个cache文件,之后再读取,就不需要 186 | 调用tokenizer进行预处理了,节省大量实验时间 187 | Inputs: 188 | file_path: 字符串,指定文件路径 189 | cache_path: cache文件保存的路径,建议这个文件的命名做好管理,否则容易混淆数据集,torch.save(data, cache_path) 190 | tokenizer: Tokenizer对象,参考:https://huggingface.co/docs/transformers/main_classes/tokenizer 191 | logger: logging日志对象 192 | usecols: 列表,列表中的字符串指定了读取的csv文件的列名,其中 193 | "Dialogue_ID","Utterance_ID","Speaker","Utterance" 194 | 是必需项 195 | args: parser.parse_args()返回的参数字典 196 | Outputs: 197 | data: DataFrame对象 198 | samples: DataFrame对象 199 | Example: 200 | import logging 201 | from transformers import BertTokenizer 202 | tokenizer = BertTokenizer.from_pretrained("bert-base-chinese") 203 | logger = logging.getLogger(__name__) 204 | file_path = "../data/CPED/test_split.csv" 205 | cache_path = "../data/CPED/test_split_cache" 206 | data, samples = cped_get_single_cache_file(file_path, cache_path, tokenizer, logger) 207 | 208 | ''' 209 | if cache_path and os.path.isfile(cache_path): 210 | logger.info("Load tokenized dataset from cache at %s", cache_path) 211 | data = torch.load(cache_path) 212 | samples = None 213 | else: # 从原始文件中读取数据 214 | logger.info("Read dataset from %s", file_path) 215 | data, samples = cped_get_single_file(file_path=file_path, 216 | tokenizer=tokenizer, 217 | logger=logger, 218 | usecols=usecols, 219 | args=args) 220 | logger.info("Finished tokenizing and encoding the dataset") 221 | logger.info("Save tokenized dataset to cache at %s", cache_path) 222 | torch.save(data, cache_path) 223 | return data, samples 224 | 225 | 226 | def cped_get_data_from_dir(dir_path, 227 | cache_path, 228 | tokenizer, 229 | logger, 230 | filenames={"train":"train_shuffle_split.csv", 231 | "valid":"valid_shuffle_split.csv", 232 | "test":"test_shuffle_split.csv"}, 233 | usecols=["Dialogue_ID","Utterance_ID","Speaker","Sentiment","Emotion","DA","Utterance","Gender","Age","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"], 234 | args=None): 235 | '''cped_get_data_from_dir: 读取dir_path指定目录下,字典filenames指定的数据集 236 | 如果存在cache,则直接读取cache文件, 237 | Inputs: 238 | dir_path: 字符串,指定数据集存放的目录 239 | cache_path: cache文件保存的路径,建议这个文件的命名做好管理,否则容易混淆数据集,torch.save(data, cache_path) 240 | tokenizer: Tokenizer对象,参考:https://huggingface.co/docs/transformers/main_classes/tokenizer 241 | logger: logging日志对象 242 | filenames: 字典,包括"train"、"valid"、"test"三个键,其值指定对应的文件名 243 | usecols: 列表,列表中的字符串指定了读取的csv文件的列名,其中 244 | "Dialogue_ID","Utterance_ID","Speaker","Utterance" 245 | 是必需项 246 | args: parser.parse_args()返回的参数字典 247 | Outputs: 248 | data: 字典,格式为{"train":train_data,"valid":valid_data, "test":test_data},每一个值为DataFrame对象 249 | samples: DataFrame对象 250 | Example: 251 | import logging 252 | from transformers import BertTokenizer 253 | tokenizer = BertTokenizer.from_pretrained("bert-base-chinese") 254 | logger = logging.getLogger(__name__) 255 | dir_path = "../data/CPED" 256 | cache_path = "../data/CPED/cped_cache" 257 | filenames = {"train":"train_shuffle_split.csv", 258 | "valid":"valid_shuffle_split.csv", 259 | "test":"test_shuffle_split.csv"} 260 | data, samples = cped_get_data_from_dir(dir_path, cache_path, tokenizer, logger, filenames) 261 | 262 | ''' 263 | if cache_path and os.path.isfile(cache_path): 264 | logger.info("Load tokenized dataset from cache at %s", cache_path) 265 | data = torch.load(cache_path) 266 | samples = None 267 | else: # 从原始文件中读取数据 268 | logger.info("Read dataset from %s", dir_path) 269 | train_data, samples = cped_get_single_file(os.path.join(dir_path,filenames["train"]), tokenizer, logger, usecols, args) 270 | valid_data, samples = cped_get_single_file(os.path.join(dir_path,filenames["valid"]), tokenizer, logger, usecols, args) 271 | test_data, samples = cped_get_single_file(os.path.join(dir_path,filenames["test"]), tokenizer, logger, usecols, args) 272 | data = {"train":train_data,"valid":valid_data, "test":test_data} 273 | logger.info("Finished tokenizing and encoding the dataset") 274 | logger.info("Save tokenized dataset to cache at %s", cache_path) 275 | torch.save(data, cache_path) 276 | return data, samples 277 | 278 | 279 | def cped_get_single_file_for_bert_gpt(file_path, 280 | bert_tokenizer, 281 | gpt_tokenizer, 282 | logger, 283 | usecols=["Dialogue_ID","Utterance_ID","Speaker","Sentiment","Emotion","DA","Utterance","Gender","Age","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"], 284 | args=None): 285 | '''cped_get_single_file_for_bert_gpt: 读取指定路径的csv文件,例如:CPED_total_text.csv、train_split.csv、... 286 | 并使用两种tokenizer进行tokenize 287 | Inputs: 288 | file_path: 字符串,指定文件路径 289 | bert_tokenizer: Tokenizer对象,参考:https://huggingface.co/docs/transformers/main_classes/tokenizer 290 | gpt_tokenizer: Tokenizer对象,参考:https://huggingface.co/docs/transformers/main_classes/tokenizer 291 | logger: logging日志对象 292 | usecols: 列表,列表中的字符串指定了读取的csv文件的列名,其中 293 | "Dialogue_ID","Utterance_ID","Speaker","Utterance" 294 | 是必需项 295 | args: parser.parse_args()返回的参数字典 296 | Outputs: 297 | data: DataFrame对象 298 | samples: DataFrame对象 299 | Example: 300 | import logging 301 | from transformers import BertTokenizer 302 | bert_tokenizer = BertTokenizer.from_pretrained("bert-base-chinese") 303 | gpt_tokenizer = BertTokenizer.from_pretrained("openai-gpt") 304 | logger = logging.getLogger(__name__) 305 | file_path = "../data/CPED/test_split.csv" 306 | data, samples = cped_get_single_file_for_bert_gpt(file_path, bert_tokenizer, gpt_tokenizer, logger) 307 | 308 | ''' 309 | logger.info("Read file from %s", file_path) 310 | data = pd.read_csv(file_path, 311 | usecols=usecols, 312 | encoding="UTF-8-SIG") 313 | samples = data.iloc[0:30] 314 | 315 | logger.info("Start tokenizing and encoding the file") 316 | data["Token_bert"] = [tokenize(s, bert_tokenizer) for s in data["Utterance"]] 317 | data["Token_gpt"] = [tokenize(s, gpt_tokenizer) for s in data["Utterance"]] 318 | logger.info("Finished tokenizing and encoding the dataset") 319 | return data, samples 320 | 321 | 322 | def cped_get_data_from_dir_for_bert_gpt(dir_path, 323 | cache_path, 324 | bert_tokenizer, 325 | gpt_tokenizer, 326 | logger, 327 | filenames={"train":"train_shuffle_split.csv", 328 | "valid":"valid_shuffle_split.csv", 329 | "test":"test_shuffle_split.csv"}, 330 | usecols=["Dialogue_ID","Utterance_ID","Speaker","Sentiment","Emotion","DA","Utterance","Gender","Age","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"], 331 | args=None): 332 | '''cped_get_data_from_dir_for_bert_gpt: 读取dir_path指定目录下,字典filenames指定的数据集 333 | 如果存在cache,则直接读取cache文件, 334 | Inputs: 335 | dir_path: 字符串,指定数据集存放的目录 336 | cache_path: cache文件保存的路径,建议这个文件的命名做好管理,否则容易混淆数据集,torch.save(data, cache_path) 337 | bert_tokenizer: Tokenizer对象,参考:https://huggingface.co/docs/transformers/main_classes/tokenizer 338 | gpt_tokenizer: Tokenizer对象,参考:https://huggingface.co/docs/transformers/main_classes/tokenizer 339 | logger: logging日志对象 340 | filenames: 字典,包括"train"、"valid"、"test"三个键,其值指定对应的文件名 341 | usecols: 列表,列表中的字符串指定了读取的csv文件的列名,其中 342 | "Dialogue_ID","Utterance_ID","Speaker","Utterance" 343 | 是必需项 344 | args: parser.parse_args()返回的参数字典 345 | Outputs: 346 | data: 字典,格式为{"train":train_data,"valid":valid_data, "test":test_data},每一个值为DataFrame对象 347 | samples: DataFrame对象 348 | Example: 349 | import logging 350 | from transformers import BertTokenizer 351 | bert_tokenizer = BertTokenizer.from_pretrained("bert-base-chinese") 352 | gpt_tokenizer = BertTokenizer.from_pretrained("openai-gpt") 353 | logger = logging.getLogger(__name__) 354 | dir_path = "../data/CPED" 355 | cache_path = "../data/CPED/cped_cache" 356 | filenames = {"train":"train_shuffle_split.csv", 357 | "valid":"valid_shuffle_split.csv", 358 | "test":"test_shuffle_split.csv"} 359 | data, samples = cped_get_data_from_dir_for_bert_gpt(dir_path, cache_path, bert_tokenizer, gpt_tokenizer, logger, filenames) 360 | 361 | ''' 362 | if cache_path and os.path.isfile(cache_path): 363 | logger.info("Load tokenized dataset from cache at %s", cache_path) 364 | data = torch.load(cache_path) 365 | samples = None 366 | else: # 从原始文件中读取数据 367 | logger.info("Read dataset from %s", dir_path) 368 | train_data, samples = cped_get_single_file_for_bert_gpt(os.path.join(dir_path,filenames["train"]), bert_tokenizer, gpt_tokenizer, logger, usecols, args) 369 | valid_data, samples = cped_get_single_file_for_bert_gpt(os.path.join(dir_path,filenames["valid"]), bert_tokenizer, gpt_tokenizer, logger, usecols, args) 370 | test_data, samples = cped_get_single_file_for_bert_gpt(os.path.join(dir_path,filenames["test"]), bert_tokenizer, gpt_tokenizer, logger, usecols, args) 371 | data = {"train":train_data,"valid":valid_data, "test":test_data} 372 | logger.info("Finished tokenizing and encoding the dataset") 373 | logger.info("Save tokenized dataset to cache at %s", cache_path) 374 | torch.save(data, cache_path) 375 | return data, samples 376 | 377 | -------------------------------------------------------------------------------- /pec_baseline/utils/dataset_statistics.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 South China University of Technology and 3 | # Engineering Research Ceter of Ministry of Education on Human Body Perception. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # Dataset statistic file 18 | # File: dataset_statistics.py 19 | # Used for dataset analysis 20 | # Author: Chen Yirong 21 | # Date: 2022.03.21 22 | 23 | import os 24 | import numpy as np 25 | import pandas as pd 26 | from os.path import join 27 | from collections import Counter 28 | 29 | 30 | def get_data_for_analysis(data_dir="../data/CPED", 31 | file_dict={"train":"train_split.csv","valid":"valid_split.csv","test":"test_split.csv"}): 32 | '''get_data from dir, which have train_split.csv, valid_split.csv, test_split.csv file 33 | Get .csv format dataset from data_dir. 34 | ''' 35 | print("Read dataset from ", data_dir) 36 | train_data = pd.read_csv(join(data_dir,file_dict["train"]), encoding="UTF-8-SIG") 37 | valid_data = pd.read_csv(join(data_dir,file_dict["valid"]), encoding="UTF-8-SIG") 38 | test_data = pd.read_csv(join(data_dir,file_dict["test"]), encoding="UTF-8-SIG") 39 | return train_data, valid_data, test_data 40 | 41 | 42 | def get_totaldata_for_analysis(data_path="/home/MMMTD/data/processed_cleaned_data/total_checked_processed_cleaned_data/checked_processed_cleaned_data.csv"): 43 | '''get total data from data_path 44 | Get .csv format dataset from data_path. 45 | ''' 46 | print("Read dataset from ", data_path) 47 | total_data = pd.read_csv(data_path, encoding="UTF-8-SIG") 48 | return total_data 49 | 50 | 51 | def get_row_statistics(data,row_name): 52 | '''get_row_statistics 53 | Get dataset row statistics with row_name 54 | E.g. 55 | train_data_TV_ID=get_row_statistics(train_data,"TV_ID") 56 | print("train TV stastics:\n", train_data_TV_ID["element_stastics"]) 57 | ''' 58 | name = row_name 59 | keys = list(set(data[name])) 60 | values = data[name].tolist() 61 | element_stastics=pd.value_counts(values) 62 | 63 | row_size = len(values) 64 | row_class = len(keys) 65 | results={"name":name,"keys":keys,"values":values,"element_stastics":element_stastics,"size":row_size,"class":row_class} 66 | return results 67 | 68 | 69 | def cout_dialogue_words(data,dialogue_id): 70 | dialogue_data = data[data['Dialogue_ID']==dialogue_id] 71 | count = 0 72 | for utt in dialogue_data["Utterance"]: 73 | count = count + len(str(utt)) 74 | return count 75 | 76 | 77 | 78 | def remove_element(utt_list,word = " "): 79 | temp_list = utt_list 80 | while word in temp_list: 81 | temp_list.remove(word) 82 | return temp_list 83 | 84 | 85 | def statistics_utterance(data,row_name="Utterance"): 86 | '''statistics_utterance 87 | Count the average and maximum word count of sentences 88 | 89 | ''' 90 | utt_list = data[row_name].tolist() 91 | utt_word_list = [remove_element(list(utterance)) for utterance in utt_list] 92 | count_utt_word = [] 93 | for utt in utt_word_list: 94 | count_utt_word.append(len(utt)) 95 | return {"max":max(count_utt_word),"avg":sum(count_utt_word)/len(count_utt_word)} 96 | 97 | 98 | def statistics_emotda(data,row_name="Emotion",dialogue_id="Dialogue_ID"): 99 | '''statistics_emotda 100 | Count the average emotion/DA per Dialogue 101 | 102 | ''' 103 | keys = list(set(data[dialogue_id])) 104 | count_dial_eda = [] 105 | for key in keys: 106 | dial_eda = list(set(data[data[dialogue_id]==key][row_name].tolist())) 107 | count_dial_eda.append(len(dial_eda)) 108 | return {"max":max(count_dial_eda),"avg":sum(count_dial_eda)/len(count_dial_eda)} 109 | 110 | 111 | def statistics_avg_duration(all_data,data,dialogue_id="Dialogue_ID",StartTime="StartTime",EndTime="EndTime"): 112 | '''statistics_emotda 113 | Count the average emotion/DA per Dialogue 114 | 115 | ''' 116 | keys = list(set(data[dialogue_id])) 117 | count_dial_time = [] 118 | for key in keys: 119 | start_time = all_data[all_data[dialogue_id]==key][StartTime] 120 | end_time = all_data[all_data[dialogue_id]==key][EndTime] 121 | time_list = np.array([int(s) for s in end_time.tolist()])-np.array([int(s) for s in start_time.tolist()]) 122 | time_list = time_list.tolist() 123 | 124 | count_dial_time.append(sum(time_list)/len(time_list)) 125 | return {"avg":sum(count_dial_time)/len(count_dial_time)} 126 | 127 | 128 | def print_speaker(data_path): 129 | print("Output all the name of speakers from "+data_path) 130 | data = get_totaldata_for_analysis(data_path) 131 | data_speaker_result = get_row_statistics(data,"说话者姓名") 132 | print(data_speaker_result["keys"]) 133 | print(data_speaker_result["class"]) 134 | 135 | 136 | def print_speaker_from_dir(input_dir="/home/MMMTD/data/processed_cleaned_data/checked_processed_cleaned_data"): 137 | if not os.path.exists(input_dir): 138 | print("unexisted data dir:"+input_dir) 139 | return False 140 | print("From"+input_dir+"load file......") 141 | file_names = os.listdir(input_dir) 142 | for file_name in file_names: 143 | print_speaker(data_path= join(input_dir,file_name)) 144 | 145 | return True 146 | 147 | 148 | def statistic_speaker(data_path = "/home/MMMTD/data/MMMTD_cleaned_speaker_annotation.xlsx"): 149 | if not os.path.isfile(data_path): 150 | print("unexisted data dir:"+input_dir) 151 | return 0 152 | else: 153 | data = pd.read_excel(data_path) # xlrd==1.2.0, do not use xlrd==2.0.1 154 | gender_result = get_row_statistics(data,"性别") 155 | age_result = get_row_statistics(data,"年龄段") 156 | return gender_result, age_result 157 | 158 | 159 | def print_sentiment(data_path, sentiment='中性情绪'): 160 | print("From "+data_path+" return "+sentiment) 161 | data = get_totaldata_for_analysis(data_path) 162 | data_result = get_row_statistics(data,"情绪(粗粒度)") 163 | print(data_result['element_stastics'][sentiment]) 164 | return data_result['element_stastics'][sentiment]/data_result['size'] 165 | 166 | 167 | def print_sentiment_from_dir(input_dir="/home/MMMTD/data/processed_cleaned_data/checked_processed_cleaned_data"): 168 | if not os.path.exists(input_dir): 169 | print("unexisted data dir:"+input_dir) 170 | return False 171 | print("From"+input_dir+"load file......") 172 | file_names = os.listdir(input_dir) 173 | result_dir = {} 174 | for file_name in file_names: 175 | result_dir[file_name] = print_sentiment(data_path= join(input_dir,file_name), sentiment='中性情绪') 176 | 177 | return result_dir 178 | 179 | 180 | def count_eda_array_from_data(data,da_name = "DA", emotion_name = "Emotion", utt_id_name = "Utterance_ID"): 181 | # Get DA and Emotion labels 182 | da_label = list(set(data[da_name].tolist())) 183 | emotion_label = list(set(data[emotion_name].tolist())) 184 | print("da_label=",da_label) 185 | print("da_label数目", len(da_label)) 186 | print("emotion_label=",emotion_label) 187 | print("emotion_label数目", len(emotion_label)) 188 | # add index to label 189 | da_id = {} 190 | id = 0 191 | for da in da_label: 192 | da_id[da]=id 193 | id = id+1 194 | print(da_id) 195 | 196 | emotion_id={} 197 | id = 0 198 | for emotion in emotion_label: 199 | emotion_id[emotion]=id 200 | id = id+1 201 | print(emotion_id) 202 | 203 | #count_eda_array = np.zeros((len(da_label),len(emotion_label))) 204 | count_eda_array = np.zeros((len(emotion_label),len(da_label))) 205 | # output the initial array 206 | print(count_eda_array) 207 | 208 | # begin statistics 209 | 210 | for utt_id in data[utt_id_name].tolist(): 211 | # for the emotion and DA of each row 212 | # print(data[data[utt_id_name]== utt_id ][da_name]) 213 | current_da_id = da_id[ str(data[data[utt_id_name]== utt_id ][da_name].values.astype("str")[0]) ] 214 | current_emotion_id = emotion_id[ str(data[data[utt_id_name]== utt_id ][emotion_name].values.astype("str")[0]) ] 215 | count_eda_array[current_emotion_id, current_da_id] = count_eda_array[current_emotion_id, current_da_id] + 1 216 | print("完成统计后的结果",count_eda_array) 217 | 218 | pro_eda_array=np.zeros((len(emotion_label),len(da_label))) 219 | 220 | da_results = get_row_statistics(data=data,row_name= da_name) 221 | emotion_results = get_row_statistics(data=data,row_name= emotion_name) 222 | 223 | 224 | for da in da_id: 225 | for emotion in emotion_id: 226 | current_da_id = da_id[da] 227 | current_emotion_id = emotion_id[emotion] 228 | current_da_number = da_results["element_stastics"][da] 229 | #current_emotion_number = emotion_results["element_stastics"][emotion] 230 | pro_eda_array[current_emotion_id, current_da_id] = count_eda_array[current_emotion_id, current_da_id]/(current_da_number) 231 | return da_id, emotion_id, count_eda_array, pro_eda_array 232 | 233 | -------------------------------------------------------------------------------- /prc_baseline/README.md: -------------------------------------------------------------------------------- 1 | # PRC: Personality Recognition in Conversation 2 | # 对话中的个性识别任务 3 | * **Author: Chen Yirong ** 4 | * **Date: 2022.06.01** 5 | 6 | **Note:** Code and model will be released in the future! 7 | 8 | ## Task Definition 9 | ## 任务定义 10 | See [https://paperswithcode.com/task/personality-recognition-in-conversation](https://paperswithcode.com/task/personality-recognition-in-conversation) --------------------------------------------------------------------------------