├── .gitignore
├── LICENSE
├── README-zh.md
├── README.md
├── data
│   └── CPED
│       ├── speakers.txt
│       ├── test_split.csv
│       ├── train_split.csv
│       └── valid_split.csv
├── envs
│   └── py3.8_torch1.9.0_ignite0.4.8_tensorflow2.2.0_cuda10.2_transformers4.18.0_paddlepaddle-gpu_2.3.0.yml
├── erc_baseline
│   └── README.md
├── images
│   ├── dataset_comparison.png
│   └── dataset_staticstics.png
├── pec_baseline
│   ├── README.md
│   ├── models
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── gpt
│   │   │   ├── __init__.py
│   │   │   ├── modeling_openai.py
│   │   │   └── tokenization_openai.py
│   │   ├── gpt2
│   │   │   ├── __init__.py
│   │   │   ├── modeling_gpt2.py
│   │   │   └── tokenization_gpt2.py
│   │   └── model_parameters.py
│   ├── train_model.py
│   ├── train_model_script.sh
│   └── utils
│       ├── README.md
│       ├── __init__.py
│       ├── base_util.py
│       ├── cped_dataset.py
│       ├── cped_util.py
│       └── dataset_statistics.py
└── prc_baseline
    └── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # dataset
132 | CPED_cache_for_CpedDataset
133 | logs/
134 | runs/
135 | results/
136 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README-zh.md:
--------------------------------------------------------------------------------
1 | # [CPED](https://github.com/scutcyr/CPED)
2 | [](#python) [](https://arxiv.org/abs/2205.14727) [](https://github.com/scutcyr/CPED/stargazers) [](https://github.com/scutcyr/CPED/blob/main/LICENSE)  [](https://github.com/psf/black) 
3 |
4 | README: [English](https://github.com/scutcyr/CPED/blob/main/README.md) | [中文](https://github.com/scutcyr/CPED/blob/main/README-zh.md)
5 | This repository provides the implementation details for the paper:
6 | **[CPED: A Large-Scale Chinese Personalized and Emotional Dialogue Dataset for Conversational AI](https://arxiv.org/abs/2205.14727)**
7 |
8 | For more information, please refer to our [paper](https://arxiv.org/abs/2205.14727).
9 | 
10 | The dataset has also been released on the LUGE (千言) platform: [https://www.luge.ai/#/luge/dataDetail?id=41](https://www.luge.ai/#/luge/dataDetail?id=41)
11 |
12 | ## Contents
13 | * Introduction
14 | * Dataset Statistics
15 | * Task Definition
16 | * Evaluation Results
17 | * Usage
18 |
19 | ## Introduction
20 | We construct a dataset named **CPED** from 40 Chinese TV shows.
21 | CPED contains multi-source knowledge related to emotions and personality traits, covering 13 emotion categories, gender, Big Five personality traits, 19 dialogue acts (DAs) and other knowledge. The table below compares CPED with other common dialogue datasets.
22 | 
23 | * We build a multi-turn Chinese Personalized and Emotional Dialogue dataset named CPED. To the best of our knowledge, CPED is the first Chinese personalized and emotional dialogue dataset. It contains more than 12K dialogues and more than 133K utterances, and it is multimodal. It can therefore be used for research on both complicated dialogue understanding and human-like conversation generation.
24 | * CPED provides three types of attribute annotations (name, gender, age), Big Five personality trait annotations, two types of emotional annotations (3-class coarse-grained sentiment and 13-class fine-grained emotion), and dialogue act (DA) annotations. Personality traits and emotions can serve as prior external knowledge for open-domain dialogue generation and improve the personification of dialogue systems.
25 | * We propose three tasks in the paper: personality recognition in conversations (PRC), emotion recognition in conversations (ERC), and personalized and emotional conversation (PEC). A set of experiments verifies the importance of personalities and emotions for conversation generation.
26 |
27 | ![dataset_comparison](images/dataset_comparison.png)
28 |
29 | ## Dataset Statistics
30 | To enable dialogue systems to learn emotional and personalized expression abilities, we provide multiple types of annotation labels, listed in the following table.
31 |
32 | | # of annos. | Labels | Num. |
33 | |:-----------:|:-------|:----:|
34 | | Sentiment | positive, neutral, and negative | 3 |
35 | | Emotion | happy, grateful, relaxed, other-positive, neutral, angry, sad, feared, depressed, disgusted, astonished, worried and other-negative | 13 |
36 | | Gender | male, female, and unknown | 3 |
37 | | Age group | children, teenager, young, middle-aged, elderly and unknown | 6 |
38 | | Big Five | high, low, and unknown | 3 |
39 | | DA | greeting (g), question (q), answer (ans), statement-opinion (sv), statement-non-opinion (sd), apology (fa), command (c), agreement/acceptance (aa), disagreement (dag), acknowledge (a), appreciation (ba), interjection (ij), conventional-closing (fc), thanking (ft), quotation (^q), reject(rj), irony (ir), comfort (cf) and other (oth) | 19 |
40 | | Scene | home, office, school, mall, hospital, restaurant, sports-venue, entertainment-venue, car, outdoor and other-scene | 11 |
41 |
42 |
43 | The distributions of gender, age group, 3-class sentiment, 13-class fine-grained emotion and DA in the CPED dataset are shown in the following figure.
44 | ![dataset_statistics](images/dataset_staticstics.png)
45 |
46 | The statistics of CPED are listed in the following table.
47 | | Statistics | Train | Dev | Test |
48 | |---------------------------------|---------|---------|---------|
49 | | # of modalities | (v,a,t) | (v,a,t) | (v,a,t) |
50 | | # of TV plays | 26 | 5 | 9 |
51 | | # of dialogues | 8,086 | 934 | 2,815 |
52 | | # of utterances | 94,187 | 11,137 | 27,438 |
53 | | # of speakers | 273 | 38 | 81 |
54 | | Avg. # utt. per dial. | 11.6 | 11.9 | 9.7 |
55 | | Max # utt. per dial. | 75 | 31 | 34 |
56 | | Avg. # of emot. per dial. | 2.8 | 3.4 | 3.2 |
57 | | Avg. # of DAs per dial. | 3.6 | 3.7 | 3.2 |
58 | | Avg. utt. length | 8.3 | 8.2 | 8.3 |
59 | | Max utt. length | 127 | 42 | 45 |
60 | | Avg. duration of an utterance | 2.1s | 2.12s | 2.21s |
61 |
62 |
63 |
64 | ## Task Definition
65 | CPED can be used to evaluate both dialogue understanding tasks and dialogue generation tasks, e.g., speaker modeling, personality recognition in conversations, emotion recognition in conversations, DA recognition in conversations, emotion prediction for responses, emotional conversation generation, personalized conversation generation, empathetic conversation generation, etc. Being multimodal, CPED can also be applied to multimodal personality or emotion recognition and multimodal conversation generation. It will play a positive role in promoting the development of cognitive intelligence.
66 | We introduce three tasks in this project:
67 | * **ERC**: [Emotion Recognition in Conversation](https://paperswithcode.com/task/emotion-recognition-in-conversation)
68 | * **PRC**: [Personality Recognition in Conversation](https://paperswithcode.com/task/personality-recognition-in-conversation)
69 | * **PEC**: [Personalized and Emotional Conversation](https://paperswithcode.com/task/personalized-and-emotional-conversation)
70 |
71 |
72 |
73 | ## Usage
74 | If you use conda to manage virtual environments, you can create the Python virtual environment for running the baseline models with the following commands:
75 | ```bash
76 | conda create -n py38 python=3.8
77 | conda activate py38
78 | pip install torch==1.9.0+cu102 torchvision==0.10.0+cu102 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
79 | pip install tensorflow==2.2.0
80 | pip install transformers==4.18.0
81 | python -m pip install paddlepaddle-gpu==2.3.0 -i https://mirror.baidu.com/pypi/simple
82 | pip install pytorch-ignite==0.4.8
83 | pip install notebook
84 | pip install pandas
85 | pip install chardet
86 | pip install matplotlib==3.5.2
87 | python -m pip install paddlenlp -i https://mirrors.aliyun.com/pypi/simple/
88 | python -m pip install ppasr -i https://mirrors.aliyun.com/pypi/simple/ -U
89 | pip install nltk
90 | pip install bert-score
91 | ```
92 |
93 | The versions of some of the key packages are as follows:
94 | ```bash
95 | python=3.8
96 | torch==1.9.0+cu102
97 | torchvision==0.10.0+cu102
98 | torchaudio==0.9.0
99 | tensorflow==2.2.0
100 | tensorboard==2.2.2
101 | transformers==4.18.0
102 | paddlepaddle-gpu==2.3.0
103 | paddlenlp==2.3.2
104 | pytorch-ignite==0.4.8
105 | matplotlib==3.5.2
106 | notebook==6.4.11
107 | pandas==1.4.2
108 | chardet==4.0.0
109 | nltk==3.7
110 | bert-score==0.3.11
111 | ```
112 |
113 |
114 |
115 | Please cite the following paper if you use the CPED dataset or this project in your research:
116 | ```
117 | @article{chen2022cped,
118 | title={{CPED}: A Large-Scale Chinese Personalized and Emotional Dialogue Dataset for Conversational AI},
119 | author={Yirong Chen and Weiquan Fan and Xiaofen Xing and Jianxin Pang and Minlie Huang and Wenjing Han and Qianfeng Tie and Xiangmin Xu},
120 | journal={arXiv preprint arXiv:2205.14727},
121 | year={2022},
122 | url={https://arxiv.org/abs/2205.14727}
123 | }
124 | ```
125 |
126 | >>> Engineering Research Center of Ministry of Education on Human Body Perception
127 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [CPED](https://github.com/scutcyr/CPED)
2 | [](#python) [](https://arxiv.org/abs/2205.14727) [](https://github.com/scutcyr/CPED/stargazers) [](https://github.com/scutcyr/CPED/blob/main/LICENSE)  [](https://github.com/psf/black) 
3 |
4 |
5 | README: [English](https://github.com/scutcyr/CPED/blob/main/README.md) | [中文](https://github.com/scutcyr/CPED/blob/main/README-zh.md)
6 | This repository provides the implementation details for the paper:
7 | **[CPED: A Large-Scale Chinese Personalized and Emotional Dialogue Dataset for Conversational AI](https://arxiv.org/abs/2205.14727)**
8 |
9 | For more information, please refer to our [paper](https://arxiv.org/abs/2205.14727).
10 |
11 | The dataset is also available on luge.ai: [https://www.luge.ai/#/luge/dataDetail?id=41](https://www.luge.ai/#/luge/dataDetail?id=41)
12 |
13 | ## Contents
14 | * Introduction
15 | * Dataset Statistics
16 | * Task Definition
17 | * Evaluation Results
18 | * Usage
19 |
20 | ## Introduction
21 | We construct a dataset named **CPED** from 40 Chinese TV shows. CPED consists of multi-source knowledge related to empathy and personal characteristics, covering 13 emotions, gender, Big Five personality traits, 19 dialogue acts and other knowledge. The table below compares CPED with some other common conversation datasets.
22 | 
23 | * We build a multi-turn Chinese Personalized and Emotional Dialogue dataset called CPED. To the best of our knowledge, CPED is the first Chinese personalized and emotional dialogue dataset. CPED contains 12K dialogues and 133K utterances with multi-modal context. Therefore, it can be used in both complicated dialogue understanding and human-like conversation generation.
24 | * CPED has been annotated with 3 character attributes (name, gender, age), Big Five personality traits, 2 types of dynamic emotional information (sentiment and emotion) and DAs. The personality traits and emotions can be used as prior external knowledge for open-domain conversation generation, giving the conversation system a good command of personification capabilities.
25 | * We propose three tasks for CPED: **personality recognition in conversations (PRC)**, **emotion recognition in conversations (ERC)**, and **personalized and emotional conversation (PEC)**. A set of experiments verify the importance of using personalities and emotions as prior external knowledge for conversation generation.
26 |
27 | ![dataset_comparison](images/dataset_comparison.png)
28 |
29 | ## Dataset Statistics
30 | In order for the dialogue system to learn emotional expression and personalized expression abilities, we provide multiple types of annotation labels listed in the following Table.
31 |
32 | | # of annos. | Labels | Num. |
33 | |:-----------:|:-------|:----:|
34 | | Sentiment | positive, neutral, and negative | 3 |
35 | | Emotion | happy, grateful, relaxed, other-positive, neutral, angry, sad, feared, depressed, disgusted, astonished, worried and other-negative | 13 |
36 | | Gender | male, female, and unknown | 3 |
37 | | Age group | children, teenager, young, middle-aged, elderly and unknown | 6 |
38 | | Big Five | high, low, and unknown | 3 |
39 | | DA | greeting (g), question (q), answer (ans), statement-opinion (sv), statement-non-opinion (sd), apology (fa), command (c), agreement/acceptance (aa), disagreement (dag), acknowledge (a), appreciation (ba), interjection (ij), conventional-closing (fc), thanking (ft), quotation (^q), reject(rj), irony (ir), comfort (cf) and other (oth) | 19 |
40 | | Scene | home, office, school, mall, hospital, restaurant, sports-venue, entertainment-venue, car, outdoor and other-scene | 11 |
41 |
42 |
43 | The distributions of Gender, Age Group, Sentiment, Emotion and DA in the CPED dataset are shown in the following figure.
44 | ![dataset_statistics](images/dataset_staticstics.png)
45 |
46 | The statistics of CPED are listed in the following table.
47 | | Statistics | Train | Dev | Test |
48 | |---------------------------------|---------|---------|---------|
49 | | # of modalities | (v,a,t) | (v,a,t) | (v,a,t) |
50 | | # of TV plays | 26 | 5 | 9 |
51 | | # of dialogues | 8,086 | 934 | 2,815 |
52 | | # of utterances | 94,187 | 11,137 | 27,438 |
53 | | # of speakers | 273 | 38 | 81 |
54 | | Avg. # utt. per dial. | 11.6 | 11.9 | 9.7 |
55 | | Max # utt. per dial. | 75 | 31 | 34 |
56 | | Avg. # of emot. per dial. | 2.8 | 3.4 | 3.2 |
57 | | Avg. # of DAs per dial. | 3.6 | 3.7 | 3.2 |
58 | | Avg. utt. length | 8.3 | 8.2 | 8.3 |
59 | | Max utt. length | 127 | 42 | 45 |
60 | | Avg. duration of an utterance | 2.1s | 2.12s | 2.21s |
61 |
62 |
63 | ## Task Definition
64 | CPED allows the evaluation of both conversational cognitive tasks and conversation generation tasks, e.g., speaker modeling, personality recognition in conversations, emotion recognition in conversations, DA recognition in conversations, emotion prediction for responses, emotional conversation generation, personalized conversation generation, empathetic conversation generation, etc. Being multimodal, CPED can also be applied to multimodal personality or emotion recognition and multimodal conversation generation. It will play a positive role in promoting the development of cognitive intelligence.
65 | We introduce three tasks in this project:
66 | * **ERC**: [Emotion Recognition in Conversation](https://paperswithcode.com/task/emotion-recognition-in-conversation)
67 | * **PRC**: [Personality Recognition in Conversation](https://paperswithcode.com/task/personality-recognition-in-conversation)
68 | * **PEC**: [Personalized and Emotional Conversation](https://paperswithcode.com/task/personalized-and-emotional-conversation)
69 |
70 |
71 |
72 | ## Usage
73 | You can create the Python virtual environment for running the baseline models with the following commands:
74 | ```bash
75 | conda create -n py38 python=3.8
76 | conda activate py38
77 | pip install torch==1.9.0+cu102 torchvision==0.10.0+cu102 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
78 | pip install tensorflow==2.2.0
79 | pip install transformers==4.18.0
80 | python -m pip install paddlepaddle-gpu==2.3.0 -i https://mirror.baidu.com/pypi/simple
81 | pip install pytorch-ignite==0.4.8
82 | pip install notebook
83 | pip install pandas
84 | pip install chardet
85 | pip install matplotlib==3.5.2
86 | python -m pip install paddlenlp -i https://mirrors.aliyun.com/pypi/simple/
87 | python -m pip install ppasr -i https://mirrors.aliyun.com/pypi/simple/ -U
88 | pip install nltk
89 | pip install bert-score
90 | ```
91 |
92 | The versions of some of the key packages are as follows:
93 | ```bash
94 | python=3.8
95 | torch==1.9.0+cu102
96 | torchvision==0.10.0+cu102
97 | torchaudio==0.9.0
98 | tensorflow==2.2.0
99 | tensorboard==2.2.2
100 | transformers==4.18.0
101 | paddlepaddle-gpu==2.3.0
102 | paddlenlp==2.3.2
103 | pytorch-ignite==0.4.8
104 | matplotlib==3.5.2
105 | notebook==6.4.11
106 | pandas==1.4.2
107 | chardet==4.0.0
108 | nltk==3.7
109 | bert-score==0.3.11
110 | ```
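111 | 
112 | After setting up the environment, you can take a quick look at the dataset splits. The snippet below is a minimal sketch assuming the split files sit under ```data/CPED``` as in the repository layout; the column name ```Speaker``` is an assumption for illustration, so check the actual CSV header:
113 | 
114 | ```python
115 | import pandas as pd
116 | 
117 | # Load the three CPED splits (paths follow the repository layout)
118 | train = pd.read_csv("data/CPED/train_split.csv")
119 | valid = pd.read_csv("data/CPED/valid_split.csv")
120 | test = pd.read_csv("data/CPED/test_split.csv")
121 | 
122 | # Utterance counts per split (cf. the statistics table above)
123 | print(len(train), len(valid), len(test))
124 | # print(train["Speaker"].nunique())  # "Speaker" is an assumed column name
125 | ```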
126 | 
127 | 
128 | 
129 | 
130 | Please cite our paper if you use CPED or this project:
131 | ```
132 | @article{chen2022cped,
133 |   title={{CPED}: A Large-Scale Chinese Personalized and Emotional Dialogue Dataset for Conversational AI},
134 |   author={Yirong Chen and Weiquan Fan and Xiaofen Xing and Jianxin Pang and Minlie Huang and Wenjing Han and Qianfeng Tie and Xiangmin Xu},
135 |   journal={arXiv preprint arXiv:2205.14727},
136 |   year={2022},
137 |   url={https://arxiv.org/abs/2205.14727}
138 | }
139 | ```
140 | 
141 | >>> Engineering Research Center of Ministry of Education on Human Body Perception
127 |
--------------------------------------------------------------------------------
/data/CPED/speakers.txt:
--------------------------------------------------------------------------------
1 | 于德伟
2 | 王柏川
3 | 袁大头
4 | 外卖小哥
5 | 林磊儿
6 | 贾小朵
7 | 刘梅
8 | 孙强
9 | 果然
10 | 钱三一
11 | 谭宗明
12 | 李月如
13 | 张医生
14 | 方芳
15 | 毕胜男
16 | 邓小可母亲
17 | 王胜男
18 | 彭永辉
19 | 刘静
20 | 关雎尔
21 | 宋倩
22 | 刘珍珍
23 | 崔朕父亲
24 | 秦羽墨
25 | 童娟
26 | 诸葛大力
27 | 林大为
28 | 王珊
29 | 丁香
30 | 陈建
31 | 方一凡
32 | 陆展博
33 | 王岩
34 | 陆母
35 | 刘铮
36 | 赵美兰
37 | 高健
38 | 罗琦
39 | 石天冬
40 | 美薇
41 | 陈爱莲
42 | 陆小山
43 | 夏峰
44 | 黄浩达
45 | 应勤
46 | 陆永福
47 | 林妙妙
48 | 唐悠悠
49 | 赵海棠
50 | 安娜
51 | 毕大千
52 | 方圆
53 | 王东方
54 | 苏明哲
55 | 乔英子
56 | 陈晓伟
57 | 白洁
58 | 方骏
59 | 江达琳
60 | 唐鹏母亲
61 | 陆晓婷
62 | 李理
63 | 小白
64 | 唐鹏
65 | 蔡钟闲
66 | 栗伟正
67 | 张盛
68 | 苑春丽
69 | 黄晓峰
70 | 诺澜
71 | 王长水
72 | 姗姗
73 | 乔卫东
74 | 孙雅娴
75 | 罗槟
76 | 蔡钟闲妻子
77 | 樊斌
78 | 徐笑
79 | 陈可妈妈
80 | 校长
81 | 金志明
82 | 梁晓慧
83 | 大磊
84 | 唐娇
85 | 孙超越
86 | 于扬
87 | 叶珊
88 | 苏明成
89 | 罗丹
90 | 张天皓
91 | 汪涵
92 | 乔伊林
93 | 路易斯
94 | 周芳
95 | 童文洁助理
96 | 江长恩
97 | 唐爽
98 | 周佳成
99 | 咖喱酱
100 | 张一男
101 | 顾大鹏
102 | 程洪斗
103 | 关谷神奇
104 | 陈佳
105 | 吴芳妮
106 | 江焱
107 | 陶语桐爸爸
108 | 莫小康
109 | 卢家凯
110 | 杨桃
111 | 方依依
112 | 蓝红
113 | 顾末易
114 | 陆小贝
115 | 蒙志远
116 | 周小北
117 | 何西
118 | 崔朕妈妈
119 | 邓小可
120 | 徐豆豆
121 | 张伟
122 | 戴曦
123 | 焦阳
124 | 张超
125 | 张一男妈妈
126 | 童心
127 | 林宛瑜
128 | 姚美娟
129 | 杨呈远妈
130 | 斯嘉丽
131 | 封印
132 | 秦羽墨前男友
133 | 顾映真
134 | 舍友
135 | 南海
136 | 崔朕
137 | 蒂娜
138 | 苏大强
139 | 赵启平
140 | 毕然
141 | 蓝馨
142 | 老罗
143 | 周格格
144 | 沈画
145 | 彭凯旋
146 | 包奕凡
147 | 邱莹莹
148 | 小熊
149 | 雷瑛
150 | 那依
151 | 袁帅
152 | 何东
153 | 朱丽
154 | 艾小天
155 | 卫哲
156 | 丽萨
157 | 泰勒
158 | 柳青
159 | 麦飞
160 | 李闻雨
161 | 黄越彬
162 | 郑海潮
163 | 张亮忠
164 | 林晓珑
165 | 陈可依
166 | 刘旭刚
167 | 胡七星
168 | 琴琴
169 | 小艾
170 | 夏雨
171 | 安迪
172 | 其他
173 | 陶语桐妈妈
174 | 司徒末
175 | 杨嘟嘟
176 | 沈英杰
177 | 王媛
178 | 奥特
179 | 刘炎
180 | 肖一飞
181 | 陆曼妮
182 | 田坤
183 | 叶昭君
184 | 季胜利
185 | 于春晓
186 | 陈美嘉
187 | 郑远东
188 | 冯晶晶
189 | 夏冰
190 | 何北
191 | 张小宇
192 | 韩文静
193 | 赵冲
194 | 段西风
195 | 肖威
196 | 邓小琪
197 | 吴非
198 | 曲筱绡
199 | 福子
200 | 冀遇
201 | 樊胜美
202 | 程俊
203 | 莫正源
204 | 朴恩熙
205 | 牛美丽
206 | 于建国
207 | 沈青川
208 | 宁巴拉
209 | 成晓峰
210 | 袁则强
211 | 胡一菲
212 | 宋宁宇
213 | 江天昊
214 | 宋光明
215 | 方朵朵
216 | 程皓
217 | 戴天高
218 | 邓文宣
219 | 傅沛
220 | 小梦
221 | 罗浩
222 | 郝泽宇
223 | 杨呈远
224 | 吴佳妮
225 | 苏青
226 | 齐大胜
227 | 果长山
228 | 苏珊
229 | 廖佳敏
230 | 田坤前女友
231 | 童小麒
232 | 梁晓慧爸爸
233 | 刘星
234 | 李三弟
235 | 顾婕
236 | 童文洁
237 | 福方树
238 | 王佳佳
239 | 赵小亮
240 | 马邦尼
241 | 曾小贤
242 | 戴娜
243 | 魏山山
244 | 童文洁老板
245 | 夏天
246 | 罗海燕
247 | 薛素梅
248 | 宋暖
249 | 王珊珊父亲
250 | 于小强
251 | 胡一统
252 | 陆小贝母亲
253 | 吕子乔
254 | 梁伊
255 | 刘光耀
256 | 徐天
257 | 向飞
258 | 冯兰芝
259 | 丛卉
260 | 于果
261 | 魏渭
262 | 罗玥
263 | 刘慧芸
264 | 蒂娜妈妈
265 | 黄芷陶
266 | 李老师
267 | 季杨杨
268 | 表姐
269 | 郝敏
270 | 夏东海
271 | 刘兰芝
272 | 李三妹
273 | 斯黛拉
274 | 孙总
275 | 陆长山
276 | 何赛
277 | 鲍家明
278 | 权筝
279 | 余峥
280 | 栗娜
281 | 刘栋
282 | 张铭阳
283 | 王涛
284 | 姚梅
285 | 王博
286 | 陶语桐
287 | 舒晴
288 | 李枫
289 | 邹男
290 | 夏雪
291 | 林君
292 | 高红
293 | 罗素
294 | 刀美岚
295 | 赵小川
296 | 姚澜
297 | 宋大楠
298 | 潘芸
299 | 欧阳雨露
300 | 邹北业
301 | 罗茜茜
302 | 苏明玉
303 | 王珊珊
--------------------------------------------------------------------------------
/envs/py3.8_torch1.9.0_ignite0.4.8_tensorflow2.2.0_cuda10.2_transformers4.18.0_paddlepaddle-gpu_2.3.0.yml:
--------------------------------------------------------------------------------
1 | name: py38
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=5.1=1_gnu
7 | - ca-certificates=2022.4.26=h06a4308_0
8 | - certifi=2022.5.18.1=py38h06a4308_0
9 | - ld_impl_linux-64=2.38=h1181459_1
10 | - libffi=3.3=he6710b0_2
11 | - libgcc-ng=11.2.0=h1234567_1
12 | - libgomp=11.2.0=h1234567_1
13 | - libstdcxx-ng=11.2.0=h1234567_1
14 | - ncurses=6.3=h7f8727e_2
15 | - openssl=1.1.1o=h7f8727e_0
16 | - pip=21.2.4=py38h06a4308_0
17 | - python=3.8.13=h12debd9_0
18 | - readline=8.1.2=h7f8727e_1
19 | - setuptools=61.2.0=py38h06a4308_0
20 | - sqlite=3.38.3=hc218d9a_0
21 | - tk=8.6.12=h1ccaba5_0
22 | - wheel=0.37.1=pyhd3eb1b0_0
23 | - xz=5.2.5=h7f8727e_1
24 | - zlib=1.2.12=h7f8727e_2
25 | - pip:
26 | - absl-py==1.1.0
27 | - aiohttp==3.8.1
28 | - aiosignal==1.2.0
29 | - appdirs==1.4.4
30 | - argon2-cffi==21.3.0
31 | - argon2-cffi-bindings==21.2.0
32 | - astor==0.8.1
33 | - asttokens==2.0.5
34 | - astunparse==1.6.3
35 | - async-timeout==4.0.2
36 | - attrs==21.4.0
37 | - audioread==2.1.9
38 | - babel==2.10.1
39 | - backcall==0.2.0
40 | - bce-python-sdk==0.8.64
41 | - beautifulsoup4==4.11.1
42 | - bert-score==0.3.11
43 | - bleach==5.0.0
44 | - cachetools==4.2.4
45 | - cffi==1.15.0
46 | - cfgv==3.3.1
47 | - chardet==4.0.0
48 | - charset-normalizer==2.0.12
49 | - click==8.1.3
50 | - cn2an==0.5.17
51 | - colorama==0.4.4
52 | - colorlog==6.6.0
53 | - cycler==0.11.0
54 | - datasets==2.2.2
55 | - debugpy==1.6.0
56 | - decorator==5.1.1
57 | - defusedxml==0.7.1
58 | - dill==0.3.4
59 | - distlib==0.3.4
60 | - entrypoints==0.4
61 | - executing==0.8.3
62 | - fastjsonschema==2.15.3
63 | - filelock==3.7.1
64 | - flake8==4.0.1
65 | - flask==2.1.2
66 | - flask-babel==2.0.0
67 | - fonttools==4.33.3
68 | - frozenlist==1.3.0
69 | - fsspec==2022.5.0
70 | - future==0.18.2
71 | - gast==0.3.3
72 | - google-auth==1.35.0
73 | - google-auth-oauthlib==0.4.6
74 | - google-pasta==0.2.0
75 | - grpcio==1.46.3
76 | - h5py==2.10.0
77 | - huggingface-hub==0.7.0
78 | - identify==2.5.1
79 | - idna==3.3
80 | - importlib-metadata==4.11.4
81 | - importlib-resources==5.7.1
82 | - ipykernel==6.13.0
83 | - ipython==8.4.0
84 | - ipython-genutils==0.2.0
85 | - itsdangerous==2.1.2
86 | - jedi==0.18.1
87 | - jieba==0.42.1
88 | - jinja2==3.1.2
89 | - joblib==1.1.0
90 | - jsonschema==4.6.0
91 | - jupyter-client==7.3.2
92 | - jupyter-core==4.10.0
93 | - jupyterlab-pygments==0.2.2
94 | - keras-preprocessing==1.1.2
95 | - kiwisolver==1.4.2
96 | - librosa==0.8.0
97 | - llvmlite==0.38.1
98 | - markdown==3.3.7
99 | - markupsafe==2.1.1
100 | - matplotlib==3.5.2
101 | - matplotlib-inline==0.1.3
102 | - mccabe==0.6.1
103 | - mistune==0.8.4
104 | - mock==4.0.3
105 | - multidict==6.0.2
106 | - multiprocess==0.70.12.2
107 | - nbclient==0.6.4
108 | - nbconvert==6.5.0
109 | - nbformat==5.4.0
110 | - nest-asyncio==1.5.5
111 | - nltk==3.7
112 | - nodeenv==1.6.0
113 | - notebook==6.4.11
114 | - numba==0.55.2
115 | - numpy==1.22.4
116 | - oauthlib==3.2.0
117 | - opt-einsum==3.3.0
118 | - packaging==21.3
119 | - paddle-bfloat==0.1.2
120 | - paddle2onnx==0.9.7
121 | - paddlefsl==1.1.0
122 | - paddlenlp==2.3.2
123 | - paddlepaddle-gpu==2.3.0
124 | - paddlespeech-feat==0.1.0
125 | - pandas==1.4.2
126 | - pandocfilters==1.5.0
127 | - parso==0.8.3
128 | - pexpect==4.8.0
129 | - pickleshare==0.7.5
130 | - pillow==9.1.1
131 | - platformdirs==2.5.2
132 | - pooch==1.6.0
133 | - ppasr==0.1.5
134 | - pre-commit==2.19.0
135 | - proces==0.1.2
136 | - prometheus-client==0.14.1
137 | - prompt-toolkit==3.0.29
138 | - protobuf==3.20.0
139 | - psutil==5.9.1
140 | - ptyprocess==0.7.0
141 | - pure-eval==0.2.2
142 | - pyarrow==8.0.0
143 | - pyasn1==0.4.8
144 | - pyasn1-modules==0.2.8
145 | - pycodestyle==2.8.0
146 | - pycparser==2.21
147 | - pycryptodome==3.14.1
148 | - pydub==0.25.1
149 | - pyflakes==2.4.0
150 | - pygments==2.12.0
151 | - pyparsing==3.0.9
152 | - pyrsistent==0.18.1
153 | - python-dateutil==2.8.2
154 | - python-levenshtein==0.12.2
155 | - pytorch-ignite==0.4.8
156 | - pytz==2022.1
157 | - pyyaml==6.0
158 | - pyzmq==23.1.0
159 | - regex==2022.6.2
160 | - requests==2.27.1
161 | - requests-oauthlib==1.3.1
162 | - resampy==0.2.2
163 | - responses==0.18.0
164 | - rsa==4.8
165 | - ruamel-yaml==0.17.21
166 | - ruamel-yaml-clib==0.2.6
167 | - sacremoses==0.0.53
168 | - scikit-learn==1.1.1
169 | - scipy==1.8.1
170 | - send2trash==1.8.0
171 | - sentencepiece==0.1.96
172 | - seqeval==1.2.2
173 | - shellcheck-py==0.8.0.4
174 | - six==1.16.0
175 | - soundfile==0.10.3.post1
176 | - soupsieve==2.3.2.post1
177 | - stack-data==0.2.0
178 | - tensorboard==2.2.2
179 | - tensorboard-plugin-wit==1.8.1
180 | - tensorflow==2.2.0
181 | - tensorflow-estimator==2.2.0
182 | - termcolor==1.1.0
183 | - terminado==0.15.0
184 | - threadpoolctl==3.1.0
185 | - tinycss2==1.1.1
186 | - tokenizers==0.12.1
187 | - toml==0.10.2
188 | - torch==1.9.0+cu102
189 | - torchaudio==0.9.0
190 | - torchvision==0.10.0+cu102
191 | - tornado==6.1
192 | - tqdm==4.59.0
193 | - traitlets==5.2.2.post1
194 | - transformers==4.18.0
195 | - typing-extensions==4.2.0
196 | - urllib3==1.26.9
197 | - virtualenv==20.14.1
198 | - visualdl==2.2.3
199 | - wcwidth==0.2.5
200 | - webencodings==0.5.1
201 | - webrtcvad==2.0.10
202 | - werkzeug==2.1.2
203 | - wrapt==1.14.1
204 | - xxhash==3.0.0
205 | - yarl==1.7.2
206 | - zhconv==1.4.3
207 | - zipp==3.8.0
208 | prefix: /home/phd-chen.yirong/anaconda3/envs/py38
209 |
--------------------------------------------------------------------------------
/erc_baseline/README.md:
--------------------------------------------------------------------------------
1 | # ERC: Emotion Recognition in Conversation
2 | * **Author:** Chen Yirong
3 | * **Date:** 2022.06.01
4 | 
5 | **Note:** Code and model will be released in the future!
6 | 
7 | 
8 | ## Task Definition
9 | See [https://paperswithcode.com/task/emotion-recognition-in-conversation](https://paperswithcode.com/task/emotion-recognition-in-conversation)
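10 | 
11 | Until then, a minimal sketch of a majority-class ERC baseline on the CPED splits could look as follows (a sketch only: the ```Emotion``` column name is an assumption, so check the CSV header of your copy):
12 | 
13 | ```python
14 | import pandas as pd
15 | 
16 | # Load the CPED splits shipped in this repository
17 | train = pd.read_csv("../data/CPED/train_split.csv")
18 | test = pd.read_csv("../data/CPED/test_split.csv")
19 | 
20 | # Predict the most frequent training emotion for every test utterance
21 | majority = train["Emotion"].mode()[0]  # "Emotion" is an assumed column name
22 | print("Majority-class accuracy:", (test["Emotion"] == majority).mean())
23 | ```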
--------------------------------------------------------------------------------
/images/dataset_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scutcyr/CPED/1e4b81c28a123f22387e06664f37e5dc9322380f/images/dataset_comparison.png
--------------------------------------------------------------------------------
/images/dataset_staticstics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scutcyr/CPED/1e4b81c28a123f22387e06664f37e5dc9322380f/images/dataset_staticstics.png
--------------------------------------------------------------------------------
/pec_baseline/README.md:
--------------------------------------------------------------------------------
1 | # PEC: Personalized and Emotional Conversation
2 | * **Author:** Chen Yirong
3 | * **Date:** 2022.06.01
4 | 
5 | **Note:** Code and model will be released in the future!
6 | 
7 | ## Task Definition
8 | See [https://paperswithcode.com/task/personalized-and-emotional-conversation](https://paperswithcode.com/task/personalized-and-emotional-conversation)
9 | 
10 | 
11 | ## Folder Description
12 | * ```config```: stores the configuration files of the models.
13 | * ```models```: stores the Python code that defines the models.
14 | * ```results```: stores the results of testing the trained models.
15 | * ```runs```: stores model parameters and related configurations generated during or after training.
16 | * ```utils```: stores the Python code for reading the dataset.
17 | 
--------------------------------------------------------------------------------
/pec_baseline/models/README.md:
--------------------------------------------------------------------------------
1 | # models
2 | * **Author:** Chen Yirong
3 | * **Date:** 2022.03.21
4 | 
5 | ## Structure
6 | Each subfolder stores one model. Folder names use combinations of lowercase letters, underscores and digits, e.g. ```gpt```, ```gpt2```, ```gpt_per```.
7 | Each model consists of three files (a minimal sketch follows the list); assuming the model is named ```xxx```:
8 | * ```__init__.py```: exposes the publicly accessible interface
9 | * ```modeling_xxx.py```: defines the model classes
10 | * ```tokenization_xxx.py```: the tokenization classes related to the model (not required)
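11 | 
12 | As a minimal sketch of this convention, a hypothetical model package ```xxx``` would re-export its public classes from its ```__init__.py``` like this (the class names are placeholders for illustration, not code from this repository):
13 | ```python
14 | # models/xxx/__init__.py
15 | from .modeling_xxx import XxxModel, XxxLMHeadModel  # model class definitions
16 | from .tokenization_xxx import XxxTokenizer  # optional tokenization class
17 | ```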
18 | 
19 | ### Base files
20 | Besides the per-model subfolders, this module contains ```base_model.py``` (environment checks such as ```is_torch_available```) and ```model_parameters.py``` (helpers for counting, showing and freezing model parameters).
21 | 
22 | 
--------------------------------------------------------------------------------
/pec_baseline/models/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 South China University of Technology and
3 | # Engineering Research Center of Ministry of Education on Human Body Perception.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Models
18 | # Author: Chen Yirong
19 | # Date: 2022.04.06
20 |
21 | __version__ = "1.0.0"
22 |
23 | # Key package versions:
24 | # pytorch: 1.9.0+
25 | # transformers: 4.11.3+
26 |
27 | from .base_model import (is_torch_available)
28 |
29 | # Model classes
30 | if is_torch_available():
31 | from . import (
32 | gpt,
33 | gpt2,
34 |         # cvgpt  # the cvgpt module is not included in this repository; importing it would fail
35 | )
36 |
37 | # Model parameters calculating and freezing
38 | from .model_parameters import (count_trainable_parameters, count_total_parameters, show_trainable_parameters,
39 | set_freeze_by_names, freeze_by_model_name, unfreeze_by_model_name)
--------------------------------------------------------------------------------
/pec_baseline/models/base_model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 South China University of Technology and
3 | # Engineering Research Center of Ministry of Education on Human Body Perception.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Basic model configuration file
18 | # File: base_model.py
19 | # Used for model configuration
20 | # Basic utilities for model configuration
21 | # Author: Chen Yirong
22 | # Date: 2022.04.06
23 |
24 | import os
25 | import logging
26 | import importlib.util
27 |
28 | logger = logging.getLogger(__name__)
29 |
30 | _torch_available = importlib.util.find_spec("torch") is not None
31 |
32 | def is_torch_available():
33 | return _torch_available
--------------------------------------------------------------------------------
/pec_baseline/models/gpt/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this
3 | # module, but to preserve other warnings. So, don't check this module at all.
4 |
5 | # Copyright 2020 The HuggingFace Team. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 |
20 | # Key package versions:
21 | # pytorch: 1.9.0+
22 | # transformers: 4.11.3+
23 |
24 |
25 | from typing import TYPE_CHECKING
26 |
27 | from transformers.file_utils import is_torch_available
28 |
29 | if is_torch_available():
30 | from .modeling_openai import (
31 | OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
32 | OpenAIGPTDoubleHeadsModel,
33 | OpenAIGPTForSequenceClassification,
34 | OpenAIGPTLMHeadModel,
35 | OpenAIGPTModel,
36 | OpenAIGPTPreTrainedModel,
37 | load_tf_weights_in_openai_gpt,
38 | )
39 |
--------------------------------------------------------------------------------
/pec_baseline/models/gpt/modeling_openai.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """PyTorch OpenAI GPT model."""
17 |
18 |
19 | # Key package versions:
20 | # pytorch: 1.9.0+
21 | # transformers: 4.11.3
22 |
23 |
24 | import json
25 | import math
26 | import os
27 | from dataclasses import dataclass
28 | from typing import Optional, Tuple
29 |
30 | import torch
31 | from torch import nn
32 | from torch.nn import CrossEntropyLoss, MSELoss
33 | from torch.cuda.amp import autocast as autocast  # for automatic mixed precision (AMP); requires torch >= 1.6
34 |
35 | from transformers.activations import gelu_new, silu
36 | from transformers.file_utils import (
37 | ModelOutput,
38 | add_code_sample_docstrings,
39 | add_start_docstrings,
40 | add_start_docstrings_to_model_forward,
41 | replace_return_docstrings,
42 | )
43 | from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
44 | from transformers.modeling_utils import (
45 | Conv1D,
46 | PreTrainedModel,
47 | SequenceSummary,
48 | find_pruneable_heads_and_indices,
49 | prune_conv1d_layer,
50 | )
51 | from transformers.utils import logging
52 | from transformers import OpenAIGPTConfig
53 |
54 |
55 | logger = logging.get_logger(__name__)
56 |
57 | _CHECKPOINT_FOR_DOC = "openai-gpt"
58 | _CONFIG_FOR_DOC = "OpenAIGPTConfig"
59 | _TOKENIZER_FOR_DOC = "OpenAIGPTTokenizer"
60 |
61 | OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
62 | "openai-gpt",
63 | # See all OpenAI GPT models at https://huggingface.co/models?filter=openai-gpt
64 | ]
65 |
66 |
67 | def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
68 | """Load tf pre-trained weights in a pytorch model (from NumPy arrays here)"""
69 | import re
70 |
71 | import numpy as np
72 |
73 | if ".ckpt" in openai_checkpoint_folder_path:
74 | openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path)
75 |
76 | logger.info(f"Loading weights from {openai_checkpoint_folder_path}")
77 |
78 | with open(openai_checkpoint_folder_path + "/parameters_names.json", "r", encoding="utf-8") as names_handle:
79 | names = json.load(names_handle)
80 | with open(openai_checkpoint_folder_path + "/params_shapes.json", "r", encoding="utf-8") as shapes_handle:
81 | shapes = json.load(shapes_handle)
82 | offsets = np.cumsum([np.prod(shape) for shape in shapes])
83 | init_params = [np.load(openai_checkpoint_folder_path + f"/params_{n}.npy") for n in range(10)]
84 | init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
85 | init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
86 |
87 | # This was used when we had a single embedding matrix for positions and tokens
88 | # init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
89 | # del init_params[1]
90 | init_params = [arr.squeeze() for arr in init_params]
91 |
92 | try:
93 | assert model.tokens_embed.weight.shape == init_params[1].shape
94 | assert model.positions_embed.weight.shape == init_params[0].shape
95 | except AssertionError as e:
96 | e.args += (model.tokens_embed.weight.shape, init_params[1].shape)
97 | e.args += (model.positions_embed.weight.shape, init_params[0].shape)
98 | raise
99 |
100 | model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
101 | model.positions_embed.weight.data = torch.from_numpy(init_params[0])
102 | names.pop(0)
103 | # Pop position and token embedding arrays
104 | init_params.pop(0)
105 | init_params.pop(0)
106 |
107 | for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]):
108 | name = name[6:] # skip "model/"
109 | assert name[-2:] == ":0"
110 | name = name[:-2]
111 | name = name.split("/")
112 | pointer = model
113 | for m_name in name:
114 | if re.fullmatch(r"[A-Za-z]+\d+", m_name):
115 | scope_names = re.split(r"(\d+)", m_name)
116 | else:
117 | scope_names = [m_name]
118 | if scope_names[0] == "g":
119 | pointer = getattr(pointer, "weight")
120 | elif scope_names[0] == "b":
121 | pointer = getattr(pointer, "bias")
122 | elif scope_names[0] == "w":
123 | pointer = getattr(pointer, "weight")
124 | else:
125 | pointer = getattr(pointer, scope_names[0])
126 | if len(scope_names) >= 2:
127 | num = int(scope_names[1])
128 | pointer = pointer[num]
129 | try:
130 | assert (
131 | pointer.shape == array.shape
132 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
133 | except AssertionError as e:
134 | e.args += (pointer.shape, array.shape)
135 | raise
136 | try:
137 | assert (
138 | pointer.shape == array.shape
139 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
140 | except AssertionError as e:
141 | e.args += (pointer.shape, array.shape)
142 | raise
143 | logger.info(f"Initialize PyTorch weight {name}")
144 | pointer.data = torch.from_numpy(array)
145 | return model
146 |
147 |
148 | ACT_FNS = {"relu": nn.ReLU, "silu": silu, "gelu": gelu_new, "swish": silu}
149 |
150 |
151 | class Attention(nn.Module):
152 | def __init__(self, nx, n_ctx, config, scale=False):
153 | super().__init__()
154 | n_state = nx # in Attention: n_state=768 (nx=n_embd)
155 | # [switch nx => n_state from Block to Attention to keep identical to TF implementation]
156 | assert n_state % config.n_head == 0
157 | self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
158 | self.n_head = config.n_head
159 | self.split_size = n_state
160 | self.scale = scale
161 |
162 | self.c_attn = Conv1D(n_state * 3, nx)
163 | self.c_proj = Conv1D(n_state, nx)
164 | self.attn_dropout = nn.Dropout(config.attn_pdrop)
165 | self.resid_dropout = nn.Dropout(config.resid_pdrop)
166 | self.pruned_heads = set()
167 |
168 | def prune_heads(self, heads):
169 | if len(heads) == 0:
170 | return
171 | heads, index = find_pruneable_heads_and_indices(
172 | heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
173 | )
174 | index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
175 | # Prune conv1d layers
176 | self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
177 | self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
178 | # Update hyper params
179 | self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
180 | self.n_head = self.n_head - len(heads)
181 | self.pruned_heads = self.pruned_heads.union(heads)
182 |
183 | def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
184 | w = torch.matmul(q, k)
185 | if self.scale:
186 | w = w / math.sqrt(v.size(-1))
187 | # w = w * self.bias + -1e9 * (1 - self.bias) # TF implementation method: mask_attn_weights
188 | # XD: self.b may be larger than w, so we need to crop it
189 | b = self.bias[:, :, : w.size(-2), : w.size(-1)]
190 | w = w * b + -1e4 * (1 - b)
191 |
192 | if attention_mask is not None:
193 | # Apply the attention mask
194 | w = w + attention_mask
195 |
196 | w = nn.Softmax(dim=-1)(w)
197 | w = self.attn_dropout(w)
198 |
199 | # Mask heads if we want to
200 | if head_mask is not None:
201 | w = w * head_mask
202 |
203 | outputs = [torch.matmul(w, v)]
204 | if output_attentions:
205 | outputs.append(w)
206 | return outputs
207 |
208 | def merge_heads(self, x):
209 | x = x.permute(0, 2, 1, 3).contiguous()
210 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
211 | return x.view(*new_x_shape) # in Tensorflow implementation: fct merge_states
212 |
213 | def split_heads(self, x, k=False):
214 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
215 | x = x.view(*new_x_shape) # in Tensorflow implementation: fct split_states
216 | if k:
217 | return x.permute(0, 2, 3, 1)
218 | else:
219 | return x.permute(0, 2, 1, 3)
220 |
221 | def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False):
222 | x = self.c_attn(x)
223 | query, key, value = x.split(self.split_size, dim=2)
224 | query = self.split_heads(query)
225 | key = self.split_heads(key, k=True)
226 | value = self.split_heads(value)
227 |
228 | attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
229 | a = attn_outputs[0]
230 |
231 | a = self.merge_heads(a)
232 | a = self.c_proj(a)
233 | a = self.resid_dropout(a)
234 |
235 | outputs = [a] + attn_outputs[1:]
236 | return outputs # a, (attentions)
237 |
238 |
239 | class MLP(nn.Module):
240 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
241 | super().__init__()
242 | nx = config.n_embd
243 | self.c_fc = Conv1D(n_state, nx)
244 | self.c_proj = Conv1D(nx, n_state)
245 | self.act = ACT_FNS[config.afn]
246 | self.dropout = nn.Dropout(config.resid_pdrop)
247 |
248 | def forward(self, x):
249 | h = self.act(self.c_fc(x))
250 | h2 = self.c_proj(h)
251 | return self.dropout(h2)
252 |
253 |
254 | class Block(nn.Module):
255 | def __init__(self, n_ctx, config, scale=False):
256 | super().__init__()
257 | nx = config.n_embd
258 | self.attn = Attention(nx, n_ctx, config, scale)
259 | self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
260 | self.mlp = MLP(4 * nx, config)
261 | self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
262 |
263 | def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False):
264 | attn_outputs = self.attn(
265 | x,
266 | attention_mask=attention_mask,
267 | head_mask=head_mask,
268 | output_attentions=output_attentions,
269 | )
270 | a = attn_outputs[0]
271 |
272 | n = self.ln_1(x + a)
273 | m = self.mlp(n)
274 | h = self.ln_2(n + m)
275 |
276 | outputs = [h] + attn_outputs[1:]
277 | return outputs
278 |
279 |
280 | class OpenAIGPTPreTrainedModel(PreTrainedModel):
281 | """
282 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
283 | models.
284 | """
285 |
286 | config_class = OpenAIGPTConfig
287 | load_tf_weights = load_tf_weights_in_openai_gpt
288 | base_model_prefix = "transformer"
289 | _keys_to_ignore_on_load_missing = [r"position_ids"]
290 |
291 | def _init_weights(self, module):
292 | """Initialize the weights."""
293 | if isinstance(module, (nn.Linear, Conv1D)):
294 | # Slightly different from the TF version which uses truncated_normal for initialization
295 | # cf https://github.com/pytorch/pytorch/pull/5617
296 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
297 | if module.bias is not None:
298 | module.bias.data.zero_()
299 | elif isinstance(module, nn.Embedding):
300 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
301 | if module.padding_idx is not None:
302 | module.weight.data[module.padding_idx].zero_()
303 | elif isinstance(module, nn.LayerNorm):
304 | module.bias.data.zero_()
305 | module.weight.data.fill_(1.0)
306 |
307 |
308 | @dataclass
309 | class OpenAIGPTDoubleHeadsModelOutput(ModelOutput):
310 | """
311 | Base class for outputs of models predicting if two sentences are consecutive or not.
312 |
313 | Args:
314 | loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
315 | Language modeling loss.
316 | mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
317 | Multiple choice classification loss.
318 | logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
319 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
320 | mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
321 | Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
322 | hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
323 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
324 | of shape :obj:`(batch_size, sequence_length, hidden_size)`.
325 |
326 | Hidden-states of the model at the output of each layer plus the initial embedding outputs.
327 | attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
328 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
329 | sequence_length, sequence_length)`.
330 |
331 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
332 | heads.
333 | """
334 |
335 | loss: Optional[torch.FloatTensor] = None
336 | mc_loss: Optional[torch.FloatTensor] = None
337 | logits: torch.FloatTensor = None
338 | mc_logits: torch.FloatTensor = None
339 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
340 | attentions: Optional[Tuple[torch.FloatTensor]] = None
341 |
342 |
343 | OPENAI_GPT_START_DOCSTRING = r"""
344 |
345 | This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
346 | methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
347 | pruning heads etc.)
348 |
349 | This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
350 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
351 | general usage and behavior.
352 |
353 | Parameters:
354 | config (:class:`~transformers.OpenAIGPTConfig`): Model configuration class with all the parameters of the model.
355 | Initializing with a config file does not load the weights associated with the model, only the
356 | configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
357 | weights.
358 | """
359 |
360 | OPENAI_GPT_INPUTS_DOCSTRING = r"""
361 | Args:
362 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
363 | Indices of input sequence tokens in the vocabulary.
364 |
365 | Indices can be obtained using :class:`~transformers.OpenAIGPTTokenizer`. See
366 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
367 | details.
368 |
369 | `What are input IDs? <../glossary.html#input-ids>`__
370 | attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
371 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
372 |
373 | - 1 for tokens that are **not masked**,
374 | - 0 for tokens that are **masked**.
375 |
376 | `What are attention masks? <../glossary.html#attention-mask>`__
377 | token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
378 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
379 | 1]``:
380 |
381 | - 0 corresponds to a `sentence A` token,
382 | - 1 corresponds to a `sentence B` token.
383 |
384 | `What are token type IDs? <../glossary.html#token-type-ids>`_
385 | position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
386 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
387 | config.max_position_embeddings - 1]``.
388 |
389 | `What are position IDs? <../glossary.html#position-ids>`__
390 | head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
391 | Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
392 |
393 | - 1 indicates the head is **not masked**,
394 | - 0 indicates the head is **masked**.
395 |
396 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
397 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
398 | This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
399 | vectors than the model's internal embedding lookup matrix.
400 | output_attentions (:obj:`bool`, `optional`):
401 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
402 | tensors for more detail.
403 | output_hidden_states (:obj:`bool`, `optional`):
404 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
405 | more detail.
406 | return_dict (:obj:`bool`, `optional`):
407 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
408 | """
409 |
410 |
411 | @add_start_docstrings(
412 | "The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.",
413 | OPENAI_GPT_START_DOCSTRING,
414 | )
415 | class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
416 | def __init__(self, config):
417 | super().__init__(config)
418 |
419 | self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd)
420 | self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
421 | self.drop = nn.Dropout(config.embd_pdrop)
422 | self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
423 |
424 | self.register_buffer("position_ids", torch.arange(config.n_positions))
425 | self.init_weights()
426 |
427 | def get_input_embeddings(self):
428 | return self.tokens_embed
429 |
430 | def set_input_embeddings(self, new_embeddings):
431 | self.tokens_embed = new_embeddings
432 |
433 | def _prune_heads(self, heads_to_prune):
434 | """
435 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
436 | """
437 | for layer, heads in heads_to_prune.items():
438 | self.h[layer].attn.prune_heads(heads)
439 |
440 | @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
441 | @add_code_sample_docstrings(
442 | checkpoint=_CHECKPOINT_FOR_DOC,
443 | output_type=BaseModelOutput,
444 | config_class=_CONFIG_FOR_DOC,
445 | )
446 | @autocast()
447 | def forward(
448 | self,
449 | input_ids=None,
450 | attention_mask=None,
451 | token_type_ids=None,
452 | position_ids=None,
453 | head_mask=None,
454 | inputs_embeds=None,
455 | output_attentions=None,
456 | output_hidden_states=None,
457 | return_dict=None,
458 | ):
459 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
460 | output_hidden_states = (
461 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
462 | )
463 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
464 |
465 | if input_ids is not None and inputs_embeds is not None:
466 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
467 | elif input_ids is not None:
468 | input_shape = input_ids.size()
469 | input_ids = input_ids.view(-1, input_shape[-1])
470 | elif inputs_embeds is not None:
471 | input_shape = inputs_embeds.size()[:-1]
472 | else:
473 | raise ValueError("You have to specify either input_ids or inputs_embeds")
474 |
475 | if position_ids is None:
476 | # Code is different from when we had a single embedding matrix from position and token embeddings
477 | position_ids = self.position_ids[None, : input_shape[-1]]
478 |
479 | # Attention mask.
480 | if attention_mask is not None:
481 | # We create a 3D attention mask from a 2D tensor mask.
482 | # Sizes are [batch_size, 1, 1, to_seq_length]
483 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
484 | # this attention mask is simpler than the triangular masking of causal attention
485 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
486 | attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
487 |
488 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
489 | # masked positions, this operation will create a tensor which is 0.0 for
490 | # positions we want to attend and -10000.0 for masked positions.
491 | # Since we are adding it to the raw scores before the softmax, this is
492 | # effectively the same as removing these entirely.
493 | attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
494 | attention_mask = (1.0 - attention_mask) * -10000.0
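  | # Editor's sketch (not part of the upstream file) of what the two lines above do
  | # to a toy mask, assuming fp32 parameters:
  | #   mask = torch.tensor([[1, 1, 0]])          # 1 = attend, 0 = padding
  | #   mask = mask.unsqueeze(1).unsqueeze(2)     # shape (1, 1, 1, 3), broadcastable
  | #   (1.0 - mask.float()) * -10000.0           # -> [[[[0., 0., -10000.]]]]
  | # Added to the raw attention scores, the -10000.0 drives the padded position's
  | # softmax weight to ~0, so no separate masking step is needed later.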
495 |
496 | # Prepare head mask if needed
497 | head_mask = self.get_head_mask(head_mask, self.config.n_layer)
498 |
499 | if inputs_embeds is None:
500 | inputs_embeds = self.tokens_embed(input_ids)
501 | position_embeds = self.positions_embed(position_ids)
502 | if token_type_ids is not None:
503 | token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
504 | token_type_embeds = self.tokens_embed(token_type_ids)
505 | else:
506 | token_type_embeds = 0
507 | hidden_states = inputs_embeds + position_embeds + token_type_embeds
508 | hidden_states = self.drop(hidden_states)
509 |
510 | output_shape = input_shape + (hidden_states.size(-1),)
511 |
512 | all_attentions = () if output_attentions else None
513 | all_hidden_states = () if output_hidden_states else None
514 | for i, block in enumerate(self.h):
515 | if output_hidden_states:
516 | all_hidden_states = all_hidden_states + (hidden_states,)
517 |
518 | outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions=output_attentions)
519 | hidden_states = outputs[0]
520 | if output_attentions:
521 | all_attentions = all_attentions + (outputs[1],)
522 |
523 | hidden_states = hidden_states.view(*output_shape)
524 | # Add last layer
525 | if output_hidden_states:
526 | all_hidden_states = all_hidden_states + (hidden_states,)
527 |
528 | if not return_dict:
529 | return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
530 |
531 | return BaseModelOutput(
532 | last_hidden_state=hidden_states,
533 | hidden_states=all_hidden_states,
534 | attentions=all_attentions,
535 | )
536 |
537 |
538 | @add_start_docstrings(
539 | """
540 | OpenAI GPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
541 | embeddings).
542 | """,
543 | OPENAI_GPT_START_DOCSTRING,
544 | )
545 | class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
546 | def __init__(self, config):
547 | super().__init__(config)
548 | self.transformer = OpenAIGPTModel(config)
549 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
550 |
551 | self.init_weights()
552 |
553 | def get_output_embeddings(self):
554 | return self.lm_head
555 |
556 | def set_output_embeddings(self, new_embeddings):
557 | self.lm_head = new_embeddings
558 |
559 | @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
560 | @add_code_sample_docstrings(
561 | checkpoint=_CHECKPOINT_FOR_DOC,
562 | output_type=CausalLMOutput,
563 | config_class=_CONFIG_FOR_DOC,
564 | )
565 | @autocast()
566 | def forward(
567 | self,
568 | input_ids=None,
569 | attention_mask=None,
570 | token_type_ids=None,
571 | position_ids=None,
572 | head_mask=None,
573 | inputs_embeds=None,
574 | labels=None,
575 | output_attentions=None,
576 | output_hidden_states=None,
577 | return_dict=None,
578 | ):
579 | r"""
580 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
581 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
582 | ``labels = input_ids``. Indices are selected in ``[-1, 0, ..., config.vocab_size]``. All labels set to
583 | ``-1`` are ignored (masked); the loss is only computed for labels in ``[0, ..., config.vocab_size]``
584 | """
585 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
586 |
587 | transformer_outputs = self.transformer(
588 | input_ids,
589 | attention_mask=attention_mask,
590 | token_type_ids=token_type_ids,
591 | position_ids=position_ids,
592 | head_mask=head_mask,
593 | inputs_embeds=inputs_embeds,
594 | output_attentions=output_attentions,
595 | output_hidden_states=output_hidden_states,
596 | return_dict=return_dict,
597 | )
598 | hidden_states = transformer_outputs[0]
599 | lm_logits = self.lm_head(hidden_states)
600 |
601 | loss = None
602 | if labels is not None:
603 | # Shift so that tokens < n predict n
604 | shift_logits = lm_logits[..., :-1, :].contiguous()
605 | shift_labels = labels[..., 1:].contiguous()
608 | # Flatten the tokens
609 | loss_fct = CrossEntropyLoss(ignore_index=-1)
610 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
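  | # Worked example (editor's note): for labels [t0, t1, t2, t3], shift_logits keeps
  | # the predictions made at positions 0..2 and shift_labels keeps [t1, t2, t3], so
  | # the logits at position i are scored against token i+1; positions labelled -1
  | # are skipped via ignore_index above.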
611 |
612 | if not return_dict:
613 | output = (lm_logits,) + transformer_outputs[1:]
614 | return ((loss,) + output) if loss is not None else output
615 |
616 | return CausalLMOutput(
617 | loss=loss,
618 | logits=lm_logits,
619 | hidden_states=transformer_outputs.hidden_states,
620 | attentions=transformer_outputs.attentions,
621 | )
622 |
623 |
624 | @add_start_docstrings(
625 | """
626 | OpenAI GPT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
627 | RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
628 | input embeddings, the classification head takes as input the input of a specified classification token index in the
629 | input sequence).
630 | """,
631 | OPENAI_GPT_START_DOCSTRING,
632 | )
633 | class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
634 | def __init__(self, config):
635 | super().__init__(config)
636 |
637 | config.num_labels = 1
638 | self.transformer = OpenAIGPTModel(config)
639 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
640 | self.multiple_choice_head = SequenceSummary(config)
641 |
642 | self.init_weights()
643 |
644 | def get_output_embeddings(self):
645 | return self.lm_head
646 |
647 | def set_output_embeddings(self, new_embeddings):
648 | self.lm_head = new_embeddings
649 |
650 | @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
651 | @replace_return_docstrings(output_type=OpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
652 | @autocast()
653 | def forward(
654 | self,
655 | input_ids=None,
656 | attention_mask=None,
657 | token_type_ids=None,
658 | position_ids=None,
659 | head_mask=None,
660 | inputs_embeds=None,
661 | mc_token_ids=None,
662 | labels=None,
663 | mc_labels=None,
664 | output_attentions=None,
665 | output_hidden_states=None,
666 | return_dict=None,
667 | ):
668 | r"""
669 | mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, defaults to index of the last token of the input):
670 | Index of the classification token in each input sequence. Selected in the range ``[0, input_ids.size(-1) -
671 | 1]``.
672 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
673 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
674 | ``labels = input_ids``. Indices are selected in ``[-1, 0, ..., config.vocab_size]``. All labels set to
675 | ``-1`` are ignored (masked); the loss is only computed for labels in ``[0, ..., config.vocab_size]``
676 | mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`):
677 | Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
678 | num_choices - 1]`` where `num_choices` is the size of the second dimension of the input tensors. (see
679 | `input_ids` above)
680 |
681 | Return:
682 |
683 | Examples::
684 |
685 | >>> from transformers import OpenAIGPTTokenizer, OpenAIGPTDoubleHeadsModel
686 | >>> import torch
687 |
688 | >>> tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
689 | >>> model = OpenAIGPTDoubleHeadsModel.from_pretrained('openai-gpt')
690 | >>> tokenizer.add_special_tokens({'cls_token': '[CLS]'}) # Add a [CLS] to the vocabulary (we should train it also!)
691 | >>> model.resize_token_embeddings(len(tokenizer))
692 |
693 | >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
694 | >>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
695 | >>> mc_token_ids = torch.tensor([input_ids.size(-1)-1, input_ids.size(-1)-1]).unsqueeze(0) # Batch size 1
696 |
697 | >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
698 | >>> lm_logits = outputs.logits
699 | >>> mc_logits = outputs.mc_logits
700 | """
701 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
702 |
703 | transformer_outputs = self.transformer(
704 | input_ids,
705 | attention_mask=attention_mask,
706 | token_type_ids=token_type_ids,
707 | position_ids=position_ids,
708 | head_mask=head_mask,
709 | inputs_embeds=inputs_embeds,
710 | output_attentions=output_attentions,
711 | output_hidden_states=output_hidden_states,
712 | return_dict=return_dict,
713 | )
714 | hidden_states = transformer_outputs[0]
715 |
716 | lm_logits = self.lm_head(hidden_states)
717 | mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
718 |
719 | lm_loss, mc_loss = None, None
720 | if mc_labels is not None:
721 | loss_fct = CrossEntropyLoss(ignore_index=-1)
722 | mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
723 | if labels is not None:
724 | shift_logits = lm_logits[..., :-1, :].contiguous()
725 | shift_labels = labels[..., 1:].contiguous()
726 | loss_fct = CrossEntropyLoss(ignore_index=-1)
727 | lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
728 |
729 | if not return_dict:
730 | output = (lm_logits, mc_logits) + transformer_outputs[1:]
731 | if mc_loss is not None:
732 | output = (mc_loss,) + output
733 | return ((lm_loss,) + output) if lm_loss is not None else output
734 |
735 | return OpenAIGPTDoubleHeadsModelOutput(
736 | loss=lm_loss,
737 | mc_loss=mc_loss,
738 | logits=lm_logits,
739 | mc_logits=mc_logits,
740 | hidden_states=transformer_outputs.hidden_states,
741 | attentions=transformer_outputs.attentions,
742 | )
743 |
744 |
745 | @add_start_docstrings(
746 | """
747 | The Original OpenAI GPT Model transformer with a sequence classification head on top (linear layer).
748 | :class:`~transformers.OpenAIGPTForSequenceClassification` uses the last token in order to do the classification, as
749 | other causal models (e.g. GPT-2) do. Since it does classification on the last token, it needs to know the
750 | position of the last token. If a :obj:`pad_token_id` is defined in the configuration, it finds the last token that
751 | is not a padding token in each row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each
752 | row of the batch. Since it cannot guess the padding tokens when :obj:`inputs_embeds` are passed instead of
753 | :obj:`input_ids`, it does the same (take the last value in each row of the batch).
754 | """,
755 | OPENAI_GPT_START_DOCSTRING,
756 | )
757 | class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
758 | def __init__(self, config):
759 | super().__init__(config)
760 | self.num_labels = config.num_labels
761 | self.transformer = OpenAIGPTModel(config)
762 | self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
763 |
764 | self.init_weights()
765 |
766 | @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
767 | @add_code_sample_docstrings(
768 | checkpoint=_CHECKPOINT_FOR_DOC,
769 | output_type=SequenceClassifierOutput,
770 | config_class=_CONFIG_FOR_DOC,
771 | )
772 | @autocast()
773 | def forward(
774 | self,
775 | input_ids=None,
776 | attention_mask=None,
777 | token_type_ids=None,
778 | position_ids=None,
779 | head_mask=None,
780 | inputs_embeds=None,
781 | labels=None,
782 | output_attentions=None,
783 | output_hidden_states=None,
784 | return_dict=None,
785 | ):
786 | r"""
787 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
788 | Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
789 | config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
790 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
791 | """
792 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
793 |
794 | transformer_outputs = self.transformer(
795 | input_ids,
796 | attention_mask=attention_mask,
797 | token_type_ids=token_type_ids,
798 | position_ids=position_ids,
799 | head_mask=head_mask,
800 | inputs_embeds=inputs_embeds,
801 | output_attentions=output_attentions,
802 | output_hidden_states=output_hidden_states,
803 | return_dict=return_dict,
804 | )
805 |
806 | hidden_states = transformer_outputs[0]
807 | logits = self.score(hidden_states)
808 |
809 | if input_ids is not None:
810 | batch_size, sequence_length = input_ids.shape[:2]
811 | else:
812 | batch_size, sequence_length = inputs_embeds.shape[:2]
813 |
814 | assert (
815 | self.config.pad_token_id is not None or batch_size == 1
816 | ), "Cannot handle batch sizes > 1 if no padding token is defined."
817 | if self.config.pad_token_id is None:
818 | sequence_lengths = -1
819 | else:
820 | if input_ids is not None:
821 | sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
822 | else:
823 | sequence_lengths = -1
824 | logger.warning(
825 | f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
826 | f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
827 | )
828 |
829 | pooled_logits = logits[range(batch_size), sequence_lengths]
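  | # Editor's sketch of the pooling above: with batch_size=2 and
  | # sequence_lengths = tensor([4, 2]), logits[range(2), [4, 2]] selects row 0's
  | # logits at position 4 and row 1's logits at position 2, i.e. each row's last
  | # non-padding token; with sequence_lengths = -1 it simply takes the last position.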
830 |
831 | loss = None
832 | if labels is not None:
833 | if self.num_labels == 1:
834 | # We are doing regression
835 | loss_fct = MSELoss()
836 | loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
837 | else:
838 | loss_fct = CrossEntropyLoss(ignore_index=-1)
839 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
840 |
841 | if not return_dict:
842 | output = (pooled_logits,) + transformer_outputs[1:]
843 | return ((loss,) + output) if loss is not None else output
844 |
845 | return SequenceClassifierOutput(
846 | loss=loss,
847 | logits=pooled_logits,
848 | hidden_states=transformer_outputs.hidden_states,
849 | attentions=transformer_outputs.attentions,
850 | )
851 |
--------------------------------------------------------------------------------
/pec_baseline/models/gpt/tokenization_openai.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for OpenAI GPT."""
16 |
17 |
18 | # Key dependency versions:
19 | # pytorch: 1.9.0+
20 | # transformers: 4.11.3+
21 |
22 | import json
23 | import os
24 | import re
25 | from typing import Optional, Tuple
26 |
27 | from transformers import PreTrainedTokenizer
28 | from transformers.utils import logging
29 | from transformers import BasicTokenizer
30 |
31 |
32 | logger = logging.get_logger(__name__)
33 |
34 | VOCAB_FILES_NAMES = {
35 | "vocab_file": "vocab.json",
36 | "merges_file": "merges.txt",
37 | }
38 |
39 | PRETRAINED_VOCAB_FILES_MAP = {
40 | "vocab_file": {"openai-gpt": "https://huggingface.co/openai-gpt/resolve/main/vocab.json"},
41 | "merges_file": {"openai-gpt": "https://huggingface.co/openai-gpt/resolve/main/merges.txt"},
42 | }
43 |
44 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
45 | "openai-gpt": 512,
46 | }
47 |
48 |
49 | def get_pairs(word):
50 | """
51 | Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length
52 | strings)
53 | """
54 | pairs = set()
55 | prev_char = word[0]
56 | for char in word[1:]:
57 | pairs.add((prev_char, char))
58 | prev_char = char
59 | return pairs
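  | # Example (editor's note): get_pairs(("h", "e", "l", "l", "o")) returns
  | # {("h", "e"), ("e", "l"), ("l", "l"), ("l", "o")} -- the adjacent symbol bigrams
  | # that bpe() below ranks against self.bpe_ranks to pick the next merge.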
60 |
61 |
62 | def text_standardize(text):
63 | """
64 | Fixes some issues the spacy tokenizer had on the books corpus; also does some whitespace standardization.
65 | """
66 | text = text.replace("—", "-")
67 | text = text.replace("–", "-")
68 | text = text.replace("―", "-")
69 | text = text.replace("…", "...")
70 | text = text.replace("´", "'")
71 | text = re.sub(r"""(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)""", r" \1 ", text)
72 | text = re.sub(r"\s*\n\s*", " \n ", text)
73 | text = re.sub(r"[^\S\n]+", " ", text)
74 | return text.strip()
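  | # Example (editor's note): text_standardize("Hello—world… ok") returns
  | # "Hello - world... ok" -- unicode dashes and ellipses are normalized, the listed
  | # punctuation is padded with spaces, and runs of whitespace are collapsed.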
75 |
76 |
77 | class OpenAIGPTTokenizer(PreTrainedTokenizer):
78 | """
79 | Construct a GPT Tokenizer. Based on Byte-Pair-Encoding with the following peculiarities:
80 |
81 | - lowercases all inputs,
82 | - uses :obj:`SpaCy` tokenizer and :obj:`ftfy` for pre-BPE tokenization if they are installed, fallback to BERT's
83 | :obj:`BasicTokenizer` if not.
84 |
85 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
86 | Users should refer to this superclass for more information regarding those methods.
87 |
88 | Args:
89 | vocab_file (:obj:`str`):
90 | Path to the vocabulary file.
91 | merges_file (:obj:`str`):
92 | Path to the merges file.
93 | unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
94 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
95 | token instead.
96 | """
97 |
98 | vocab_files_names = VOCAB_FILES_NAMES
99 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
100 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
101 | model_input_names = ["input_ids", "attention_mask"]
102 |
103 | def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
104 | super().__init__(unk_token=unk_token, **kwargs)
105 |
106 | try:
107 | import ftfy
108 | from spacy.lang.en import English
109 |
110 | _nlp = English()
111 | self.nlp = _nlp.Defaults.create_tokenizer(_nlp)
112 | self.fix_text = ftfy.fix_text
113 | except ImportError:
114 | logger.warning("ftfy or spacy is not installed, using BERT BasicTokenizer instead of SpaCy & ftfy.")
115 | self.nlp = BasicTokenizer(do_lower_case=True)
116 | self.fix_text = None
117 |
118 | with open(vocab_file, encoding="utf-8") as vocab_handle:
119 | self.encoder = json.load(vocab_handle)
120 | self.decoder = {v: k for k, v in self.encoder.items()}
121 | with open(merges_file, encoding="utf-8") as merges_handle:
122 | merges = merges_handle.read().split("\n")[1:-1]
123 | merges = [tuple(merge.split()) for merge in merges]
124 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
125 | self.cache = {}
126 |
127 | @property
128 | def do_lower_case(self):
129 | return True
130 |
131 | @property
132 | def vocab_size(self):
133 | return len(self.encoder)
134 |
135 | def get_vocab(self):
136 | return dict(self.encoder, **self.added_tokens_encoder)
137 |
138 | def bpe(self, token):
139 | word = tuple(token[:-1]) + (token[-1] + "</w>",)
140 | if token in self.cache:
141 | return self.cache[token]
142 | pairs = get_pairs(word)
143 |
144 | if not pairs:
145 | return token + "</w>"
146 |
147 | while True:
148 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
149 | if bigram not in self.bpe_ranks:
150 | break
151 | first, second = bigram
152 | new_word = []
153 | i = 0
154 | while i < len(word):
155 | try:
156 | j = word.index(first, i)
157 | except ValueError:
158 | new_word.extend(word[i:])
159 | break
160 | else:
161 | new_word.extend(word[i:j])
162 | i = j
163 |
164 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
165 | new_word.append(first + second)
166 | i += 2
167 | else:
168 | new_word.append(word[i])
169 | i += 1
170 | new_word = tuple(new_word)
171 | word = new_word
172 | if len(word) == 1:
173 | break
174 | else:
175 | pairs = get_pairs(word)
176 | word = " ".join(word)
177 | if word == "\n  </w>":
178 | word = "\n</w>"
179 | self.cache[token] = word
180 | return word
181 |
182 | def _tokenize(self, text):
183 | """Tokenize a string."""
184 | split_tokens = []
185 | if self.fix_text is None:
186 | # Using BERT's BasicTokenizer
187 | text = self.nlp.tokenize(text)
188 | for token in text:
189 | split_tokens.extend([t for t in self.bpe(token).split(" ")])
190 | else:
191 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
192 | text = self.nlp(text_standardize(self.fix_text(text)))
193 | for token in text:
194 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(" ")])
195 | return split_tokens
196 |
197 | def _convert_token_to_id(self, token):
198 | """Converts a token (str) in an id using the vocab."""
199 | return self.encoder.get(token, self.encoder.get(self.unk_token))
200 |
201 | def _convert_id_to_token(self, index):
202 | """Converts an id in a token (BPE) using the vocab."""
203 | return self.decoder.get(index, self.unk_token)
204 |
205 | def convert_tokens_to_string(self, tokens):
206 | """Converts a sequence of tokens (string) in a single string."""
207 | out_string = "".join(tokens).replace("</w>", " ").strip()
208 | return out_string
209 |
210 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
211 | if not os.path.isdir(save_directory):
212 | logger.error(f"Vocabulary path ({save_directory}) should be a directory")
213 | return
214 | vocab_file = os.path.join(
215 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
216 | )
217 | merge_file = os.path.join(
218 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
219 | )
220 |
221 | with open(vocab_file, "w", encoding="utf-8") as f:
222 | f.write(json.dumps(self.encoder, ensure_ascii=False))
223 |
224 | index = 0
225 | with open(merge_file, "w", encoding="utf-8") as writer:
226 | writer.write("#version: 0.2\n")
227 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
228 | if index != token_index:
229 | logger.warning(
230 | f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
231 | " Please check that the tokenizer is not corrupted!"
232 | )
233 | index = token_index
234 | writer.write(" ".join(bpe_tokens) + "\n")
235 | index += 1
236 |
237 | return vocab_file, merge_file
238 |
--------------------------------------------------------------------------------
/pec_baseline/models/gpt2/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this
3 | # module, but to preserve other warnings. So, don't check this module at all.
4 |
5 | # Copyright 2020 The HuggingFace Team. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 |
20 | # Key dependency versions:
21 | # pytorch: 1.9.0+
22 | # transformers: 4.11.3
23 |
24 |
25 | from typing import TYPE_CHECKING
26 |
27 | from transformers.file_utils import is_torch_available
28 |
29 |
30 | if is_torch_available():
31 | from .modeling_gpt2 import (
32 | GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
33 | GPT2DoubleHeadsModel,
34 | GPT2ForSequenceClassification,
35 | GPT2ForTokenClassification,
36 | GPT2LMHeadModel,
37 | GPT2Model,
38 | GPT2PreTrainedModel,
39 | load_tf_weights_in_gpt2,
40 | )
41 |
--------------------------------------------------------------------------------
/pec_baseline/models/gpt2/tokenization_gpt2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for OpenAI GPT."""
16 |
17 |
18 | # Key dependency versions:
19 | # pytorch: 1.9.0+
20 | # transformers: 4.11.3
21 |
22 |
23 | import json
24 | import os
25 | from functools import lru_cache
26 | from typing import TYPE_CHECKING, List, Optional, Tuple
27 |
28 | import regex as re
29 |
30 | from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
31 | from transformers.utils import logging
32 |
33 |
34 | if TYPE_CHECKING:
35 | from transformers.pipelines.conversational import Conversation
36 |
37 | logger = logging.get_logger(__name__)
38 |
39 | VOCAB_FILES_NAMES = {
40 | "vocab_file": "vocab.json",
41 | "merges_file": "merges.txt",
42 | }
43 |
44 | PRETRAINED_VOCAB_FILES_MAP = {
45 | "vocab_file": {
46 | "gpt2": "https://huggingface.co/gpt2/resolve/main/vocab.json",
47 | "gpt2-medium": "https://huggingface.co/gpt2-medium/resolve/main/vocab.json",
48 | "gpt2-large": "https://huggingface.co/gpt2-large/resolve/main/vocab.json",
49 | "gpt2-xl": "https://huggingface.co/gpt2-xl/resolve/main/vocab.json",
50 | "distilgpt2": "https://huggingface.co/distilgpt2/resolve/main/vocab.json",
51 | },
52 | "merges_file": {
53 | "gpt2": "https://huggingface.co/gpt2/resolve/main/merges.txt",
54 | "gpt2-medium": "https://huggingface.co/gpt2-medium/resolve/main/merges.txt",
55 | "gpt2-large": "https://huggingface.co/gpt2-large/resolve/main/merges.txt",
56 | "gpt2-xl": "https://huggingface.co/gpt2-xl/resolve/main/merges.txt",
57 | "distilgpt2": "https://huggingface.co/distilgpt2/resolve/main/merges.txt",
58 | },
59 | }
60 |
61 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
62 | "gpt2": 1024,
63 | "gpt2-medium": 1024,
64 | "gpt2-large": 1024,
65 | "gpt2-xl": 1024,
66 | "distilgpt2": 1024,
67 | }
68 |
69 |
70 | @lru_cache()
71 | def bytes_to_unicode():
72 | """
73 | Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to whitespace/control
74 | characters that the bpe code barfs on.
75 |
76 | The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
77 | if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
78 | decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
79 | tables between utf-8 bytes and unicode strings.
80 | """
81 | bs = (
82 | list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
83 | )
84 | cs = bs[:]
85 | n = 0
86 | for b in range(2 ** 8):
87 | if b not in bs:
88 | bs.append(b)
89 | cs.append(2 ** 8 + n)
90 | n += 1
91 | cs = [chr(n) for n in cs]
92 | return dict(zip(bs, cs))
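  | # Example (editor's note): printable bytes map to themselves, while bytes the BPE
  | # would choke on are shifted into unused code points:
  | #   bytes_to_unicode()[ord("A")] == "A"
  | #   bytes_to_unicode()[32] == "Ġ"    # space   -> U+0120
  | #   bytes_to_unicode()[10] == "Ċ"    # newline -> U+010A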
93 |
94 |
95 | def get_pairs(word):
96 | """
97 | Return set of symbol pairs in a word.
98 |
99 | Word is represented as tuple of symbols (symbols being variable-length strings).
100 | """
101 | pairs = set()
102 | prev_char = word[0]
103 | for char in word[1:]:
104 | pairs.add((prev_char, char))
105 | prev_char = char
106 | return pairs
107 |
108 |
109 | class GPT2Tokenizer(PreTrainedTokenizer):
110 | """
111 | Construct a GPT-2 tokenizer. Based on byte-level Byte-Pair-Encoding.
112 |
113 | This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
114 | be encoded differently whether it is at the beginning of the sentence (without space) or not:
115 |
116 | ::
117 |
118 | >>> from transformers import GPT2Tokenizer
119 | >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
120 | >>> tokenizer("Hello world")['input_ids']
121 | [15496, 995]
122 | >>> tokenizer(" Hello world")['input_ids']
123 | [18435, 995]
124 |
125 | You can get around that behavior by passing ``add_prefix_space=True`` when instantiating this tokenizer or when you
126 | call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
127 |
128 | .. note::
129 |
130 | When used with ``is_split_into_words=True``, this tokenizer will add a space before each word (even the first
131 | one).
132 |
133 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
134 | Users should refer to this superclass for more information regarding those methods.
135 |
136 | Args:
137 | vocab_file (:obj:`str`):
138 | Path to the vocabulary file.
139 | merges_file (:obj:`str`):
140 | Path to the merges file.
141 | errors (:obj:`str`, `optional`, defaults to :obj:`"replace"`):
142 | Paradigm to follow when decoding bytes to UTF-8. See `bytes.decode
143 | <https://docs.python.org/3/library/stdtypes.html#bytes.decode>`__ for more information.
144 | unk_token (:obj:`str`, `optional`, defaults to :obj:`<|endoftext|>`):
145 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
146 | token instead.
147 | bos_token (:obj:`str`, `optional`, defaults to :obj:`<|endoftext|>`):
148 | The beginning of sequence token.
149 | eos_token (:obj:`str`, `optional`, defaults to :obj:`<|endoftext|>`):
150 | The end of sequence token.
151 | add_prefix_space (:obj:`bool`, `optional`, defaults to :obj:`False`):
152 | Whether or not to add an initial space to the input. This allows treating the leading word just like any
153 | other word (the GPT2 tokenizer detects the beginning of words by the preceding space).
154 | """
155 |
156 | vocab_files_names = VOCAB_FILES_NAMES
157 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
158 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
159 | model_input_names = ["input_ids", "attention_mask"]
160 |
161 | def __init__(
162 | self,
163 | vocab_file,
164 | merges_file,
165 | errors="replace",
166 | unk_token="<|endoftext|>",
167 | bos_token="<|endoftext|>",
168 | eos_token="<|endoftext|>",
169 | add_prefix_space=False,
170 | **kwargs
171 | ):
172 | bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
173 | eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
174 | unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
175 | super().__init__(
176 | errors=errors,
177 | unk_token=unk_token,
178 | bos_token=bos_token,
179 | eos_token=eos_token,
180 | add_prefix_space=add_prefix_space,
181 | **kwargs,
182 | )
183 |
184 | with open(vocab_file, encoding="utf-8") as vocab_handle:
185 | self.encoder = json.load(vocab_handle)
186 | self.decoder = {v: k for k, v in self.encoder.items()}
187 | self.errors = errors # how to handle errors in decoding
188 | self.byte_encoder = bytes_to_unicode()
189 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
190 | with open(merges_file, encoding="utf-8") as merges_handle:
191 | bpe_merges = merges_handle.read().split("\n")[1:-1]
192 | bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
193 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
194 | self.cache = {}
195 | self.add_prefix_space = add_prefix_space
196 |
197 | # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
198 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
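  | # Example (editor's note): re.findall(self.pat, "Hello's world!!") yields
  | # ["Hello", "'s", " world", "!!"] -- contractions, space-prefixed words and
  | # punctuation runs are pre-split before byte-level BPE is applied to each piece.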
199 |
200 | @property
201 | def vocab_size(self):
202 | return len(self.encoder)
203 |
204 | def get_vocab(self):
205 | return dict(self.encoder, **self.added_tokens_encoder)
206 |
207 | def bpe(self, token):
208 | if token in self.cache:
209 | return self.cache[token]
210 | word = tuple(token)
211 | pairs = get_pairs(word)
212 |
213 | if not pairs:
214 | return token
215 |
216 | while True:
217 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
218 | if bigram not in self.bpe_ranks:
219 | break
220 | first, second = bigram
221 | new_word = []
222 | i = 0
223 | while i < len(word):
224 | try:
225 | j = word.index(first, i)
226 | except ValueError:
227 | new_word.extend(word[i:])
228 | break
229 | else:
230 | new_word.extend(word[i:j])
231 | i = j
232 |
233 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
234 | new_word.append(first + second)
235 | i += 2
236 | else:
237 | new_word.append(word[i])
238 | i += 1
239 | new_word = tuple(new_word)
240 | word = new_word
241 | if len(word) == 1:
242 | break
243 | else:
244 | pairs = get_pairs(word)
245 | word = " ".join(word)
246 | self.cache[token] = word
247 | return word
248 |
249 | def _tokenize(self, text):
250 | """Tokenize a string."""
251 | bpe_tokens = []
252 | for token in re.findall(self.pat, text):
253 | token = "".join(
254 | self.byte_encoder[b] for b in token.encode("utf-8")
255 | ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
256 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
257 | return bpe_tokens
258 |
259 | def _convert_token_to_id(self, token):
260 | """Converts a token (str) in an id using the vocab."""
261 | return self.encoder.get(token, self.encoder.get(self.unk_token))
262 |
263 | def _convert_id_to_token(self, index):
264 | """Converts an index (integer) in a token (str) using the vocab."""
265 | return self.decoder.get(index)
266 |
267 | def convert_tokens_to_string(self, tokens):
268 | """Converts a sequence of tokens (string) in a single string."""
269 | text = "".join(tokens)
270 | text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
271 | return text
272 |
273 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
274 | if not os.path.isdir(save_directory):
275 | logger.error(f"Vocabulary path ({save_directory}) should be a directory")
276 | return
277 | vocab_file = os.path.join(
278 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
279 | )
280 | merge_file = os.path.join(
281 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
282 | )
283 |
284 | with open(vocab_file, "w", encoding="utf-8") as f:
285 | f.write(json.dumps(self.encoder, ensure_ascii=False))
286 |
287 | index = 0
288 | with open(merge_file, "w", encoding="utf-8") as writer:
289 | writer.write("#version: 0.2\n")
290 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
291 | if index != token_index:
292 | logger.warning(
293 | f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
294 | " Please check that the tokenizer is not corrupted!"
295 | )
296 | index = token_index
297 | writer.write(" ".join(bpe_tokens) + "\n")
298 | index += 1
299 |
300 | return vocab_file, merge_file
301 |
302 | def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
303 | add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
304 | if is_split_into_words or add_prefix_space:
305 | text = " " + text
306 | return (text, kwargs)
307 |
308 | def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
309 | input_ids = []
310 | for is_user, text in conversation.iter_texts():
311 | input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
312 | if len(input_ids) > self.model_max_length:
313 | input_ids = input_ids[-self.model_max_length :]
314 | return input_ids
315 |
--------------------------------------------------------------------------------
/pec_baseline/models/model_parameters.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 South China University of Technology and
3 | # Engineering Research Center of Ministry of Education on Human Body Perception.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Model parameters processing file
18 | # File: model_parameters.py
19 | # Used for model parameters analysis
20 | # Author: Chen Yirong
21 | # Date: 2022.04.06
22 | from collections.abc import Iterable  # needed by set_freeze_by_names / set_freeze_by_idxs below
23 | def count_trainable_parameters(model):
24 | '''Get the number of trainable parameters.
25 | Usage example: print(f'The model has {count_trainable_parameters(model):,} trainable parameters')
26 | '''
27 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
28 |
29 | def count_total_parameters(model):
30 | '''Get the total number of parameters in the model.
31 | Usage example: print(f'The model has {count_total_parameters(model):,} total parameters')
32 | '''
33 | return sum(p.numel() for p in model.parameters())
34 |
35 | def show_trainable_parameters(model):
36 | for name, param in model.named_parameters():
37 | if param.requires_grad:
38 | print(name)
39 |
40 | # Freeze model parameters
41 | # Reference: https://blog.csdn.net/weixin_41712499/article/details/111295683?utm_medium=distribute.pc_relevant.none-task-blog-2~default~baidujs_title~default-5.no_search_link&spm=1001.2101.3001.4242
42 | def set_freeze_by_names(model, layer_names, freeze=True):
43 | if not isinstance(layer_names, Iterable):
44 | layer_names = [layer_names]
45 | for name, child in model.named_children():
46 | if name not in layer_names:
47 | continue
48 | for param in child.parameters():
49 | param.requires_grad = not freeze
50 |
51 | def freeze_by_names(model, layer_names):
52 | set_freeze_by_names(model, layer_names, True)
53 |
54 | def unfreeze_by_names(model, layer_names):
55 | set_freeze_by_names(model, layer_names, False)
56 |
57 | def set_freeze_by_idxs(model, idxs, freeze=True):
58 | if not isinstance(idxs, Iterable):
59 | idxs = [idxs]
60 | num_child = len(list(model.children()))
61 | idxs = tuple(map(lambda idx: num_child + idx if idx < 0 else idx, idxs))
62 | for idx, child in enumerate(model.children()):
63 | if idx not in idxs:
64 | continue
65 | for param in child.parameters():
66 | param.requires_grad = not freeze
67 |
68 | def freeze_by_idxs(model, idxs):
69 | set_freeze_by_idxs(model, idxs, True)
70 |
71 | def unfreeze_by_idxs(model, idxs):
72 | set_freeze_by_idxs(model, idxs, False)
73 |
74 | def freeze_by_model_name(model, model_name):
75 | for name, param in model.named_parameters():
76 | if name.startswith(model_name):
77 | param.requires_grad = False
78 |
79 | def unfreeze_by_model_name(model, model_name):
80 | for name, param in model.named_parameters():
81 | if name.startswith(model_name):
82 | param.requires_grad = True
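  | # Usage sketch (editor's note; the model and module names below are hypothetical
  | # and depend on what train_model.py actually instantiates):
  | #   model = OpenAIGPTLMHeadModel(config)
  | #   freeze_by_model_name(model, "transformer.tokens_embed")    # freeze token embeddings
  | #   freeze_by_idxs(model.transformer.h, range(0, 6))           # freeze the first 6 blocks
  | #   print(f'{count_trainable_parameters(model):,} of '
  | #         f'{count_total_parameters(model):,} parameters remain trainable')
  | #   unfreeze_by_model_name(model, "transformer.tokens_embed")  # thaw them again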
--------------------------------------------------------------------------------
/pec_baseline/train_model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 South China University of Technology and
3 | # Engineering Research Center of Ministry of Education on Human Body Perception.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Model training code
18 | # File: train_model.py
19 | # Used for training model
20 | # Author: Chen Yirong
21 | # Date: 2022.04.06
22 |
23 | # Key dependency versions:
24 | # pytorch: 1.9.0+
25 | # pytorch-ignite: 0.4.8
26 | # transformers: 4.18.0
27 |
28 | import os
29 | import json
30 | import time
31 | import math
32 | import torch
33 | import socket
34 | import random
35 | import logging
36 | import numpy as np
37 | from pprint import pformat
38 | from argparse import ArgumentParser  # used to pass command-line arguments
39 | from torch.optim.lr_scheduler import LambdaLR, CyclicLR, OneCycleLR
40 | from torch.nn.parallel import DistributedDataParallel  # for distributed model training
41 | from torch.cuda.amp import autocast as autocast  # for automatic mixed precision; requires torch 1.6+
42 | from ignite.engine import Engine, Events
43 | from ignite.handlers import ModelCheckpoint, EarlyStopping, Checkpoint
44 | from ignite.metrics import Loss, MetricsLambda, RunningAverage
45 | from ignite.contrib.handlers import ProgressBar, PiecewiseLinear, LRScheduler
46 | from ignite.contrib.handlers import TensorboardLogger, global_step_from_engine
47 | from ignite.contrib.handlers.tensorboard_logger import OutputHandler, OptimizerParamsHandler
48 |
49 | from transformers import (WEIGHTS_NAME, CONFIG_NAME, BertTokenizer, OpenAIGPTTokenizer, OpenAIGPTConfig, GPT2Config) # , AdamW
50 |
51 | from torch.optim import AdamW
52 |
53 | # Python packages written for this project
54 | from models import (count_trainable_parameters, count_total_parameters, show_trainable_parameters, freeze_by_model_name, unfreeze_by_model_name)
55 | # Modified from the transformers code to support mixed-precision training
56 | from models.gpt import OpenAIGPTLMHeadModel
57 | from models.gpt2 import GPT2LMHeadModel
58 | # Dataloader construction function for loading the dataset
59 | from utils import build_cped_dataloaders
60 |
61 | logger = logging.getLogger(__file__)
62 |
63 | def setup_seed(seed):
64 | torch.manual_seed(seed)
65 | torch.cuda.manual_seed_all(seed)
66 | np.random.seed(seed)
67 | random.seed(seed)
68 | torch.backends.cudnn.deterministic = True
69 |
70 | setup_seed(2022)
71 |
72 |
73 | def average_distributed_scalar(scalar, args):
74 | """ Average a scalar over the nodes if we are in distributed training. We use this for distributed evaluation. """
75 | if args.local_rank == -1:
76 | return scalar
77 | scalar_t = torch.tensor(scalar, dtype=torch.float, device=args.device) / torch.distributed.get_world_size()
78 | torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM)
79 | return scalar_t.item()
80 |
81 | def score_function(engine):
82 | '''Minimize the ppl, i.e. maximize the negative ppl.
83 |
84 | '''
85 | return -engine.state.metrics["average_ppl"]
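  | # Editor's note: ignite handlers such as EarlyStopping and Checkpoint treat a
  | # larger score as better, so returning -ppl turns "maximize the score" into
  | # "minimize the perplexity". A hedged wiring sketch (the real handler setup
  | # happens later in this file):
  | #   handler = EarlyStopping(patience=5, score_function=score_function, trainer=trainer)
  | #   evaluator.add_event_handler(Events.COMPLETED, handler)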
86 |
87 |
88 |
89 | def train():
90 | '''train()
91 | The complete, encapsulated model-training procedure.
92 |
93 | '''
94 | # Argument definitions
95 | parser = ArgumentParser()
96 | # Model type and paths
97 | parser.add_argument("--model_type", type=str, default="GPT", choices=['GPT', 'GPT-2'], help="Type of model")
98 | parser.add_argument("--model_checkpoint", type=str, default="/home/phd-chen.yirong/PretrainedModel/models.huggingface.co/bert-base-chinese", help="Path or URL of the model")
99 | parser.add_argument("--gpt_model_checkpoint", type=str, default="/home/phd-chen.yirong/PretrainedModel/CDial-GPT/LCCD_GPT_FOR_GPTSPEAKERROBOT", help="Path or URL of the GPT model used to initialized the GPT part of the UiBot.")
100 | parser.add_argument("--bert_model_checkpoint", type=str, default="./runs/SPEAKERBERT", help="Path or URL of the BERT model used to initialized the BERT part of the UiBot.")
101 | parser.add_argument('--log_file', '-log_file', type=str, default="./logs", help="Output logs to a file under this path")
102 |
103 | # Dataset name, paths and related settings
104 | parser.add_argument("--dataset_name", type=str, default="CPED", choices=['CPED', 'MELD', 'CPED-shuffle'], help="Name of dataset")
105 | parser.add_argument("--data_path", type=str, default="./data/CPED/", help="Dir of the dataset (a directory or a file name)")
106 | parser.add_argument("--cache_path", type=str, default="./data/CPED_cache_for_CpedDataset", help="Path of the dataset cache (must be a file name)")
107 | parser.add_argument("--use_speaker_name_as_speaker_list", action='store_true',
108 | help="If true using speaker name as speaker_list")
109 | parser.add_argument("--emotion_type", type=str, default="Emotion", choices=['Sentiment', 'BaseEmotion', 'Emotion'], help="Type of Emotion")
110 | parser.add_argument("--da_type", type=str, default="DA", choices=['DA', 'BaseDA'], help="Type of DA")
111 | parser.add_argument('--with_emotion', action='store_true', help="use emotion as token_type")
112 | parser.add_argument('--with_da', action='store_true', help="use da as token_type")
113 | parser.add_argument('--with_current_speaker', action='store_true', help="use current speaker as control signal")
114 | parser.add_argument('--with_current_persona', action='store_true', help="use current persona as control signal")
115 | parser.add_argument('--with_current_emotion', action='store_true', help="use current emotion as control signal")
116 | parser.add_argument('--with_current_da', action='store_true', help="use current da as control signal")
117 | parser.add_argument('--set_eda_in_speaker', action='store_true', help="set eda in speaker")
118 | parser.add_argument('--set_current_speaker_mask', action='store_true', help="set current_speaker_mask")
119 |
120 | # Training configuration
121 | parser.add_argument('--find_unused_parameters', action='store_true', help="If True find_unused_parameters")
122 | parser.add_argument('--show_parameters', action='store_true', help="If True show model parameters")
123 | parser.add_argument("--from_step", type=int, default=-1, help="Init learning rate from this step")
124 | parser.add_argument('--pretrained', action='store_true', help="If False train from scratch")
125 | ## AdamW optimizer settings: weight decay (regularization), beta1, beta2 and eps
126 | parser.add_argument('--L2_regularization', action='store_true', help="If False train without L2 Regularization")
127 | parser.add_argument("--L2_weight_decay", type=float, default=1e-2, help="L2 weight decay")
128 | parser.add_argument("--adamw_beta1", type=float, default=0.9, help="Adam's beta1 parameter")
129 | parser.add_argument("--adamw_beta2", type=float, default=0.999, help="Adam's beta2 parameter")
130 | parser.add_argument("--adamw_eps", type=float, default=1e-6, help="Adam's epsilon for numerical stability")
131 |
132 | # Weight of each component of the loss function
133 | parser.add_argument("--alpha_nll", type=float, default=1.0, help="alpha_nll")
134 | parser.add_argument("--alpha_emotion", type=float, default=1.0, help="alpha_emotion")
135 | parser.add_argument("--alpha_da", type=float, default=1.0, help="alpha_da")
136 | parser.add_argument("--alpha_per_gen", type=float, default=1.0, help="alpha_per_gen")
137 | parser.add_argument("--alpha_per_neu", type=float, default=1.0, help="alpha_per_neu")
138 | parser.add_argument("--alpha_per_ext", type=float, default=1.0, help="alpha_per_ext")
139 | parser.add_argument("--alpha_per_ope", type=float, default=1.0, help="alpha_per_ope")
140 | parser.add_argument("--alpha_per_agr", type=float, default=1.0, help="alpha_per_agr")
141 | parser.add_argument("--alpha_per_con", type=float, default=1.0, help="alpha_per_con")
142 |
143 | # Freeze some layers of the model
144 | parser.add_argument('--freeze_model', action='store_true', help="If True freeze some layers of the model")
145 | parser.add_argument('--freeze_start_layer', type=int, default=0, help="Start of the layer range to freeze; start and end are both in 0~11 and start <= end")
146 | parser.add_argument('--freeze_end_layer', type=int, default=11, help="End of the layer range to freeze; start and end are both in 0~11 and start <= end")
147 |
148 |
149 | parser.add_argument("--not_return_dict", action='store_true', help="If False the model return dict as result(适配最新版本transformers的模型输出)")
150 | parser.add_argument("--do_unscale", action='store_true', help="Calling scaler.unscale_(optimizer) before clipping enables you to clip unscaled gradients as usual(梯度裁剪)")
151 | parser.add_argument("--retain_graph", action='store_true', help="If true using retain_graph=True in loss.backward")
152 | parser.add_argument("--num_workers", type=int, default=4, help="Number of subprocesses for data loading")
153 | parser.add_argument("--n_epochs", type=int, default=32, help="Number of training epochs")
154 | parser.add_argument("--train_batch_size", type=int, default=1, help="Batch size for training")
155 | parser.add_argument("--valid_batch_size", type=int, default=1, help="Batch size for validation")
156 | parser.add_argument("--test_batch_size", type=int, default=1, help="Batch size for testing")
157 | parser.add_argument("--max_history", type=int, default=25, help="Number of previous exchanges to keep in history")
158 |
159 | parser.add_argument("--n_emd", type=int, default=768, help="Number of n_emd in config file (for noam)")
160 |     # Learning-rate schedule
161 |     parser.add_argument("--scheduler", type=str, default="noam", choices=['noam', 'linear', 'cyclic', '1cycle', 'fixedlr'], help="Learning-rate scheduler")
162 | parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate")
163 | parser.add_argument("--base_lr", type=float, default=1e-3, help="Initial learning rate which is the lower boundary in the cycle for each parameter group.")
164 | parser.add_argument("--max_lr", type=float, default=5e-3, help="Upper learning rate boundaries in the cycle for each parameter group.")
165 | ## CyclicLR
166 | parser.add_argument("--cycliclr_mode", type=str, default="triangular2", choices=['triangular', 'triangular2', 'exp_range'], help="mode of CyclicLR, see https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CyclicLR.html#torch.optim.lr_scheduler.CyclicLR")
167 |
168 | parser.add_argument("--eval_before_start", action='store_true',
169 | help="If true start with a first evaluation before training")
170 | parser.add_argument("--warmup_steps", type=int, default=5000, help="Warm up steps")
171 |     parser.add_argument("--valid_steps", type=int, default=5000, help="Perform validation every X steps")
172 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Accumulate gradients on several steps")
173 | parser.add_argument("--max_norm", type=float, default=2.0, help="Clipping gradient norm")
174 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
175 | help="Device (cuda or cpu)")
176 |     parser.add_argument("--autocast", action='store_true',
177 |                         help="If true, use autocast (automatic mixed precision) to accelerate training")
178 | parser.add_argument("--local_rank", type=int, default=-1,
179 | help="Local rank for distributed training (-1: not distributed)")
180 |
181 | args = parser.parse_args()
182 |
183 |     # Logging configuration
184 | # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process.
185 | # logger.info => log main process only, logger.warning => log all processes
186 |     # the name of the log file looks like './logs/Jan11_22-55-46_gpu144_GPT_CPED.log'
187 |     # The log information is written to a dedicated file, which makes it easy to
188 |     # inspect the log when the program is started at a time scheduled with the ```at``` command.
189 | if not os.path.exists(args.log_file):
190 |         # create the log directory if it does not exist
191 | os.makedirs(args.log_file)
192 | log_file_name_or_tensorboard_dir_name = str(time.strftime('%b%d_%H-%M-%S',time.localtime(time.time())))+'_'+str(socket.gethostname())+'_'+args.model_type+'_'+args.dataset_name
193 | logging_file_name = os.path.join(args.log_file,log_file_name_or_tensorboard_dir_name+'.log')
194 | logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
195 | filename=logging_file_name) # output the log information to the log file
196 | logger.warning("Running process %d", args.local_rank)
197 | logger.info("Arguments: %s", pformat(args))
198 |
199 |     # File names of the train/valid/test splits of each dataset
200 | cped_filenames = {"train":"train_split.csv",
201 | "valid":"valid_split.csv",
202 | "test":"test_split.csv"}
203 | meld_filenames = {"train":"train_sent_emo.csv",
204 | "valid":"dev_sent_emo.csv",
205 | "test":"test_sent_emo.csv"}
206 | cped_shuffle_filenames = {"train":"train_shuffle_split.csv",
207 | "valid":"valid_shuffle_split.csv",
208 | "test":"test_shuffle_split.csv"}
209 | if args.dataset_name == "MELD":
210 | filenames = meld_filenames
211 | elif args.dataset_name == "CPED":
212 | filenames = cped_filenames
213 | elif args.dataset_name == "CPED-shuffle":
214 | filenames = cped_shuffle_filenames
215 |
216 |     # Distributed training configuration
217 | # Initialize distributed training if needed
218 | args.distributed = (args.local_rank != -1)
219 | if args.distributed:
220 | torch.cuda.set_device(args.local_rank)
221 | args.device = torch.device("cuda", args.local_rank)
222 | torch.distributed.init_process_group(backend='nccl', init_method='env://')
223 |
224 | logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
225 |
226 |     # Select the model class, config class and tokenizer according to the model type
227 | if args.model_type == 'GPT':
228 | model_class = OpenAIGPTLMHeadModel
229 | config_class = OpenAIGPTConfig
230 | tokenizer_class = BertTokenizer
231 |
232 | elif args.model_type == 'GPT-2':
233 | model_class = GPT2LMHeadModel
234 | config_class = GPT2Config
235 | tokenizer_class = BertTokenizer
236 |
237 |
238 |     '''Add new models here:
239 | elif args.model_type == 'MODELNAME':
240 | model_class = ModelClass
241 | config_class = ModelConfigClass
242 | tokenizer_class = TokenizerClass
243 | '''
244 |
245 |
246 |     # Initialize the model parameters
247 |     if args.pretrained: # load the weights of a pretrained model
248 |         tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint, do_lower_case=True)
249 |         model = model_class.from_pretrained(args.model_checkpoint)
250 |     else: # not initializing from a pretrained model (model/tokenizer must be defined here before training can proceed)
251 |         print("Not initializing from a pretrained model")
252 |
253 |     # Print the model structure and parameter count
254 |     if args.show_parameters:
255 |         print(model) # print the network structure
256 |         # print the total number of model parameters
257 |         # https://huggingface.co/docs/transformers/main_classes/model#transformers.modeling_utils.ModuleUtilsMixin.num_parameters
258 |         total_params = model.num_parameters() # use the model's own helper to count its parameters
259 | print(f'{total_params:,} total parameters.')
260 |
261 |     # Freeze some layers of the model
262 |     if args.freeze_model:
263 |         freeze_by_model_name(model, "transformer.tokens_embed")
264 |         freeze_by_model_name(model, "transformer.positions_embed")
265 |         for i in range(args.freeze_start_layer,args.freeze_end_layer+1):
266 |             # freeze layers args.freeze_start_layer through args.freeze_end_layer
267 |             freeze_by_model_name(model, "transformer.h."+str(i)+".") # the trailing dot ensures that freezing layers 1~6 does not also freeze layers 10 and 11
268 |
269 |         # List the trainable layers of the model
270 |         print("Trainable layers:")
271 |         for name, param in model.named_parameters():
272 |             if param.requires_grad:
273 |                 print(name,':',param.size())
274 |         # count the trainable parameters
275 |         # total_trainable_params = count_trainable_parameters(model)
276 |         total_trainable_params = model.num_parameters(only_trainable=True)
277 |         print(f'{total_trainable_params:,} total trainable parameters.')
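        # Note: freeze_by_model_name is imported from elsewhere in this repo.
        # A minimal sketch of the name-prefix matching assumed above (an
        # illustration, not the repo's actual definition):
        #
        #     def freeze_by_model_name(model, name_prefix):
        #         for name, param in model.named_parameters():
        #             if name.startswith(name_prefix):
        #                 param.requires_grad = False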
278 |
279 | model.to(args.device)
280 |
281 |
282 |
283 |     if args.L2_regularization: # L2 Regularization
284 |         # Legacy L2 regularization, kept for reference
285 |         # reference: https://blog.csdn.net/mch2869253130/article/details/105994044
286 |         # a list of parameters excluded from L2 regularization can be used, see the commented block below
287 |         # optimizer = AdamW([{'params': model.parameters(), 'weight_decay': args.L2_weight_decay, 'initial_lr': args.lr}], lr=args.lr, betas=(0.9, 0.999), eps=1e-6, weight_decay=args.L2_weight_decay, correct_bias=True)
288 |
289 |         # Current version (decoupled weight decay)
290 |         # see: https://arxiv.org/pdf/1711.05101.pdf
291 |         # https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html?highlight=adamw#torch.optim.AdamW
292 |         # transformers version: from transformers import AdamW
293 |         # optimizer = AdamW([{'params': model.parameters(), 'initial_lr': args.lr}], lr=args.lr, betas=(args.adamw_beta1, args.adamw_beta2), eps=args.adamw_eps, weight_decay=args.L2_weight_decay, correct_bias=True)
294 |         # pytorch version: from torch.optim import AdamW
295 |         optimizer = AdamW([{'params': model.parameters(), 'initial_lr': args.lr}], lr=args.lr, betas=(args.adamw_beta1, args.adamw_beta2), eps=args.adamw_eps, weight_decay=args.L2_weight_decay)
296 | '''
297 | no_decay = ['bias', 'bias_ih_l0', 'bias_hh_l0', 'LayerNorm.weight','layernorm_1.weight']
298 | parameters_list = []
299 | optimizer_grouped_parameters = [
300 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.L2_weight_decay, 'initial_lr': args.lr},
301 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'initial_lr': args.lr}
302 | ]
303 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, correct_bias=True)
304 | '''
305 | else: # not L2 Regularization
306 |         # transformers version: from transformers import AdamW
307 |         # optimizer = AdamW([{'params': model.parameters(), 'initial_lr': args.lr}], lr=args.lr, betas=(args.adamw_beta1, args.adamw_beta2), eps=args.adamw_eps, weight_decay=args.L2_weight_decay, correct_bias=True)
308 |         # pytorch version: from torch.optim import AdamW
309 | optimizer = AdamW([{'params': model.parameters(), 'initial_lr': args.lr}], lr=args.lr, betas=(args.adamw_beta1, args.adamw_beta2), eps=args.adamw_eps)
310 |
311 |     if args.autocast:
312 |         # Mixed-precision training
313 |         # see: https://pytorch.org/docs/1.9.0/amp.html?highlight=torch%20cuda%20amp%20gradscaler
314 |         # https://pytorch.org/docs/1.9.0/notes/amp_examples.html#amp-examples
315 |         scaler = torch.cuda.amp.GradScaler() # requires PyTorch 1.6+
316 |
317 | if args.distributed:
318 | # Add "find_unused_parameters=True" to avoid the following error
319 | # ERROR:ignite.engine.engine.Engine:Current run is terminating due to exception:
320 | # Expected to have finished reduction in the prior iteration before starting a new one.
321 | # This error indicates that your module has parameters that were not used in producing loss.
322 | # You can enable unused parameter detection by (1) passing the keyword argument
323 | # `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`;
324 | if args.find_unused_parameters:
325 | model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
326 | else:
327 | model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) # ,find_unused_parameters=True
328 |
329 |
330 | logger.info("Prepare datasets...")
331 | if args.dataset_name == 'CPED':
332 | loader_class = build_cped_dataloaders
333 |
334 | if args.model_type == 'GPT' or args.model_type == 'GPT-2':
335 | train_loader, valid_loader, train_sampler, valid_sampler = build_cped_dataloaders(args, tokenizer, logger, load_test=False, filenames=filenames)
336 | test_loader, test_sampler = build_cped_dataloaders(args, tokenizer, logger, load_test=True, filenames=filenames)
337 |
338 |
339 |
340 | # Training function and trainer
341 | def update(engine, batch):
342 | model.train()
343 |         # Unpack the batch
344 |         if args.model_type == 'GPT' or args.model_type == 'GPT-2':
345 |             input_ids, token_type_ids, emotion_ids, da_ids, current_speaker_id, current_persona_ids, current_emotion_id, current_da_id, lm_labels = tuple(None if input_tensor is None else input_tensor.to(args.device) for input_tensor in batch)
346 |
347 |
348 |         # Forward pass
349 |         '''Reference code for other model types:
350 | if args.model_type == 'SPEAKERBERT':
351 | if args.autocast:
352 | with autocast():
353 | (lm_loss), *_ = model(input_ids=input_ids, masked_lm_labels=lm_labels, speaker_type_ids=speaker_type_ids)
354 | else:
355 | (lm_loss), *_ = model(input_ids=input_ids, masked_lm_labels=lm_labels, speaker_type_ids=speaker_type_ids)
356 | '''
357 | if args.model_type == 'GPT' or args.model_type == 'GPT-2':
358 | if args.autocast:
359 | with autocast():
360 | CausalLMOutput = model(input_ids=input_ids, labels=lm_labels, token_type_ids=token_type_ids)
361 | else:
362 | CausalLMOutput = model(input_ids=input_ids, labels=lm_labels, token_type_ids=token_type_ids)
363 | lm_loss = CausalLMOutput.loss
364 | #print("results=", lm_loss)
365 |
366 |
367 |         # Backward pass
368 | if args.autocast:
369 | with autocast():
370 | loss = lm_loss / args.gradient_accumulation_steps
371 | else:
372 | loss = lm_loss / args.gradient_accumulation_steps
373 |         if args.autocast: # mixed-precision training, requires torch 1.6+
374 |             scaler.scale(loss).backward(retain_graph=args.retain_graph) # retain_graph is unrelated to amp; it is only needed when several backward() calls share parts of the graph
375 |
376 | if engine.state.iteration % args.gradient_accumulation_steps == 0:
377 |                 # see: https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-accumulation
378 |                 if args.do_unscale:
379 |                     # see: https://pytorch.org/docs/1.9.0/notes/amp_examples.html#amp-examples
380 | # Unscales the gradients of optimizer's assigned params in-place
381 | scaler.unscale_(optimizer)
382 | # Since the gradients of optimizer's assigned params are unscaled, clips as usual
383 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
384 | scaler.step(optimizer)
385 | scaler.update()
386 | optimizer.zero_grad()
387 |
388 |         else:
389 |             loss.backward(retain_graph=args.retain_graph) # retain_graph=True avoids: RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
390 |             torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm) # clip the gradient norm to prevent exploding gradients
391 | if engine.state.iteration % args.gradient_accumulation_steps == 0:
392 | optimizer.step()
393 | optimizer.zero_grad()
394 |
395 | return loss.item(), optimizer.param_groups[0]['lr']
396 |
397 | trainer = Engine(update)
398 |
399 | # Evaluation function and evaluator (evaluator output is the input of the metrics)
400 | def inference(engine, batch):
401 | model.eval()
402 |         # Unpack the batch
403 |         if args.model_type == 'GPT' or args.model_type == 'GPT-2':
404 |             input_ids, token_type_ids, emotion_ids, da_ids, current_speaker_id, current_persona_ids, current_emotion_id, current_da_id, lm_labels = tuple(None if input_tensor is None else input_tensor.to(args.device) for input_tensor in batch)
405 |
406 | with torch.no_grad():
407 |             # Forward pass
408 | if args.model_type == 'GPT' or args.model_type == 'GPT-2':
409 | if args.autocast:
410 | with autocast():
411 | CausalLMOutput = model(input_ids=input_ids, token_type_ids=token_type_ids)
412 | else:
413 | CausalLMOutput = model(input_ids=input_ids, token_type_ids=token_type_ids)
414 |
415 |
416 | lm_logits = CausalLMOutput.logits
417 |             lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1)) # drop the last logit: position t predicts token t+1
418 |             lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1) # drop the first label accordingly
419 | return lm_logits_flat_shifted, lm_labels_flat_shifted
420 |
421 | evaluator = Engine(inference)
422 |
423 |     # Test function and testor (testor output is the input of the metrics)
424 | def test(engine, batch):
425 | model.eval()
426 |         # Unpack the batch
427 |         if args.model_type == 'GPT' or args.model_type == 'GPT-2':
428 |             input_ids, token_type_ids, emotion_ids, da_ids, current_speaker_id, current_persona_ids, current_emotion_id, current_da_id, lm_labels = tuple(None if input_tensor is None else input_tensor.to(args.device) for input_tensor in batch)
429 |
430 | with torch.no_grad():
431 |             # Forward pass
432 | if args.model_type == 'GPT' or args.model_type == 'GPT-2':
433 | if args.autocast:
434 | with autocast():
435 | CausalLMOutput = model(input_ids=input_ids, token_type_ids=token_type_ids)
436 | else:
437 | CausalLMOutput = model(input_ids=input_ids, token_type_ids=token_type_ids)
438 | lm_logits = CausalLMOutput.logits
439 | lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
440 | lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
441 | return lm_logits_flat_shifted, lm_labels_flat_shifted
442 |
443 | testor = Engine(test)
444 |
445 |
446 | # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
447 | trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(valid_loader))
448 | trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: testor.run(test_loader))
449 |
450 | if args.n_epochs < 1:
451 | trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(valid_loader))
452 | trainer.add_event_handler(Events.COMPLETED, lambda _: testor.run(test_loader))
453 |
454 |     if args.eval_before_start:
455 |         trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(valid_loader))
456 |         trainer.add_event_handler(Events.STARTED, lambda _: testor.run(test_loader))
457 |
458 |
459 | # Evaluation during training
460 | @trainer.on(Events.ITERATION_STARTED)
461 | def log_iterations(engine):
462 | # if engine.state.iteration % max(int(0.1 * len(train_loader)), 1) == 0:
463 | if engine.state.iteration % args.valid_steps == 0:
464 | evaluator.run(valid_loader)
465 | testor.run(test_loader)
466 |
467 | # Make sure distributed data samplers split the dataset nicely between the distributed processes
468 | if args.distributed:
469 | trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch))
470 | evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch))
471 | testor.add_event_handler(Events.EPOCH_STARTED, lambda engine: test_sampler.set_epoch(engine.state.epoch))
472 |
473 |     # Noam learning-rate decay
474 |     # d_model = model.config.n_embd
475 |     # see Section 5.3 of "Attention Is All You Need" (https://arxiv.org/abs/1706.03762)
476 | d_model = args.n_emd
477 | noam_lambda = lambda step: (
478 | d_model ** (-0.5) * min((step + 1) ** (-0.5), (step + 1) * args.warmup_steps ** (-1.5)))
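    # For illustration: LambdaLR scales the group's initial_lr by noam_lambda(step).
    # The factor grows linearly during warmup and decays as (step+1)**-0.5 afterwards,
    # peaking near step+1 = warmup_steps at d_model**-0.5 * warmup_steps**-0.5
    # (about 5.1e-4 with the defaults d_model=768, warmup_steps=5000).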
479 | noam_scheduler = LambdaLR(optimizer, lr_lambda=noam_lambda, last_epoch=args.from_step)
480 | if args.scheduler == "noam":
481 | scheduler = LRScheduler(noam_scheduler)
482 | if args.scheduler == "linear":
483 | scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
484 | if args.scheduler == "cyclic":
485 |         # paper: https://arxiv.org/pdf/1506.01186.pdf (Cyclical Learning Rates)
486 | scheduler = LRScheduler(CyclicLR(optimizer, base_lr=args.base_lr, max_lr=args.max_lr, step_size_up=2000, step_size_down=2000, mode=args.cycliclr_mode, gamma=1.0, scale_fn=None, scale_mode='cycle', cycle_momentum=False, base_momentum=0.8, max_momentum=0.9, last_epoch=-1))
487 | if args.scheduler == "1cycle":
488 | scheduler = LRScheduler(OneCycleLR(optimizer, args.lr, total_steps=args.n_epochs, epochs=None, steps_per_epoch=None, pct_start=0.3, anneal_strategy='cos', cycle_momentum=False, base_momentum=0.85, max_momentum=0.95, div_factor=25.0, final_div_factor=10000.0, last_epoch=-1))
489 | if args.scheduler == "fixedlr":
490 | scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), args.lr)])
491 |
492 | trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
493 |
494 |
495 |
496 | RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss")
497 | RunningAverage(output_transform=lambda x: x[1]).attach(trainer, "lr")
498 | metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0], x[1]))}
499 | metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args)})
500 | metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
501 | for name, metric in metrics.items():
502 | metric.attach(evaluator, name)
503 | metric.attach(testor, name)
504 |
505 | # On the main process: add progress bar, tensorboard, checkpoints
506 | # And save model, configuration and tokenizer before we start to train
507 | if args.local_rank in [-1, 0]:
508 | pbar = ProgressBar(persist=True, mininterval=2)
509 | pbar.attach(trainer, metric_names=["loss", "lr"])
510 | evaluator.add_event_handler(Events.COMPLETED,
511 | lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))
512 | testor.add_event_handler(Events.COMPLETED,
513 | lambda _: pbar.log_message("Test: %s" % pformat(testor.state.metrics)))
514 |
515 |         # tb_logger = TensorboardLogger(log_dir=None, comment='_'+args.model_type+'_'+args.dataset_name)
516 |         # use the same name for the log file and the tensorboard output directory, 2022.04.11
517 | tb_logger = TensorboardLogger(log_dir=os.path.join("./runs/",log_file_name_or_tensorboard_dir_name))
518 |
519 | tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]),
520 | event_name=Events.ITERATION_COMPLETED)
521 |
522 | tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
523 |
524 | tb_logger.attach_output_handler(evaluator,
525 | event_name=Events.EPOCH_COMPLETED,
526 | tag="validation",
527 | metric_names=list(metrics.keys()),
528 | global_step_transform=global_step_from_engine(trainer))
529 | tb_logger.attach_output_handler(testor,
530 | event_name=Events.EPOCH_COMPLETED,
531 | tag="test",
532 | metric_names=list(metrics.keys()),
533 | global_step_transform=global_step_from_engine(trainer))
534 | '''
535 | tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys())),
536 | event_name=Events.EPOCH_COMPLETED, global_step_transform=global_step_from_engine(trainer))
537 | tb_logger.attach(testor, log_handler=OutputHandler(tag="test", metric_names=list(metrics.keys())),
538 | event_name=Events.EPOCH_COMPLETED, global_step_transform=global_step_from_engine(trainer))
539 | '''
540 |
541 |         # Stop training if the test-set ppl has not improved for 3 consecutive epochs
542 | early_stop_handler = EarlyStopping(patience=3, score_function=score_function, trainer=trainer)
543 | testor.add_event_handler(Events.COMPLETED, early_stop_handler)
544 |
545 |         #checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint', save_interval=1, n_saved=3)
546 |         # Checkpoint API of ignite v0.4.8
547 |
548 | best_model_handler = Checkpoint(
549 | {"model": model},
550 | tb_logger.writer.log_dir,
551 | filename_prefix="best",
552 | n_saved=2,
553 | global_step_transform=global_step_from_engine(trainer),
554 | score_name="test_ppl",
555 | score_function=score_function,
556 | )
557 | testor.add_event_handler(Events.COMPLETED, best_model_handler)
558 |
559 | '''
560 | # save model after evaluation
561 | testor.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
562 | 'mymodel': getattr(model, 'module', model)})
563 | evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
564 | 'mymodel': getattr(model, 'module', model)})
565 | trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
566 | 'mymodel': getattr(model, 'module', model)}) # "getattr" take care of distributed encapsulation
567 | '''
568 |
569 | torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin')
570 | getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
571 |
572 |         #tokenizer.save_vocabulary(tb_logger.writer.log_dir)
573 |         # save the vocabulary including the newly added tokens
574 | tokenizer.save_pretrained(tb_logger.writer.log_dir)
575 | with open(tb_logger.writer.log_dir + "/training_args.json",'w',encoding='utf-8') as json_file:
576 | json.dump(pformat(args),json_file,ensure_ascii=False)
577 |
578 | # Run the training
579 | trainer.run(train_loader, max_epochs=args.n_epochs)
580 |
581 | # On the main process: close tensorboard logger and rename the last checkpoint
582 | # (for easy re-loading with OpenAIGPTModel.from_pretrained method)
583 | if args.local_rank in [-1, 0] and args.n_epochs > 0:
584 |         # rename the best checkpoint to the file name given by WEIGHTS_NAME
585 |         # adjusted for ignite 0.4.8
586 | # os.rename(os.path.join(tb_logger.writer.log_dir, checkpoint_handler._saved[-1][1]),
587 | # os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)) # TODO: PR in ignite to have better access to saved file paths (cleaner)
588 | os.rename(os.path.join(tb_logger.writer.log_dir, best_model_handler.last_checkpoint),
589 | os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME))
590 |
591 | tb_logger.close()
592 |
593 |
594 | if __name__ == "__main__":
595 | train()
--------------------------------------------------------------------------------
/pec_baseline/train_model_script.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # coding=utf-8
3 | # Copyright 2021 South China University of Technology and
4 | # Engineering Research Center of Ministry of Education on Human Body Perception.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | # Author: Chen Yirong
19 | # Date: 2022.04.06
20 |
21 | # Train the model with torch.distributed.launch
22 | # Trained on Ubuntu18.04 with 2 GeForce RTX 2080ti GPUs and 125G Memory
23 | # --n_epochs=120 --warmup_steps=10000 --lr=5.5e-3 --pretrained --train_batch_size=4 --valid_batch_size=2 --test_batch_size=1 --num_workers=4
24 |
25 | # Number of CPUs: 2, cores per CPU: 10
26 | # cat /proc/cpuinfo| grep "physical id"| sort| uniq| wc -l
27 | # cat /proc/cpuinfo| grep "cpu cores"| uniq
28 |
29 |
30 |
31 | #############################################################################
32 | # Chinese dataset: CPED
33 | dataset_name=CPED
34 | data_path=../data/CPED
35 | model_checkpoint=~/PretrainedModel/CDial-GPT/CDial-GPT_LCCC-base
36 | cache_path=../data/CPED_cache_for_CpedDataset
37 |
38 | # GPT
39 | CUDA_VISIBLE_DEVICES=5 python -m torch.distributed.launch --nproc_per_node=1 --master_addr 127.0.0.1 --master_port 2301 train_model.py --model_type=GPT --model_checkpoint=~/PretrainedModel/CDial-GPT/CDial-GPT_LCCC-base --pretrained --dataset_name=CPED --data_path=../data/CPED/ --cache_path=../data/CPED_cache_for_CpedDataset --lr=6.25e-5 --scheduler=noam --autocast --train_batch_size=1 --valid_batch_size=1 --test_batch_size=1 --n_epochs=2
40 |
--------------------------------------------------------------------------------
/pec_baseline/utils/README.md:
--------------------------------------------------------------------------------
1 | # utils
2 | # Data loading module
3 | * **Author: Chen Yirong**
4 | * **Date: 2022.03.21**
5 |
6 | ## Architecture
7 | ### Base files
8 | base_util.py: shared helper functions for reading the datasets
9 | dataset_statistics.py: utilities for computing dataset statistics
10 |
11 | ### Dataset processing files
12 | Each dataset (say its name is xxx) comes with two .py files, laid out as follows (see the usage sketch at the end of this README):
13 | * xxx_util.py: wraps the functions for reading the dataset, plus some constants
14 | * xxx_dataset.py: defines a torch.utils.data.Dataset subclass and a build_xxx_dataloaders function
15 |
16 | ### Processing files for the CPED dataset:
17 | * cped_util.py: wraps the functions for reading the CPED dataset, plus some constants
18 | * cped_dataset.py: defines a torch.utils.data.Dataset subclass and the build_cped_dataloaders function
19 |
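## Usage sketch
A minimal sketch of building the CPED dataloaders (the `args` fields below mirror the argparse options consumed in train_model.py; the exact set of required fields and the checkpoint path are assumptions/placeholders):

```python
import logging
from argparse import Namespace
from transformers import BertTokenizer
from utils import build_cped_dataloaders

args = Namespace(data_path="../data/CPED", cache_path="../data/CPED_cache_for_CpedDataset",
                 emotion_type="Emotion", da_type="DA", max_history=25,
                 with_current_speaker=False, with_current_persona=False,
                 with_current_emotion=False, with_current_da=False,
                 with_emotion=False, with_da=False,
                 use_speaker_name_as_speaker_list=False, set_eda_in_speaker=False,
                 set_current_speaker_mask=False, train_batch_size=1,
                 valid_batch_size=1, test_batch_size=1, distributed=False)
tokenizer = BertTokenizer.from_pretrained("path/to/CDial-GPT_LCCC-base")
logger = logging.getLogger(__name__)

train_loader, valid_loader, train_sampler, valid_sampler = build_cped_dataloaders(
    args, tokenizer, logger, load_test=False)
```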
--------------------------------------------------------------------------------
/pec_baseline/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 South China University of Technology and
3 | # Engineering Research Center of Ministry of Education on Human Body Perception.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Dataset data loading module
18 | # Author: Chen Yirong
19 | # Date: 2022.03.21
20 |
21 | __version__ = "1.0.0"
22 |
23 | # Key package versions:
24 | # pytorch: 1.9.0+
25 |
26 |
27 | try:
28 | import absl.logging
29 | absl.logging.set_verbosity('info')
30 | absl.logging.set_stderrthreshold('info')
31 | absl.logging._warn_preinit_stderr = False
32 | except Exception:
33 | pass
34 |
35 | import logging
36 |
37 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
38 |
39 |
40 | # Files and general utilities
41 | from .base_util import (load_csv_data_from_dir,shuffle_total_data,combine_csv_files,save_speaker,load_speaker,
42 | convert_speaker_to_id,convert_id_to_speaker,convert_cache_to_csv,is_torch_available)
43 |
44 | from .dataset_statistics import (get_data_for_analysis, get_totaldata_for_analysis, get_row_statistics, cout_dialogue_words)
45 |
46 |
47 | # Dataset utilities
48 | if is_torch_available():
49 |     # functions and constants for loading the CPED dataset
50 | from .cped_util import (cped_get_single_file,cped_get_single_cache_file,cped_get_data_from_dir,
51 | cped_get_single_file_for_bert_gpt,cped_get_data_from_dir_for_bert_gpt,
52 | CPED_SPECIAL_TOKENS,CPED_IGNORE_ID,CPED_DA_TOKENS,CPED_SENTIMENT_TOKENS,
53 | CPED_EMOTION_TOKENS,CPED_DA_TO_TOKENS,CPED_SENTIMENT_TO_TOKENS,CPED_EMOTION_TO_TOKENS,
54 | CPED_DA_TO_ID,CPED_EMOTION_TO_ID,CPED_GENDER_TO_ID,CPED_BIGFIVE_TO_ID,CPED_SPEAKER_TYPE_TO_ID)
55 |
56 | from .cped_dataset import (CpedDataset, build_cped_dataloaders, find_split_id_of_response, create_speaker_type,
57 | convert_emotion_to_tokens, convert_da_to_tokens, set_da_in_speaker, set_emotion_in_speaker)
58 |
59 |
60 |
61 |
62 |
--------------------------------------------------------------------------------
/pec_baseline/utils/base_util.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 South China University of Technology and
3 | # Engineering Research Center of Ministry of Education on Human Body Perception.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Dataset data loading file
18 | # File: base_util.py
19 | # Used for dataset loading
20 | # Basic helper functions for dataset loading
21 | # Author: Chen Yirong
22 | # Date: 2022.03.21
23 |
24 | import os
25 | import re
26 | import math
27 | import json
28 | import shutil
29 | import random
30 | import collections
31 | import importlib.util
32 | import pandas as pd
33 | import logging
34 | from io import open
35 | from os.path import join
36 | from .dataset_statistics import get_row_statistics
37 |
38 | logger = logging.getLogger(__name__)
39 |
40 |
41 | _torch_available = importlib.util.find_spec("torch") is not None
42 |
43 |
44 | def is_torch_available():
45 | return _torch_available
46 |
47 | def load_csv_data_from_dir(data_dir="../data/CPED",
48 | file_dict={"train":"train_split.csv",
49 | "valid":"valid_split.csv",
50 | "test":"test_split.csv"}):
51 |     '''Load data from a directory that contains train_split.csv, valid_split.csv and test_split.csv
52 |     Inputs:
53 |         **data_dir**: str, directory containing the csv split files
54 |         **file_dict**: dict mapping the split names "train"/"valid"/"test" to file names
55 |     '''
56 | print("Read dataset from ", data_dir)
57 | train_data = pd.read_csv(join(data_dir,file_dict["train"]), encoding="UTF-8-SIG")
58 | valid_data = pd.read_csv(join(data_dir,file_dict["valid"]), encoding="UTF-8-SIG")
59 | test_data = pd.read_csv(join(data_dir,file_dict["test"]), encoding="UTF-8-SIG")
60 | return train_data, valid_data, test_data
61 |
62 |
63 | def shuffle_total_data(data_path,
64 | save_path,
65 | validation_split_percentage=0.1,
66 | test_split_percentage=0.1,
67 | file_names = ["train_shuffle_split.csv", "valid_shuffle_split.csv", "test_shuffle_split.csv"],
68 | regen=False):
69 |     '''shuffle_total_data
70 |     Purpose: shuffle a single .csv file and split it into train/valid/test sets, saving each split separately
71 |     Inputs:
72 |         data_path: path of the source .csv file
73 |         save_path: directory in which the split files are saved
74 |         validation_split_percentage: fraction of dialogues used for the validation set
75 |         test_split_percentage: fraction of dialogues used for the test set
76 |     '''
77 |     if regen==False:
78 |         print("Skipping regeneration!")
79 | return False
80 | else:
81 |         # the following deletes the existing files; this is a destructive operation
82 | if os.path.exists(join(save_path,file_names[0])):
83 | os.remove(join(save_path,file_names[0]))
84 |
85 | if os.path.exists(join(save_path,file_names[1])):
86 | os.remove(join(save_path,file_names[1]))
87 |
88 | if os.path.exists(join(save_path,file_names[2])):
89 | os.remove(join(save_path,file_names[2]))
90 |
91 | print("Read dataset from ", data_path)
92 | data = pd.read_csv(data_path,
93 | usecols=["Dialogue_ID","Utterance_ID","Speaker","Sentiment","Emotion","DA","Utterance","Gender","Age","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"],
94 | encoding="UTF-8-SIG")
95 |         # split into train/valid/test sets
96 |         keys = list(set(data['Dialogue_ID']))
97 |         random.shuffle(keys) # shuffle the dialogue ids
98 |         validation_split_id = int(len(keys)*(1-validation_split_percentage-test_split_percentage))
99 |         test_split_id = int(len(keys)*(1-test_split_percentage))
100 |         train_keys = keys[:validation_split_id] # dialogue ids of the training set
101 |         valid_keys = keys[validation_split_id:test_split_id] # dialogue ids of the validation set
102 |         test_keys = keys[test_split_id:] # dialogue ids of the test set
103 | train_data = data[data['Dialogue_ID'].isin(train_keys)]
104 | valid_data = data[data['Dialogue_ID'].isin(valid_keys)]
105 | test_data = data[data['Dialogue_ID'].isin(test_keys)]
106 |
107 | train_data.to_csv(join(save_path,file_names[0]), encoding="UTF-8-SIG", index=False)
108 | valid_data.to_csv(join(save_path,file_names[1]), encoding="UTF-8-SIG", index=False)
109 | test_data.to_csv(join(save_path,file_names[2]), encoding="UTF-8-SIG", index=False)
110 |         print("Dataset generation finished!")
111 |
112 | return True
113 |
114 |
115 | # Merge all csv files under a given directory into a single csv file
116 | # (implemented mainly to make aggregate statistics easier)
117 | def combine_csv_files(data_path="./MELD/",
118 |                       save_path="./MELD/MELD_total_text.csv",
119 |                       regen=False):
120 |     '''combine_csv_files
121 |     Merge all csv files under data_path into a single csv file saved at save_path
122 |
123 |     Example:
124 |
125 |     combine_csv_files(data_path="./MELD/",
126 |                       save_path="./MELD/MELD_total_text.csv",
127 |                       regen=True)
128 |
129 |
130 |     '''
131 |     if regen==False:
132 |         print("Skipping regeneration!")
133 |         return False
134 |
135 |
136 |
137 |     files = os.listdir(data_path)
138 |     if os.path.isfile(save_path):
139 |         return 0
140 |     else:
141 |         try:
142 |             main_list = []
143 |             for i in range(len(files)):
144 |                 content = pd.read_csv(join(data_path, files[i]), encoding="UTF-8-SIG")
145 |                 if i == 0:
146 |                     main_list.extend([list(content.keys())]) # keep the header row once
147 |                 main_list.extend(content.values.tolist())
148 |
149 |             main_dict = {} # transpose the row-major rows into one list per column
150 |             for i in list(zip(*main_list)):
151 |                 main_dict[i[0]] = list(i[1:])
152 |             data_df = pd.DataFrame(main_dict)
153 |             data_df.to_csv(save_path, encoding="UTF-8-SIG", index=False)
154 |         except Exception:
155 |             print("Error while merging into [%s]" % save_path)
156 |
157 |
158 |
159 | def save_speaker(data_path, save_path, row_name="Speaker", regen=False):
160 |     '''Collect all speaker names appearing in a dataset and save them as a name table
161 |     Example:
162 |         save_speaker(data_path="/148Dataset/Dataset/MELD/MELD/train_sent_emo.csv",
163 |                      save_path="/148Dataset/Dataset/MELD/MELD/speakers.txt",
164 |                      row_name="Speaker",
165 |                      regen=True)
166 |     '''
167 | if os.path.exists(save_path):
168 | if regen == False:
169 | return None
170 | elif regen == True:
171 | os.remove(save_path)
172 |
173 | data = pd.read_csv(data_path, encoding="UTF-8-SIG")
174 | results = get_row_statistics(data,row_name)
175 | print(results["keys"])
176 | print(results["element_stastics"])
177 | with open(save_path, 'w') as f:
178 | for i in range(len(results["keys"])):
179 | f.write(results["keys"][i]+'\n')
180 | return True
181 |
182 |
183 | def load_speaker(speakers_file): # speakers.txt
184 |     """Loads a speaker file into a dictionary.
185 |     speakers_file: text file listing all speaker names
186 |     speakers: the returned list of speaker names
187 |     speakers_to_ids: returned dict mapping a speaker name to its id
188 |     ids_to_speakers: returned dict mapping an id back to its speaker name
189 |     """
190 | speakers_to_ids = collections.OrderedDict()
191 | with open(speakers_file, "r", encoding="utf-8") as reader:
192 | speakers = reader.readlines()
193 | for index, token in enumerate(speakers):
194 | token = token.rstrip('\n')
195 | speakers[index] = token
196 | speakers_to_ids[token] = index
197 | ids_to_speakers = collections.OrderedDict([(ids, tok) for tok, ids in speakers_to_ids.items()])
198 | return speakers, speakers_to_ids, ids_to_speakers
199 |
200 |
201 | def convert_speaker_to_id(speakers_to_ids, speaker, unk_token="其他"):
202 | """ Converts a speaker (str/unicode) in an id using the speakers. """
203 | return speakers_to_ids.get(speaker, speakers_to_ids.get(unk_token))
204 |
205 | def convert_id_to_speaker(ids_to_speakers, index, unk_token="其他"):
206 | """Converts an index (integer) in a speaker (string/unicode) using the speakers."""
207 | return ids_to_speakers.get(index, unk_token)
208 |
209 | def convert_cache_to_csv(dataset_cache,output_dir):
210 |     import torch # local import so this module stays importable without torch; the valid/test splits can be dumped analogously
211 |     data = torch.load(dataset_cache)
212 |     train_data = data["train"]
213 |     valid_data = data["valid"]
214 |     test_data = data["test"]
215 |     train_data.to_csv(join(output_dir,dataset_cache+"train.csv"),columns=["Dialogue_ID","Utterance_ID","Speaker","Sentiment","Emotion","DA","Utterance","Gender","Age","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"])
216 |
217 |
--------------------------------------------------------------------------------
/pec_baseline/utils/cped_dataset.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 South China University of Technology and
3 | # Engineering Research Center of Ministry of Education on Human Body Perception.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # CPED Dataset data loading file
18 | # File: cped_dataset.py
19 | # Used for CPED dataset loading
20 | # Author: Chen Yirong
21 | # Date: 2022.03.29
22 |
23 | # Notes on the CPED dataset
24 | # The dataset is stored as follows:
25 | # CPED/ top-level folder,
26 | #     CPED_total_text.csv csv file aggregating all dialogue data,
27 | #     speakers.txt text file listing all speaker names,
28 | #     train_split.csv training set; shares no speakers with valid_split.csv or test_split.csv
29 | #     valid_split.csv validation set; shares no speakers with train_split.csv or test_split.csv
30 | #     test_split.csv test set; shares no speakers with train_split.csv or valid_split.csv
31 | #     train_shuffle_split.csv obtained by shuffling CPED_total_text.csv and splitting 8:1:1
32 | #     valid_shuffle_split.csv obtained by shuffling CPED_total_text.csv and splitting 8:1:1
33 | #     test_shuffle_split.csv obtained by shuffling CPED_total_text.csv and splitting 8:1:1
34 | # The two splitting schemes serve different research scenarios!
35 |
36 | # Function naming convention in this file: cped_function_name, e.g.
37 | # cped_get_total_data, cped_get_single_file
38 |
39 | # Class naming convention in this file: CpedClassName, e.g.
40 | # CpedDataset, CpedBertDataset, CpedBertGptDataset
41 |
42 | # Key package versions:
43 | # pytorch: 1.9.0+
44 | import torch
45 | import logging
46 | from itertools import chain # flattens a 2-D list into a 1-D list: [[1127, 1127, 6432, 6814],[118, 117, 116]] ---> [1127, 1127, 6432, 6814, 118, 117, 116]
47 | from torch.utils.data import Dataset # see: https://pytorch.org/docs/1.9.0/data.html#torch.utils.data.Dataset
48 | from torch.utils.data import DataLoader
49 | from torch.nn.utils.rnn import pad_sequence
50 | from .cped_util import (tokenize, cped_get_single_file, cped_get_single_cache_file, cped_get_data_from_dir,
51 | cped_get_single_file_for_bert_gpt, cped_get_data_from_dir_for_bert_gpt,
52 | CPED_SPECIAL_TOKENS, CPED_IGNORE_ID, CPED_DA_TOKENS, CPED_SENTIMENT_TOKENS,
53 | CPED_EMOTION_TOKENS, CPED_DA_TO_TOKENS, CPED_SENTIMENT_TO_TOKENS, CPED_EMOTION_TO_TOKENS,
54 | CPED_DA_TO_ID, CPED_EMOTION_TO_ID, CPED_GENDER_TO_ID, CPED_BIGFIVE_TO_ID, CPED_SPEAKER_TYPE_TO_ID)
55 |
56 |
57 | logger = logging.getLogger(__name__)
58 |
59 |
60 | def find_split_id_of_response(speaker_list,responder):
61 |     '''find_split_id_of_response
62 |     Inputs:
63 |         speaker_list: list of speaker names, e.g. ['诺澜', '诺澜', '胡一菲', '胡一菲', '胡一菲', '诺澜', '诺澜']
64 |         responder: string, name of the responder, e.g. '诺澜'
65 |     Outputs:
66 |         split_id: negative integer in the range -1 to -(len(speaker_list)-1)
67 |     Examples:
68 |         speaker_list = ['诺澜', '诺澜', '胡一菲', '胡一菲', '胡一菲', '诺澜', '诺澜']
69 |         responder = '诺澜'
70 |         split_id= find_split_id_of_response(speaker_list,responder)
71 |         # returns: -2
72 |         # utterance_history = data_index["Token"].tolist()[-max_history_utterances:split_id]
73 |         # response = data_index["Token"].tolist()[split_id:]
74 |     '''
75 |     split_id = -1
76 |     for i in range(-2,-len(speaker_list),-1):
77 |         if speaker_list[i] != responder:
78 |             return split_id
79 |         else:
80 |             split_id = split_id-1
81 |     return -1 # edge case: only one person speaks; treat only the last utterance as the response
82 |
83 | def create_speaker_type(speaker_list,responder=None):
84 |     '''create_speaker_type: convert a list of names into a list made of "[speaker1]"/"[speaker2]"
85 |     Inputs:
86 |         speaker_list: list of speaker names, e.g. ['诺澜', '诺澜', '胡一菲', '胡一菲', '胡一菲']
87 |         responder: string, name of the responder, e.g. '诺澜'
88 |     Outputs:
89 |         speaker_type_list: list made of "[speaker1]" and "[speaker2]"
90 |     '''
91 |     if responder is None: # no responder specified
92 |         speaker2 = speaker_list[-1] # the speaker of the last utterance is taken as the responder
93 |     else:
94 |         speaker2 = responder
95 |     speaker_type_list = []
96 |     for speaker in speaker_list:
97 |         if speaker==speaker2:
98 |             speaker_type_list.append("[speaker2]") # "[speaker2]" denotes the responder
99 |         else:
100 |             speaker_type_list.append("[speaker1]")
101 |     return speaker_type_list
102 |
103 |
104 | def convert_emotion_to_tokens(emotion_list,
105 |                               emotion_type="Emotion",
106 |                               SELECTED_EMOTION_TO_TOKENS={"Emotion":CPED_EMOTION_TO_TOKENS,
107 |                                                           "Sentiment":CPED_SENTIMENT_TO_TOKENS}):
108 |     '''convert_emotion_to_tokens: convert a list of emotion labels into emotion tokens from the vocabulary
109 |     Inputs:
110 |         emotion_list: emotion-label column of a dialogue; each element is the raw emotion label of the corresponding utterance, e.g. ["happy","happy"]
111 |         emotion_type: string, the emotion type, "Emotion" or "Sentiment"; selects which dict in SELECTED_EMOTION_TO_TOKENS
112 |                       is used to map the raw labels to token labels
113 |         SELECTED_EMOTION_TO_TOKENS: dict whose keys are emotion types and whose values are the label-to-token dicts defined in .cped_util
114 |     Outputs:
115 |         emotion_tokens_list: the converted list of emotion tokens (unknown labels fall back to "[neutral]")
116 |     '''
117 |     # emotion_tokens_list = [SELECTED_EMOTION_TO_TOKENS[emotion_type][emo] for emo in emotion_list]
118 |     emotion_tokens_list = []
119 |     for emo in emotion_list:
120 |         if emo not in SELECTED_EMOTION_TO_TOKENS[emotion_type]:
121 |             emotion_tokens_list.append("[neutral]")
122 |         else:
123 |             emotion_tokens_list.append(SELECTED_EMOTION_TO_TOKENS[emotion_type][emo])
124 |     return emotion_tokens_list
125 |
126 | def convert_da_to_tokens(da_list,
127 |                          da_type="DA",
128 |                          SELECTED_DA_TO_TOKENS={"DA":CPED_DA_TO_TOKENS}):
129 |     '''convert_da_to_tokens: convert a list of DA labels into DA tokens from the vocabulary
130 |     Inputs:
131 |         da_list: DA-label column of a dialogue; each element is the raw DA label of the corresponding utterance
132 |         da_type: string, the DA type, "DA" or a custom DA column name; selects which dict in SELECTED_DA_TO_TOKENS
133 |                  is used to map the raw labels to token labels
134 |         SELECTED_DA_TO_TOKENS: dict whose keys are DA types and whose values are the label-to-token dicts defined in .cped_util
135 |     Outputs:
136 |         da_tokens_list: the converted list of DA tokens
137 |     '''
138 |     da_tokens_list = [SELECTED_DA_TO_TOKENS[da_type][da] for da in da_list]
139 |     return da_tokens_list
140 |
141 |
142 | def set_da_in_speaker(da_ids,input_ids,bos, eos, speaker1, speaker2, pad):
143 |     '''set_da_in_speaker: superimpose the DA embedding only at the special-token positions
144 |     (bos/eos/speaker markers); every other position is set to pad
145 |     '''
146 |     special_token_ids_list = [bos, eos, speaker1, speaker2]
147 |     new_da_ids = []
148 |     for i,da in enumerate(da_ids):
149 |         if input_ids[i] in special_token_ids_list:
150 |             new_da_ids.append(da)
151 |         else:
152 |             new_da_ids.append(pad)
153 |     return new_da_ids
154 |
155 | def set_emotion_in_speaker(emotion_ids,input_ids,bos, eos, speaker1, speaker2, pad):
156 |     '''set_emotion_in_speaker: superimpose the emotion embedding only at the special-token positions
157 |     (bos/eos/speaker markers); every other position is set to pad
158 |     '''
159 |     special_token_ids_list = [bos, eos, speaker1, speaker2]
160 |     new_emotion_ids = []
161 |     for i,emotion in enumerate(emotion_ids):
162 |         if input_ids[i] in special_token_ids_list:
163 |             new_emotion_ids.append(emotion)
164 |         else:
165 |             new_emotion_ids.append(pad)
166 |     return new_emotion_ids
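# Illustration (hypothetical ids): with special tokens at positions 0, 1, 4 and 6,
#   input_ids  = [bos, speaker1, 21,  35,  speaker2, 48,  eos]
#   da_ids     = [d0,  d0,       d0,  d0,  d1,       d1,  d1 ]
# set_da_in_speaker keeps the DA id only where input_ids holds a special token:
#   new_da_ids = [d0,  d0,       pad, pad, d1,       pad, d1 ]
# set_emotion_in_speaker behaves identically for emotion ids.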
167 |
168 |
169 | class CpedDataset(Dataset):
170 |     '''CpedDataset: for standard dialogue-generation models such as CDialGPT, optionally adding emotion,
171 |     persona or DA control signals, or extra embeddings to be summed; not suitable for models such as SPEAKERBERT or BERTGPT
172 |
173 |     Inputs:
174 |         data: DataFrame of preprocessed dialogue data, one utterance per row; data must contain at least the
175 |               'Dialogue_ID' and "Token" columns, where data['Dialogue_ID'] is the sample id
176 |               and data["Token"] holds the tokenizer ids of the utterance, looking like
177 |               [2108, 3342, 3342, 8024, 1962, 1008, 2769, 1420, 1127, 1127, 6432, 6814]
178 |         tokenizer: Tokenizer object, see: https://huggingface.co/docs/transformers/main_classes/tokenizer
179 |         emotion_type: string, name of the emotion column, either "Sentiment" or "Emotion"; custom emotion columns may be added
180 |         da_type: string, name of the DA column, currently only "DA"; custom DA columns may be added
181 |         persona_type: list of persona column names
182 |         max_history: maximum number of dialogue turns; number of utterances = 2*max_history
183 |         batch_first: bool, whether the batch dimension comes first, i.e. (batch,...)
184 |         lm_labels: bool, whether to return the response labels
185 |         with_current_speaker: bool, provide the speaker of the response
186 |         with_current_persona: bool, provide the persona of the response
187 |         with_current_emotion: bool, provide the emotion of the response
188 |         with_current_da: bool, provide the dialogue act (DA) of the response
189 |         with_emotion: bool, use emotion embeddings
190 |         with_da=False: bool, use DA embeddings
191 |         use_speaker_name_as_speaker_list: bool, use the speaker names themselves as the speaker_list
192 |         set_eda_in_speaker: bool, embed emotion/DA only at the speaker marker positions
193 |
194 | Outputs:
195 | **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``.
196 | Indices of input sequence tokens in the vocabulary.
197 | **token_type_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``.
198 | A parallel sequence of tokens (can be used to indicate various portions of the inputs).
199 | The embeddings from these tokens will be summed with the respective token embeddings.
200 | Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices)
201 | **emotion_ids**: (`optional`, returned when ``with_emotion=True``)
202 | ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``.
203 | **da_ids**: (`optional`, returned when ``with_da=True``)
204 | ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``.
205 | **current_speaker_id**: (`optional`, returned when ``with_current_speaker=True``)
206 | ``torch.LongTensor`` of shape ``(batch_size, 1)``.
207 | **current_persona_ids**: (`optional`, returned when ``with_current_persona=True``)
208 | ``torch.LongTensor`` of shape ``(batch_size, persona_size)``.
209 | **current_emotion_id**: (`optional`, returned when ``with_current_emotion=True``)
210 | ``torch.LongTensor`` of shape ``(batch_size, 1)``.
211 | **current_da_id**: (`optional`, returned when ``with_current_da=True``)
212 | ``torch.LongTensor`` of shape ``(batch_size, 1)``.
213 | **lm_labels**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
214 |
215 | Examples:
216 |
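        A minimal sketch (assumes `train_data` and `tokenizer` prepared as in
        build_cped_dataloaders, i.e. tokenized data with a "Token" column):

            train_dataset = CpedDataset(data=train_data, tokenizer=tokenizer)
            input_ids, token_type_ids, *rest, lm_labels = train_dataset.collate(
                [train_dataset[0], train_dataset[1]])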
217 |
218 | '''
219 | def __init__(self,
220 | data,
221 | tokenizer,
222 | emotion_type="Emotion",
223 | da_type="DA",
224 | persona_type=["Gender","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness","Age"],
225 |                  max_history=25, # i.e. up to 50 utterances
226 | batch_first=True,
227 | lm_labels=True,
228 | with_current_speaker=False,
229 | with_current_persona=False,
230 | with_current_emotion=False,
231 | with_current_da=False,
232 | with_emotion=False,
233 | with_da=False,
234 | use_speaker_name_as_speaker_list=False,
235 | set_eda_in_speaker=False,
236 | set_current_speaker_mask=False,
237 |                  max_word_length=512): # limit on the total number of tokens
238 | self.data = data
239 | self.tokenizer = tokenizer
240 |         self.emotion_type = emotion_type # 'Emotion', name of the emotion-label column
241 |         self.da_type = da_type # 'DA', name of the DA-label column
242 | self.persona_type = persona_type
243 | self.with_current_speaker = with_current_speaker
244 | self.with_current_persona = with_current_persona
245 | self.with_current_emotion = with_current_emotion
246 | self.with_current_da = with_current_da
247 | self.with_emotion=with_emotion # Whether use emotion to help generate dialogue
248 | self.with_da=with_da # Whether use DA to help generate dialogue
249 | self.max_history = max_history # Maximum number of dialogue turns
250 | self.max_history_utterances = 2*max_history # Maximum number of dialogue sentences
251 | self.max_word_length = max_word_length
252 | self.use_speaker_name_as_speaker_list = use_speaker_name_as_speaker_list
253 | self.set_eda_in_speaker = set_eda_in_speaker
254 | self.set_current_speaker_mask = set_current_speaker_mask
255 | self.pad = tokenizer.pad_token_id
256 | self.batch_first = batch_first
257 | self.lm_labels = lm_labels
258 | self.keys = list(set(self.data['Dialogue_ID']))
259 | self.len = len(self.keys)
260 |
261 | def __len__(self):
262 | return self.len
263 |
264 |     def __getitem__(self, index):
265 |         dialogue_id = self.keys[index] # id of the current dialogue sample
266 |         data_index = self.data[self.data['Dialogue_ID']==dialogue_id]
267 |
268 |         if len(data_index["Speaker"].tolist()) > self.max_history_utterances: # more utterances than self.max_history_utterances
269 |             max_history_utterances = self.max_history_utterances
270 |         else: # fewer utterances than self.max_history_utterances
271 |             max_history_utterances = len(data_index["Speaker"].tolist())
272 |         # shrink the window while "number of utterances + number of tokens in data_index['Token'] + 2" exceeds self.max_word_length
273 |         while len(data_index["Speaker"].tolist()[-max_history_utterances:])+len(list(chain(*data_index["Token"].tolist()[-max_history_utterances:])))+2>self.max_word_length:
274 |             max_history_utterances = max_history_utterances - 1
275 |
276 |         speaker_name_list = data_index["Speaker"].tolist()[-max_history_utterances:] # list of speaker names
277 |         responder = speaker_name_list[-1] # name of the responder
278 |         responder_token = self.tokenizer.convert_tokens_to_ids(responder) # an integer id
279 |
280 |         # find the split index between the response and the dialogue history
281 |         response_split_id = find_split_id_of_response(speaker_name_list,responder)
282 |         # dialogue history, looks like: [[2108, 3342, 3342, 8024], [1962, 1008, 2769, 1420]]
283 |         history_utterance_tokens = data_index["Token"].tolist()[-max_history_utterances:response_split_id]
284 |         # response, e.g.: [[1127, 1127, 6432, 6814],[118, 117, 116]] ---> [1127, 1127, 6432, 6814, 118, 117, 116]
285 |         if self.lm_labels:
286 |             response_utterance_tokens = data_index["Token"].tolist()[response_split_id:]
287 |             response_utterance_tokens = list(chain(*response_utterance_tokens)) # flatten the 2-D list
288 |         else:
289 |             response_utterance_tokens = []
290 |
291 |         # build history_speaker_types for the dialogue history
292 |         if self.use_speaker_name_as_speaker_list:
293 |             # speaker-name embeddings: the speaker names must be added to the vocabulary!
294 |             history_speaker_types = speaker_name_list[-max_history_utterances:response_split_id]
295 |         else:
296 |             # use create_speaker_type to build the speaker embedding labels
297 |             # "[speaker2]" denotes the responder, "[speaker1]" the other speaker
298 |             history_speaker_types = create_speaker_type(speaker_list=speaker_name_list[-max_history_utterances:response_split_id], responder=responder)
299 |         # convert the string labels to ids
300 |         history_speaker_tokens = self.tokenizer.convert_tokens_to_ids(history_speaker_types)
301 |
302 |         # build history_emotion_tokens for the dialogue history
303 |         if self.with_emotion: # the emotion labels must be added to the vocabulary!
304 |             history_emotion_tokens = convert_emotion_to_tokens(emotion_list=data_index[self.emotion_type].tolist()[-max_history_utterances:response_split_id],
305 |                                                                emotion_type=self.emotion_type)
306 |             history_emotion_tokens = self.tokenizer.convert_tokens_to_ids(history_emotion_tokens)
307 |         else:
308 |             history_emotion_tokens = []
309 |
310 |         # build history_da_tokens for the dialogue history
311 |         if self.with_da: # the DA labels must be added to the vocabulary!
312 |             history_da_tokens = convert_da_to_tokens(da_list=data_index[self.da_type].tolist()[-max_history_utterances:response_split_id],
313 |                                                      da_type=self.da_type)
314 |             history_da_tokens = self.tokenizer.convert_tokens_to_ids(history_da_tokens)
315 |         else:
316 |             history_da_tokens = []
317 |
318 |
319 |         # build the emotion/DA/persona signals that condition the response
320 |         # the following share one embedding matrix with the word embeddings
321 |         current_emotion_token = self.tokenizer.convert_tokens_to_ids(data_index[self.emotion_type].tolist()[-1])
322 |         current_da_token = self.tokenizer.convert_tokens_to_ids(data_index[self.da_type].tolist()[-1])
323 |         # the following use a separate embedding, not shared with the word embeddings
324 |         current_emotion_id = CPED_EMOTION_TO_ID[data_index[self.emotion_type].tolist()[-1]]
325 |         current_da_id = CPED_DA_TO_ID[data_index[self.da_type].tolist()[-1]]
326 | if self.with_current_persona:
327 | current_gender_id = CPED_GENDER_TO_ID[data_index[self.persona_type[0]].tolist()[-1]]
328 | current_Neuroticism_id = CPED_BIGFIVE_TO_ID[data_index[self.persona_type[1]].tolist()[-1]]
329 | current_Extraversion_id = CPED_BIGFIVE_TO_ID[data_index[self.persona_type[2]].tolist()[-1]]
330 | current_Openness_id = CPED_BIGFIVE_TO_ID[data_index[self.persona_type[3]].tolist()[-1]]
331 | current_Agreeableness_id = CPED_BIGFIVE_TO_ID[data_index[self.persona_type[4]].tolist()[-1]]
332 | current_Conscientiousness_id = CPED_BIGFIVE_TO_ID[data_index[self.persona_type[5]].tolist()[-1]]
333 | current_persona_ids = [current_gender_id,current_Neuroticism_id,current_Extraversion_id,current_Openness_id,
334 | current_Agreeableness_id,current_Conscientiousness_id]
335 | else:
336 | current_persona_ids = []
337 |
338 | return self.process(history_speaker_tokens,
339 | history_utterance_tokens,
340 | history_emotion_tokens,
341 | history_da_tokens,
342 | responder_token,
343 | current_emotion_token,
344 | current_da_token,
345 | current_emotion_id,
346 | current_da_id,
347 | current_persona_ids,
348 | response_utterance_tokens)
349 |
350 | def process(self,
351 | history_speaker_tokens,
352 | history_utterance_tokens,
353 | history_emotion_tokens,
354 | history_da_tokens,
355 | responder_token,
356 | current_emotion_token,
357 | current_da_token,
358 | current_emotion_id,
359 | current_da_id,
360 | current_persona_ids,
361 | response_utterance_tokens,
362 | with_eos=True):
363 | instance = {}
364 | bos, eos, speaker1, speaker2 = self.tokenizer.convert_tokens_to_ids(CPED_SPECIAL_TOKENS)
365 | speaker_tokens = history_speaker_tokens + [responder_token]
366 | emotion_tokens = history_emotion_tokens + [current_emotion_token]
367 | da_tokens = history_da_tokens + [current_da_token]
368 | sequence = [[bos]] + history_utterance_tokens + [response_utterance_tokens + ([eos] if with_eos else [])]
369 | sequence = [sequence[0]] + [[speaker_tokens[i]] + s
370 | for i, s in enumerate(sequence[1:])]
371 | instance["input_ids"] = list(chain(*sequence))
372 | instance["token_type_ids"] = [bos] + [speaker_tokens[i] for i, s in
373 | enumerate(sequence[1:])
374 | for _ in s]
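        # Illustration (hypothetical token ids) for a one-utterance history plus response:
        #   input_ids      = [bos] [speaker1] h1 h2 [speaker2] r1 r2 [eos]
        #   token_type_ids = [bos] [speaker1] [speaker1] [speaker1] [speaker2] [speaker2] [speaker2] [speaker2]
        # and lm_labels (built below) mask everything except the response:
        #   lm_labels      = -1 -1 -1 -1 -1 r1 r2 [eos]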
375 |
376 | if self.with_da:
377 | instance["da_ids"] = [bos] + [da_tokens[i] for i, s in
378 | enumerate(sequence[1:])
379 | for _ in s]
380 | # only set the DA in [speaker1] or [speaker2]
381 | if self.set_eda_in_speaker:
382 | instance["da_ids"] = set_da_in_speaker(instance["da_ids"],instance["input_ids"],bos, eos, speaker1, speaker2, self.pad)
383 | if self.with_emotion:
384 | instance["emotion_ids"] = [bos] + [emotion_tokens[i] for i, s in
385 | enumerate(sequence[1:])
386 | for _ in s]
387 | # only set the emotion in [speaker1] or [speaker2]
388 | if self.set_eda_in_speaker:
389 |                 instance["emotion_ids"] = set_emotion_in_speaker(instance["emotion_ids"],instance["input_ids"],bos, eos, speaker1, speaker2, self.pad)
390 | if self.with_current_speaker:
391 | instance["current_speaker_id"] = responder_token
392 |
393 | if self.with_current_emotion:
394 | instance["current_emotion_id"] = current_emotion_id
395 |
396 | if self.with_current_da:
397 | instance["current_da_id"] = current_da_id
398 |
399 | if self.with_current_persona:
400 | instance["current_persona_ids"] = current_persona_ids
401 |
402 | instance["lm_labels"] = [-1] * len(instance["input_ids"])
403 | if self.lm_labels:
404 | instance["lm_labels"] = ([-1] * sum(len(s) for s in sequence[:-1])) + [-1] + sequence[-1][1:]
405 | if self.set_current_speaker_mask:
406 | instance["current_speaker_mask"] = ([-1] * sum(len(s) for s in sequence[:-1])) + [1] + ([-1] * len(sequence[-1][1:]) )
407 |
408 | return instance
409 |
410 | def collate(self, batch):
411 | input_ids = pad_sequence(
412 | [torch.tensor(instance["input_ids"], dtype=torch.long) for instance in batch],
413 | batch_first=self.batch_first, padding_value=self.pad)
414 | token_type_ids = pad_sequence(
415 | [torch.tensor(instance["token_type_ids"], dtype=torch.long) for instance in batch],
416 | batch_first=self.batch_first, padding_value=self.pad)
417 |
418 | if self.with_emotion:
419 | emotion_ids = pad_sequence(
420 | [torch.tensor(instance["emotion_ids"], dtype=torch.long) for instance in batch],
421 | batch_first=self.batch_first, padding_value=self.pad)
422 | else:
423 | emotion_ids = None
424 |
425 | if self.with_da:
426 | da_ids = pad_sequence(
427 | [torch.tensor(instance["da_ids"], dtype=torch.long) for instance in batch],
428 | batch_first=self.batch_first, padding_value=self.pad)
429 | else:
430 | da_ids = None
431 |
432 | if self.with_current_speaker:
433 |             current_speaker_id = torch.tensor(
434 |                 [instance["current_speaker_id"] for instance in batch],
435 |                 dtype=torch.long)
436 | else:
437 | current_speaker_id = None
438 |
439 | if self.with_current_persona:
440 | current_persona_ids = pad_sequence(
441 | [torch.tensor(instance["current_persona_ids"], dtype=torch.long) for instance in batch],
442 | batch_first=self.batch_first, padding_value=1) # padding_value=1 means unknown here
443 | else:
444 | current_persona_ids = None
445 |
446 | if self.with_current_emotion:
447 |             current_emotion_id = torch.tensor(
448 |                 [instance["current_emotion_id"] for instance in batch],
449 |                 dtype=torch.long)
450 | else:
451 | current_emotion_id = None
452 |
453 | if self.with_current_da:
454 |             current_da_id = torch.tensor(
455 |                 [instance["current_da_id"] for instance in batch],
456 |                 dtype=torch.long)
457 | else:
458 | current_da_id = None
459 | lm_labels = pad_sequence(
460 | [torch.tensor(instance["lm_labels"], dtype=torch.long) for instance in batch],
461 | batch_first=self.batch_first, padding_value=-1)
462 |
463 | if self.set_current_speaker_mask: # for CVGPT
464 | current_speaker_mask = pad_sequence(
465 | [torch.tensor(instance["current_speaker_mask"], dtype=torch.long) for instance in batch],
466 | batch_first=self.batch_first, padding_value=-1)
467 | return input_ids, token_type_ids, emotion_ids, da_ids, current_speaker_id, current_persona_ids, current_emotion_id, current_da_id, lm_labels, current_speaker_mask
468 | else:
469 | return input_ids, token_type_ids, emotion_ids, da_ids, current_speaker_id, current_persona_ids, current_emotion_id, current_da_id, lm_labels
470 |
471 |
472 | def build_cped_dataloaders(args,
473 | tokenizer,
474 | logger,
475 | load_test=False,
476 | filenames={"train":"train_shuffle_split.csv",
477 | "valid":"valid_shuffle_split.csv",
478 | "test":"test_shuffle_split.csv"}):
479 | data,sample = cped_get_data_from_dir(dir_path=args.data_path,
480 | cache_path=args.cache_path,
481 | tokenizer=tokenizer,
482 | logger=logger,
483 | filenames=filenames)
484 |
485 |     if not load_test:
486 | logger.info("Build train and validation dataloaders")
487 | train_data = data["train"]
488 | valid_data = data["valid"]
489 | train_dataset = CpedDataset(data=train_data,
490 | tokenizer=tokenizer,
491 | emotion_type=args.emotion_type,
492 | da_type=args.da_type,
493 | persona_type=["Gender","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"],
494 | max_history=args.max_history,
495 | batch_first=True,
496 | lm_labels=True,
497 | with_current_speaker=args.with_current_speaker,
498 | with_current_persona=args.with_current_persona,
499 | with_current_emotion=args.with_current_emotion,
500 | with_current_da=args.with_current_da,
501 | with_emotion=args.with_emotion,
502 | with_da=args.with_da,
503 | use_speaker_name_as_speaker_list=args.use_speaker_name_as_speaker_list,
504 | set_eda_in_speaker=args.set_eda_in_speaker,
505 | set_current_speaker_mask=args.set_current_speaker_mask)
506 |
507 |
508 | valid_dataset = CpedDataset(data=valid_data,
509 | tokenizer=tokenizer,
510 | emotion_type=args.emotion_type,
511 | da_type=args.da_type,
512 | persona_type=["Gender","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"],
513 | max_history=args.max_history,
514 | batch_first=True,
515 | lm_labels=True,
516 | with_current_speaker=args.with_current_speaker,
517 | with_current_persona=args.with_current_persona,
518 | with_current_emotion=args.with_current_emotion,
519 | with_current_da=args.with_current_da,
520 | with_emotion=args.with_emotion,
521 | with_da=args.with_da,
522 | use_speaker_name_as_speaker_list=args.use_speaker_name_as_speaker_list,
523 | set_eda_in_speaker=args.set_eda_in_speaker,
524 | set_current_speaker_mask=args.set_current_speaker_mask)
525 |
526 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
527 | valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None
528 | train_loader = DataLoader(train_dataset,
529 | sampler=train_sampler,
530 | collate_fn=train_dataset.collate,
531 | num_workers=args.num_workers,
532 | batch_size=args.train_batch_size,
533 | shuffle=(not args.distributed))
534 | valid_loader = DataLoader(valid_dataset,
535 | sampler=valid_sampler,
536 | collate_fn=valid_dataset.collate,
537 | num_workers=args.num_workers,
538 | batch_size=args.valid_batch_size,
539 | shuffle=False)
540 |
541 | return train_loader, valid_loader, train_sampler, valid_sampler
542 |
543 | else:
544 | logger.info("Build test dataloaders")
545 | test_data = data["test"]
546 | test_dataset = CpedDataset(data=test_data,
547 | tokenizer=tokenizer,
548 | emotion_type=args.emotion_type,
549 | da_type=args.da_type,
550 | persona_type=["Gender","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"],
551 | max_history=args.max_history,
552 | batch_first=True,
553 | lm_labels=True,
554 | with_current_speaker=args.with_current_speaker,
555 | with_current_persona=args.with_current_persona,
556 | with_current_emotion=args.with_current_emotion,
557 | with_current_da=args.with_current_da,
558 | with_emotion=args.with_emotion,
559 | with_da=args.with_da,
560 | use_speaker_name_as_speaker_list=args.use_speaker_name_as_speaker_list,
561 | set_eda_in_speaker=args.set_eda_in_speaker,
562 | set_current_speaker_mask=args.set_current_speaker_mask)
563 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset) if args.distributed else None
564 | test_loader = DataLoader(test_dataset,
565 | sampler=test_sampler,
566 | collate_fn=test_dataset.collate,
567 | num_workers=args.num_workers,
568 | batch_size=args.test_batch_size,
569 | shuffle=False)
570 | return test_loader, test_sampler
571 |
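572 | 
573 | # --- Usage sketch (illustrative only) ---
574 | # A minimal sketch of how build_cped_dataloaders is wired up, assuming the
575 | # repo's data layout and a bert-base-chinese tokenizer. The argument values in
576 | # the Namespace below are illustrative assumptions; in the real pipeline they
577 | # come from the training script's argparse options.
578 | if __name__ == "__main__":
579 |     import logging
580 |     from argparse import Namespace
581 |     from transformers import BertTokenizer
582 |     logging.basicConfig(level=logging.INFO)
583 |     example_logger = logging.getLogger(__name__)
584 |     example_tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
585 |     example_args = Namespace(
586 |         data_path="../data/CPED", cache_path="../data/CPED/cped_cache",
587 |         emotion_type="Emotion", da_type="DA", max_history=25,
588 |         with_current_speaker=False, with_current_persona=False,
589 |         with_current_emotion=False, with_current_da=False,
590 |         with_emotion=False, with_da=False,
591 |         use_speaker_name_as_speaker_list=False,
592 |         set_eda_in_speaker=False, set_current_speaker_mask=False,
593 |         distributed=False, num_workers=1,
594 |         train_batch_size=2, valid_batch_size=2)
595 |     # use the non-shuffled speaker-independent splits shipped in data/CPED
596 |     train_loader, valid_loader, _, _ = build_cped_dataloaders(
597 |         example_args, example_tokenizer, example_logger,
598 |         filenames={"train": "train_split.csv",
599 |                    "valid": "valid_split.csv",
600 |                    "test": "test_split.csv"})
601 |     input_ids, token_type_ids, *rest = next(iter(train_loader))
602 |     print(input_ids.shape, token_type_ids.shape)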
--------------------------------------------------------------------------------
/pec_baseline/utils/cped_util.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 South China University of Technology and
3 | # Engineering Research Center of Ministry of Education on Human Body Perception.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # CPED Dataset data loading file
18 | # File: cped_util.py
19 | # Used for CPED dataset loading
20 | # Author: Chen Yirong
21 | # Date: 2022.03.22
22 |
23 | # Notes on the CPED dataset
24 | # The dataset is stored as follows:
25 | # CPED/ top-level folder,
26 | #    CPED_total_text.csv csv file aggregating all of the dialogue data,
27 | #    speakers.txt text file collecting the names of all speakers,
28 | #    train_split.csv training set; no speaker overlap with valid_split.csv or test_split.csv
29 | #    valid_split.csv validation set; no speaker overlap with train_split.csv or test_split.csv
30 | #    test_split.csv test set; no speaker overlap with train_split.csv or valid_split.csv
31 | #    train_shuffle_split.csv obtained by shuffling CPED_total_text.csv and splitting it 8:1:1
32 | #    valid_shuffle_split.csv obtained by shuffling CPED_total_text.csv and splitting it 8:1:1
33 | #    test_shuffle_split.csv obtained by shuffling CPED_total_text.csv and splitting it 8:1:1
34 | # These two splitting schemes can be used for different research scenarios!
35 |
36 | # Function naming convention in this file: cped_function_name, e.g.:
37 | # cped_get_total_data, cped_get_single_file
38 |
39 | # Key package versions:
40 | # pytorch: 1.9.0+
41 |
42 | import os
43 | import re
44 | import torch
45 | import logging
46 | import pandas as pd
47 |
48 | logger = logging.getLogger(__name__)
49 |
50 |
51 | # Constants used by the CPED dataset, e.g. emotion labels, dialogue acts, etc.
52 | CPED_SPECIAL_TOKENS = ["[CLS]", "[SEP]", "[speaker1]", "[speaker2]"]
53 | CPED_IGNORE_ID = -1 # Tokens with indices set to ``-1`` are ignored; used for training SpeakerBert
54 |
55 | # The following lists are added to the vocabulary
56 | CPED_DA_TOKENS = ["[greeting]","[question]","[answer]","[statement-opinion]","[statement-non-opinion]","[apology]",
57 | "[command]","[agreement]","[disagreement]","[acknowledge]","[appreciation]","[interjection]",
58 | "[conventional-closing]","[quotation]","[reject]","[irony]","[comfort]","[thanking]","[da-other]"] # 19 DA labels
59 |
60 | CPED_SENTIMENT_TOKENS = ["[neutral]","[positive]","[negative]"]
61 |
62 | CPED_EMOTION_TOKENS = ["[happy]","[grateful]","[relaxed]","[positive-other]","[anger]","[sadness]","[fear]",
63 | "[depress]","[disgust]","[astonished]","[worried]","[negative-other]","[neutral]"] # 13 emotion labels
64 |
65 | CPED_DA_TO_TOKENS = {'greeting': '[greeting]', 'question': '[question]', 'answer': '[answer]',
66 | 'statement-opinion': '[statement-opinion]', 'statement-non-opinion': '[statement-non-opinion]',
67 | 'apology': '[apology]', 'command': '[command]', 'agreement': '[agreement]',
68 | 'disagreement': '[disagreement]', 'acknowledge': '[acknowledge]', 'appreciation': '[appreciation]',
69 | 'interjection': '[interjection]', 'conventional-closing': '[conventional-closing]',
70 | 'quotation': '[quotation]', 'reject': '[reject]', 'irony': '[irony]',
71 | 'comfort': '[comfort]','thanking':'[thanking]', 'other': '[da-other]'}
72 |
73 | CPED_SENTIMENT_TO_TOKENS = {'neutral': '[neutral]', 'positive': '[positive]', 'negative': '[negative]'}
74 |
75 | CPED_EMOTION_TO_TOKENS = {'happy': '[happy]', 'grateful': '[grateful]', 'relaxed': '[relaxed]',
76 | 'positive-other': '[positive-other]', 'anger': '[anger]', 'sadness': '[sadness]',
77 | 'fear': '[fear]', 'depress': '[depress]', 'disgust': '[disgust]',
78 | 'astonished': '[astonished]', 'worried': '[worried]', 'negative-other': '[negative-other]',
79 | 'neutral': '[neutral]'}
80 |
81 | CPED_DA_TO_ID = {'greeting': 0, 'question': 1, 'answer': 2, 'statement-opinion': 3, 'statement-non-opinion': 4,
82 | 'apology': 5, 'command': 6, 'agreement': 7, 'disagreement': 8, 'acknowledge': 9, 'appreciation': 10,
83 | 'interjection': 11, 'conventional-closing': 12, 'quotation': 13, 'reject': 14, 'irony': 15,
84 | 'comfort': 16,'thanking':17, 'other': 18}
85 |
86 | CPED_EMOTION_TO_ID = {'happy': 0, 'grateful': 1, 'relaxed': 2, 'positive-other': 3, 'anger': 4, 'sadness': 5,
87 | 'fear': 6, 'depress': 7, 'disgust': 8, 'astonished': 9, 'worried': 10,
88 | 'negative-other': 11, 'neutral': 12}
89 |
90 | CPED_GENDER_TO_ID = {'female': 0, 'unknown': 1, 'male': 2}
91 | CPED_BIGFIVE_TO_ID = {'low': 0, 'unknown': 1, 'high': 2}
92 |
93 | CPED_SPEAKER_TYPE_TO_ID ={"[speaker1]": 0, "[speaker2]": 1, "[MASK]": 2}
94 |
95 | # Add punctuation marks to speech-recognition (ASR) text, see:
96 | # https://blog.csdn.net/qq_33200967/article/details/122474859
97 |
98 |
99 |
100 |
101 |
102 | def tokenize(utterance, tokenizer):
103 |     '''tokenize: tokenize the utterance with the given tokenizer
104 |     Inputs:
105 |         utterance: string, a single sentence
106 |         tokenizer: Tokenizer object, see: https://huggingface.co/docs/transformers/main_classes/tokenizer
107 |     Outputs:
108 |         ids: list in which each element is the id of a token, e.g.: [2108, 3342, 3342, 8024, 1962, 1008, 2769, 1420, 1127, 1127, 6432, 6814]
109 |     Example:
110 |         from transformers import BertTokenizer
111 |         tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
112 |         ids = tokenize(utterance="季杨杨,好像我听凡凡说过", tokenizer=tokenizer)
113 |         print(ids)
114 |         # returns: [2108, 3342, 3342, 8024, 1962, 1008, 2769, 1420, 1127, 1127, 6432, 6814]
115 |     '''
116 |     utterance = str(utterance) # make sure the input is a str
117 |     # append a question mark to questions
118 |     utterance = utterance.replace("吗", "吗?")
119 |     utterance = utterance.replace("??", "?")
120 | 
121 |     # append an exclamation mark to exclamations
122 |     utterance = utterance.replace("啊", "啊!")
123 |     utterance = utterance.replace("吧", "吧!")
124 |     utterance = utterance.replace("啦", "啦!")
125 |     utterance = utterance.replace("呀", "呀!")
126 |     utterance = utterance.replace("!!", "!")
127 | 
128 |     # for mid-sentence breaks that are neither questions nor exclamations, insert commas
129 |     utterance = utterance.replace(" ", ",")
130 |     # duplicated punctuation marks were already merged by the replacements above
131 |     utterance = "".join(utterance.split()) # remove any remaining whitespace
132 | 
133 |     utt_list = list(utterance) # "季杨杨,好像我听凡凡说过" --> ['季', '杨', '杨', ',', '好', '像', '我', '听', '凡', '凡', '说', '过']
134 | 
135 |     utterance = ' '.join(utt_list) # ['季', '杨', '杨', ',', '好', '像', '我', '听', '凡', '凡', '说', '过'] --> "季 杨 杨 , 好 像 我 听 凡 凡 说 过"
136 |     return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(utterance))
137 |
138 |
139 | def cped_get_single_file(file_path,
140 | tokenizer,
141 | logger,
142 | usecols=["Dialogue_ID","Utterance_ID","Speaker","Sentiment","Emotion","DA","Utterance","Gender","Age","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"],
143 | args=None):
144 |     '''cped_get_single_file: read the csv file at the given path, e.g.: CPED_total_text.csv, train_split.csv, ...
145 |     Inputs:
146 |         file_path: string specifying the file path
147 |         tokenizer: Tokenizer object, see: https://huggingface.co/docs/transformers/main_classes/tokenizer
148 |         logger: logging object
149 |         usecols: list of strings naming the csv columns to read, of which
150 |             "Dialogue_ID","Utterance_ID","Speaker","Utterance"
151 |             are required
152 |         args: the argument namespace returned by parser.parse_args()
153 |     Outputs:
154 |         data: DataFrame object
155 |         samples: DataFrame object
156 | Example:
157 | import logging
158 | from transformers import BertTokenizer
159 | tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
160 | logger = logging.getLogger(__name__)
161 | file_path = "../data/CPED/test_split.csv"
162 | data, samples = cped_get_single_file(file_path, tokenizer, logger)
163 |
164 | '''
165 | logger.info("Read file from %s", file_path)
166 | data = pd.read_csv(file_path,
167 | usecols=usecols,
168 | encoding="UTF-8-SIG")
169 | samples = data.iloc[0:30]
170 |
171 | logger.info("Start tokenizing and encoding the file")
172 | data["Token"] = [tokenize(s, tokenizer) for s in data["Utterance"]]
173 | logger.info("Finished tokenizing and encoding the dataset")
174 | return data, samples
175 |
176 |
177 | def cped_get_single_cache_file(file_path,
178 | cache_path,
179 | tokenizer,
180 | logger,
181 | usecols=["Dialogue_ID","Utterance_ID","Speaker","Sentiment","Emotion","DA","Utterance","Gender","Age","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"],
182 | args=None):
183 |     '''cped_get_single_cache_file: read the csv file at the given path; if a cache exists, load the cache file directly,
184 |         e.g.: CPED_total_text.csv, train_split.csv, ...
185 |         The biggest difference from cped_get_single_file is that the first read saves a cache file, so later reads
186 |         no longer need to run the tokenizer for preprocessing, which saves a lot of experiment time
187 |     Inputs:
188 |         file_path: string specifying the file path
189 |         cache_path: path where the cache file is saved; name this file carefully, otherwise datasets are easily mixed up; saved via torch.save(data, cache_path)
190 |         tokenizer: Tokenizer object, see: https://huggingface.co/docs/transformers/main_classes/tokenizer
191 |         logger: logging object
192 |         usecols: list of strings naming the csv columns to read, of which
193 |             "Dialogue_ID","Utterance_ID","Speaker","Utterance"
194 |             are required
195 |         args: the argument namespace returned by parser.parse_args()
196 |     Outputs:
197 |         data: DataFrame object
198 |         samples: DataFrame object
199 | Example:
200 | import logging
201 | from transformers import BertTokenizer
202 | tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
203 | logger = logging.getLogger(__name__)
204 | file_path = "../data/CPED/test_split.csv"
205 | cache_path = "../data/CPED/test_split_cache"
206 | data, samples = cped_get_single_cache_file(file_path, cache_path, tokenizer, logger)
207 |
208 | '''
209 | if cache_path and os.path.isfile(cache_path):
210 | logger.info("Load tokenized dataset from cache at %s", cache_path)
211 | data = torch.load(cache_path)
212 | samples = None
213 |     else: # read the data from the original file
214 | logger.info("Read dataset from %s", file_path)
215 | data, samples = cped_get_single_file(file_path=file_path,
216 | tokenizer=tokenizer,
217 | logger=logger,
218 | usecols=usecols,
219 | args=args)
220 | logger.info("Finished tokenizing and encoding the dataset")
221 | logger.info("Save tokenized dataset to cache at %s", cache_path)
222 | torch.save(data, cache_path)
223 | return data, samples
224 |
225 |
226 | def cped_get_data_from_dir(dir_path,
227 | cache_path,
228 | tokenizer,
229 | logger,
230 | filenames={"train":"train_shuffle_split.csv",
231 | "valid":"valid_shuffle_split.csv",
232 | "test":"test_shuffle_split.csv"},
233 | usecols=["Dialogue_ID","Utterance_ID","Speaker","Sentiment","Emotion","DA","Utterance","Gender","Age","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"],
234 | args=None):
235 |     '''cped_get_data_from_dir: read the datasets named by the dict filenames from the directory dir_path;
236 |         if a cache exists, load the cache file directly
237 |     Inputs:
238 |         dir_path: string specifying the directory holding the dataset
239 |         cache_path: path where the cache file is saved; name this file carefully, otherwise datasets are easily mixed up; saved via torch.save(data, cache_path)
240 |         tokenizer: Tokenizer object, see: https://huggingface.co/docs/transformers/main_classes/tokenizer
241 |         logger: logging object
242 |         filenames: dict with the three keys "train", "valid" and "test", whose values name the corresponding files
243 |         usecols: list of strings naming the csv columns to read, of which
244 |             "Dialogue_ID","Utterance_ID","Speaker","Utterance"
245 |             are required
246 |         args: the argument namespace returned by parser.parse_args()
247 |     Outputs:
248 |         data: dict of the form {"train":train_data,"valid":valid_data, "test":test_data}; each value is a DataFrame object
249 |         samples: DataFrame object
250 | Example:
251 | import logging
252 | from transformers import BertTokenizer
253 | tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
254 | logger = logging.getLogger(__name__)
255 | dir_path = "../data/CPED"
256 | cache_path = "../data/CPED/cped_cache"
257 | filenames = {"train":"train_shuffle_split.csv",
258 | "valid":"valid_shuffle_split.csv",
259 | "test":"test_shuffle_split.csv"}
260 | data, samples = cped_get_data_from_dir(dir_path, cache_path, tokenizer, logger, filenames)
261 |
262 | '''
263 | if cache_path and os.path.isfile(cache_path):
264 | logger.info("Load tokenized dataset from cache at %s", cache_path)
265 | data = torch.load(cache_path)
266 | samples = None
267 |     else: # read the data from the original files
268 | logger.info("Read dataset from %s", dir_path)
269 | train_data, samples = cped_get_single_file(os.path.join(dir_path,filenames["train"]), tokenizer, logger, usecols, args)
270 | valid_data, samples = cped_get_single_file(os.path.join(dir_path,filenames["valid"]), tokenizer, logger, usecols, args)
271 | test_data, samples = cped_get_single_file(os.path.join(dir_path,filenames["test"]), tokenizer, logger, usecols, args)
272 | data = {"train":train_data,"valid":valid_data, "test":test_data}
273 | logger.info("Finished tokenizing and encoding the dataset")
274 | logger.info("Save tokenized dataset to cache at %s", cache_path)
275 | torch.save(data, cache_path)
276 | return data, samples
277 |
278 |
279 | def cped_get_single_file_for_bert_gpt(file_path,
280 | bert_tokenizer,
281 | gpt_tokenizer,
282 | logger,
283 | usecols=["Dialogue_ID","Utterance_ID","Speaker","Sentiment","Emotion","DA","Utterance","Gender","Age","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"],
284 | args=None):
285 |     '''cped_get_single_file_for_bert_gpt: read the csv file at the given path, e.g.: CPED_total_text.csv, train_split.csv, ...
286 |         and tokenize it with two different tokenizers
287 |     Inputs:
288 |         file_path: string specifying the file path
289 |         bert_tokenizer: Tokenizer object, see: https://huggingface.co/docs/transformers/main_classes/tokenizer
290 |         gpt_tokenizer: Tokenizer object, see: https://huggingface.co/docs/transformers/main_classes/tokenizer
291 |         logger: logging object
292 |         usecols: list of strings naming the csv columns to read, of which
293 |             "Dialogue_ID","Utterance_ID","Speaker","Utterance"
294 |             are required
295 |         args: the argument namespace returned by parser.parse_args()
296 |     Outputs:
297 |         data: DataFrame object
298 |         samples: DataFrame object
299 | Example:
300 | import logging
301 | from transformers import BertTokenizer
302 | bert_tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
303 | gpt_tokenizer = BertTokenizer.from_pretrained("openai-gpt")
304 | logger = logging.getLogger(__name__)
305 | file_path = "../data/CPED/test_split.csv"
306 | data, samples = cped_get_single_file_for_bert_gpt(file_path, bert_tokenizer, gpt_tokenizer, logger)
307 |
308 | '''
309 | logger.info("Read file from %s", file_path)
310 | data = pd.read_csv(file_path,
311 | usecols=usecols,
312 | encoding="UTF-8-SIG")
313 | samples = data.iloc[0:30]
314 |
315 | logger.info("Start tokenizing and encoding the file")
316 | data["Token_bert"] = [tokenize(s, bert_tokenizer) for s in data["Utterance"]]
317 | data["Token_gpt"] = [tokenize(s, gpt_tokenizer) for s in data["Utterance"]]
318 | logger.info("Finished tokenizing and encoding the dataset")
319 | return data, samples
320 |
321 |
322 | def cped_get_data_from_dir_for_bert_gpt(dir_path,
323 | cache_path,
324 | bert_tokenizer,
325 | gpt_tokenizer,
326 | logger,
327 | filenames={"train":"train_shuffle_split.csv",
328 | "valid":"valid_shuffle_split.csv",
329 | "test":"test_shuffle_split.csv"},
330 | usecols=["Dialogue_ID","Utterance_ID","Speaker","Sentiment","Emotion","DA","Utterance","Gender","Age","Neuroticism","Extraversion","Openness","Agreeableness","Conscientiousness"],
331 | args=None):
332 |     '''cped_get_data_from_dir_for_bert_gpt: read the datasets named by the dict filenames from the directory dir_path;
333 |         if a cache exists, load the cache file directly
334 |     Inputs:
335 |         dir_path: string specifying the directory holding the dataset
336 |         cache_path: path where the cache file is saved; name this file carefully, otherwise datasets are easily mixed up; saved via torch.save(data, cache_path)
337 |         bert_tokenizer: Tokenizer object, see: https://huggingface.co/docs/transformers/main_classes/tokenizer
338 |         gpt_tokenizer: Tokenizer object, see: https://huggingface.co/docs/transformers/main_classes/tokenizer
339 |         logger: logging object
340 |         filenames: dict with the three keys "train", "valid" and "test", whose values name the corresponding files
341 |         usecols: list of strings naming the csv columns to read, of which
342 |             "Dialogue_ID","Utterance_ID","Speaker","Utterance"
343 |             are required
344 |         args: the argument namespace returned by parser.parse_args()
345 |     Outputs:
346 |         data: dict of the form {"train":train_data,"valid":valid_data, "test":test_data}; each value is a DataFrame object
347 |         samples: DataFrame object
348 | Example:
349 | import logging
350 | from transformers import BertTokenizer
351 | bert_tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
352 | gpt_tokenizer = BertTokenizer.from_pretrained("openai-gpt")
353 | logger = logging.getLogger(__name__)
354 | dir_path = "../data/CPED"
355 | cache_path = "../data/CPED/cped_cache"
356 | filenames = {"train":"train_shuffle_split.csv",
357 | "valid":"valid_shuffle_split.csv",
358 | "test":"test_shuffle_split.csv"}
359 | data, samples = cped_get_data_from_dir_for_bert_gpt(dir_path, cache_path, bert_tokenizer, gpt_tokenizer, logger, filenames)
360 |
361 | '''
362 | if cache_path and os.path.isfile(cache_path):
363 | logger.info("Load tokenized dataset from cache at %s", cache_path)
364 | data = torch.load(cache_path)
365 | samples = None
366 |     else: # read the data from the original files
367 | logger.info("Read dataset from %s", dir_path)
368 | train_data, samples = cped_get_single_file_for_bert_gpt(os.path.join(dir_path,filenames["train"]), bert_tokenizer, gpt_tokenizer, logger, usecols, args)
369 | valid_data, samples = cped_get_single_file_for_bert_gpt(os.path.join(dir_path,filenames["valid"]), bert_tokenizer, gpt_tokenizer, logger, usecols, args)
370 | test_data, samples = cped_get_single_file_for_bert_gpt(os.path.join(dir_path,filenames["test"]), bert_tokenizer, gpt_tokenizer, logger, usecols, args)
371 | data = {"train":train_data,"valid":valid_data, "test":test_data}
372 | logger.info("Finished tokenizing and encoding the dataset")
373 | logger.info("Save tokenized dataset to cache at %s", cache_path)
374 | torch.save(data, cache_path)
375 | return data, samples
376 |
377 |
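378 | # --- Usage sketch (illustrative only) ---
379 | # A minimal demo of the constants and helpers above: map an emotion/DA label
380 | # to its special token and numeric id, then tokenize a single utterance.
381 | # Assumes bert-base-chinese is available; the training scripts register the
382 | # special tokens through their own setup, so add_tokens here is for demo only.
383 | if __name__ == "__main__":
384 |     from transformers import BertTokenizer
385 |     demo_tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
386 |     demo_tokenizer.add_tokens(CPED_SPECIAL_TOKENS + CPED_EMOTION_TOKENS + CPED_DA_TOKENS)
387 |     print(CPED_EMOTION_TO_TOKENS["happy"], CPED_EMOTION_TO_ID["happy"])    # [happy] 0
388 |     print(CPED_DA_TO_TOKENS["greeting"], CPED_DA_TO_ID["greeting"])        # [greeting] 0
389 |     print(tokenize("季杨杨,好像我听凡凡说过", demo_tokenizer))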
--------------------------------------------------------------------------------
/pec_baseline/utils/dataset_statistics.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 South China University of Technology and
3 | # Engineering Research Center of Ministry of Education on Human Body Perception.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Dataset statistics file
18 | # File: dataset_statistics.py
19 | # Used for dataset analysis
20 | # Author: Chen Yirong
21 | # Date: 2022.03.21
22 |
23 | import os
24 | import numpy as np
25 | import pandas as pd
26 | from os.path import join
27 | from collections import Counter
28 |
29 |
30 | def get_data_for_analysis(data_dir="../data/CPED",
31 | file_dict={"train":"train_split.csv","valid":"valid_split.csv","test":"test_split.csv"}):
32 |     '''get_data_for_analysis
33 |     Load the .csv datasets train_split.csv, valid_split.csv and test_split.csv from data_dir.
34 |     '''
35 | print("Read dataset from ", data_dir)
36 | train_data = pd.read_csv(join(data_dir,file_dict["train"]), encoding="UTF-8-SIG")
37 | valid_data = pd.read_csv(join(data_dir,file_dict["valid"]), encoding="UTF-8-SIG")
38 | test_data = pd.read_csv(join(data_dir,file_dict["test"]), encoding="UTF-8-SIG")
39 | return train_data, valid_data, test_data
40 |
41 |
42 | def get_totaldata_for_analysis(data_path="/home/MMMTD/data/processed_cleaned_data/total_checked_processed_cleaned_data/checked_processed_cleaned_data.csv"):
43 | '''get total data from data_path
44 | Get .csv format dataset from data_path.
45 | '''
46 | print("Read dataset from ", data_path)
47 | total_data = pd.read_csv(data_path, encoding="UTF-8-SIG")
48 | return total_data
49 |
50 |
51 | def get_row_statistics(data,row_name):
52 |     '''get_row_statistics
53 |     Get statistics of the dataset column named by row_name
54 |     E.g.
55 |     train_data_TV_ID = get_row_statistics(train_data,"TV_ID")
56 |     print("train TV statistics:\n", train_data_TV_ID["element_stastics"])
57 |     '''
58 | name = row_name
59 | keys = list(set(data[name]))
60 | values = data[name].tolist()
61 |     element_stastics = pd.value_counts(values)
62 |
63 | row_size = len(values)
64 | row_class = len(keys)
65 | results={"name":name,"keys":keys,"values":values,"element_stastics":element_stastics,"size":row_size,"class":row_class}
66 | return results
67 |
68 |
69 | def cout_dialogue_words(data,dialogue_id): # count the total number of characters over all utterances of a dialogue
70 | dialogue_data = data[data['Dialogue_ID']==dialogue_id]
71 | count = 0
72 | for utt in dialogue_data["Utterance"]:
73 | count = count + len(str(utt))
74 | return count
75 |
76 |
77 |
78 | def remove_element(utt_list,word = " "): # remove all occurrences of word from utt_list (modifies the list in place)
79 | temp_list = utt_list
80 | while word in temp_list:
81 | temp_list.remove(word)
82 | return temp_list
83 |
84 |
85 | def statistics_utterance(data,row_name="Utterance"):
86 |     '''statistics_utterance
87 |     Count the average and maximum number of characters per utterance
88 |
89 | '''
90 | utt_list = data[row_name].tolist()
91 | utt_word_list = [remove_element(list(utterance)) for utterance in utt_list]
92 | count_utt_word = []
93 | for utt in utt_word_list:
94 | count_utt_word.append(len(utt))
95 | return {"max":max(count_utt_word),"avg":sum(count_utt_word)/len(count_utt_word)}
96 |
97 |
98 | def statistics_emotda(data,row_name="Emotion",dialogue_id="Dialogue_ID"):
99 |     '''statistics_emotda
100 |     Count the average and maximum number of distinct emotion/DA labels per dialogue
101 | 
102 |     '''
103 | keys = list(set(data[dialogue_id]))
104 | count_dial_eda = []
105 | for key in keys:
106 | dial_eda = list(set(data[data[dialogue_id]==key][row_name].tolist()))
107 | count_dial_eda.append(len(dial_eda))
108 | return {"max":max(count_dial_eda),"avg":sum(count_dial_eda)/len(count_dial_eda)}
109 |
110 |
111 | def statistics_avg_duration(all_data,data,dialogue_id="Dialogue_ID",StartTime="StartTime",EndTime="EndTime"):
112 |     '''statistics_avg_duration
113 |     Count the average utterance duration (EndTime - StartTime) per dialogue
114 |
115 | '''
116 | keys = list(set(data[dialogue_id]))
117 | count_dial_time = []
118 | for key in keys:
119 | start_time = all_data[all_data[dialogue_id]==key][StartTime]
120 | end_time = all_data[all_data[dialogue_id]==key][EndTime]
121 | time_list = np.array([int(s) for s in end_time.tolist()])-np.array([int(s) for s in start_time.tolist()])
122 | time_list = time_list.tolist()
123 |
124 | count_dial_time.append(sum(time_list)/len(time_list))
125 | return {"avg":sum(count_dial_time)/len(count_dial_time)}
126 |
127 |
128 | def print_speaker(data_path):
129 |     print("Output all the names of the speakers from "+data_path)
130 | data = get_totaldata_for_analysis(data_path)
131 | data_speaker_result = get_row_statistics(data,"说话者姓名")
132 | print(data_speaker_result["keys"])
133 | print(data_speaker_result["class"])
134 |
135 |
136 | def print_speaker_from_dir(input_dir="/home/MMMTD/data/processed_cleaned_data/checked_processed_cleaned_data"):
137 | if not os.path.exists(input_dir):
138 |         print("Nonexistent data dir: "+input_dir)
139 | return False
140 |     print("Loading files from "+input_dir+" ......")
141 | file_names = os.listdir(input_dir)
142 | for file_name in file_names:
143 | print_speaker(data_path= join(input_dir,file_name))
144 |
145 | return True
146 |
147 |
148 | def statistic_speaker(data_path = "/home/MMMTD/data/MMMTD_cleaned_speaker_annotation.xlsx"):
149 | if not os.path.isfile(data_path):
150 |         print("Nonexistent data file: "+data_path)
151 | return 0
152 | else:
153 | data = pd.read_excel(data_path) # xlrd==1.2.0, do not use xlrd==2.0.1
154 | gender_result = get_row_statistics(data,"性别")
155 | age_result = get_row_statistics(data,"年龄段")
156 | return gender_result, age_result
157 |
158 |
159 | def print_sentiment(data_path, sentiment='中性情绪'):
160 |     print("From "+data_path+" compute the ratio of "+sentiment)
161 | data = get_totaldata_for_analysis(data_path)
162 | data_result = get_row_statistics(data,"情绪(粗粒度)")
163 | print(data_result['element_stastics'][sentiment])
164 | return data_result['element_stastics'][sentiment]/data_result['size']
165 |
166 |
167 | def print_sentiment_from_dir(input_dir="/home/MMMTD/data/processed_cleaned_data/checked_processed_cleaned_data"):
168 | if not os.path.exists(input_dir):
169 |         print("Nonexistent data dir: "+input_dir)
170 | return False
171 |     print("Loading files from "+input_dir+" ......")
172 | file_names = os.listdir(input_dir)
173 | result_dir = {}
174 | for file_name in file_names:
175 | result_dir[file_name] = print_sentiment(data_path= join(input_dir,file_name), sentiment='中性情绪')
176 |
177 | return result_dir
178 |
179 |
180 | def count_eda_array_from_data(data,da_name = "DA", emotion_name = "Emotion", utt_id_name = "Utterance_ID"):
181 | # Get DA and Emotion labels
182 | da_label = list(set(data[da_name].tolist()))
183 | emotion_label = list(set(data[emotion_name].tolist()))
184 |     print("da_label=",da_label)
185 |     print("number of da_label:", len(da_label))
186 |     print("emotion_label=",emotion_label)
187 |     print("number of emotion_label:", len(emotion_label))
188 | # add index to label
189 |     da_id = {}
190 |     idx = 0
191 |     for da in da_label:
192 |         da_id[da] = idx
193 |         idx = idx + 1
194 |     print(da_id)
195 | 
196 |     emotion_id = {}
197 |     idx = 0
198 |     for emotion in emotion_label:
199 |         emotion_id[emotion] = idx
200 |         idx = idx + 1
201 |     print(emotion_id)
202 |
203 | #count_eda_array = np.zeros((len(da_label),len(emotion_label)))
204 | count_eda_array = np.zeros((len(emotion_label),len(da_label)))
205 | # output the initial array
206 | print(count_eda_array)
207 |
208 | # begin statistics
209 |
210 | for utt_id in data[utt_id_name].tolist():
211 | # for the emotion and DA of each row
212 | # print(data[data[utt_id_name]== utt_id ][da_name])
213 | current_da_id = da_id[ str(data[data[utt_id_name]== utt_id ][da_name].values.astype("str")[0]) ]
214 | current_emotion_id = emotion_id[ str(data[data[utt_id_name]== utt_id ][emotion_name].values.astype("str")[0]) ]
215 | count_eda_array[current_emotion_id, current_da_id] = count_eda_array[current_emotion_id, current_da_id] + 1
216 |     print("result after counting:",count_eda_array)
217 |
218 | pro_eda_array=np.zeros((len(emotion_label),len(da_label)))
219 |
220 | da_results = get_row_statistics(data=data,row_name= da_name)
221 | emotion_results = get_row_statistics(data=data,row_name= emotion_name)
222 |
223 |
224 | for da in da_id:
225 | for emotion in emotion_id:
226 | current_da_id = da_id[da]
227 | current_emotion_id = emotion_id[emotion]
228 | current_da_number = da_results["element_stastics"][da]
229 | #current_emotion_number = emotion_results["element_stastics"][emotion]
230 | pro_eda_array[current_emotion_id, current_da_id] = count_eda_array[current_emotion_id, current_da_id]/(current_da_number)
231 | return da_id, emotion_id, count_eda_array, pro_eda_array
232 |
233 |
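234 | # --- Usage sketch (illustrative only) ---
235 | # How the helpers above are typically combined: load the three CPED splits and
236 | # print a few statistics. The relative data_dir path is an assumption; adjust
237 | # it to wherever the CPED csv files live.
238 | if __name__ == "__main__":
239 |     train_data, valid_data, test_data = get_data_for_analysis(data_dir="../data/CPED")
240 |     speaker_stats = get_row_statistics(train_data, "Speaker")
241 |     print("number of speakers in train:", speaker_stats["class"])
242 |     print("utterance length in train (chars):", statistics_utterance(train_data))
243 |     print("distinct emotions per dialogue in train:", statistics_emotda(train_data, "Emotion"))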
--------------------------------------------------------------------------------
/prc_baseline/README.md:
--------------------------------------------------------------------------------
1 | # PRC: Personality Recognition in Conversation
2 | * **Author: Chen Yirong**
3 | * **Date: 2022.06.01**
4 | 
5 | **Note:** Code and model will be released in the future!
6 | 
7 | ## Task Definition
8 | See [https://paperswithcode.com/task/personality-recognition-in-conversation](https://paperswithcode.com/task/personality-recognition-in-conversation)
--------------------------------------------------------------------------------