├── .DS_Store
├── .gitignore
├── LICENSE
├── README.md
├── cosent
│   ├── config.py
│   ├── loading.py
│   ├── model.py
│   ├── test.py
│   └── train.py
├── dataset
│   ├── .DS_Store
│   └── STS-B
│       ├── dev.txt
│       ├── test.txt
│       └── train.txt
├── esimcse
│   ├── ESimCSE.py
│   ├── loading.py
│   ├── test.py
│   └── train.py
├── matching
│   ├── README.md
│   ├── config.py
│   ├── cosent.py
│   ├── esimcse.py
│   ├── log.py
│   ├── promptbert.py
│   ├── retrieval.py
│   ├── sbert.py
│   ├── simcse.py
│   └── view.py
├── pics
│   ├── .DS_Store
│   ├── cosent_loss.svg
│   ├── esimcse.png
│   ├── esimcse_loss.svg
│   ├── promptbert_loss.svg
│   ├── sbert.png
│   ├── simcse.png
│   └── simcse_loss.svg
├── promptbert
│   ├── PromptBert.py
│   ├── loading.py
│   ├── test.py
│   └── train.py
├── sbert
│   ├── config.py
│   ├── loading.py
│   ├── model.py
│   └── train.py
└── simcse
    ├── .DS_Store
    ├── SimCSE.py
    ├── loading.py
    ├── test.py
    └── train.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Macielyoung/sentence_representation_matching/810bd68e366810814572876fd9cdb380238e5b19/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## sentence_representation_matching
2 |
3 | This project collects text matching models: three unsupervised approaches (SimCSE, ESimCSE, PromptBert) and two supervised approaches (SBert, CoSent).
4 |
5 |
6 |
7 | ### Unsupervised Text Matching
8 |
9 | #### 1. SimCSE
10 |
11 | SimCSE exploits the dropout mechanism of the Transformer encoder: the same sentence is encoded twice, and the two dropout-perturbed outputs form a positive pair for contrastive learning, which pulls positives together and pushes the other in-batch samples apart as negatives.
12 |
13 | ##### Model architecture:
14 |
15 | 
16 |
17 |
18 |
19 | ##### Loss function:
20 |
21 | 
22 |
23 |
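Below is a minimal sketch of the unsupervised SimCSE objective shown above (InfoNCE over two dropout views with in-batch negatives). The function and tensor names and the temperature value are illustrative and not taken from this repository's code.

```python
import torch
import torch.nn.functional as F

def simcse_loss(z1: torch.Tensor, z2: torch.Tensor, tau: float = 0.05) -> torch.Tensor:
    # z1, z2: [batch, dim] embeddings of the same sentences under two dropout masks
    z1, z2 = F.normalize(z1, dim=-1), F.normalize(z2, dim=-1)
    sim = z1 @ z2.t() / tau                                 # [batch, batch] cosine similarities
    labels = torch.arange(sim.size(0), device=sim.device)   # the diagonal entries are the positives
    return F.cross_entropy(sim, labels)
```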
24 |
25 | ##### Results:
26 |
27 | | data  | Pretrained                  | Pool_type      | Dropout | Batch_size | Dev_corr | Test_corr                             |
28 | | :---- | :-------------------------- | :------------: | :-----: | :--------: | -------- | ------------------------------------ |
29 | | STS-B | hfl/chinese-bert-wwm-ext | avg_first_last | 0.1 | 64 | 0.76076 | 0.70924 |
30 | | STS-B | hfl/chinese-bert-wwm-ext | avg_first_last | 0.2 | 64 | 0.75996 | **0.71474** |
31 | | STS-B | hfl/chinese-bert-wwm-ext | avg_first_last | 0.3 | 64 | 0.76518 | 0.71237 |
32 | | STS-B | hfl/chinese-roberta-wwm-ext | avg_first_last | 0.1 | 64 | 0.75933 | 0.69070 |
33 | | STS-B | hfl/chinese-roberta-wwm-ext | avg_first_last | 0.2 | 64 | 0.76907 | **0.72410** |
34 | | STS-B | hfl/chinese-roberta-wwm-ext | avg_first_last | 0.3 | 64 | 0.77203 | 0.72155 |
35 |
36 | References:
37 |
38 | 1)https://github.com/princeton-nlp/SimCSE
39 |
40 | 2)https://github.com/KwangKa/SIMCSE_unsup
41 |
42 | 3)https://arxiv.org/pdf/2104.08821.pdf
43 |
44 |
45 |
46 | #### 2. ESimCSE
47 |
48 | Building on SimCSE, ESimCSE constructs positive pairs by repeating a few words within a sentence, and introduces momentum contrast to enlarge the pool of negative samples.
49 |
50 | ##### Model architecture:
51 |
52 | 
53 |
54 |
55 |
56 | ##### Loss function:
57 |
58 | 
59 |
60 |
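The contrastive loss with a momentum queue can be sketched as follows, assuming l2-normalized query/key embeddings and a queue of earlier key embeddings; the names are illustrative, and the full training logic lives in esimcse/train.py.

```python
import torch

def esimcse_loss(query, key, queue, tau=0.05):
    # query, key: [N, D] normalized embeddings of each sentence and its word-repetition copy
    # queue: [Q, D] normalized embeddings of earlier keys, used as extra negatives
    pos = torch.exp((query * key).sum(dim=1) / tau)             # positive-pair similarities
    batch_all = torch.exp(query @ key.t() / tau).sum(dim=1)     # in-batch similarities (includes the positives)
    queue_all = torch.exp(query @ queue.t() / tau).sum(dim=1)   # similarities against the momentum queue
    return -torch.log(pos / (batch_all + queue_all)).mean()
```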
61 |
62 | ##### Results:
63 |
64 | | data  | Pretrained                  | Dup_rate | Queue_num | Pool_type      | Dropout | Batch_size | Dev_corr | Test_corr                              |
65 | | ----- | :-------------------------- | -------- | --------- | :------------: | ------- | :--------: | -------- | ------------------------------------- |
66 | | STS-B | hfl/chinese-bert-wwm-ext | 0.2 | 32 | avg_first_last | 0.1 | 64 | 0.77274 | 0.69639 |
67 | | STS-B | hfl/chinese-bert-wwm-ext | 0.2 | 32 | avg_first_last | 0.2 | 64 | 0.77047 | 0.70042 |
68 | | STS-B | hfl/chinese-bert-wwm-ext | 0.2 | 32 | avg_first_last | 0.3 | 64 | 0.77963 | **0.72478** |
69 | | STS-B | hfl/chinese-roberta-wwm-ext | 0.3 | 64 | avg_first_last | 0.1 | 64 | 0.77508 | 0.7206 |
70 | | STS-B | hfl/chinese-roberta-wwm-ext | 0.3 | 64 | avg_first_last | 0.2 | 64 | 0.77416 | 0.7096 |
71 | | STS-B | hfl/chinese-roberta-wwm-ext | 0.3 | 64 | avg_first_last | 0.3 | 64 | 0.78093 | **0.72495** |
72 |
73 | Reference: https://arxiv.org/pdf/2109.04380.pdf
74 |
75 |
76 |
77 | #### 3. PromptBert
78 |
79 | PromptBert derives sentence embeddings from prompts: the embeddings produced by two different templates for the same sentence form a positive pair, while the other samples in the same batch serve as negatives.
80 |
81 | ##### Loss function:
82 |
83 | 
84 |
85 | ```
86 | Two sentence templates are used in this experiment:
87 | 1) [X],它的意思是[MASK]。    ("[X], it means [MASK].")
88 | 2) [X],这句话的意思是[MASK]。  ("[X], this sentence means [MASK].")
89 |
90 | To cancel the bias introduced by the prompt itself when computing the loss, the sentence embedding is obtained by subtracting the [MASK] representation of the bare template from the [MASK] representation of the template filled with the sentence.
91 | ```
92 |
93 |
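A minimal sketch of this template-debiasing step; `mask_token_id` and the hidden-state arguments are hypothetical names (see promptbert/PromptBert.py for the repository's actual implementation).

```python
import torch

def mask_hidden_state(last_hidden_state, input_ids, mask_token_id):
    # pick the hidden state at the [MASK] position of each sequence
    # (assumes exactly one [MASK] per sequence): [B, L, D] -> [B, D]
    batch_idx, mask_idx = (input_ids == mask_token_id).nonzero(as_tuple=True)
    return last_hidden_state[batch_idx, mask_idx]

def debiased_sentence_embedding(h_mask_filled_template, h_mask_bare_template):
    # filled-template [MASK] representation minus bare-template [MASK] representation
    return h_mask_filled_template - h_mask_bare_template
```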
94 |
95 | ##### Results:
96 |
97 | | data  | Pretrained                  | Pool_type | Dropout | Batch_size | Dev_corr | Test_corr                            |
98 | | ----- | :-------------------------- | :-------: | ------- | :--------: | -------- | ------------------------------------ |
99 | | STS-B | hfl/chinese-bert-wwm-ext | x_index | 0.1 | 32 | 0.78216 | **0.73185** |
100 | | STS-B | hfl/chinese-bert-wwm-ext | x_index | 0.2 | 32 | 0.78362 | 0.73129 |
101 | | STS-B | hfl/chinese-bert-wwm-ext | x_index | 0.3 | 32 | 0.76617 | 0.71597 |
102 | | STS-B | hfl/chinese-roberta-wwm-ext | x_index | 0.1 | 32 | 0.79963 | **0.73492** |
103 | | STS-B | hfl/chinese-roberta-wwm-ext | x_index | 0.2 | 32 | 0.7764 | 0.72024 |
104 | | STS-B | hfl/chinese-roberta-wwm-ext | x_index | 0.3 | 32 | 0.77875 | 0.73153 |
105 |
106 | Reference: https://arxiv.org/pdf/2201.04337.pdf
107 |
108 |
109 |
110 | #### 4. Model Comparison
111 |
112 | The best test-set result of each model across the experiment groups above is summarized below.
113 |
114 | | Model      | Pretrained-Bert | Pretrained-Roberta |
115 | | ---------- | :------------: | :----------------: |
116 | | SimCSE | 0.71474 | 0.72410 |
117 | | ESimCSE | 0.72478 | 0.72495 |
118 | | PromptBERT | 0.73185 | 0.73492 |
119 |
120 |
121 |
122 | ### Supervised Text Matching
123 |
124 | #### 1. SBert
125 |
126 | SBert fine-tunes BERT as a bi-encoder (twin towers) and uses an MSE loss to fit the cosine similarity between the two texts.
127 |
128 | Model architecture:
129 |
130 | 
131 |
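A minimal sketch of this training objective, assuming both sentences have already been encoded by the shared encoder; the names are illustrative (see the sbert/ directory for the repository's implementation).

```python
import torch.nn.functional as F

def sbert_mse_loss(emb_a, emb_b, gold_score):
    # emb_a, emb_b: [B, D] sentence embeddings from the shared encoder
    # gold_score: [B] target similarity (e.g. a normalized STS score)
    cos = F.cosine_similarity(emb_a, emb_b, dim=-1)
    return F.mse_loss(cos, gold_score)
```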
132 | Reference: https://www.sbert.net/docs/training/overview.html
133 |
134 |
135 |
136 | #### 2. CoSent
137 |
138 | CoSent builds a ranking-style loss: the distance of every positive pair should be smaller than the distance of every negative pair; by how much is left to the model and the data rather than being fixed by an absolute margin.
139 |
140 | Loss function:
141 |
142 | 
143 |
144 | References:
145 |
146 | 1)https://spaces.ac.cn/archives/8847
147 |
148 | 2)https://github.com/shawroad/CoSENT_Pytorch
149 |
--------------------------------------------------------------------------------
/cosent/config.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | class Params:
4 | epoches = 50
5 | max_length = 32
6 | batch_size = 64
7 | dropout = 0.15
8 | learning_rate = 3e-5
9 | threshold = 0.5
10 | gradient_accumulation_steps = 100
11 | display_steps = 500
12 | pooler_type = "mean"
13 | pretrained_model = "hfl/chinese-roberta-wwm-ext-large"
14 | cosent_model = "models/cosent.pth"
15 |
--------------------------------------------------------------------------------
/cosent/loading.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import torch
3 | from torch.utils.data import Dataset
4 | from transformers import AutoTokenizer
5 | from torch.utils.data import DataLoader
6 | from config import Params
7 |
8 |
9 | class Loader():
10 | def __init__(self, data_file):
11 | data_df = pd.read_csv(data_file)
12 | self.data_df = data_df.fillna("")
13 |
14 |
15 | def get_dataset(self):
16 | sentences = []
17 | labels = []
18 | for _, row in self.data_df.iterrows():
19 | question1 = row['question1']
20 | question2 = row['question2']
21 | label = row['label']
22 | sentences.extend([question1, question2])
23 | labels.extend([label, label])
24 | # rows.append([question1, label])
25 | # rows.append([question2, label])
26 | return sentences, labels
27 |
28 |
29 | class CustomerDataset(Dataset):
30 | def __init__(self, sentence, label, tokenizer):
31 | self.sentence = sentence
32 | self.label = label
33 | self.tokenizer = tokenizer
34 |
35 |
36 | def __len__(self):
37 | return len(self.sentence)
38 |
39 |
40 | def __getitem__(self, index):
41 | input_encodings = self.tokenizer(self.sentence[index])
42 | input_ids = input_encodings['input_ids']
43 | attention_mask = input_encodings['attention_mask']
44 | token_type_ids = input_encodings['token_type_ids']
45 | item = {'input_ids': input_ids,
46 | 'attention_mask': attention_mask,
47 | 'token_type_ids': token_type_ids,
48 | 'label': self.label[index]}
49 | return item
50 |
51 |
52 | def pad_to_maxlen(input_ids, max_len, pad_value=0):
53 | if len(input_ids) >= max_len:
54 | input_ids = input_ids[:max_len]
55 | else:
56 | input_ids = input_ids + [pad_value] * (max_len - len(input_ids))
57 | return input_ids
58 |
59 |
60 | def collate_fn(batch):
61 |     # pad per batch: find the longest sequence in the current batch
62 | max_len = max([len(d['input_ids']) for d in batch])
63 |
64 |     # if the batch maximum exceeds the configured global max length, cap it at the global max
65 | max_len = max_len if max_len <= Params.max_length else Params.max_length
66 |
67 | input_ids, attention_mask, token_type_ids, labels = [], [], [], []
68 |
69 | for item in batch:
70 | input_ids.append(pad_to_maxlen(item['input_ids'], max_len=max_len))
71 | attention_mask.append(pad_to_maxlen(item['attention_mask'], max_len=max_len))
72 | token_type_ids.append(pad_to_maxlen(item['token_type_ids'], max_len=max_len))
73 | labels.append(item['label'])
74 |
75 | all_input_ids = torch.tensor(input_ids, dtype=torch.long)
76 | all_input_mask = torch.tensor(attention_mask, dtype=torch.long)
77 | all_segment_ids = torch.tensor(token_type_ids, dtype=torch.long)
78 | all_label_ids = torch.tensor(labels, dtype=torch.float)
79 | return all_input_ids, all_input_mask, all_segment_ids, all_label_ids
80 |
81 |
82 | if __name__ == "__main__":
83 | data_file = "data/train_dataset.csv"
84 | loader = Loader(data_file)
85 | sentence, label = loader.get_dataset()
86 | print("load question file done!")
87 |
88 | tokenizer = AutoTokenizer.from_pretrained(Params.pretrained_model)
89 | dataset = CustomerDataset(sentence, label, tokenizer)
90 | # print(dataset[0])
91 | train_loader = DataLoader(dataset,
92 | shuffle=False,
93 | batch_size=Params.batch_size,
94 | collate_fn=collate_fn)
95 |
96 | for dl in train_loader:
97 | print(dl)
98 | exit(0)
--------------------------------------------------------------------------------
/cosent/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/1/5
3 | # @Author : Maciel
4 |
5 | import torch.nn as nn
6 | from transformers import AutoConfig, AutoModel
7 | import torch
8 |
9 |
10 | class CoSent(nn.Module):
11 | def __init__(self, pretrained="hfl/chinese-bert-wwm-ext", pool_type="cls", dropout_prob=0.3):
12 | super().__init__()
13 | conf = AutoConfig.from_pretrained(pretrained)
14 | conf.attention_probs_dropout_prob = dropout_prob
15 | conf.hidden_dropout_prob = dropout_prob
16 | self.encoder = AutoModel.from_pretrained(pretrained, config=conf)
17 | assert pool_type in ["cls", "pooler", "mean"], "invalid pool_type: %s" % pool_type
18 | self.pool_type = pool_type
19 |
20 |
21 | def forward(self, input_ids, attention_mask, token_type_ids):
22 | if self.pool_type == "cls":
23 | output = self.encoder(input_ids,
24 | attention_mask=attention_mask,
25 | token_type_ids=token_type_ids)
26 | output = output.last_hidden_state[:, 0]
27 | elif self.pool_type == "pooler":
28 | output = self.encoder(input_ids,
29 | attention_mask=attention_mask,
30 | token_type_ids=token_type_ids)
31 | output = output.pooler_output
32 | elif self.pool_type == "mean":
33 | output = self.get_mean_tensor(input_ids, attention_mask)
34 | return output
35 |
36 |
37 | def get_mean_tensor(self, input_ids, attention_mask):
38 | '''
39 | get first and last layer avg tensor
40 | '''
41 | encode_states = self.encoder(input_ids, attention_mask=attention_mask, output_hidden_states=True)
42 | hidden_states = encode_states.hidden_states
43 | last_avg_state = self.get_avg_tensor(hidden_states[-1], attention_mask)
44 | first_avg_state = self.get_avg_tensor(hidden_states[1], attention_mask)
45 | mean_avg_state = (last_avg_state + first_avg_state) / 2
46 | return mean_avg_state
47 |
48 |
49 | def get_avg_tensor(self, layer_hidden_state, attention_mask):
50 | '''
51 |         layer_hidden_state: hidden states of one encoder layer [B * L * D]
52 |         attention_mask: padding mask of the sentences [B * L]
53 |         return: avg_embeddings, average embedding over non-padding tokens [B * D]
54 | '''
55 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(layer_hidden_state.size()).float()
56 | sum_embeddings = torch.sum(layer_hidden_state * input_mask_expanded, 1)
57 | sum_mask = input_mask_expanded.sum(1)
58 | sum_mask = torch.clamp(sum_mask, min=1e-9)
59 | avg_embeddings = sum_embeddings / sum_mask
60 | return avg_embeddings
61 |
62 |
63 | def get_avg_tensor2(self, layer_hidden_state, attention_mask):
64 | '''
65 |         layer_hidden_state: hidden states of one encoder layer [B * L * D]
66 |         attention_mask: padding mask of the sentences [B * L]
67 |         return: average embedding over non-padding tokens [B * D]
68 | '''
69 | layer_hidden_dim = layer_hidden_state.shape[-1]
70 | attention_repeat_mask = attention_mask.unsqueeze(dim=-1).tile(layer_hidden_dim)
71 | layer_attention_state = torch.mul(layer_hidden_state, attention_repeat_mask)
72 | layer_sum_state = layer_attention_state.sum(dim=1)
73 | # print(last_attention_state.shape)
74 |
75 | attention_length_mask = attention_mask.sum(dim=-1)
76 | attention_length_repeat_mask = attention_length_mask.unsqueeze(dim=-1).tile(layer_hidden_dim)
77 | # print(attention_length_repeat_mask.shape)
78 |
79 | layer_avg_state = torch.mul(layer_sum_state, 1/attention_length_repeat_mask)
80 | return layer_avg_state
--------------------------------------------------------------------------------
/cosent/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pandas as pd
3 | from transformers import AutoTokenizer
4 | from model import CoSent
5 | from config import Params
6 |
7 |
8 | def read_test_data(test_file):
9 | test_df = pd.read_csv(test_file)
10 | test_df = test_df.fillna("")
11 | return test_df
12 |
13 |
14 | def split_similarity(similarity, threshold):
15 | if similarity >= threshold:
16 | return 1
17 | else:
18 | return 0
19 |
20 |
21 | def calculate_accuracy(data_df, threshold):
22 | data_df['pred'] = data_df.apply(lambda x: split_similarity(x['similarity'], threshold), axis=1)
23 | pred_correct = data_df[data_df['pred'] == data_df['label']]
24 | # pred_error = data_df[data_df['pred'] != data_df['label']]
25 | # print(pred_error)
26 | accuracy = len(pred_correct) / len(data_df)
27 | return accuracy, len(pred_correct), len(data_df)
28 |
29 |
30 | class CoSentRetrieval(object):
31 | def __init__(self, pretrained_model_path, cosent_path):
32 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path)
33 | model = CoSent(Params.pretrained_model, Params.pooler_type, Params.dropout)
34 | self.checkpoint = torch.load(cosent_path, map_location='cpu')
35 | model.load_state_dict(self.checkpoint['model_state_dict'])
36 | model.eval()
37 | self.model = model
38 |
39 |
40 | def print_checkpoint_info(self):
41 | loss = self.checkpoint['loss']
42 | epoch = self.checkpoint['epoch']
43 | model_info = {'loss': loss, 'epoch': epoch}
44 | return model_info
45 |
46 |
47 | def calculate_sentence_embedding(self, sentence):
48 | device = "cpu"
49 | input_encodings = self.tokenizer(sentence,
50 | padding=True,
51 | truncation=True,
52 | max_length=Params.max_length,
53 | return_tensors='pt')
54 | sentence_embedding = self.model(input_encodings['input_ids'].to(device),
55 | input_encodings['attention_mask'].to(device),
56 | input_encodings['token_type_ids'].to(device))
57 | return sentence_embedding
58 |
59 |
60 | def calculate_sentence_similarity(self, sentence1, sentence2):
61 | sentence1 = sentence1.strip()
62 | sentence2 = sentence2.strip()
63 | sentence1_embedding = self.calculate_sentence_embedding(sentence1)
64 | sentence2_embedding = self.calculate_sentence_embedding(sentence2)
65 | similarity = torch.cosine_similarity(sentence1_embedding, sentence2_embedding, dim=-1)
66 | similarity = float(similarity.item())
67 | return similarity
68 |
69 |
70 | cosent_retrieval = CoSentRetrieval(Params.pretrained_model, Params.cosent_model)
71 | model_info = cosent_retrieval.print_checkpoint_info()
72 | print("load model done, model_info: {}".format(model_info))
73 | test_file = "data/test_dataset.csv"
74 | test_df = read_test_data(test_file)
75 | results = []
76 | for rid, row in test_df.iterrows():
77 | question1 = row['question1']
78 | question2 = row['question2']
79 | label = row['label']
80 | similarity = cosent_retrieval.calculate_sentence_similarity(question1, question2)
81 | item = {'question1': question1,
82 | 'question2': question2,
83 | 'similarity': similarity,
84 | 'label': label}
85 | if rid % 100 == 0:
86 | print("rid: {}, item: {}".format(rid, item))
87 | results.append(item)
88 | print("prediction done!")
89 |
90 |
91 | pred_df = pd.DataFrame(results)
92 | pred_file = "results/test_pred.csv"
93 | pred_df.to_csv(pred_file)
94 | max_acc = 0
95 | for t in range(50, 80, 1):
96 | t = t / 100
97 | acc, correct, num = calculate_accuracy(pred_df, t)
98 | if acc > max_acc:
99 | max_acc = acc
100 | print(t, acc, correct, num)
101 |
102 | print(max_acc)
--------------------------------------------------------------------------------
/cosent/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils.data import DataLoader
4 | from transformers import AutoTokenizer
5 | from transformers import AdamW, get_linear_schedule_with_warmup
6 | from sklearn.metrics import accuracy_score
7 | from tqdm import tqdm
8 | from loading import Loader, CustomerDataset, collate_fn
9 | from config import Params
10 | from model import CoSent
11 | import os
12 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
13 |
14 |
15 | def compute_sim(y_pred):
16 |     # 1. l2-normalize the sentence embeddings so that a plain dot product yields the cosine similarity
17 | norms = (y_pred ** 2).sum(axis=1, keepdims=True) ** 0.5
18 | # y_pred = y_pred / torch.clip(norms, 1e-8, torch.inf)
19 | y_pred = y_pred / norms
20 |
21 |     # 2. dot product of even and odd rows (each text pair occupies two consecutive rows)
22 | sim = torch.sum(y_pred[::2] * y_pred[1::2], dim=1)
23 | return sim
24 |
25 |
26 | def compute_acc(y_true, y_sim, threshold):
27 |     # 1. take the ground-truth labels (every two rows form one text pair)
28 |     y_true = y_true[::2]    # e.g. tensor([1, 0, 1]), the ground-truth labels
29 |
30 |     # 2. binarize the similarities with the threshold
31 | y_pred_label = (y_sim >= threshold).float()
32 | acc = accuracy_score(y_pred_label.detach().cpu().numpy(), y_true.cpu().numpy())
33 | return acc
34 |
35 |
36 | def compute_loss(y_true, y_sim):
37 |     # 1. take the ground-truth labels (every two rows form one text pair)
38 |     y_true = y_true[::2]    # e.g. tensor([1, 0, 1]), the ground-truth labels
39 |
40 |     # 2. scale the sentence similarities (factor 20)
41 | y_sim = y_sim * 20
42 |
43 |     # 3. keep the differences (negative-pair similarity) - (positive-pair similarity)
44 |     y_sim = y_sim[:, None] - y_sim[None, :]   # pairwise differences of all cosine values
45 |                                               # entry (i, j) is cos_i - cos_j
46 |     y_true = y_true[:, None] < y_true[None, :]   # True where pair i is negative and pair j is positive
47 | y_true = y_true.float()
48 | y_sim = y_sim - (1 - y_true) * 1e12
49 | y_sim = y_sim.view(-1)
50 | if torch.cuda.is_available():
51 |         y_sim = torch.cat((torch.tensor([0]).float().cuda(), y_sim), dim=0)  # prepend 0 because e^0 = 1, which adds the constant 1 inside the log
52 |     else:
53 |         y_sim = torch.cat((torch.tensor([0]).float(), y_sim), dim=0)  # prepend 0 because e^0 = 1, which adds the constant 1 inside the log
54 |
55 | return torch.logsumexp(y_sim, dim=0)
56 |
57 |
58 | data_file = "data/train_dataset.csv"
59 | loader = Loader(data_file)
60 | sentence, label = loader.get_dataset()
61 | print("load question file done!")
62 |
63 | # load tokenizer
64 | tokenizer = AutoTokenizer.from_pretrained(Params.pretrained_model)
65 | train_dataset = CustomerDataset(sentence, label, tokenizer)
66 | train_loader = DataLoader(train_dataset,
67 | shuffle=False,
68 | batch_size=Params.batch_size,
69 | collate_fn=collate_fn)
70 | print("tokenize all batch done!")
71 |
72 | total_steps = len(train_loader) * Params.epoches
73 | train_optimization_steps = int(len(train_dataset) / Params.batch_size) * Params.epoches
74 |
75 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
76 | # load model
77 | model = CoSent(Params.pretrained_model, Params.pooler_type, Params.dropout)
78 | model.to(device)
79 |
80 | param_optimizer = list(model.named_parameters())
81 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
82 | optimizer_grouped_parameters = [
83 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
84 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
85 | ]
86 |
87 | # load optimizer and scheduler
88 | optimizer = AdamW(optimizer_grouped_parameters, lr=Params.learning_rate)
89 | scheduler = get_linear_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=0.05 * total_steps,
90 | num_training_steps=total_steps)
91 |
92 | best_loss = 1000000
93 | print("start training...")
94 | for epoch in range(Params.epoches):
95 | model.train()
96 | epoch_losses = []
97 | epoch_acces = []
98 | step = 0
99 |
100 | for batch in tqdm(train_loader):
101 | input_ids, input_mask, segment_ids, label_ids = batch
102 | input_ids = input_ids.to(device)
103 | input_mask = input_mask.to(device)
104 | segment_ids = segment_ids.to(device)
105 | label_ids = label_ids.to(device)
106 |
107 | output = model(input_ids, input_mask, segment_ids)
108 | y_sim = compute_sim(output)
109 | loss = compute_loss(label_ids, y_sim)
110 | acc = compute_acc(label_ids, y_sim, Params.threshold)
111 |
112 | # update gradient and scheduler
113 | optimizer.zero_grad()
114 | loss.backward()
115 | optimizer.step()
116 | scheduler.step()
117 |
118 | if step % Params.display_steps == 0:
119 | print("Epoch: {}, Step: {}, Batch loss: {}, acc: {}".format(epoch, step, loss.item(), acc), flush=True)
120 | epoch_losses.append(loss.item())
121 | epoch_acces.append(acc)
122 | step += 1
123 |
124 | avg_epoch_loss = np.mean(epoch_losses)
125 | avg_epoch_acc = np.mean(epoch_acces)
126 | print("Epoch: {}, avg loss: {}, acc: {}".format(epoch, avg_epoch_loss, avg_epoch_acc), flush=True)
127 | if avg_epoch_loss < best_loss:
128 | best_loss = avg_epoch_loss
129 | print("Epoch: {}, best loss: {}, acc: {} save model".format(epoch, best_loss, avg_epoch_acc), flush=True)
130 | torch.save({
131 | 'epoch': epoch,
132 | 'model_state_dict': model.state_dict(),
133 | 'loss': avg_epoch_loss,
134 | 'acc': avg_epoch_acc,
135 | }, Params.cosent_model)
--------------------------------------------------------------------------------
/dataset/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Macielyoung/sentence_representation_matching/810bd68e366810814572876fd9cdb380238e5b19/dataset/.DS_Store
--------------------------------------------------------------------------------
/esimcse/ESimCSE.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/9/17
3 | # @Author : Maciel
4 |
5 | import torch.nn as nn
6 | from transformers import AutoConfig, AutoModel
7 | import torch
8 |
9 |
10 | class ESimCSE(nn.Module):
11 | def __init__(self, pretrained="hfl/chinese-bert-wwm-ext", pool_type="cls", dropout_prob=0.3):
12 | super().__init__()
13 | conf = AutoConfig.from_pretrained(pretrained)
14 | conf.attention_probs_dropout_prob = dropout_prob
15 | conf.hidden_dropout_prob = dropout_prob
16 | self.encoder = AutoModel.from_pretrained(pretrained, config=conf)
17 | assert pool_type in ["cls", "pooler", "avg_first_last", "avg_last_two"], "invalid pool_type: %s" % pool_type
18 | self.pool_type = pool_type
19 |
20 |
21 | def forward(self, input_ids, attention_mask, token_type_ids):
22 | output = self.encoder(input_ids,
23 | attention_mask=attention_mask,
24 | token_type_ids=token_type_ids,
25 | output_hidden_states=True)
26 | hidden_states = output.hidden_states
27 | if self.pool_type == "cls":
28 | output = output.last_hidden_state[:, 0]
29 | elif self.pool_type == "pooler":
30 | output = output.pooler_output
31 | elif self.pool_type == "avg_first_last":
32 | top_first_state = self.get_avg_tensor(hidden_states[1], attention_mask)
33 | last_first_state = self.get_avg_tensor(hidden_states[-1], attention_mask)
34 | output = (top_first_state + last_first_state) / 2
35 | else:
36 | last_first_state = self.get_avg_tensor(hidden_states[-1], attention_mask)
37 | last_second_state = self.get_avg_tensor(hidden_states[-2], attention_mask)
38 | output = (last_first_state + last_second_state) / 2
39 | return output
40 |
41 |
42 | def get_avg_tensor(self, layer_hidden_state, attention_mask):
43 | '''
44 |         layer_hidden_state: hidden states of one encoder layer [B * L * D]
45 |         attention_mask: padding mask of the sentences [B * L]
46 |         return: avg_embeddings, average embedding over non-padding tokens [B * D]
47 | '''
48 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(layer_hidden_state.size()).float()
49 | sum_embeddings = torch.sum(layer_hidden_state * input_mask_expanded, 1)
50 | sum_mask = input_mask_expanded.sum(1)
51 | sum_mask = torch.clamp(sum_mask, min=1e-9)
52 | avg_embeddings = sum_embeddings / sum_mask
53 | return avg_embeddings
--------------------------------------------------------------------------------
/esimcse/loading.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/9/17
3 | # @Author : Maciel
4 |
5 | from torch.utils.data import DataLoader
6 |
7 |
8 | class MatchingDataSet:
9 | def read_train_file(self, trainfile, devfile, testfile, filetype):
10 | sents = []
11 | if filetype == "txt":
12 | with open(trainfile, 'r') as f:
13 | for line in f.readlines():
14 | _, s1, s2, _ = line.strip().split(u"||")
15 | sents.append(s1)
16 | sents.append(s2)
17 | with open(devfile, 'r') as f:
18 | for line in f.readlines():
19 | _, s1, s2, _ = line.strip().split(u"||")
20 | sents.append(s1)
21 | sents.append(s2)
22 | with open(testfile, 'r') as f:
23 | for line in f.readlines():
24 | _, s1, s2, _ = line.strip().split(u"||")
25 | sents.append(s1)
26 | sents.append(s2)
27 | return sents
28 |
29 | def read_eval_file(self, file, filetype):
30 | sents = []
31 | if filetype == "txt":
32 | with open(file, 'r') as f:
33 | for line in f.readlines():
34 | _, s1, s2, s = line.strip().split(u"||")
35 | item = {'sent1': s1,
36 | 'sent2': s2,
37 | 'score': float(s)}
38 | sents.append(item)
39 | return sents
40 |
41 |
42 | if __name__ == "__main__":
43 | trainfile = "../dataset/STS-B/train.txt"
44 | devfile = "../dataset/STS-B/dev.txt"
45 | testfile = "../dataset/STS-B/test.txt"
46 | match_dataset = MatchingDataSet()
47 |
48 | train_list = match_dataset.read_train_file(trainfile, devfile, testfile, "txt")
49 | print(train_list[:5])
50 |
51 | train_lengths = [len(sentence) for sentence in train_list]
52 | max_len = max(train_lengths)
53 |
54 |
55 | dev_list = match_dataset.read_eval_file(devfile, "txt")
56 | dev_sen1_length = [len(d['sent1']) for d in dev_list]
57 | dev_sen2_length = [len(d['sent2']) for d in dev_list]
58 | max_sen1 = max(dev_sen1_length)
59 | max_sen2 = max(dev_sen2_length)
60 | print(max_len, max_sen1, max_sen2)
61 | # dev_loader = DataLoader(dev_list,
62 | # batch_size=8)
63 | # for batch in dev_loader:
64 | # print(batch)
65 | # exit(0)
--------------------------------------------------------------------------------
/esimcse/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/9/19
3 | # @Author : Maciel
4 |
5 | from loading import MatchingDataSet
6 | import torch
7 | import torch.nn.functional as F
8 | import scipy.stats
9 | from torch.utils.data import DataLoader
10 | from ESimCSE import ESimCSE
11 | from transformers import AutoTokenizer
12 | import os
13 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
14 |
15 |
16 | def eval(model, tokenizer, test_loader, device, max_length):
17 | model.eval()
18 | model.to(device)
19 |
20 | all_sims, all_scores = [], []
21 | with torch.no_grad():
22 | for data in test_loader:
23 | sent1 = data['sent1']
24 | sent2 = data['sent2']
25 | score = data['score']
26 | sent1_encoding = tokenizer(sent1,
27 | padding=True,
28 | truncation=True,
29 | max_length=max_length,
30 | return_tensors='pt')
31 | sent1_encoding = {key: value.to(device) for key, value in sent1_encoding.items()}
32 | sent2_encoding = tokenizer(sent2,
33 | padding=True,
34 | truncation=True,
35 | max_length=max_length,
36 | return_tensors='pt')
37 | sent2_encoding = {key: value.to(device) for key, value in sent2_encoding.items()}
38 |
39 | sent1_output = model(**sent1_encoding)
40 | sent2_output = model(**sent2_encoding)
41 | sim_score = F.cosine_similarity(sent1_output, sent2_output).cpu().tolist()
42 | all_sims += sim_score
43 | all_scores += score.tolist()
44 | corr = scipy.stats.spearmanr(all_sims, all_scores).correlation
45 | return corr
46 |
47 |
48 | def test(testfile, pretrained, pool_type, dropout_rate, model_path, max_length):
49 | match_dataset = MatchingDataSet()
50 | testfile_type = "txt"
51 | test_list = match_dataset.read_eval_file(testfile, testfile_type)
52 | print("test samples num: {}".format(len(test_list)))
53 |
54 | test_loader = DataLoader(test_list,
55 | batch_size=128)
56 | print("test batch num: {}".format(len(test_loader)))
57 |
58 | tokenizer = AutoTokenizer.from_pretrained(pretrained)
59 | model = ESimCSE(pretrained, pool_type, dropout_rate)
60 | model.load_state_dict(torch.load(model_path)['model_state_dict'])
61 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62 |
63 | test_corr = eval(model, tokenizer, test_loader, device, max_length)
64 | print("test corr: {}".format(test_corr))
65 |
66 |
67 | if __name__ == "__main__":
68 | testfile = "../dataset/STS-B/test.txt"
69 | pretrained = "hfl/chinese-roberta-wwm-ext"
70 | pool_type = "avg_first_last"
71 | dropout_rate = 0.1
72 | max_length = 128
73 | model_path = "../models/esimcse_roberta_stsb.pth"
74 |
75 | test(testfile, pretrained, pool_type, dropout_rate, model_path, max_length)
76 |
--------------------------------------------------------------------------------
/esimcse/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/9/20
3 | # @Author : Maciel
4 |
5 |
6 | from transformers import AutoTokenizer
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 | from torch.utils.data import DataLoader
11 | import scipy.stats
12 | from loading import MatchingDataSet
13 | from ESimCSE import ESimCSE
14 | import random
15 | import argparse
16 | from loguru import logger
17 | import copy
18 | import os
19 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
20 |
21 |
22 | def parse_args():
23 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
24 | parser.add_argument("--trainfile", type=str, default="../dataset/STS-B/train.txt", help="train file path")
25 | parser.add_argument("--devfile", type=str, default="../dataset/STS-B/dev.txt", help="dev file path")
26 | parser.add_argument("--testfile", type=str, default="../dataset/STS-B/test.txt", help="test file path")
27 | parser.add_argument("--filetype", type=str, default="txt", help="train and dev file type")
28 | parser.add_argument("--pretrained", type=str, default="hfl/chinese-roberta-wwm-ext", help="huggingface pretrained model")
29 | parser.add_argument("--model_out", type=str, default="../models/esimcse_roberta_stsb.pth", help="model output path")
30 | parser.add_argument("--dup_rate", type=float, default=0.2, help="repeat word probability")
31 |     parser.add_argument("--queue_size", type=float, default=0.5, help="negative queue num / batch size")
32 | parser.add_argument("--momentum", type=float, default=0.995, help="momentum parameter")
33 | parser.add_argument("--max_length", type=int, default=128, help="sentence max length")
34 | parser.add_argument("--batch_size", type=int, default=64, help="batch size")
35 | parser.add_argument("--epochs", type=int, default=10, help="epochs")
36 | parser.add_argument("--lr", type=float, default=3e-5, help="learning rate")
37 | parser.add_argument("--tao", type=float, default=0.05, help="temperature")
38 | parser.add_argument("--device", type=str, default="cuda", help="device")
39 | parser.add_argument("--display_interval", type=int, default=100, help="display interval")
40 | parser.add_argument("--pool_type", type=str, default="avg_first_last", help="pool_type")
41 | parser.add_argument("--dropout_rate", type=float, default=0.3, help="dropout_rate")
42 | parser.add_argument("--task", type=str, default="esimcse", help="task name")
43 | args = parser.parse_args()
44 | return args
45 |
46 |
47 | def compute_loss(query, key, queue, tao=0.05):
48 | '''
49 |     @function: compute the contrastive loss
50 |
51 |     @input:
52 |         query: tensor, embeddings of the original (query) sentences
53 |         key: tensor, embeddings of the augmented (key) sentences
54 |         queue: tensor, embeddings of the history queue (extra negatives)
55 |         tao: float, temperature hyperparameter, default 0.05
56 |
57 |     @return: loss (tensor), the contrastive loss value
58 | '''
59 | # N: batch, D: dim
60 | N, D = query.shape[0], query.shape[1]
61 |
62 | # calculate positive similarity
63 | pos = torch.exp(torch.div(torch.bmm(query.view(N,1,D), key.view(N,D,1)).view(N,1),tao))
64 |
65 | # calculate inner_batch similarity
66 | batch_all = torch.sum(torch.exp(torch.div(torch.mm(query.view(N,D),torch.t(key)),tao)),dim=1)
67 | # calculate inner_queue similarity
68 | queue_all = torch.sum(torch.exp(torch.div(torch.mm(query.view(N,D),torch.t(queue)),tao)),dim=1)
69 |
70 | denominator = batch_all + queue_all
71 |
72 | loss = torch.mean(-torch.log(torch.div(pos, denominator)))
73 | return loss
74 |
75 |
76 | def construct_queue(args, train_loader, tokenizer, key_encoder):
77 | flag = 0
78 | queue_num = int(args.queue_size * args.batch_size)
79 | queue = None
80 | while True:
81 | with torch.no_grad():
82 | for pid, data in enumerate(train_loader):
83 |                 # use batches different from the initial ones as negative examples
84 | if pid < 10:
85 | continue
86 | query_encodings = tokenizer(data,
87 | padding=True,
88 | truncation=True,
89 | max_length=args.max_length,
90 | return_tensors='pt')
91 | query_encodings = {key: value.to(args.device) for key, value in query_encodings.items()}
92 | query_embedding = key_encoder(**query_encodings)
93 | if queue is None:
94 | queue = query_embedding
95 | else:
96 | if queue.shape[0] < queue_num:
97 | queue = torch.cat((queue, query_embedding), 0)
98 | else:
99 | flag = 1
100 | if flag == 1:
101 | break
102 | if flag == 1:
103 | break
104 | queue = queue[-queue_num:]
105 | queue = torch.div(queue, torch.norm(queue, dim=1).reshape(-1, 1))
106 | return queue
107 |
108 |
109 | def repeat_word(tokenizer, sentence, dup_rate):
110 | '''
111 |     @function: duplicate a random subset of tokens in the sentence
112 |
113 |     @input:
114 |         sentence: string, the input sentence
115 |
116 |     @return:
117 |         dup_sentence: string, the sentence with some tokens duplicated
118 | '''
119 | word_tokens = tokenizer.tokenize(sentence)
120 |
121 | # dup_len ∈ [0, max(2, int(dup_rate ∗ N))]
122 | max_len = max(2, int(dup_rate * len(word_tokens)))
123 |     # make sure the sampled duplication count does not exceed the number of tokens
124 | dup_len = min(random.choice(range(max_len+1)), len(word_tokens))
125 |
126 | random_indices = random.sample(range(len(word_tokens)), dup_len)
127 | # print(max_len, dup_len, random_indices)
128 |
129 | dup_word_tokens = []
130 | for index, word in enumerate(word_tokens):
131 | dup_word_tokens.append(word)
132 | if index in random_indices and "#" not in word:
133 | dup_word_tokens.append(word)
134 | dup_sentence = tokenizer.decode(tokenizer.convert_tokens_to_ids(dup_word_tokens)).replace(" ", "")
135 | # dup_sentence = "".join(dup_word_tokens)
136 | return dup_sentence
137 |
138 |
139 | def eval(model, tokenizer, dev_loader, args):
140 | model.eval()
141 | model.to(args.device)
142 |
143 | all_sims, all_scores = [], []
144 | with torch.no_grad():
145 | for data in dev_loader:
146 | sent1 = data['sent1']
147 | sent2 = data['sent2']
148 | score = data['score']
149 | sent1_encoding = tokenizer(sent1,
150 | padding=True,
151 | truncation=True,
152 | max_length=args.max_length,
153 | return_tensors='pt')
154 | sent1_encoding = {key: value.to(args.device) for key, value in sent1_encoding.items()}
155 | sent2_encoding = tokenizer(sent2,
156 | padding=True,
157 | truncation=True,
158 | max_length=args.max_length,
159 | return_tensors='pt')
160 | sent2_encoding = {key: value.to(args.device) for key, value in sent2_encoding.items()}
161 |
162 | sent1_output = model(**sent1_encoding)
163 | sent2_output = model(**sent2_encoding)
164 | sim_score = F.cosine_similarity(sent1_output, sent2_output).cpu().tolist()
165 | all_sims += sim_score
166 | all_scores += score.tolist()
167 | corr = scipy.stats.spearmanr(all_sims, all_scores).correlation
168 | return corr
169 |
170 |
171 | def train(args):
172 | train_file = args.trainfile
173 | dev_file = args.devfile
174 | test_file = args.testfile
175 | file_type = args.filetype
176 | queue_num = int(args.queue_size * args.batch_size)
177 | match_dataset = MatchingDataSet()
178 | train_list = match_dataset.read_train_file(train_file, dev_file, test_file, file_type)
179 | dev_list = match_dataset.read_eval_file(dev_file, file_type)
180 | logger.info("train samples num: {}, dev samples num: {}".format(len(train_list), len(dev_list)))
181 |
182 | train_loader = DataLoader(train_list,
183 | batch_size=args.batch_size,
184 | shuffle=True)
185 | dev_loader = DataLoader(dev_list,
186 | batch_size=args.batch_size)
187 | logger.info("train batch num: {}, dev batch num: {}".format(len(train_loader), len(dev_loader)))
188 |
189 | tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
190 | query_encoder = ESimCSE(args.pretrained, args.pool_type, args.dropout_rate)
191 | key_encoder = copy.deepcopy(query_encoder)
192 | query_encoder.train()
193 | query_encoder.to(args.device)
194 | key_encoder.eval()
195 | key_encoder.to(args.device)
196 |
197 | optimizer = torch.optim.AdamW(query_encoder.parameters(), lr=args.lr)
198 |     # build the queue of negative-sample embeddings
199 | queue_embeddings = construct_queue(args, train_loader, tokenizer, key_encoder)
200 |
201 | batch_idx = 0
202 | best_corr = 0
203 | best_loss = 1000000
204 | for epoch in range(args.epochs):
205 | epoch_losses = []
206 | for data in train_loader:
207 | optimizer.zero_grad()
208 |             # build positive samples by word repetition
209 | key_data = [repeat_word(tokenizer, sentence, args.dup_rate) for sentence in data]
210 |
211 | query_encodings = tokenizer(data,
212 | padding=True,
213 | truncation=True,
214 | max_length=args.max_length,
215 | return_tensors='pt')
216 | query_encodings = {key: value.to(args.device) for key, value in query_encodings.items()}
217 | key_encodings = tokenizer(key_data,
218 | padding=True,
219 | truncation=True,
220 | max_length=args.max_length,
221 | return_tensors='pt')
222 | key_encodings = {key: value.to(args.device) for key, value in key_encodings.items()}
223 |
224 | query_embeddings = query_encoder(**query_encodings)
225 | key_embeddings = key_encoder(**key_encodings).detach()
226 |
227 |             # normalize the embeddings for the subsequent similarity computation
228 | query_embeddings = F.normalize(query_embeddings, dim=1)
229 | key_embeddings = F.normalize(key_embeddings, dim=1)
230 |
231 | batch_loss = compute_loss(query_embeddings, key_embeddings, queue_embeddings, args.tao)
232 | epoch_losses.append(batch_loss.item())
233 | # print(batch_idx, batch_loss.item())
234 |
235 | batch_loss.backward()
236 | optimizer.step()
237 |
238 |             # update the negative-sample embeddings in the queue
239 | # queue_embeddings = torch.cat((queue_embeddings, query_embeddings.detach()), 0)
240 | queue_embeddings = torch.cat((queue_embeddings, key_embeddings), 0)
241 | queue_embeddings = queue_embeddings[-queue_num:, :]
242 |
243 |             # momentum update of the key encoder parameters
244 | for query_params, key_params in zip(query_encoder.parameters(), key_encoder.parameters()):
245 | key_params.data.copy_(args.momentum * key_params + (1-args.momentum) * query_params)
246 | key_params.requires_grad = False
247 |
248 | if batch_idx % args.display_interval == 0:
249 | logger.info("Epoch: {}, batch: {}, loss: {}".format(epoch, batch_idx, batch_loss.item()))
250 | batch_idx += 1
251 |
252 | avg_epoch_loss = np.mean(epoch_losses)
253 | dev_corr = eval(query_encoder, tokenizer, dev_loader, args)
254 | logger.info("epoch: {}, avg loss: {}, dev corr: {}".format(epoch, avg_epoch_loss, dev_corr))
255 | # if avg_epoch_loss <= best_loss and dev_corr >= best_corr:
256 | if dev_corr >= best_corr:
257 | best_corr = dev_corr
258 | best_loss = avg_epoch_loss
259 | torch.save({
260 | 'epoch': epoch,
261 | 'batch': batch_idx,
262 | 'model_state_dict': query_encoder.state_dict(),
263 | 'loss': best_loss,
264 | 'corr': best_corr
265 | }, args.model_out)
266 | logger.info("epoch: {}, batch: {}, best loss: {}, best corr: {}, save model".format(epoch, batch_idx, avg_epoch_loss, dev_corr))
267 |
268 |
269 | if __name__ == "__main__":
270 | args = parse_args()
271 | logger.info("args: {}".format(args))
272 | train(args)
273 |
--------------------------------------------------------------------------------
/matching/README.md:
--------------------------------------------------------------------------------
1 | ## Sentence_Matching
2 |
3 | This module is a text-matching service: depending on the channel, it calls a different model to compute the similarity of two sentences (a minimal dispatch sketch follows the list below).
4 |
5 | 1. channel = 0: SimCSE model.
6 | 2. channel = 1: ESimCSE model, using dropout on the same sentence to build positive pairs and momentum contrast to add negative pairs.
7 | 3. channel = 2: ESimCSE model, repeating some words of a sentence to build positive pairs and momentum contrast to add negative pairs.
8 | 4. channel = 3: ESimCSE and SimCSE combined with a multi-task loss.
9 | 5. channel = 4: PromptBERT model.
10 | 6. channel = 5: SBert model.
11 | 7. channel = 6: CoSent model.
12 |
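A minimal, hypothetical sketch of the channel-to-model mapping; the actual dispatch logic lives in retrieval.py / view.py and may differ.

```python
# Hypothetical mapping for illustration only; see retrieval.py / view.py for the real service code.
CHANNEL_MODELS = {
    0: "simcse",
    1: "esimcse (dropout positives)",
    2: "esimcse (word-repetition positives)",
    3: "esimcse + simcse (multi-task)",
    4: "promptbert",
    5: "sbert",
    6: "cosent",
}

def pick_model(channel: int) -> str:
    return CHANNEL_MODELS[channel]
```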
--------------------------------------------------------------------------------
/matching/config.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | class Params:
4 | esimcse_same_dropout = 0.15
5 | esimcse_repeat_dropout = 0.15
6 | esimcse_multi_dropout = 0.15
7 | promptbert_dropout = 0.1
8 | simcse_dropout = 0.3
9 | sbert_dropout = 0.2
10 | cosent_dropout = 0.15
11 | max_length = 32
12 | pool_type = "pooler"
13 | sbert_pool_type = "mean"
14 | cosent_pool_type = "mean"
15 | mask_token = "[MASK]"
16 | replace_token = "[X]"
17 | # prompt_templates = ['“[UNK]”,它的意思是[MASK]。', '“[UNK]”,这句话的意思是[MASK]。']
18 | prompt_templates = ['"{}",它的意思是[MASK]。'.format(replace_token), '"{}",这句话的意思是[MASK]。'.format(replace_token)]
19 | pretrained_model = "hfl/chinese-roberta-wwm-ext-large"
20 | esimcse_repeat_model = "models/esimcse_0.32_0.15_160.pth"
21 | esimcse_same_model = "models/esimcse_0.15_64.pth"
22 | esimcse_multi_model = "models/esimcse_multi_0.15_64.pth"
23 | promptbert_model = "models/promptbert_1231.pth"
24 | simcse_model = "models/simcse_1226.pth"
25 | sbert_model = "models/sbert_0106.pth"
26 | cosent_model = "models/cosent_0119.pth"
--------------------------------------------------------------------------------
/matching/cosent.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/1/5
3 | # @Author : Maciel
4 |
5 | import torch.nn as nn
6 | from transformers import AutoConfig, AutoModel
7 | import torch
8 |
9 |
10 | class CoSent(nn.Module):
11 | def __init__(self, pretrained="hfl/chinese-bert-wwm-ext", pool_type="cls", dropout_prob=0.3):
12 | super().__init__()
13 | conf = AutoConfig.from_pretrained(pretrained)
14 | conf.attention_probs_dropout_prob = dropout_prob
15 | conf.hidden_dropout_prob = dropout_prob
16 | self.encoder = AutoModel.from_pretrained(pretrained, config=conf)
17 | assert pool_type in ["cls", "pooler", "mean"], "invalid pool_type: %s" % pool_type
18 | self.pool_type = pool_type
19 |
20 |
21 | def forward(self, input_ids, attention_mask, token_type_ids):
22 | if self.pool_type == "cls":
23 | output = self.encoder(input_ids,
24 | attention_mask=attention_mask,
25 | token_type_ids=token_type_ids)
26 | output = output.last_hidden_state[:, 0]
27 | elif self.pool_type == "pooler":
28 | output = self.encoder(input_ids,
29 | attention_mask=attention_mask,
30 | token_type_ids=token_type_ids)
31 | output = output.pooler_output
32 | elif self.pool_type == "mean":
33 | output = self.get_mean_tensor(input_ids, attention_mask)
34 | return output
35 |
36 |
37 | def get_mean_tensor(self, input_ids, attention_mask):
38 | '''
39 | get first and last layer avg tensor
40 | '''
41 | encode_states = self.encoder(input_ids, attention_mask=attention_mask, output_hidden_states=True)
42 | hidden_states = encode_states.hidden_states
43 | last_avg_state = self.get_avg_tensor(hidden_states[-1], attention_mask)
44 | first_avg_state = self.get_avg_tensor(hidden_states[1], attention_mask)
45 | mean_avg_state = (last_avg_state + first_avg_state) / 2
46 | return mean_avg_state
47 |
48 |
49 | def get_avg_tensor(self, layer_hidden_state, attention_mask):
50 | '''
51 |         layer_hidden_state: hidden states of one encoder layer [B * L * D]
52 |         attention_mask: padding mask of the sentence [B * L]
53 |         return: avg_embeddings, average embedding over non-padding tokens [B * D]
54 | '''
55 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(layer_hidden_state.size()).float()
56 | sum_embeddings = torch.sum(layer_hidden_state * input_mask_expanded, 1)
57 | sum_mask = input_mask_expanded.sum(1)
58 | sum_mask = torch.clamp(sum_mask, min=1e-9)
59 | avg_embeddings = sum_embeddings / sum_mask
60 | return avg_embeddings
61 |
62 |
63 | def get_avg_tensor2(self, layer_hidden_state, attention_mask):
64 | '''
65 |         layer_hidden_state: hidden states of one encoder layer [B * L * D]
66 |         attention_mask: padding mask of the sentence [B * L]
67 |         return: average embedding over non-padding tokens [B * D]
68 | '''
69 | layer_hidden_dim = layer_hidden_state.shape[-1]
70 | attention_repeat_mask = attention_mask.unsqueeze(dim=-1).tile(layer_hidden_dim)
71 | layer_attention_state = torch.mul(layer_hidden_state, attention_repeat_mask)
72 | layer_sum_state = layer_attention_state.sum(dim=1)
73 | # print(last_attention_state.shape)
74 |
75 | attention_length_mask = attention_mask.sum(dim=-1)
76 | attention_length_repeat_mask = attention_length_mask.unsqueeze(dim=-1).tile(layer_hidden_dim)
77 | # print(attention_length_repeat_mask.shape)
78 |
79 | layer_avg_state = torch.mul(layer_sum_state, 1/attention_length_repeat_mask)
80 | return layer_avg_state
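81 | 
82 | 
83 | # A minimal usage sketch (illustrative only, not part of the original module): it
84 | # mirrors how retrieval.py builds CoSent with mean pooling; the pretrained model
85 | # name and the example sentences below are assumptions.
86 | if __name__ == "__main__":
87 |     from transformers import AutoTokenizer
88 | 
89 |     pretrained = "hfl/chinese-roberta-wwm-ext-large"
90 |     tokenizer = AutoTokenizer.from_pretrained(pretrained)
91 |     model = CoSent(pretrained, pool_type="mean", dropout_prob=0.15)
92 |     model.eval()   # a trained checkpoint would normally be loaded here
93 | 
94 |     encodings = tokenizer(["今天天气真好", "今天天气不错"],
95 |                           padding=True,
96 |                           truncation=True,
97 |                           max_length=32,
98 |                           return_tensors='pt')
99 |     with torch.no_grad():
100 |         embeddings = model(encodings['input_ids'],
101 |                            encodings['attention_mask'],
102 |                            encodings['token_type_ids'])
103 |     similarity = torch.cosine_similarity(embeddings[0], embeddings[1], dim=-1)
104 |     print(float(similarity))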
--------------------------------------------------------------------------------
/matching/esimcse.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2021/12/16
3 | # @Author : Maciel
4 |
5 | import torch.nn as nn
6 | from transformers import BertConfig, BertModel
7 |
8 |
9 | class ESimCSE(nn.Module):
10 | def __init__(self, pretrained="hfl/chinese-bert-wwm-ext", dropout_prob=0.15):
11 | super().__init__()
12 | conf = BertConfig.from_pretrained(pretrained)
13 | conf.attention_probs_dropout_prob = dropout_prob
14 | conf.hidden_dropout_prob = dropout_prob
15 | self.encoder = BertModel.from_pretrained(pretrained, config=conf)
16 | self.fc = nn.Linear(conf.hidden_size, conf.hidden_size)
17 |
18 |
19 | def forward(self, input_ids, attention_mask, token_type_ids):
20 | output = self.encoder(input_ids,
21 | attention_mask=attention_mask,
22 | token_type_ids=token_type_ids)
23 | output = output.last_hidden_state[:, 0]
24 | output = self.fc(output)
25 | return output
--------------------------------------------------------------------------------
/matching/log.py:
--------------------------------------------------------------------------------
1 | from loguru import logger
2 |
3 |
4 | logger.add("logs/runtime.log")
--------------------------------------------------------------------------------
/matching/promptbert.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModel, AutoConfig, AutoTokenizer
2 | import torch
3 | import torch.nn as nn
4 | from config import Params
5 |
6 |
7 | class PromptBERT(nn.Module):
8 | def __init__(self, pretrained_model_path, dropout_prob, mask_id):
9 | super().__init__()
10 | conf = AutoConfig.from_pretrained(pretrained_model_path)
11 | conf.attention_probs_dropout_prob = dropout_prob
12 | conf.hidden_dropout_prob = dropout_prob
13 | self.encoder = AutoModel.from_pretrained(pretrained_model_path, config=conf)
14 | self.mask_id = mask_id
15 |
16 |
17 | def forward(self, prompt_input_ids, prompt_attention_mask, prompt_token_type_ids, template_input_ids, template_attention_mask, template_token_type_ids):
18 | '''
19 |         @function: compute the difference between the prompt [MASK] representation hi and the template [MASK] representation h^i
20 | 
21 |         @input:
22 |             prompt_input_ids: input ids of the prompt sentence
23 |             prompt_attention_mask: attention mask of the prompt sentence
24 |             prompt_token_type_ids: token type ids of the prompt sentence
25 |             template_input_ids: input ids of the template sentence
26 |             template_attention_mask: attention mask of the template sentence
27 |             template_token_type_ids: token type ids of the template sentence
28 | 
29 |         @return: sentence_embedding: sentence representation vector
30 | '''
31 | prompt_mask_embedding = self.calculate_mask_embedding(prompt_input_ids, prompt_attention_mask, prompt_token_type_ids)
32 | template_mask_embedding = self.calculate_mask_embedding(template_input_ids, template_attention_mask, template_token_type_ids)
33 | sentence_embedding = prompt_mask_embedding - template_mask_embedding
34 | return sentence_embedding
35 |
36 |
37 | def calculate_mask_embedding(self, input_ids, attention_mask, token_type_ids):
38 | # print("input_ids: ", input_ids)
39 | output = self.encoder(input_ids,
40 | attention_mask=attention_mask,
41 | token_type_ids=token_type_ids)
42 | token_embeddings = output[0]
43 | mask_index = (input_ids == self.mask_id).long()
44 | # print("mask_index: ", mask_index)
45 | mask_embedding = self.get_mask_embedding(token_embeddings, mask_index)
46 | return mask_embedding
47 |
48 |
49 | def get_mask_embedding(self, token_embeddings, mask_index):
50 | '''
51 |         @function: fetch the embedding at the [MASK] position
52 | 
53 |         @input:
54 |             token_embeddings: Tensor, token outputs of the last encoder layer
55 |             mask_index: Tensor, positions of the [MASK] token
56 | 
57 |         @return: mask_embedding: Tensor, embedding of the [MASK] token
58 | '''
59 | input_mask_expanded = mask_index.unsqueeze(-1).expand(token_embeddings.size()).float()
60 | # print("input_mask_expanded: ", input_mask_expanded)
61 | # print("input mask expaned shape: ", input_mask_expanded.shape)
62 | mask_embedding = torch.sum(token_embeddings * input_mask_expanded, 1)
63 | return mask_embedding
64 |
65 |
66 | if __name__ == "__main__":
67 | prompt_templates = '[UNK],这句话的意思是[MASK]'
68 | # sentence = "天气很好"
69 | sentence = '"[X][X][X]",这句话的意思是[MASK]。'
70 |     tokenizer = AutoTokenizer.from_pretrained(Params.pretrained_model)
71 | special_token_dict = {'additional_special_tokens': ['[X]']}
72 | tokenizer.add_special_tokens(special_token_dict)
73 | sen_tokens = tokenizer.tokenize(sentence)
74 | sen_encodings = tokenizer(sentence,
75 | return_tensors='pt')
76 | print(sen_tokens)
77 | print(sen_encodings)
78 | exit(0)
79 |
80 | mask_id = tokenizer.convert_tokens_to_ids(Params.mask_token)
81 |
82 |     model = PromptBERT(Params.pretrained_model, Params.promptbert_dropout, mask_id)
83 | model.train()
84 |
85 | while True:
86 | print("input your sentence:")
87 | sentences = input()
88 | sentence_list = sentences.split(";")
89 |
90 | prompt_lines, template_lines = [], []
91 | for sentence in sentence_list:
92 | words_num = len(tokenizer.tokenize(sentence))
93 | prompt_line = prompt_templates.replace('[UNK]', sentence)
94 | template_line = prompt_templates.replace('[UNK]', '[UNK]'*words_num)
95 | print("prompt_line: {}, template_line: {}".format(prompt_line, template_line))
96 | prompt_lines.append(prompt_line)
97 | template_lines.append(template_line)
98 |
99 | prompt_encodings = tokenizer(list(prompt_lines),
100 | padding=True,
101 | truncation=True,
102 | max_length=Params.max_length,
103 | return_tensors='pt')
104 | template_encodings = tokenizer(list(template_lines),
105 | padding=True,
106 | truncation=True,
107 | max_length=Params.max_length,
108 | return_tensors='pt')
109 |
110 | prompt_mask_embedding = model.calculate_mask_embedding(prompt_encodings['input_ids'],
111 | prompt_encodings['attention_mask'],
112 | prompt_encodings['token_type_ids'])
113 | template_mask_embedding = model.calculate_mask_embedding(template_encodings['input_ids'],
114 | template_encodings['attention_mask'],
115 | template_encodings['token_type_ids'])
116 | # print(prompt_mask_embedding.shape)
117 |
--------------------------------------------------------------------------------
/matching/retrieval.py:
--------------------------------------------------------------------------------
1 | from simcse import SimCSE
2 | from esimcse import ESimCSE
3 | from promptbert import PromptBERT
4 | from sbert import SBERT
5 | from cosent import CoSent
6 | from config import Params
7 | from log import logger
8 | import torch
9 | from transformers import AutoTokenizer
10 |
11 |
12 | class SimCSERetrieval(object):
13 | def __init__(self, pretrained_model_path, simcse_path, pool_type, dropout):
14 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path)
15 | model = SimCSE(Params.pretrained_model, pool_type, dropout)
16 | self.checkpoint = torch.load(simcse_path, map_location='cpu')
17 | model.load_state_dict(self.checkpoint['model_state_dict'])
18 | model.eval()
19 | self.model = model
20 |
21 |
22 | def print_checkpoint_info(self):
23 | loss = self.checkpoint['loss']
24 | epoch = self.checkpoint['epoch']
25 | model_info = {'loss': loss, 'epoch': epoch}
26 | return model_info
27 |
28 |
29 | def calculate_sentence_embedding(self, sentence):
30 | device = "cpu"
31 | input_encodings = self.tokenizer(sentence,
32 | padding=True,
33 | truncation=True,
34 | max_length=Params.max_length,
35 | return_tensors='pt')
36 | sentence_embedding = self.model(input_encodings['input_ids'].to(device),
37 | input_encodings['attention_mask'].to(device),
38 | input_encodings['token_type_ids'].to(device))
39 | return sentence_embedding
40 |
41 |
42 | def calculate_sentence_similarity(self, sentence1, sentence2):
43 | sentence1 = sentence1.strip()
44 | sentence2 = sentence2.strip()
45 | sentence1_embedding = self.calculate_sentence_embedding(sentence1)
46 | sentence2_embedding = self.calculate_sentence_embedding(sentence2)
47 | similarity = torch.cosine_similarity(sentence1_embedding, sentence2_embedding, dim=-1)
48 | similarity = float(similarity.item())
49 | return similarity
50 |
51 |
52 | class ESimCSERetrieval(object):
53 | def __init__(self, pretrained_model_path, esimcse_path, dropout):
54 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path)
55 | model = ESimCSE(Params.pretrained_model, dropout)
56 | self.checkpoint = torch.load(esimcse_path, map_location='cpu')
57 | model.load_state_dict(self.checkpoint['model_state_dict'])
58 | model.eval()
59 | self.model = model
60 |
61 |
62 | def print_checkpoint_info(self):
63 | loss = self.checkpoint['loss']
64 | epoch = self.checkpoint['epoch']
65 | model_info = {'loss': loss, 'epoch': epoch}
66 | return model_info
67 |
68 |
69 | def calculate_sentence_embedding(self, sentence):
70 | device = "cpu"
71 | input_encodings = self.tokenizer(sentence,
72 | padding=True,
73 | truncation=True,
74 | max_length=Params.max_length,
75 | return_tensors='pt')
76 | sentence_embedding = self.model(input_encodings['input_ids'].to(device),
77 | input_encodings['attention_mask'].to(device),
78 | input_encodings['token_type_ids'].to(device))
79 | return sentence_embedding
80 |
81 |
82 | def calculate_sentence_similarity(self, sentence1, sentence2):
83 | sentence1 = sentence1.strip()
84 | sentence2 = sentence2.strip()
85 | sentence1_embedding = self.calculate_sentence_embedding(sentence1)
86 | sentence2_embedding = self.calculate_sentence_embedding(sentence2)
87 | similarity = torch.cosine_similarity(sentence1_embedding, sentence2_embedding, dim=-1)
88 | similarity = float(similarity.item())
89 | return similarity
90 |
91 |
92 | class PromptBertRetrieval(object):
93 | def __init__(self, pretrained_model_path, promptbert_path, dropout):
94 | super().__init__()
95 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path)
96 | special_token_dict = {'additional_special_tokens': ['[X]']}
97 | self.tokenizer.add_special_tokens(special_token_dict)
98 | mask_id = self.tokenizer.convert_tokens_to_ids(Params.mask_token)
99 | model = PromptBERT(pretrained_model_path, dropout, mask_id)
100 | model.encoder.resize_token_embeddings(len(self.tokenizer))
101 | checkpoint = torch.load(promptbert_path, map_location='cpu')
102 |         model.load_state_dict(checkpoint['model_state_dict'])
103 |         model.eval()
104 |         self.checkpoint, self.model = checkpoint, model
105 |
106 |
107 | def print_checkpoint_info(self):
108 | loss = self.checkpoint['loss']
109 | epoch = self.checkpoint['epoch']
110 | model_info = {'loss': loss, 'epoch': epoch}
111 | return model_info
112 |
113 |
114 | def calculate_sentence_mask_embedding(self, sentence):
115 | device = "cpu"
116 | prompt_sentence = Params.prompt_templates[0].replace("[X]", sentence)
117 | prompt_encodings = self.tokenizer(prompt_sentence,
118 | padding=True,
119 | truncation=True,
120 | max_length=Params.max_length,
121 | return_tensors='pt')
122 | sentence_mask_embedding = self.model.calculate_mask_embedding(prompt_encodings['input_ids'].to(device),
123 | prompt_encodings['attention_mask'].to(device),
124 | prompt_encodings['token_type_ids'].to(device))
125 | return sentence_mask_embedding
126 |
127 |
128 | def calculate_sentence_embedding(self, sentence):
129 | device = "cpu"
130 | prompt_sentence = Params.prompt_templates[0].replace("[X]", sentence)
131 | sentence_num = len(self.tokenizer.tokenize(sentence))
132 | template_sentence = Params.prompt_templates[0].replace("[X]", "[X]"*sentence_num)
133 | prompt_encodings = self.tokenizer(prompt_sentence,
134 | padding=True,
135 | truncation=True,
136 | max_length=Params.max_length,
137 | return_tensors='pt')
138 | template_encodings = self.tokenizer(template_sentence,
139 | padding=True,
140 | truncation=True,
141 | max_length=Params.max_length,
142 | return_tensors='pt')
143 | sentence_embedding = self.model(prompt_input_ids=prompt_encodings['input_ids'].to(device),
144 | prompt_attention_mask=prompt_encodings['attention_mask'].to(device),
145 | prompt_token_type_ids=prompt_encodings['token_type_ids'].to(device),
146 | template_input_ids=template_encodings['input_ids'].to(device),
147 | template_attention_mask=template_encodings['attention_mask'].to(device),
148 | template_token_type_ids=template_encodings['token_type_ids'].to(device))
149 | return sentence_embedding
150 |
151 |
152 | def calculate_sentence_similarity(self, sentence1, sentence2):
153 | # sentence1_embedding = self.calculate_sentence_mask_embedding(sentence1)
154 | # sentence2_embedding = self.calculate_sentence_mask_embedding(sentence2)
155 | sentence1_embedding = self.calculate_sentence_embedding(sentence1)
156 | sentence2_embedding = self.calculate_sentence_embedding(sentence2)
157 | similarity = torch.cosine_similarity(sentence1_embedding, sentence2_embedding, dim=-1)
158 | similarity = float(similarity.item())
159 | return similarity
160 |
161 |
162 | class SBERTRetrieval(object):
163 | def __init__(self, pretrained_model_path, sbert_path, pool_type, dropout):
164 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path)
165 | model = SBERT(Params.pretrained_model, pool_type, dropout)
166 | self.checkpoint = torch.load(sbert_path, map_location='cpu')
167 | model.load_state_dict(self.checkpoint['model_state_dict'])
168 | model.eval()
169 | self.model = model
170 |
171 |
172 | def print_checkpoint_info(self):
173 | loss = self.checkpoint['train_loss']
174 | epoch = self.checkpoint['epoch']
175 | model_info = {'loss': loss, 'epoch': epoch}
176 | return model_info
177 |
178 |
179 | def calculate_sentence_embedding(self, sentence):
180 | device = "cpu"
181 | input_encodings = self.tokenizer(sentence,
182 | padding=True,
183 | truncation=True,
184 | max_length=Params.max_length,
185 | return_tensors='pt')
186 | sentence_embedding = self.model(input_encodings['input_ids'].to(device),
187 | input_encodings['attention_mask'].to(device),
188 | input_encodings['token_type_ids'].to(device))
189 | return sentence_embedding
190 |
191 |
192 | def calculate_sentence_similarity(self, sentence1, sentence2):
193 | sentence1 = sentence1.strip()
194 | sentence2 = sentence2.strip()
195 | sentence1_embedding = self.calculate_sentence_embedding(sentence1)
196 | sentence2_embedding = self.calculate_sentence_embedding(sentence2)
197 | similarity = torch.cosine_similarity(sentence1_embedding, sentence2_embedding, dim=-1)
198 | similarity = float(similarity.item())
199 | return similarity
200 |
201 |
202 | class CoSentRetrieval(object):
203 | def __init__(self, pretrained_model_path, cosent_path):
204 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path)
205 | model = CoSent(Params.pretrained_model, Params.cosent_pool_type, Params.cosent_dropout)
206 | self.checkpoint = torch.load(cosent_path, map_location='cpu')
207 | model.load_state_dict(self.checkpoint['model_state_dict'])
208 | model.eval()
209 | self.model = model
210 |
211 |
212 | def print_checkpoint_info(self):
213 | loss = self.checkpoint['loss']
214 | epoch = self.checkpoint['epoch']
215 | model_info = {'loss': loss, 'epoch': epoch}
216 | return model_info
217 |
218 |
219 | def calculate_sentence_embedding(self, sentence):
220 | device = "cpu"
221 | input_encodings = self.tokenizer(sentence,
222 | padding=True,
223 | truncation=True,
224 | max_length=Params.max_length,
225 | return_tensors='pt')
226 | sentence_embedding = self.model(input_encodings['input_ids'].to(device),
227 | input_encodings['attention_mask'].to(device),
228 | input_encodings['token_type_ids'].to(device))
229 | return sentence_embedding
230 |
231 |
232 | def calculate_sentence_similarity(self, sentence1, sentence2):
233 | sentence1 = sentence1.strip()
234 | sentence2 = sentence2.strip()
235 | sentence1_embedding = self.calculate_sentence_embedding(sentence1)
236 | sentence2_embedding = self.calculate_sentence_embedding(sentence2)
237 | similarity = torch.cosine_similarity(sentence1_embedding, sentence2_embedding, dim=-1)
238 | similarity = float(similarity.item())
239 | return similarity
240 |
241 |
242 | simcse_retrieval = SimCSERetrieval(Params.pretrained_model, Params.simcse_model, Params.pool_type, Params.simcse_dropout)
243 | logger.info("loaded simcse model successfully!")
244 | esimcse_repeat_retrieval = ESimCSERetrieval(Params.pretrained_model, Params.esimcse_repeat_model, Params.esimcse_repeat_dropout)
245 | logger.info("loaded esimcse repeat model successfully!")
246 | esimcse_same_retrieval = ESimCSERetrieval(Params.pretrained_model, Params.esimcse_same_model, Params.esimcse_same_dropout)
247 | logger.info("loaded esimcse same model successfully!")
248 | esimcse_multi_retrieval = ESimCSERetrieval(Params.pretrained_model, Params.esimcse_multi_model, Params.esimcse_multi_dropout)
249 | logger.info("loaded esimcse multi model successfully!")
250 | promptbert_retrieval = PromptBertRetrieval(Params.pretrained_model, Params.promptbert_model, Params.promptbert_dropout)
251 | logger.info("loaded promptbert model successfully!")
252 | sbert_retrieval = SBERTRetrieval(Params.pretrained_model, Params.sbert_model, Params.sbert_pool_type, Params.sbert_dropout)
253 | logger.info("loaded sbert model successfully!")
254 | cosent_retrieval = CoSentRetrieval(Params.pretrained_model, Params.cosent_model)
255 | logger.info("loaded cosent model successfully!")
256 |
257 |
258 | if __name__ == "__main__":
259 | # model_path = "models/esimcse_0.32_0.15_160.pth"
260 | # model_path = "models/esimcse_multi_0.15_64.pth"
261 | # model_path = "models/esimcse_0.15_64.pth"
262 |
263 |
264 |
265 | # simcse_retrieval = SimCSERetrieval(Params.pretrained_model, Params.simcse_model, Params.pool_type, Params.simcse_dropout)
266 | # model_info = simcse_retrieval.print_checkpoint_info()
267 | # print(model_info)
268 |
269 | model_info = sbert_retrieval.print_checkpoint_info()
270 | print(model_info)
271 |
272 | while True:
273 | print("input your sentence1:")
274 | sentence1 = input()
275 | print("input your sentence2:")
276 | sentence2 = input()
277 |
278 | sbert_sentence_similarity = sbert_retrieval.calculate_sentence_similarity(sentence1, sentence2)
279 | # promptbert_sentence_similarity = prom.calculate_sentence_similarity(sentence1, sentence2)
280 | # print("simcse sim: {}, promptbert sim: {}".format(simcse_sentence_similarity, promptbert_sentence_similarity))
281 | print("sbert similarity: {}".format(sbert_sentence_similarity))
--------------------------------------------------------------------------------
/matching/sbert.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/1/5
3 | # @Author : Maciel
4 |
5 | import torch.nn as nn
6 | from transformers import BertConfig, BertModel
7 | import torch
8 |
9 |
10 | class SBERT(nn.Module):
11 | def __init__(self, pretrained="hfl/chinese-bert-wwm-ext", pool_type="cls", dropout_prob=0.3):
12 | super().__init__()
13 | conf = BertConfig.from_pretrained(pretrained)
14 | conf.attention_probs_dropout_prob = dropout_prob
15 | conf.hidden_dropout_prob = dropout_prob
16 | self.encoder = BertModel.from_pretrained(pretrained, config=conf)
17 | assert pool_type in ["cls", "pooler", "mean"], "invalid pool_type: %s" % pool_type
18 | self.pool_type = pool_type
19 |
20 |
21 | def forward(self, input_ids, attention_mask, token_type_ids):
22 | if self.pool_type == "cls":
23 | output = self.encoder(input_ids,
24 | attention_mask=attention_mask,
25 | token_type_ids=token_type_ids)
26 | output = output.last_hidden_state[:, 0]
27 | elif self.pool_type == "pooler":
28 | output = self.encoder(input_ids,
29 | attention_mask=attention_mask,
30 | token_type_ids=token_type_ids)
31 | output = output.pooler_output
32 | elif self.pool_type == "mean":
33 | output = self.get_mean_tensor(input_ids, attention_mask)
34 | return output
35 |
36 |
37 | def get_mean_tensor(self, input_ids, attention_mask):
38 | encode_states = self.encoder(input_ids, attention_mask=attention_mask, output_hidden_states=True)
39 | hidden_states = encode_states.hidden_states
40 | last_avg_state = self.get_avg_tensor(hidden_states[-1], attention_mask)
41 | first_avg_state = self.get_avg_tensor(hidden_states[1], attention_mask)
42 | mean_avg_state = (last_avg_state + first_avg_state) / 2
43 | return mean_avg_state
44 |
45 |
46 | def get_avg_tensor(self, layer_hidden_state, attention_mask):
47 | '''
48 |         layer_hidden_state: hidden states of one encoder layer [B * L * D]
49 |         attention_mask: padding mask of the sentence [B * L]
50 |         return: average embedding over non-padding tokens [B * D]
51 | '''
52 | layer_hidden_dim = layer_hidden_state.shape[-1]
53 | attention_repeat_mask = attention_mask.unsqueeze(dim=-1).tile(layer_hidden_dim)
54 | layer_attention_state = torch.mul(layer_hidden_state, attention_repeat_mask)
55 | layer_sum_state = layer_attention_state.sum(dim=1)
56 | # print(last_attention_state.shape)
57 |
58 | attention_length_mask = attention_mask.sum(dim=-1)
59 | attention_length_repeat_mask = attention_length_mask.unsqueeze(dim=-1).tile(layer_hidden_dim)
60 | # print(attention_length_repeat_mask.shape)
61 |
62 | layer_avg_state = torch.mul(layer_sum_state, 1/attention_length_repeat_mask)
63 | return layer_avg_state
64 |
65 |
66 | def get_avg_tensor2(self, layer_hidden_state, attention_mask):
67 | '''
68 |         layer_hidden_state: hidden states of one encoder layer [B * L * D]
69 |         attention_mask: padding mask of the sentence [B * L]
70 |         return: avg_embeddings, average embedding over non-padding tokens [B * D]
71 | '''
72 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(layer_hidden_state.size()).float()
73 | sum_embeddings = torch.sum(layer_hidden_state * input_mask_expanded, 1)
74 | sum_mask = input_mask_expanded.sum(1)
75 | sum_mask = torch.clamp(sum_mask, min=1e-9)
76 | avg_embeddings = sum_embeddings / sum_mask
77 | return avg_embeddings
--------------------------------------------------------------------------------
/matching/simcse.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2021/6/10
3 | # @Author : kaka
4 |
5 | import torch.nn as nn
6 | from transformers import BertConfig, BertModel
7 |
8 |
9 | class SimCSE(nn.Module):
10 | def __init__(self, pretrained="hfl/chinese-bert-wwm-ext", pool_type="cls", dropout_prob=0.3):
11 | super().__init__()
12 | conf = BertConfig.from_pretrained(pretrained)
13 | conf.attention_probs_dropout_prob = dropout_prob
14 | conf.hidden_dropout_prob = dropout_prob
15 | self.encoder = BertModel.from_pretrained(pretrained, config=conf)
16 | assert pool_type in ["cls", "pooler"], "invalid pool_type: %s" % pool_type
17 | self.pool_type = pool_type
18 |
19 |
20 | def forward(self, input_ids, attention_mask, token_type_ids):
21 | output = self.encoder(input_ids,
22 | attention_mask=attention_mask,
23 | token_type_ids=token_type_ids)
24 | if self.pool_type == "cls":
25 | output = output.last_hidden_state[:, 0]
26 | elif self.pool_type == "pooler":
27 | output = output.pooler_output
28 | return output
--------------------------------------------------------------------------------
/matching/view.py:
--------------------------------------------------------------------------------
1 | from flask import Flask, request, jsonify
2 | from retrieval import simcse_retrieval, esimcse_same_retrieval, esimcse_repeat_retrieval, esimcse_multi_retrieval, promptbert_retrieval
3 | from retrieval import sbert_retrieval, cosent_retrieval
4 |
5 |
6 | app = Flask(__name__)
7 | @app.route('/compare', methods=['GET', 'POST'])
8 | def compare():
9 | sentence1 = request.args.get("sentence1", "")
10 | sentence2 = request.args.get("sentence2", "")
11 | channel = request.args.get("channel", "")
12 | sentence1 = strip_sentence(sentence1)
13 | sentence2 = strip_sentence(sentence2)
14 |
15 | if channel == "0":
16 | model_name = "simcse"
17 | model_desc = "use simcse model for data augmentation and compare positive samples"
18 | model_info = simcse_retrieval.print_checkpoint_info()
19 | model_info['model_name'] = model_name
20 | model_info['model_desc'] = model_desc
21 | similarity = simcse_retrieval.calculate_sentence_similarity(sentence1, sentence2)
22 | elif channel == "1":
23 | model_name = "esimcse_same_positive"
24 | model_desc = "use same sentence as positive pairs and construct a negative queue"
25 | model_info = esimcse_same_retrieval.print_checkpoint_info()
26 | model_info['model_name'] = model_name
27 | model_info['model_desc'] = model_desc
28 | similarity = esimcse_same_retrieval.calculate_sentence_similarity(sentence1, sentence2)
29 | elif channel == "2":
30 | model_name = "esimcse_repeat_positive"
31 | model_desc = "use repeat word for data augmentation as positive pairs and construct a negative queue"
32 | model_info = esimcse_repeat_retrieval.print_checkpoint_info()
33 | model_info['model_name'] = model_name
34 | model_info['model_desc'] = model_desc
35 | similarity = esimcse_repeat_retrieval.calculate_sentence_similarity(sentence1, sentence2)
36 | elif channel == "3":
37 | model_name = "esimcse_multi_positive"
38 | model_desc = "Multi task loss: use same sentence and repeat word as positive pairs and construct a negative queue"
39 | model_info = esimcse_multi_retrieval.print_checkpoint_info()
40 | model_info['model_name'] = model_name
41 | model_info['model_desc'] = model_desc
42 | similarity = esimcse_multi_retrieval.calculate_sentence_similarity(sentence1, sentence2)
43 | elif channel == "4":
44 | model_name = "promptbert"
45 | model_desc = "use different templates to generate sentence embedding as positive pairs"
46 | model_info = promptbert_retrieval.print_checkpoint_info()
47 | model_info['model_name'] = model_name
48 | model_info['model_desc'] = model_desc
49 | similarity = promptbert_retrieval.calculate_sentence_similarity(sentence1, sentence2)
50 | elif channel == "5":
51 | model_name = "sbert"
52 | model_desc = "train sentence-bert structure model with cosine similarity and mse loss, using high prediction probability cases as training dataset"
53 | model_info = sbert_retrieval.print_checkpoint_info()
54 | model_info['model_name'] = model_name
55 | model_info['model_desc'] = model_desc
56 | similarity = sbert_retrieval.calculate_sentence_similarity(sentence1, sentence2)
57 | elif channel == "6":
58 | model_name = "cosent"
59 | model_desc = "train cosent structure model with contrastive loss"
60 | model_info = cosent_retrieval.print_checkpoint_info()
61 | model_info['model_name'] = model_name
62 | model_info['model_desc'] = model_desc
63 | similarity = cosent_retrieval.calculate_sentence_similarity(sentence1, sentence2)
64 | else:
65 |         model_info = {'model_name': 'invalid channel'}
66 | similarity = None
67 |
68 | sent_info = {'sentence1': sentence1, 'sentence2': sentence2, 'similarity': similarity}
69 | result = {'model_info': model_info, 'sentence_info': sent_info}
70 | data = {'code': 200, 'message': 'OK', 'data': result}
71 | resp = jsonify(data)
72 | return resp
73 |
74 |
75 | def strip_sentence(sentence):
76 | sentence = sentence.strip().lower()
77 | sentence = sentence.replace("?", "").replace("?", "")
78 | return sentence
79 |
80 |
81 | if __name__ == "__main__":
82 | app.run(host='0.0.0.0', port=8080)
--------------------------------------------------------------------------------
/pics/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Macielyoung/sentence_representation_matching/810bd68e366810814572876fd9cdb380238e5b19/pics/.DS_Store
--------------------------------------------------------------------------------
/pics/cosent_loss.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/pics/esimcse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Macielyoung/sentence_representation_matching/810bd68e366810814572876fd9cdb380238e5b19/pics/esimcse.png
--------------------------------------------------------------------------------
/pics/esimcse_loss.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/pics/promptbert_loss.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/pics/sbert.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Macielyoung/sentence_representation_matching/810bd68e366810814572876fd9cdb380238e5b19/pics/sbert.png
--------------------------------------------------------------------------------
/pics/simcse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Macielyoung/sentence_representation_matching/810bd68e366810814572876fd9cdb380238e5b19/pics/simcse.png
--------------------------------------------------------------------------------
/pics/simcse_loss.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/promptbert/PromptBert.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/9/21
3 | # @Author : Maciel
4 |
5 | import torch.nn as nn
6 | from transformers import AutoConfig, AutoModel
7 | import torch
8 |
9 |
10 | class PromptBERT(nn.Module):
11 | def __init__(self, pretrained_model_path, dropout_prob, mask_id):
12 | super().__init__()
13 | conf = AutoConfig.from_pretrained(pretrained_model_path)
14 | conf.attention_probs_dropout_prob = dropout_prob
15 | conf.hidden_dropout_prob = dropout_prob
16 | self.encoder = AutoModel.from_pretrained(pretrained_model_path, config=conf)
17 | self.mask_id = mask_id
18 |
19 |
20 | def forward(self, prompt_input_ids, prompt_attention_mask, prompt_token_type_ids, template_input_ids, template_attention_mask, template_token_type_ids):
21 | '''
22 |         @function: compute the difference between the prompt [MASK] representation hi and the template [MASK] representation h^i
23 | 
24 |         @input:
25 |             prompt_input_ids: input ids of the prompt sentence
26 |             prompt_attention_mask: attention mask of the prompt sentence
27 |             prompt_token_type_ids: token type ids of the prompt sentence
28 |             template_input_ids: input ids of the template sentence
29 |             template_attention_mask: attention mask of the template sentence
30 |             template_token_type_ids: token type ids of the template sentence
31 | 
32 |         @return: sentence_embedding: sentence representation vector
33 | '''
34 | prompt_mask_embedding = self.calculate_mask_embedding(prompt_input_ids, prompt_attention_mask, prompt_token_type_ids)
35 | template_mask_embedding = self.calculate_mask_embedding(template_input_ids, template_attention_mask, template_token_type_ids)
36 | sentence_embedding = prompt_mask_embedding - template_mask_embedding
37 | return sentence_embedding
38 |
39 |
40 | def calculate_mask_embedding(self, input_ids, attention_mask, token_type_ids):
41 | # print("input_ids: ", input_ids)
42 | output = self.encoder(input_ids,
43 | attention_mask=attention_mask,
44 | token_type_ids=token_type_ids)
45 | token_embeddings = output[0]
46 | mask_index = (input_ids == self.mask_id).long()
47 | # print("mask_index: ", mask_index)
48 | mask_embedding = self.get_mask_embedding(token_embeddings, mask_index)
49 | return mask_embedding
50 |
51 |
52 | def get_mask_embedding(self, token_embeddings, mask_index):
53 | '''
54 |         @function: fetch the embedding at the [MASK] position
55 | 
56 |         @input:
57 |             token_embeddings: Tensor, token outputs of the last encoder layer
58 |             mask_index: Tensor, positions of the [MASK] token
59 | 
60 |         @return: mask_embedding: Tensor, embedding of the [MASK] token
61 | '''
62 | input_mask_expanded = mask_index.unsqueeze(-1).expand(token_embeddings.size()).float()
63 | # print("input_mask_expanded: ", input_mask_expanded)
64 | # print("input mask expaned shape: ", input_mask_expanded.shape)
65 | mask_embedding = torch.sum(token_embeddings * input_mask_expanded, 1)
66 | return mask_embedding
67 |
68 |
69 | def forward_sentence(self, input_ids, attention_mask, token_type_ids, pool_type):
70 | output = self.encoder(input_ids,
71 | attention_mask=attention_mask,
72 | token_type_ids=token_type_ids,
73 | output_hidden_states=True)
74 | hidden_states = output.hidden_states
75 | if pool_type == "cls":
76 | output = output.last_hidden_state[:, 0]
77 | elif pool_type == "pooler":
78 | output = output.pooler_output
79 | elif pool_type == "avg_first_last":
80 | top_first_state = self.get_avg_tensor(hidden_states[1], attention_mask)
81 | last_first_state = self.get_avg_tensor(hidden_states[-1], attention_mask)
82 | output = (top_first_state + last_first_state) / 2
83 | else:
84 | last_first_state = self.get_avg_tensor(hidden_states[-1], attention_mask)
85 | last_second_state = self.get_avg_tensor(hidden_states[-2], attention_mask)
86 | output = (last_first_state + last_second_state) / 2
87 | return output
88 |
89 |
90 | def get_avg_tensor(self, layer_hidden_state, attention_mask):
91 | '''
92 |         layer_hidden_state: hidden states of one encoder layer [B * L * D]
93 |         attention_mask: padding mask of the sentence [B * L]
94 |         return: avg_embeddings, average embedding over non-padding tokens [B * D]
95 | '''
96 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(layer_hidden_state.size()).float()
97 | sum_embeddings = torch.sum(layer_hidden_state * input_mask_expanded, 1)
98 | sum_mask = input_mask_expanded.sum(1)
99 | sum_mask = torch.clamp(sum_mask, min=1e-9)
100 | avg_embeddings = sum_embeddings / sum_mask
101 | return avg_embeddings
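102 | 
103 | 
104 | # A minimal usage sketch (illustrative only, not part of the original module): it
105 | # mirrors how train.py builds the model, i.e. the [X] placeholder is registered as
106 | # a special token, the embedding matrix is resized, and forward() receives the
107 | # prompt sentence together with its [X]-filled template. The pretrained model name
108 | # and the example sentence are assumptions.
109 | if __name__ == "__main__":
110 |     from transformers import AutoTokenizer
111 | 
112 |     pretrained = "hfl/chinese-roberta-wwm-ext"
113 |     tokenizer = AutoTokenizer.from_pretrained(pretrained)
114 |     tokenizer.add_special_tokens({'additional_special_tokens': ['[X]']})
115 |     model = PromptBERT(pretrained, dropout_prob=0.1, mask_id=tokenizer.mask_token_id)
116 |     model.encoder.resize_token_embeddings(len(tokenizer))
117 |     model.eval()
118 | 
119 |     sentence = "今天天气真好"
120 |     words_len = len(tokenizer.tokenize(sentence))
121 |     prompt = '[X],它的意思是[MASK]'.replace('[X]', sentence)
122 |     template = '[X],它的意思是[MASK]'.replace('[X]', '[X]' * words_len)
123 |     prompt_enc = tokenizer(prompt, return_tensors='pt')
124 |     template_enc = tokenizer(template, return_tensors='pt')
125 |     with torch.no_grad():
126 |         embedding = model(prompt_enc['input_ids'],
127 |                           prompt_enc['attention_mask'],
128 |                           prompt_enc['token_type_ids'],
129 |                           template_enc['input_ids'],
130 |                           template_enc['attention_mask'],
131 |                           template_enc['token_type_ids'])
132 |     print(embedding.shape)   # [1, hidden_size]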
--------------------------------------------------------------------------------
/promptbert/loading.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/9/17
3 | # @Author : Maciel
4 |
5 | from torch.utils.data import DataLoader
6 |
7 |
8 | class MatchingDataSet:
9 | def read_train_file(self, trainfile, devfile, testfile, filetype):
10 | sents = []
11 | if filetype == "txt":
12 | with open(trainfile, 'r') as f:
13 | for line in f.readlines():
14 | _, s1, s2, _ = line.strip().split(u"||")
15 | sents.append(s1)
16 | sents.append(s2)
17 | with open(devfile, 'r') as f:
18 | for line in f.readlines():
19 | _, s1, s2, _ = line.strip().split(u"||")
20 | sents.append(s1)
21 | sents.append(s2)
22 | with open(testfile, 'r') as f:
23 | for line in f.readlines():
24 | _, s1, s2, _ = line.strip().split(u"||")
25 | sents.append(s1)
26 | sents.append(s2)
27 | return sents
28 |
29 | def read_eval_file(self, file, filetype):
30 | sents = []
31 | if filetype == "txt":
32 | with open(file, 'r') as f:
33 | for line in f.readlines():
34 | _, s1, s2, s = line.strip().split(u"||")
35 | item = {'sent1': s1,
36 | 'sent2': s2,
37 | 'score': float(s)}
38 | sents.append(item)
39 | return sents
40 |
41 |
42 | if __name__ == "__main__":
43 | trainfile = "../dataset/STS-B/train.txt"
44 | devfile = "../dataset/STS-B/dev.txt"
45 | testfile = "../dataset/STS-B/test.txt"
46 | match_dataset = MatchingDataSet()
47 |
48 | train_list = match_dataset.read_train_file(trainfile, devfile, testfile, "txt")
49 | print(train_list[:5])
50 |
51 | train_lengths = [len(sentence) for sentence in train_list]
52 | max_len = max(train_lengths)
53 |
54 |
55 | dev_list = match_dataset.read_eval_file(devfile, "txt")
56 | dev_sen1_length = [len(d['sent1']) for d in dev_list]
57 | dev_sen2_length = [len(d['sent2']) for d in dev_list]
58 | max_sen1 = max(dev_sen1_length)
59 | max_sen2 = max(dev_sen2_length)
60 | print(max_len, max_sen1, max_sen2)
61 | # dev_loader = DataLoader(dev_list,
62 | # batch_size=8)
63 | # for batch in dev_loader:
64 | # print(batch)
65 | # exit(0)
--------------------------------------------------------------------------------
/promptbert/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/9/19
3 | # @Author : Maciel
4 |
5 | from loading import MatchingDataSet
6 | import torch
7 | import torch.nn.functional as F
8 | import scipy.stats
9 | from torch.utils.data import DataLoader
10 | from PromptBert import PromptBERT
11 | from transformers import AutoTokenizer
12 | import os
13 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
14 |
15 |
16 | def eval(model, tokenizer, dev_loader, device, max_length, pool_type):
17 | model.eval()
18 | model.to(device)
19 |
20 | all_sims, all_scores = [], []
21 | with torch.no_grad():
22 | for data in dev_loader:
23 | sent1 = data['sent1']
24 | sent2 = data['sent2']
25 | score = data['score']
26 | sent1_encoding = tokenizer(sent1,
27 | padding=True,
28 | truncation=True,
29 | max_length=max_length,
30 | return_tensors='pt')
31 | sent1_encoding = {key: value.to(device) for key, value in sent1_encoding.items()}
32 | sent2_encoding = tokenizer(sent2,
33 | padding=True,
34 | truncation=True,
35 | max_length=max_length,
36 | return_tensors='pt')
37 | sent2_encoding = {key: value.to(device) for key, value in sent2_encoding.items()}
38 |
39 | sent1_output = model.forward_sentence(sent1_encoding['input_ids'],
40 | sent1_encoding['attention_mask'],
41 | sent1_encoding['token_type_ids'],
42 | pool_type)
43 | sent2_output = model.forward_sentence(sent2_encoding['input_ids'],
44 | sent2_encoding['attention_mask'],
45 | sent2_encoding['token_type_ids'],
46 | pool_type)
47 | sim_score = F.cosine_similarity(sent1_output, sent2_output).cpu().tolist()
48 | all_sims += sim_score
49 | all_scores += score.tolist()
50 | corr = scipy.stats.spearmanr(all_sims, all_scores).correlation
51 | return corr
52 |
53 |
54 | def eval2(model, tokenizer, dev_loader, device, max_length, pool_type):
55 | model.eval()
56 | model.to(device)
57 |
58 | all_sims, all_scores = [], []
59 | with torch.no_grad():
60 | for data in dev_loader:
61 | sent1 = data['sent1']
62 | sent2 = data['sent2']
63 | score = data['score']
64 |
65 | # print("sent1: {}, sent2: {}".format(sent1, sent2))
66 | prompt_template_sent1 = [transform_sentence(s1, tokenizer, max_length) for s1 in sent1]
67 | prompt_sent1 = [pair[0] for pair in prompt_template_sent1]
68 | template_sent1 = [pair[1] for pair in prompt_template_sent1]
69 | # print("prompt sent1: {}".format(prompt_sent1))
70 | # print("template sent1: {}".format(template_sent1))
71 | prompt_encoding1 = encode_sentences(tokenizer, prompt_sent1, max_length)
72 | template_encoding1 = encode_sentences(tokenizer, template_sent1, max_length)
73 | prompt_encoding1 = {key: value.to(device) for key, value in prompt_encoding1.items()}
74 | template_encoding1 = {key: value.to(device) for key, value in template_encoding1.items()}
75 | # print("prompt_encoding1 input_ids {}".format(prompt_encoding1['input_ids']))
76 | # print("template_encoding1 input_ids: {}".format(template_encoding1['input_ids']))
77 |
78 | prompt_template_sent2 = [transform_sentence(s2, tokenizer, max_length) for s2 in sent2]
79 | prompt_sent2 = [pair[0] for pair in prompt_template_sent2]
80 | template_sent2 = [pair[1] for pair in prompt_template_sent2]
81 | # print("prompt sent2: {}".format(prompt_sent2))
82 | # print("template sent2: {}".format(template_sent2))
83 | prompt_encoding2 = encode_sentences(tokenizer, prompt_sent2, max_length)
84 | template_encoding2 = encode_sentences(tokenizer, template_sent2, max_length)
85 | prompt_encoding2 = {key: value.to(device) for key, value in prompt_encoding2.items()}
86 | template_encoding2 = {key: value.to(device) for key, value in template_encoding2.items()}
87 | # print("prompt_encoding2 input_ids: {}".format(prompt_encoding2['input_ids']))
88 | # print("template_encoding2 input_ids: {}".format(template_encoding2['input_ids']))
89 |
90 | sent1_output = model(prompt_encoding1['input_ids'],
91 | prompt_encoding1['attention_mask'],
92 | prompt_encoding1['token_type_ids'],
93 | template_encoding1['input_ids'],
94 | template_encoding1['attention_mask'],
95 | template_encoding1['token_type_ids'])
96 | sent2_output = model(prompt_encoding2['input_ids'],
97 | prompt_encoding2['attention_mask'],
98 | prompt_encoding2['token_type_ids'],
99 | template_encoding2['input_ids'],
100 | template_encoding2['attention_mask'],
101 | template_encoding2['token_type_ids'])
102 | # print("sen1 output shape: {}, sen2 output shape: {}".format(sent1_output.shape, sent2_output.shape))
103 | sim_score = F.cosine_similarity(sent1_output, sent2_output).cpu().tolist()
104 | all_sims += sim_score
105 | all_scores += score.tolist()
106 | # print("sim_score: {}, score: {}".format(sim_score, score))
107 | corr = scipy.stats.spearmanr(all_sims, all_scores).correlation
108 | return corr
109 |
110 |
111 | def transform_sentence(sentence, tokenizer, max_length):
112 | prompt_templates = ['[X],它的意思是[MASK]', '[X],这句话的意思是[MASK]']
113 | words_list = tokenizer.tokenize(sentence)
114 | words_num = len(words_list)
115 | sentence_template = []
116 | for template in prompt_templates:
117 | if words_num > max_length - 15:
118 |             words_list = words_list[:max_length - 15]   # truncate so the prompt template still fits
119 | sentence = tokenizer.decode(tokenizer.convert_tokens_to_ids(words_list)).replace(" ", "")
120 |
121 | words_len = len(tokenizer.tokenize(sentence))
122 | prompt_sentence = template.replace("[X]", sentence)
123 | template_sentence = template.replace("[X]", "[X]"*words_len)
124 | sentence_template += [prompt_sentence, template_sentence]
125 | return sentence_template
126 |
127 |
128 | def encode_sentences(tokenizer, sen_list, max_length):
129 | sen_encoding = tokenizer(sen_list,
130 | padding=True,
131 | truncation=True,
132 | max_length=max_length,
133 | return_tensors='pt')
134 | return sen_encoding
135 |
136 |
137 | def test(testfile, pretrained, pool_type, dropout_rate, model_path, max_length):
138 | match_dataset = MatchingDataSet()
139 | testfile_type = "txt"
140 | test_list = match_dataset.read_eval_file(testfile, testfile_type)
141 | print("test samples num: {}".format(len(test_list)))
142 |
143 | test_loader = DataLoader(test_list,
144 | batch_size=4)
145 | print("test batch num: {}".format(len(test_loader)))
146 |
147 | tokenizer = AutoTokenizer.from_pretrained(pretrained)
148 | special_token_dict = {'additional_special_tokens': ['[X]']}
149 | tokenizer.add_special_tokens(special_token_dict)
150 | mask_id = tokenizer.mask_token_id
151 | model = PromptBERT(pretrained, dropout_rate, mask_id)
152 | model.encoder.resize_token_embeddings(len(tokenizer))
153 | model.load_state_dict(torch.load(model_path, map_location='cpu')['model_state_dict'])
154 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
155 |
156 | # test_corr = eval(model, tokenizer, test_loader, device, max_length, pool_type)
157 | test_corr2 = eval2(model, tokenizer, test_loader, device, max_length, pool_type)
158 | # print("test corr: {}, test_corr2: {}".format(test_corr, test_corr2))
159 | print(test_corr2)
160 |
161 |
162 | if __name__ == "__main__":
163 | testfile = "../dataset/STS-B/test.txt"
164 | pretrained = "hfl/chinese-roberta-wwm-ext"
165 | pool_type = "avg_first_last"
166 | dropout_rate = 0.1
167 | max_length = 128
168 | model_path = "../models/promptbert_roberta_stsb.pth"
169 |
170 | test(testfile, pretrained, pool_type, dropout_rate, model_path, max_length)
171 |
--------------------------------------------------------------------------------
/promptbert/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/9/21
3 | # @Author : Maciel
4 |
5 |
6 | from transformers import AutoTokenizer
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 | from torch.utils.data import DataLoader
11 | import scipy.stats
12 | from loading import MatchingDataSet
13 | from PromptBert import PromptBERT
14 | import argparse
15 | from loguru import logger
16 | import os
17 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
18 |
19 |
20 | def parse_args():
21 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
22 | parser.add_argument("--trainfile", type=str, default="../dataset/STS-B/train.txt", help="train file path")
23 | parser.add_argument("--devfile", type=str, default="../dataset/STS-B/dev.txt", help="dev file path")
24 | parser.add_argument("--testfile", type=str, default="../dataset/STS-B/test.txt", help="test file path")
25 | parser.add_argument("--filetype", type=str, default="txt", help="train and dev file type")
26 | parser.add_argument("--pretrained", type=str, default="hfl/chinese-roberta-wwm-ext", help="huggingface pretrained model")
27 | parser.add_argument("--model_out", type=str, default="../models/promptbert_roberta_stsb.pth", help="model output path")
28 | parser.add_argument("--max_length", type=int, default=128, help="sentence max length")
29 | parser.add_argument("--batch_size", type=int, default=32, help="batch size")
30 | parser.add_argument("--epochs", type=int, default=10, help="epochs")
31 | parser.add_argument("--lr", type=float, default=5e-5, help="learning rate")
32 | parser.add_argument("--tao", type=float, default=0.05, help="temperature")
33 | parser.add_argument("--device", type=str, default="cuda", help="device")
34 | parser.add_argument("--display_interval", type=int, default=100, help="display interval")
35 | parser.add_argument("--pool_type", type=str, default="avg_first_last", help="pool_type")
36 | parser.add_argument("--dropout_rate", type=float, default=0.3, help="dropout_rate")
37 | parser.add_argument("--task", type=str, default="promptbert", help="task name")
38 | args = parser.parse_args()
39 | return args
40 |
41 |
42 | def compute_loss(query, key, tao=0.05):
43 |     # normalize the representations so that the dot products below give cosine similarities
44 | query = F.normalize(query, dim=1)
45 | key = F.normalize(key, dim=1)
46 | # print(query.shape, key.shape)
47 | N, D = query.shape[0], query.shape[1]
48 |
49 | # calculate positive similarity
50 | batch_pos = torch.exp(torch.div(torch.bmm(query.view(N, 1, D), key.view(N, D, 1)).view(N, 1), tao))
51 |
52 | # calculate inner_batch all similarity
53 | batch_all = torch.sum(torch.exp(torch.div(torch.mm(query.view(N, D), torch.t(key)), tao)), dim=1)
54 |
55 | loss = torch.mean(-torch.log(torch.div(batch_pos, batch_all)))
56 | return loss
57 |
58 |
59 | def eval(model, tokenizer, dev_loader, args):
60 | model.eval()
61 | model.to(args.device)
62 |
63 | all_sims, all_scores = [], []
64 | with torch.no_grad():
65 | for data in dev_loader:
66 | sent1 = data['sent1']
67 | sent2 = data['sent2']
68 | score = data['score']
69 | sent1_encoding = tokenizer(sent1,
70 | padding=True,
71 | truncation=True,
72 | max_length=args.max_length,
73 | return_tensors='pt')
74 | sent1_encoding = {key: value.to(args.device) for key, value in sent1_encoding.items()}
75 | sent2_encoding = tokenizer(sent2,
76 | padding=True,
77 | truncation=True,
78 | max_length=args.max_length,
79 | return_tensors='pt')
80 | sent2_encoding = {key: value.to(args.device) for key, value in sent2_encoding.items()}
81 |
82 | sent1_output = model.forward_sentence(sent1_encoding['input_ids'],
83 | sent1_encoding['attention_mask'],
84 | sent1_encoding['token_type_ids'],
85 | args.pool_type)
86 | sent2_output = model.forward_sentence(sent2_encoding['input_ids'],
87 | sent2_encoding['attention_mask'],
88 | sent2_encoding['token_type_ids'],
89 | args.pool_type)
90 | sim_score = F.cosine_similarity(sent1_output, sent2_output).cpu().tolist()
91 | all_sims += sim_score
92 | all_scores += score.tolist()
93 | corr = scipy.stats.spearmanr(all_sims, all_scores).correlation
94 | return corr
95 |
96 |
97 | def eval2(model, tokenizer, dev_loader, args):
98 | model.eval()
99 | model.to(args.device)
100 |
101 | all_sims, all_scores = [], []
102 | with torch.no_grad():
103 | for data in dev_loader:
104 | sent1 = data['sent1']
105 | sent2 = data['sent2']
106 | score = data['score']
107 |
108 | prompt_template_sent1 = [transform_sentence(s1, tokenizer, args) for s1 in sent1]
109 | prompt_sent1 = [pair[0] for pair in prompt_template_sent1]
110 | template_sent1 = [pair[1] for pair in prompt_template_sent1]
111 | prompt_encoding1 = encode_sentences(tokenizer, prompt_sent1, args)
112 | template_encoding1 = encode_sentences(tokenizer, template_sent1, args)
113 | prompt_encoding1 = {key: value.to(args.device) for key, value in prompt_encoding1.items()}
114 | template_encoding1 = {key: value.to(args.device) for key, value in template_encoding1.items()}
115 |
116 | prompt_template_sent2 = [transform_sentence(s2, tokenizer, args) for s2 in sent2]
117 | prompt_sent2 = [pair[0] for pair in prompt_template_sent2]
118 | template_sent2 = [pair[1] for pair in prompt_template_sent2]
119 | prompt_encoding2 = encode_sentences(tokenizer, prompt_sent2, args)
120 | template_encoding2 = encode_sentences(tokenizer, template_sent2, args)
121 | prompt_encoding2 = {key: value.to(args.device) for key, value in prompt_encoding2.items()}
122 | template_encoding2 = {key: value.to(args.device) for key, value in template_encoding2.items()}
123 |
124 | sent1_output = model(prompt_encoding1['input_ids'],
125 | prompt_encoding1['attention_mask'],
126 | prompt_encoding1['token_type_ids'],
127 | template_encoding1['input_ids'],
128 | template_encoding1['attention_mask'],
129 | template_encoding1['token_type_ids'])
130 | sent2_output = model(prompt_encoding2['input_ids'],
131 | prompt_encoding2['attention_mask'],
132 | prompt_encoding2['token_type_ids'],
133 | template_encoding2['input_ids'],
134 | template_encoding2['attention_mask'],
135 | template_encoding2['token_type_ids'])
136 | sim_score = F.cosine_similarity(sent1_output, sent2_output).cpu().tolist()
137 | all_sims += sim_score
138 | all_scores += score.tolist()
139 | corr = scipy.stats.spearmanr(all_sims, all_scores).correlation
140 | return corr
141 |
142 |
143 | def transform_sentence(sentence, tokenizer, args):
144 | prompt_templates = ['[X],它的意思是[MASK]', '[X],这句话的意思是[MASK]']
145 | words_list = tokenizer.tokenize(sentence)
146 | words_num = len(words_list)
147 | sentence_template = []
148 | for template in prompt_templates:
149 | if words_num > args.max_length - 15:
150 |             words_list = words_list[:args.max_length - 15]   # truncate so the prompt template still fits
151 | sentence = tokenizer.decode(tokenizer.convert_tokens_to_ids(words_list)).replace(" ", "")
152 |
153 | words_len = len(tokenizer.tokenize(sentence))
154 | prompt_sentence = template.replace("[X]", sentence)
155 | template_sentence = template.replace("[X]", "[X]"*words_len)
156 | sentence_template += [prompt_sentence, template_sentence]
157 | return sentence_template
158 |
159 |
160 | def encode_sentences(tokenizer, sen_list, args):
161 | sen_encoding = tokenizer(sen_list,
162 | padding=True,
163 | truncation=True,
164 | max_length=args.max_length,
165 | return_tensors='pt')
166 | return sen_encoding
167 |
168 |
169 | def build_dataset(dataloader, tokenizer, args):
170 | data_encodings = []
171 | for data in dataloader:
172 | prompt_template_sentences = [transform_sentence(sentence, tokenizer, args) for sentence in data]
173 | prompt_sent1_list = [pair[0] for pair in prompt_template_sentences]
174 | template_sent1_list = [pair[1] for pair in prompt_template_sentences]
175 | prompt_sent2_list = [pair[2] for pair in prompt_template_sentences]
176 | template_sent2_list = [pair[3] for pair in prompt_template_sentences]
177 |
178 | prompt_encoding1 = encode_sentences(tokenizer, prompt_sent1_list, args)
179 | template_encoding1 = encode_sentences(tokenizer, template_sent1_list, args)
180 | prompt_encoding2 = encode_sentences(tokenizer, prompt_sent2_list, args)
181 | template_encoding2 = encode_sentences(tokenizer, template_sent2_list, args)
182 | data_encodings.append([prompt_encoding1, template_encoding1, prompt_encoding2, template_encoding2])
183 | return data_encodings
184 |
185 |
186 | def train(args):
187 | train_file = args.trainfile
188 | dev_file = args.devfile
189 | test_file = args.testfile
190 | file_type = args.filetype
191 | match_dataset = MatchingDataSet()
192 | train_list = match_dataset.read_train_file(train_file, dev_file, test_file, file_type)
193 | dev_list = match_dataset.read_eval_file(dev_file, file_type)
194 | logger.info("train samples num: {}, dev samples num: {}".format(len(train_list), len(dev_list)))
195 |
196 | train_loader = DataLoader(train_list,
197 | batch_size=args.batch_size,
198 | shuffle=True)
199 | dev_loader = DataLoader(dev_list,
200 | batch_size=args.batch_size)
201 | logger.info("train batch num: {}, dev batch num: {}".format(len(train_loader), len(dev_loader)))
202 |
203 | tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
204 | special_token_dict = {'additional_special_tokens': ['[X]']}
205 | tokenizer.add_special_tokens(special_token_dict)
206 | mask_id = tokenizer.mask_token_id
207 |
208 | model = PromptBERT(args.pretrained, args.dropout_rate, mask_id)
209 | model.encoder.resize_token_embeddings(len(tokenizer))
210 | model.train()
211 | model.to(args.device)
212 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
213 |
214 | train_encodings = build_dataset(train_loader, tokenizer, args)
215 |
216 | batch_idx = 0
217 | best_corr = 0
218 | best_loss = 1000000
219 | for epoch in range(args.epochs):
220 | epoch_losses = []
221 | for data in train_encodings:
222 | optimizer.zero_grad()
223 | prompt_encoding1, template_encoding1, prompt_encoding2, template_encoding2 = data
224 | prompt_encoding1 = {key: value.to(args.device) for key, value in prompt_encoding1.items()}
225 | template_encoding1 = {key: value.to(args.device) for key, value in template_encoding1.items()}
226 | prompt_encoding2 = {key: value.to(args.device) for key, value in prompt_encoding2.items()}
227 | template_encoding2 = {key: value.to(args.device) for key, value in template_encoding2.items()}
228 |
229 | query_embedding = model(prompt_encoding1['input_ids'],
230 | prompt_encoding1['attention_mask'],
231 | prompt_encoding1['token_type_ids'],
232 | template_encoding1['input_ids'],
233 | template_encoding1['attention_mask'],
234 | template_encoding1['token_type_ids'])
235 | key_embedding = model(prompt_encoding2['input_ids'],
236 | prompt_encoding2['attention_mask'],
237 | prompt_encoding2['token_type_ids'],
238 | template_encoding2['input_ids'],
239 | template_encoding2['attention_mask'],
240 | template_encoding2['token_type_ids'])
241 | batch_loss = compute_loss(query_embedding, key_embedding, args.tao)
242 |
243 | batch_loss.backward()
244 | optimizer.step()
245 | epoch_losses.append(batch_loss.item())
246 | if batch_idx % args.display_interval == 0:
247 | logger.info("Epoch: {}, batch: {}, loss: {}".format(epoch, batch_idx, batch_loss.item()))
248 | batch_idx += 1
249 |
250 | avg_epoch_loss = np.mean(epoch_losses)
251 | dev_corr = eval2(model, tokenizer, dev_loader, args)
252 | logger.info("epoch: {}, avg loss: {}, dev corr: {}".format(epoch, avg_epoch_loss, dev_corr))
253 | if avg_epoch_loss <= best_loss and dev_corr >= best_corr:
254 | best_corr = dev_corr
255 | best_loss = avg_epoch_loss
256 | torch.save({
257 | 'epoch': epoch,
258 | 'batch': batch_idx,
259 | 'model_state_dict': model.state_dict(),
260 | 'loss': best_loss,
261 | 'corr': best_corr
262 | }, args.model_out)
263 | logger.info("epoch: {}, batch: {}, best loss: {}, best corr: {}, save model".format(epoch, batch_idx, avg_epoch_loss, dev_corr))
264 |
265 |
266 | if __name__ == "__main__":
267 | args = parse_args()
268 | logger.info("args: {}".format(args))
269 | train(args)
270 |
271 | # sentence = "今天天气真不错啊"
272 | # tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
273 | # special_token_dict = {'additional_special_tokens': ['[X]']}
274 | # tokenizer.add_special_tokens(special_token_dict)
275 | # sentence_template = transform_sentence(sentence, tokenizer, args)
276 | # print(sentence_template)
277 |
278 | # st_encoding = encode_sentences(tokenizer, sentence_template, args)
279 | # print(st_encoding['input_ids'])
280 |
--------------------------------------------------------------------------------
/sbert/config.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | class Params:
4 | epoches = 100
5 | batch_size = 32
6 | max_length = 32
7 | learning_rate = 2e-5
8 | dropout = 0.2
9 | warmup_steps = 100
10 | display_interval = 500
11 | pretrained_model = "hfl/chinese-roberta-wwm-ext-large"
12 | sbert_model = "models/sbert_0106.pth"
13 | # pretrained_model = "clue/roberta_chinese_pair_large"
14 | pool_type = "mean"
15 | train_file = "data/train_dataset.csv"
16 | test_file = "data/test_dataset.csv"
--------------------------------------------------------------------------------
/sbert/loading.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import torch
3 | from torch.utils.data import Dataset
4 | from transformers import AutoTokenizer
5 | from config import Params
6 |
7 |
8 | class LoadDataset():
9 | def __init__(self, data_file):
10 | data_df = pd.read_csv(data_file)
11 | self.data_df = data_df.fillna("")
12 |
13 |
14 | def get_dataset(self):
15 | question1_list = list(self.data_df['question1'])
16 | question2_list = list(self.data_df['question2'])
17 | label_list = list(self.data_df['label'])
18 | return question1_list, question2_list, label_list
19 |
20 |
21 |     def get_encodings(self, tokenizer, questions):
22 |         question_encodings = tokenizer(questions,
23 | truncation=True,
24 | padding=True,
25 | max_length=Params.max_length,
26 | return_tensors='pt')
27 | return question_encodings
28 |
29 |
30 | class PairDataset(Dataset):
31 | def __init__(self, q1_encodings, q2_encodings, labels):
32 | self.q1_encodings = q1_encodings
33 | self.q2_encodings = q2_encodings
34 | self.labels = labels
35 |
36 |
37 |     # fetch a single sample: the two question encodings plus the label
38 |     def __getitem__(self, idx):
39 |         item1 = {key: val[idx] for key, val in self.q1_encodings.items()}
40 |         labels = torch.tensor(int(self.labels[idx]))
41 |         item2 = {key: val[idx] for key, val in self.q2_encodings.items()}
42 |         return item1, item2, labels
43 |
44 |
45 | def __len__(self):
46 | return len(self.labels)
47 |
48 |
49 | if __name__ == "__main__":
50 | print(1)
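51 | 
52 |     # Illustrative sketch (not part of the original file): build a PairDataset from a tiny
53 |     # in-memory DataFrame instead of a CSV, to show the expected column names and what
54 |     # __getitem__ returns. The two toy question pairs below are placeholders, not repository data.
55 |     toy_df = pd.DataFrame({'question1': ["花呗怎么还款", "今天天气如何"],
56 |                            'question2': ["花呗如何还款", "明天会下雨吗"],
57 |                            'label': [1, 0]})
58 |     tokenizer = AutoTokenizer.from_pretrained(Params.pretrained_model)
59 |     q1_encodings = tokenizer(list(toy_df['question1']), truncation=True, padding=True,
60 |                              max_length=Params.max_length, return_tensors='pt')
61 |     q2_encodings = tokenizer(list(toy_df['question2']), truncation=True, padding=True,
62 |                              max_length=Params.max_length, return_tensors='pt')
63 |     toy_dataset = PairDataset(q1_encodings, q2_encodings, list(toy_df['label']))
64 |     item1, item2, label = toy_dataset[0]
65 |     print(item1['input_ids'].shape, item2['input_ids'].shape, label)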
--------------------------------------------------------------------------------
/sbert/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/1/5
3 | # @Author : Maciel
4 |
5 | import torch.nn as nn
6 | from transformers import BertConfig, BertModel
7 | import torch
8 |
9 |
10 | class SBERT(nn.Module):
11 | def __init__(self, pretrained="hfl/chinese-bert-wwm-ext", pool_type="cls", dropout_prob=0.3):
12 | super().__init__()
13 | conf = BertConfig.from_pretrained(pretrained)
14 | conf.attention_probs_dropout_prob = dropout_prob
15 | conf.hidden_dropout_prob = dropout_prob
16 | self.encoder = BertModel.from_pretrained(pretrained, config=conf)
17 | assert pool_type in ["cls", "pooler", "mean"], "invalid pool_type: %s" % pool_type
18 | self.pool_type = pool_type
19 |
20 |
21 | def forward(self, input_ids, attention_mask, token_type_ids):
22 | if self.pool_type == "cls":
23 | output = self.encoder(input_ids,
24 | attention_mask=attention_mask,
25 | token_type_ids=token_type_ids)
26 | output = output.last_hidden_state[:, 0]
27 | elif self.pool_type == "pooler":
28 | output = self.encoder(input_ids,
29 | attention_mask=attention_mask,
30 | token_type_ids=token_type_ids)
31 | output = output.pooler_output
32 | elif self.pool_type == "mean":
33 | output = self.get_mean_tensor(input_ids, attention_mask)
34 | return output
35 |
36 |
37 | def get_mean_tensor(self, input_ids, attention_mask):
38 | encode_states = self.encoder(input_ids, attention_mask=attention_mask, output_hidden_states=True)
39 | hidden_states = encode_states.hidden_states
40 | last_avg_state = self.get_avg_tensor(hidden_states[-1], attention_mask)
41 | first_avg_state = self.get_avg_tensor(hidden_states[1], attention_mask)
42 | mean_avg_state = (last_avg_state + first_avg_state) / 2
43 | return mean_avg_state
44 |
45 |
46 | def get_avg_tensor(self, layer_hidden_state, attention_mask):
47 |         '''
48 |         layer_hidden_state: hidden states of one encoder layer [B * L * D]
49 |         attention_mask: mask marking the non-padding positions of the sentence [B * L]
50 |         return: average vector over the non-padding tokens [B * D]
51 |         '''
52 | layer_hidden_dim = layer_hidden_state.shape[-1]
53 | attention_repeat_mask = attention_mask.unsqueeze(dim=-1).tile(layer_hidden_dim)
54 | layer_attention_state = torch.mul(layer_hidden_state, attention_repeat_mask)
55 | layer_sum_state = layer_attention_state.sum(dim=1)
56 | # print(last_attention_state.shape)
57 |
58 | attention_length_mask = attention_mask.sum(dim=-1)
59 | attention_length_repeat_mask = attention_length_mask.unsqueeze(dim=-1).tile(layer_hidden_dim)
60 | # print(attention_length_repeat_mask.shape)
61 |
62 | layer_avg_state = torch.mul(layer_sum_state, 1/attention_length_repeat_mask)
63 | return layer_avg_state
64 |
65 |
66 | def get_avg_tensor2(self, layer_hidden_state, attention_mask):
67 |         '''
68 |         layer_hidden_state: hidden states of one encoder layer [B * L * D]
69 |         attention_mask: mask marking the non-padding positions of the sentence [B * L]
70 |         return: avg_embeddings, average vector over the non-padding tokens [B * D]
71 |         '''
72 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(layer_hidden_state.size()).float()
73 | sum_embeddings = torch.sum(layer_hidden_state * input_mask_expanded, 1)
74 | sum_mask = input_mask_expanded.sum(1)
75 | sum_mask = torch.clamp(sum_mask, min=1e-9)
76 | avg_embeddings = sum_embeddings / sum_mask
77 | return avg_embeddings
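78 | 
79 | 
80 | # Illustrative sketch (not part of the original file): the masked mean pooling used by
81 | # get_avg_tensor2, restated as a standalone function and checked on toy tensors, so the
82 | # [B * L * D] -> [B * D] reduction can be followed without downloading pretrained weights.
83 | def masked_mean(layer_hidden_state, attention_mask):
84 |     mask = attention_mask.unsqueeze(-1).expand(layer_hidden_state.size()).float()
85 |     summed = torch.sum(layer_hidden_state * mask, dim=1)
86 |     counts = torch.clamp(mask.sum(dim=1), min=1e-9)
87 |     return summed / counts
88 | 
89 | 
90 | if __name__ == "__main__":
91 |     hidden = torch.randn(2, 4, 8)                      # two sentences, four tokens, dim 8
92 |     mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # 0 marks padding positions
93 |     pooled = masked_mean(hidden, mask)
94 |     print(pooled.shape)                                # torch.Size([2, 8])
95 |     # each sentence vector is the average of its non-padding token vectors only
96 |     print(torch.allclose(pooled[0], hidden[0, :3].mean(dim=0)))   # True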
--------------------------------------------------------------------------------
/sbert/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.data import DataLoader
4 | from transformers import AutoTokenizer, get_linear_schedule_with_warmup
5 | import numpy as np
6 | from tqdm import tqdm
7 | from sklearn.metrics import accuracy_score
8 | from loading import LoadDataset, PairDataset
9 | from model import SBERT
10 | from config import Params
11 | import os
12 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
13 |
14 |
15 | def compute_loss(similarity, label, loss_fn):
16 | # mse loss
17 | loss = loss_fn(similarity, label)
18 | return loss
19 |
20 |
21 | def compute_acc(similarity, label):
22 | pred = (similarity >= 0.5).long()
23 | acc = accuracy_score(pred.detach().cpu().numpy(), label.cpu().numpy())
24 | return acc
25 |
26 |
27 | def testing(model, test_loader):
28 | model.eval()
29 | test_loss, test_acc = [], []
30 | for test_q1, test_q2, test_label in test_loader:
31 | test_q1_input_ids = test_q1['input_ids'].to(device)
32 | test_q1_attention_mask = test_q1['attention_mask'].to(device)
33 | test_q1_token_type_ids = test_q1['token_type_ids'].to(device)
34 |
35 | test_q2_input_ids = test_q2['input_ids'].to(device)
36 | test_q2_attention_mask = test_q2['attention_mask'].to(device)
37 | test_q2_token_type_ids = test_q2['token_type_ids'].to(device)
38 |
39 | test_label = test_label.float().to(device)
40 |
41 | test_q1_embedding = model(test_q1_input_ids, test_q1_attention_mask, test_q1_token_type_ids)
42 | test_q2_embedding = model(test_q2_input_ids, test_q2_attention_mask, test_q2_token_type_ids)
43 | test_similarity = torch.cosine_similarity(test_q1_embedding, test_q2_embedding, dim=1)
44 | batch_test_loss = compute_loss(test_similarity, test_label, loss_fn)
45 | batch_test_acc = compute_acc(test_similarity, test_label)
46 |
47 | test_loss.append(batch_test_loss.item())
48 | test_acc.append(batch_test_acc)
49 |
50 | test_avg_loss = np.mean(test_loss)
51 | test_avg_acc = np.mean(test_acc)
52 | return test_avg_loss, test_avg_acc
53 |
54 |
55 | train_loading = LoadDataset(Params.train_file)
56 | test_loading = LoadDataset(Params.test_file)
57 |
58 | # load tokenizer
59 | tokenizer = AutoTokenizer.from_pretrained(Params.pretrained_model)
60 |
61 | # load train dataset
62 | train_question1, train_question2, train_labels = train_loading.get_dataset()
63 | train_q1_encodings = train_loading.get_encodings(tokenizer, train_question1)
64 | train_q2_encodings = train_loading.get_encodings(tokenizer, train_question2)
65 | train_dataset = PairDataset(train_q1_encodings, train_q2_encodings, train_labels)
66 | train_loader = DataLoader(train_dataset,
67 | batch_size=Params.batch_size,
68 | shuffle=True)
69 |
70 | # load test dataset
71 | test_question1, test_question2, test_labels = test_loading.get_dataset()
72 | test_q1_encodings = test_loading.get_encodings(tokenizer, test_question1)
73 | test_q2_encodings = test_loading.get_encodings(tokenizer, test_question2)
74 | test_dataset = PairDataset(test_q1_encodings, test_q2_encodings, test_labels)
75 | test_loader = DataLoader(test_dataset,
76 | batch_size=Params.batch_size)
77 |
78 | # load model
79 | model = SBERT(Params.pretrained_model, Params.pool_type, Params.dropout)
80 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
81 | model.to(device)
82 |
83 | # load loss function
84 | loss_fn = nn.MSELoss()
85 |
86 | # load optimizer
87 | optim = torch.optim.AdamW(model.parameters(), lr=Params.learning_rate)
88 | total_steps = len(train_loader)
89 | scheduler = get_linear_schedule_with_warmup(optim,
90 |                                             num_warmup_steps=Params.warmup_steps,  # default value in run_glue.py
91 |                                             num_training_steps=total_steps * Params.epoches)  # decay over all epochs, not just one
92 |
93 |
94 | best_loss = 100000
95 | for epoch in range(Params.epoches):
96 | model.train()
97 | batch_num = 0
98 | epoch_losses = []
99 | epoch_acces = []
100 | for q1, q2, label in tqdm(train_loader):
101 |
102 | q1_input_ids = q1['input_ids'].to(device)
103 | q1_attention_mask = q1['attention_mask'].to(device)
104 | q1_token_type_ids = q1['token_type_ids'].to(device)
105 |
106 | q2_input_ids = q2['input_ids'].to(device)
107 | q2_attention_mask = q2['attention_mask'].to(device)
108 | q2_token_type_ids = q2['token_type_ids'].to(device)
109 |
110 | label = label.float().to(device)
111 | # print(q1_input_ids, q2_input_ids, label)
112 |
113 | optim.zero_grad()
114 | q1_embedding = model(q1_input_ids, q1_attention_mask, q1_token_type_ids)
115 | q2_embedding = model(q2_input_ids, q2_attention_mask, q2_token_type_ids)
116 | similarity = torch.cosine_similarity(q1_embedding, q2_embedding, dim=1)
117 | batch_loss = compute_loss(similarity, label, loss_fn)
118 | # print("batch loss: {}, type: {}".format(batch_loss.item(), batch_loss.dtype))
119 | batch_acc = compute_acc(similarity, label)
120 |
121 |         # backward pass and gradient clipping
122 | batch_loss.backward()
123 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
124 |
125 |         # update parameters and the learning-rate schedule
126 | optim.step()
127 | scheduler.step()
128 |
129 | epoch_losses.append(batch_loss.item())
130 | epoch_acces.append(batch_acc)
131 |
132 | if batch_num % Params.display_interval == 0:
133 | print("Epoch: {}, batch: {}/{}, loss: {}, acc: {}".format(epoch, batch_num, total_steps, batch_loss, batch_acc), flush=True)
134 | batch_num += 1
135 |
136 | epoch_avg_loss = np.mean(epoch_losses)
137 | epoch_avg_acc = np.mean(epoch_acces)
138 | print("Epoch: {}, avg loss: {}, acc: {}".format(epoch, epoch_avg_loss, epoch_avg_acc), flush=True)
139 | if epoch_avg_loss < best_loss:
140 | test_avg_loss, test_avg_acc = testing(model, test_loader)
141 |
142 | print("Epoch: {}, train loss: {}, acc: {}, test loss: {}, acc: {}, save best model".format(epoch, epoch_avg_loss, epoch_avg_acc, test_avg_loss, test_avg_acc), flush=True)
143 | torch.save({
144 | 'epoch': epoch,
145 | 'model_state_dict': model.state_dict(),
146 | 'train_loss': epoch_avg_loss,
147 | 'train_acc': epoch_avg_acc,
148 | 'test_loss': test_avg_loss,
149 | 'test_acc': test_avg_acc
150 | }, Params.sbert_model)
151 | best_loss = epoch_avg_loss
152 |
153 |
154 |
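155 | # Inference sketch (illustrative, commented out, not part of the original script): reload the
156 | # checkpoint saved above and score one question pair with cosine similarity, mirroring what
157 | # testing() does for whole batches. The example pair is a placeholder.
158 | # checkpoint = torch.load(Params.sbert_model, map_location=device)
159 | # model.load_state_dict(checkpoint['model_state_dict'])
160 | # model.eval()
161 | # with torch.no_grad():
162 | #     enc1 = tokenizer("花呗怎么还款", truncation=True, max_length=Params.max_length, return_tensors='pt')
163 | #     enc2 = tokenizer("花呗如何还款", truncation=True, max_length=Params.max_length, return_tensors='pt')
164 | #     emb1 = model(enc1['input_ids'].to(device), enc1['attention_mask'].to(device), enc1['token_type_ids'].to(device))
165 | #     emb2 = model(enc2['input_ids'].to(device), enc2['attention_mask'].to(device), enc2['token_type_ids'].to(device))
166 | #     print(torch.cosine_similarity(emb1, emb2, dim=1).item())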
--------------------------------------------------------------------------------
/simcse/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Macielyoung/sentence_representation_matching/810bd68e366810814572876fd9cdb380238e5b19/simcse/.DS_Store
--------------------------------------------------------------------------------
/simcse/SimCSE.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/9/17
3 | # @Author : Maciel
4 |
5 | import torch.nn as nn
6 | from transformers import BertConfig, BertModel
7 | import torch
8 |
9 |
10 | class SimCSE(nn.Module):
11 | def __init__(self, pretrained="hfl/chinese-bert-wwm-ext", pool_type="cls", dropout_prob=0.3):
12 | super().__init__()
13 | conf = BertConfig.from_pretrained(pretrained)
14 | conf.attention_probs_dropout_prob = dropout_prob
15 | conf.hidden_dropout_prob = dropout_prob
16 | self.encoder = BertModel.from_pretrained(pretrained, config=conf)
17 | assert pool_type in ["cls", "pooler", "avg_first_last", "avg_last_two"], "invalid pool_type: %s" % pool_type
18 | self.pool_type = pool_type
19 |
20 |
21 | def forward(self, input_ids, attention_mask, token_type_ids):
22 | output = self.encoder(input_ids,
23 | attention_mask=attention_mask,
24 | token_type_ids=token_type_ids,
25 | output_hidden_states=True)
26 | hidden_states = output.hidden_states
27 | if self.pool_type == "cls":
28 | output = output.last_hidden_state[:, 0]
29 | elif self.pool_type == "pooler":
30 | output = output.pooler_output
31 | elif self.pool_type == "avg_first_last":
32 | top_first_state = self.get_avg_tensor(hidden_states[1], attention_mask)
33 | last_first_state = self.get_avg_tensor(hidden_states[-1], attention_mask)
34 | output = (top_first_state + last_first_state) / 2
35 | else:
36 | last_first_state = self.get_avg_tensor(hidden_states[-1], attention_mask)
37 | last_second_state = self.get_avg_tensor(hidden_states[-2], attention_mask)
38 | output = (last_first_state + last_second_state) / 2
39 | return output
40 |
41 |
42 | def get_avg_tensor(self, layer_hidden_state, attention_mask):
43 |         '''
44 |         layer_hidden_state: hidden states of one encoder layer [B * L * D]
45 |         attention_mask: mask marking the non-padding positions of the sentence [B * L]
46 |         return: avg_embeddings, average vector over the non-padding tokens [B * D]
47 |         '''
48 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(layer_hidden_state.size()).float()
49 | sum_embeddings = torch.sum(layer_hidden_state * input_mask_expanded, 1)
50 | sum_mask = input_mask_expanded.sum(1)
51 | sum_mask = torch.clamp(sum_mask, min=1e-9)
52 | avg_embeddings = sum_embeddings / sum_mask
53 | return avg_embeddings
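54 | 
55 | 
56 | # Illustrative usage sketch (not part of the original file): encode two sentences with the
57 | # "avg_first_last" pooling and compare them by cosine similarity. This downloads pretrained
58 | # weights on first run; the sentences and model name below are just placeholders.
59 | if __name__ == "__main__":
60 |     from transformers import AutoTokenizer
61 |     pretrained = "hfl/chinese-bert-wwm-ext"
62 |     tokenizer = AutoTokenizer.from_pretrained(pretrained)
63 |     model = SimCSE(pretrained, pool_type="avg_first_last", dropout_prob=0.3)
64 |     model.eval()
65 |     encodings = tokenizer(["今天天气真好", "今天天气怎么样"], padding=True, return_tensors='pt')
66 |     with torch.no_grad():
67 |         embeddings = model(**encodings)
68 |     print(torch.cosine_similarity(embeddings[0:1], embeddings[1:2]).item())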
--------------------------------------------------------------------------------
/simcse/loading.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/9/17
3 | # @Author : Maciel
4 |
5 | from torch.utils.data import DataLoader
6 |
7 |
8 | class MatchingDataSet:
9 | def read_train_file(self, trainfile, devfile, testfile, filetype):
10 | sents = []
11 | if filetype == "txt":
12 | with open(trainfile, 'r') as f:
13 | for line in f.readlines():
14 | _, s1, s2, _ = line.strip().split(u"||")
15 | sents.append(s1)
16 | sents.append(s2)
17 | with open(devfile, 'r') as f:
18 | for line in f.readlines():
19 | _, s1, s2, _ = line.strip().split(u"||")
20 | sents.append(s1)
21 | sents.append(s2)
22 | with open(testfile, 'r') as f:
23 | for line in f.readlines():
24 | _, s1, s2, _ = line.strip().split(u"||")
25 | sents.append(s1)
26 | sents.append(s2)
27 | return sents
28 |
29 | def read_eval_file(self, file, filetype):
30 | sents = []
31 | if filetype == "txt":
32 | with open(file, 'r') as f:
33 | for line in f.readlines():
34 | _, s1, s2, s = line.strip().split(u"||")
35 | item = {'sent1': s1,
36 | 'sent2': s2,
37 | 'score': float(s)}
38 | sents.append(item)
39 | return sents
40 |
41 |
42 | if __name__ == "__main__":
43 | trainfile = "../dataset/STS-B/train.txt"
44 | devfile = "../dataset/STS-B/dev.txt"
45 | testfile = "../dataset/STS-B/test.txt"
46 | match_dataset = MatchingDataSet()
47 |
48 | train_list = match_dataset.read_train_file(trainfile, devfile, testfile, "txt")
49 | print(train_list[:5])
50 |
51 | train_lengths = [len(sentence) for sentence in train_list]
52 | max_len = max(train_lengths)
53 |
54 |
55 | dev_list = match_dataset.read_eval_file(devfile, "txt")
56 | dev_sen1_length = [len(d['sent1']) for d in dev_list]
57 | dev_sen2_length = [len(d['sent2']) for d in dev_list]
58 | max_sen1 = max(dev_sen1_length)
59 | max_sen2 = max(dev_sen2_length)
60 | print(max_len, max_sen1, max_sen2)
61 | # dev_loader = DataLoader(dev_list,
62 | # batch_size=8)
63 | # for batch in dev_loader:
64 | # print(batch)
65 | # exit(0)
--------------------------------------------------------------------------------
/simcse/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/9/19
3 | # @Author : Maciel
4 |
5 | from loading import MatchingDataSet
6 | import torch
7 | import torch.nn.functional as F
8 | import scipy.stats
9 | from torch.utils.data import DataLoader
10 | import numpy as np
11 | from SimCSE import SimCSE
12 | from transformers import AutoTokenizer
13 | import os
14 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
15 |
16 |
17 | def eval(model, tokenizer, test_loader, device, max_length):
18 | model.eval()
19 | model.to(device)
20 |
21 | all_sims, all_scores = [], []
22 | with torch.no_grad():
23 | for data in test_loader:
24 | sent1 = data['sent1']
25 | sent2 = data['sent2']
26 | score = data['score']
27 | sent1_encoding = tokenizer(sent1,
28 | padding=True,
29 | truncation=True,
30 | max_length=max_length,
31 | return_tensors='pt')
32 | sent1_encoding = {key: value.to(device) for key, value in sent1_encoding.items()}
33 | sent2_encoding = tokenizer(sent2,
34 | padding=True,
35 | truncation=True,
36 | max_length=max_length,
37 | return_tensors='pt')
38 | sent2_encoding = {key: value.to(device) for key, value in sent2_encoding.items()}
39 |
40 | sent1_output = model(**sent1_encoding)
41 | sent2_output = model(**sent2_encoding)
42 | sim_score = F.cosine_similarity(sent1_output, sent2_output).cpu().tolist()
43 | all_sims += sim_score
44 | all_scores += score.tolist()
45 | corr = scipy.stats.spearmanr(all_sims, all_scores).correlation
46 | return corr
47 |
48 |
49 | def test(testfile, pretrained, pool_type, dropout_rate, model_path, max_length):
50 | match_dataset = MatchingDataSet()
51 | testfile_type = "txt"
52 | test_list = match_dataset.read_eval_file(testfile, testfile_type)
53 | print("test samples num: {}".format(len(test_list)))
54 |
55 | test_loader = DataLoader(test_list,
56 | batch_size=128)
57 | print("test batch num: {}".format(len(test_loader)))
58 |
59 | tokenizer = AutoTokenizer.from_pretrained(pretrained)
60 | model = SimCSE(pretrained, pool_type, dropout_rate)
61 |     model.load_state_dict(torch.load(model_path, map_location='cpu')['model_state_dict'])
62 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
63 |
64 | test_corr = eval(model, tokenizer, test_loader, device, max_length)
65 | print("test corr: {}".format(test_corr))
66 |
67 |
68 | if __name__ == "__main__":
69 | testfile = "../dataset/STS-B/test.txt"
70 | pretrained = "hfl/chinese-roberta-wwm-ext-large"
71 | pool_type = "avg_first_last"
72 | dropout_rate = 0.3
73 | max_length = 128
74 | model_path = "../models/simcse_roberta_large_stsb.pth"
75 |
76 | test(testfile, pretrained, pool_type, dropout_rate, model_path, max_length)
77 |
--------------------------------------------------------------------------------
/simcse/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/9/17
3 | # @Author : Maciel
4 |
5 |
6 | import argparse
7 | import os
8 | from loading import MatchingDataSet
9 | import torch
10 | import torch.nn.functional as F
11 | import scipy.stats
12 | from torch.utils.data import DataLoader
13 | from transformers import AutoTokenizer
14 | import numpy as np
15 | from SimCSE import SimCSE
16 | from loguru import logger
17 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
18 | logger.add("../runtime.log")
19 |
20 |
21 | def parse_args():
22 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
23 | parser.add_argument("--trainfile", type=str, default="../dataset/STS-B/train.txt", help="train file path")
24 | parser.add_argument("--devfile", type=str, default="../dataset/STS-B/dev.txt", help="dev file path")
25 | parser.add_argument("--testfile", type=str, default="../dataset/STS-B/test.txt", help="test file path")
26 | parser.add_argument("--filetype", type=str, default="txt", help="train and dev file type")
27 | parser.add_argument("--pretrained", type=str, default="hfl/chinese-roberta-wwm-ext-large", help="huggingface pretrained model")
28 | parser.add_argument("--model_out", type=str, default="../models/simcse_roberta_large_stsb.pth", help="model output path")
29 | # parser.add_argument("--num_proc", type=int, default=1, help="dataset process thread num")
30 | parser.add_argument("--max_length", type=int, default=128, help="sentence max length")
31 | parser.add_argument("--batch_size", type=int, default=16, help="batch size")
32 | parser.add_argument("--epochs", type=int, default=100, help="epochs")
33 | parser.add_argument("--lr", type=float, default=3e-5, help="learning rate")
34 | parser.add_argument("--tao", type=float, default=0.05, help="temperature")
35 | parser.add_argument("--device", type=str, default="cuda", help="device")
36 | parser.add_argument("--display_interval", type=int, default=100, help="display interval")
37 | # parser.add_argument("--save_interval", type=int, default=10, help="save interval")
38 | parser.add_argument("--pool_type", type=str, default="avg_first_last", help="pool_type")
39 | parser.add_argument("--dropout_rate", type=float, default=0.3, help="dropout_rate")
40 | parser.add_argument("--task", type=str, default="simcse", help="task name")
41 | args = parser.parse_args()
42 | return args
43 |
44 |
45 | def duplicate_batch(batch):
46 |     '''
47 |     repeat each sentence twice so that adjacent rows form a positive pair
48 |     '''
49 | new_batch = []
50 | for sentence in batch:
51 | new_batch += [sentence, sentence]
52 | return new_batch
53 |
54 |
55 | def compute_loss(y_pred, tao=0.05, device="cuda"):
56 | idxs = torch.arange(0, y_pred.shape[0], device=device)
57 |     y_true = idxs + 1 - idxs % 2 * 2  # the positive of row 2k is row 2k+1 and vice versa
58 |     similarities = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=2)
59 |     similarities = similarities - torch.eye(y_pred.shape[0], device=device) * 1e12  # mask out self-similarity on the diagonal
60 | similarities = similarities / tao
61 | loss = F.cross_entropy(similarities, y_true)
62 | return torch.mean(loss)
63 |
64 |
65 | def eval(model, tokenizer, dev_loader, args):
66 | model.eval()
67 | model.to(args.device)
68 |
69 | all_sims, all_scores = [], []
70 | with torch.no_grad():
71 | for data in dev_loader:
72 | sent1 = data['sent1']
73 | sent2 = data['sent2']
74 | score = data['score']
75 | sent1_encoding = tokenizer(sent1,
76 | padding=True,
77 | truncation=True,
78 | max_length=args.max_length,
79 | return_tensors='pt')
80 | sent1_encoding = {key: value.to(args.device) for key, value in sent1_encoding.items()}
81 | sent2_encoding = tokenizer(sent2,
82 | padding=True,
83 | truncation=True,
84 | max_length=args.max_length,
85 | return_tensors='pt')
86 | sent2_encoding = {key: value.to(args.device) for key, value in sent2_encoding.items()}
87 |
88 | sent1_output = model(**sent1_encoding)
89 | sent2_output = model(**sent2_encoding)
90 | sim_score = F.cosine_similarity(sent1_output, sent2_output).cpu().tolist()
91 | all_sims += sim_score
92 | all_scores += score.tolist()
93 | corr = scipy.stats.spearmanr(all_sims, all_scores).correlation
94 | return corr
95 |
96 |
97 | def train(args):
98 | train_file = args.trainfile
99 | dev_file = args.devfile
100 | test_file = args.testfile
101 | file_type = args.filetype
102 | match_dataset = MatchingDataSet()
103 | train_list = match_dataset.read_train_file(train_file, dev_file, test_file, file_type)
104 | dev_list = match_dataset.read_eval_file(dev_file, file_type)
105 | logger.info("train samples num: {}, dev samples num: {}".format(len(train_list), len(dev_list)))
106 |
107 | train_loader = DataLoader(train_list,
108 | batch_size=args.batch_size,
109 | shuffle=True)
110 | dev_loader = DataLoader(dev_list,
111 | batch_size=args.batch_size)
112 | logger.info("train batch num: {}, dev batch num: {}".format(len(train_loader), len(dev_loader)))
113 |
114 | tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
115 | model = SimCSE(args.pretrained, args.pool_type, args.dropout_rate)
116 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
117 |     model.to(args.device)
118 | 
119 |     batch_idx = 0
120 |     best_corr = 0
121 |     best_loss = 1000000
122 |     for epoch in range(args.epochs):
123 |         model.train()  # eval() at the end of each epoch switches to eval mode; dropout must stay on for SimCSE training
124 |         epoch_losses = []
125 | for data in train_loader:
126 | batch_idx += 1
127 | batch_data = duplicate_batch(data)
128 | encodings = tokenizer(batch_data,
129 | padding=True,
130 | truncation=True,
131 | max_length=args.max_length,
132 | return_tensors='pt')
133 | encodings = {key: value.to(args.device) for key, value in encodings.items()}
134 | output = model(**encodings)
135 | batch_loss = compute_loss(output, args.tao, args.device)
136 | optimizer.zero_grad()
137 | batch_loss.backward()
138 | optimizer.step()
139 | epoch_losses.append(batch_loss.item())
140 | if batch_idx % args.display_interval == 0:
141 | logger.info("epoch: {}, batch: {}, loss: {}".format(epoch, batch_idx, batch_loss.item()))
142 | avg_epoch_loss = np.mean(epoch_losses)
143 | dev_corr = eval(model, tokenizer, dev_loader, args)
144 | logger.info("epoch: {}, avg loss: {}, dev corr: {}".format(epoch, avg_epoch_loss, dev_corr))
145 | if dev_corr >= best_corr and avg_epoch_loss <= best_loss:
146 | best_corr = dev_corr
147 | best_loss = avg_epoch_loss
148 | torch.save({
149 | 'epoch': epoch,
150 | 'batch': batch_idx,
151 | 'model_state_dict': model.state_dict(),
152 | 'loss': best_loss,
153 | 'corr': best_corr
154 | }, args.model_out)
155 | logger.info("epoch: {}, batch: {}, best loss: {}, best corr: {}, save model".format(epoch, batch_idx, avg_epoch_loss, dev_corr))
156 |
157 |
158 | if __name__ == "__main__":
159 | args = parse_args()
160 | train(args)
161 |
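162 |     # Worked sketch (illustrative, commented out): what compute_loss sees for batch_size = 2.
163 |     # duplicate_batch turns [s0, s1] into [s0, s0, s1, s1], so y_pred has 4 rows and the labels
164 |     # y_true = idxs + 1 - idxs % 2 * 2 = [1, 0, 3, 2] point each row at its own dropout-noised copy.
165 |     # y_pred = torch.randn(4, 768)
166 |     # print(compute_loss(y_pred, tao=0.05, device="cpu"))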
--------------------------------------------------------------------------------