├── .DS_Store ├── .gitignore ├── LICENSE ├── README.md ├── cosent ├── config.py ├── loading.py ├── model.py ├── test.py └── train.py ├── dataset ├── .DS_Store └── STS-B │ ├── dev.txt │ ├── test.txt │ └── train.txt ├── esimcse ├── ESimCSE.py ├── loading.py ├── test.py └── train.py ├── matching ├── README.md ├── config.py ├── cosent.py ├── esimcse.py ├── log.py ├── promptbert.py ├── retrieval.py ├── sbert.py ├── simcse.py └── view.py ├── pics ├── .DS_Store ├── cosent_loss.svg ├── esimcse.png ├── esimcse_loss.svg ├── promptbert_loss.svg ├── sbert.png ├── simcse.png └── simcse_loss.svg ├── promptbert ├── PromptBert.py ├── loading.py ├── test.py └── train.py ├── sbert ├── config.py ├── loading.py ├── model.py └── train.py └── simcse ├── .DS_Store ├── SimCSE.py ├── loading.py ├── test.py └── train.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Macielyoung/sentence_representation_matching/810bd68e366810814572876fd9cdb380238e5b19/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## sentence_representation_matching 2 | 3 | 该项目主要是文本匹配相关模型,包含使用SimCSE、ESimCSE、PromptBert三种无监督文本匹配模型和SBert、CoSent两种有监督文本匹配模型。 4 | 5 | 6 | 7 | ### 无监督文本匹配 8 | 9 | #### 1. 
SimCSE 10 | 11 | 利用Transformer Dropout机制,使用两次作为正样本对比,以此来拉近正样本,推开负样本。 12 | 13 | ##### 模型结构: 14 | 15 | ![simcse](pics/simcse.png) 16 | 17 | 18 | 19 | ##### 损失函数: 20 | 21 | ![simcse_loss](pics/simcse_loss.svg) 22 | 23 | 24 | 25 | ##### 模型效果: 26 | 27 | | data | Pertained | Pool_type | Dropout | Batch_size | Dev_corr | Test_corr | 28 | | :---- | :-------------------------- | :------------: | :-----: | :--------: | -------- | ------------------------------------ | 29 | | STS-B | hfl/chinese-bert-wwm-ext | avg_first_last | 0.1 | 64 | 0.76076 | 0.70924 | 30 | | STS-B | hfl/chinese-bert-wwm-ext | avg_first_last | 0.2 | 64 | 0.75996 | **0.71474** | 31 | | STS-B | hfl/chinese-bert-wwm-ext | avg_first_last | 0.3 | 64 | 0.76518 | 0.71237 | 32 | | STS-B | hfl/chinese-roberta-wwm-ext | avg_first_last | 0.1 | 64 | 0.75933 | 0.69070 | 33 | | STS-B | hfl/chinese-roberta-wwm-ext | avg_first_last | 0.2 | 64 | 0.76907 | **0.72410** | 34 | | STS-B | hfl/chinese-roberta-wwm-ext | avg_first_last | 0.3 | 64 | 0.77203 | 0.72155 | 35 | 36 | 参考: 37 | 38 | 1)https://github.com/princeton-nlp/SimCSE 39 | 40 | 2)https://github.com/KwangKa/SIMCSE_unsup 41 | 42 | 3)https://arxiv.org/pdf/2104.08821.pdf 43 | 44 | 45 | 46 | #### 2. ESimCSE 47 | 48 | 在SimCSE的基础上,通过重复句子中部分词组来构造正样本,同时引入动量对比来增加负样本。 49 | 50 | ##### 模型结构: 51 | 52 | ![esimcse](pics/esimcse.png) 53 | 54 | 55 | 56 | ##### 损失函数: 57 | 58 | ![esimcse_loss](pics/esimcse_loss.svg) 59 | 60 | 61 | 62 | ##### 模型效果: 63 | 64 | | data | Pertained | Dup_rate | Queue_num | Pool_type | Dropout | Batch_size | Dev_corr | Test_corr | 65 | | ----- | :-------------------------- | -------- | --------- | :------------: | ------- | :--------: | -------- | ------------------------------------- | 66 | | STS-B | hfl/chinese-bert-wwm-ext | 0.2 | 32 | avg_first_last | 0.1 | 64 | 0.77274 | 0.69639 | 67 | | STS-B | hfl/chinese-bert-wwm-ext | 0.2 | 32 | avg_first_last | 0.2 | 64 | 0.77047 | 0.70042 | 68 | | STS-B | hfl/chinese-bert-wwm-ext | 0.2 | 32 | avg_first_last | 0.3 | 64 | 0.77963 | **0.72478** | 69 | | STS-B | hfl/chinese-roberta-wwm-ext | 0.3 | 64 | avg_first_last | 0.1 | 64 | 0.77508 | 0.7206 | 70 | | STS-B | hfl/chinese-roberta-wwm-ext | 0.3 | 64 | avg_first_last | 0.2 | 64 | 0.77416 | 0.7096 | 71 | | STS-B | hfl/chinese-roberta-wwm-ext | 0.3 | 64 | avg_first_last | 0.3 | 64 | 0.78093 | **0.72495** | 72 | 73 | 参考:https://arxiv.org/pdf/2109.04380.pdf 74 | 75 | 76 | 77 | #### 3. 
PromptBert 78 | 79 | 使用Prompt方式来表征语义向量,通过不同模板产生的语义向量构造正样本,同一批次中的其他样本作为负样本。 80 | 81 | ##### 损失函数: 82 | 83 | ![promptbert_loss](pics/promptbert_loss.svg) 84 | 85 | ``` 86 | 本实验使用两个句子模板: 87 | 1)[X],它的意思是[MASK]。 88 | 2)[X],这句话的意思是[MASK]。 89 | 90 | 在计算损失函数时为了消除Prompt模板影响,通过替换模板后的句子[MASK]获取的表征减去模板中[MASK]获取的表征来得到句子向量表征。 91 | ``` 92 | 93 | 94 | 95 | ##### 模型效果: 96 | 97 | | data | Pertained | Pool_type | Dropout | Batch_size | Dev_corr | Test_corr | 98 | | ----- | :-------------------------- | :-------: | ------- | :--------: | -------- | ------------------------------------ | 99 | | STS-B | hfl/chinese-bert-wwm-ext | x_index | 0.1 | 32 | 0.78216 | **0.73185** | 100 | | STS-B | hfl/chinese-bert-wwm-ext | x_index | 0.2 | 32 | 0.78362 | 0.73129 | 101 | | STS-B | hfl/chinese-bert-wwm-ext | x_index | 0.3 | 32 | 0.76617 | 0.71597 | 102 | | STS-B | hfl/chinese-roberta-wwm-ext | x_index | 0.1 | 32 | 0.79963 | **0.73492** | 103 | | STS-B | hfl/chinese-roberta-wwm-ext | x_index | 0.2 | 32 | 0.7764 | 0.72024 | 104 | | STS-B | hfl/chinese-roberta-wwm-ext | x_index | 0.3 | 32 | 0.77875 | 0.73153 | 105 | 106 | 参考:https://arxiv.org/pdf/2201.04337.pdf 107 | 108 | 109 | 110 | #### 4. 模型对比 111 | 112 | 通过各组对比试验,并挑选模型最优测试集结果展示如下。 113 | 114 | | Model | Pertained-Bert | Pretrained-Roberta | 115 | | ---------- | :------------: | :----------------: | 116 | | SimCSE | 0.71474 | 0.72410 | 117 | | ESimCSE | 0.72478 | 0.72495 | 118 | | PromptBERT | 0.73185 | 0.73492 | 119 | 120 | 121 | 122 | ### 有监督文本匹配 123 | 124 | #### 1. SBert 125 | 126 | 使用双塔式来微调Bert,MSE损失函数来拟合文本之间的cosine相似度。 127 | 128 | 模型结构: 129 | 130 | ![SBERT Siamese Network Architecture](pics/sbert.png) 131 | 132 | 参考:https://www.sbert.net/docs/training/overview.html 133 | 134 | 135 | 136 | #### 2. CoSent 137 | 138 | 构造一个排序式损失函数,即所有正样本对的距离都应该小于负样本对的距离,具体小多少由模型和数据决定,没有一个绝对关系。 139 | 140 | 损失函数: 141 | 142 | ![cosent_loss](pics/cosent_loss.svg) 143 | 144 | 参考: 145 | 146 | 1)https://spaces.ac.cn/archives/8847 147 | 148 | 2)https://github.com/shawroad/CoSENT_Pytorch 149 | -------------------------------------------------------------------------------- /cosent/config.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class Params: 4 | epoches = 50 5 | max_length = 32 6 | batch_size = 64 7 | dropout = 0.15 8 | learning_rate = 3e-5 9 | threshold = 0.5 10 | gradient_accumulation_steps = 100 11 | display_steps = 500 12 | pooler_type = "mean" 13 | pretrained_model = "hfl/chinese-roberta-wwm-ext-large" 14 | cosent_model = "models/cosent.pth" 15 | -------------------------------------------------------------------------------- /cosent/loading.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | from torch.utils.data import Dataset 4 | from transformers import AutoTokenizer 5 | from torch.utils.data import DataLoader 6 | from config import Params 7 | 8 | 9 | class Loader(): 10 | def __init__(self, data_file): 11 | data_df = pd.read_csv(data_file) 12 | self.data_df = data_df.fillna("") 13 | 14 | 15 | def get_dataset(self): 16 | sentences = [] 17 | labels = [] 18 | for _, row in self.data_df.iterrows(): 19 | question1 = row['question1'] 20 | question2 = row['question2'] 21 | label = row['label'] 22 | sentences.extend([question1, question2]) 23 | labels.extend([label, label]) 24 | # rows.append([question1, label]) 25 | # rows.append([question2, label]) 26 | return sentences, labels 27 | 28 | 29 | class CustomerDataset(Dataset): 30 | def __init__(self, sentence, 
label, tokenizer): 31 | self.sentence = sentence 32 | self.label = label 33 | self.tokenizer = tokenizer 34 | 35 | 36 | def __len__(self): 37 | return len(self.sentence) 38 | 39 | 40 | def __getitem__(self, index): 41 | input_encodings = self.tokenizer(self.sentence[index]) 42 | input_ids = input_encodings['input_ids'] 43 | attention_mask = input_encodings['attention_mask'] 44 | token_type_ids = input_encodings['token_type_ids'] 45 | item = {'input_ids': input_ids, 46 | 'attention_mask': attention_mask, 47 | 'token_type_ids': token_type_ids, 48 | 'label': self.label[index]} 49 | return item 50 | 51 | 52 | def pad_to_maxlen(input_ids, max_len, pad_value=0): 53 | if len(input_ids) >= max_len: 54 | input_ids = input_ids[:max_len] 55 | else: 56 | input_ids = input_ids + [pad_value] * (max_len - len(input_ids)) 57 | return input_ids 58 | 59 | 60 | def collate_fn(batch): 61 | # 按batch进行padding获取当前batch中最大长度 62 | max_len = max([len(d['input_ids']) for d in batch]) 63 | 64 | # 如果当前最长长度超过设定的全局最大长度,则取全局最大长度 65 | max_len = max_len if max_len <= Params.max_length else Params.max_length 66 | 67 | input_ids, attention_mask, token_type_ids, labels = [], [], [], [] 68 | 69 | for item in batch: 70 | input_ids.append(pad_to_maxlen(item['input_ids'], max_len=max_len)) 71 | attention_mask.append(pad_to_maxlen(item['attention_mask'], max_len=max_len)) 72 | token_type_ids.append(pad_to_maxlen(item['token_type_ids'], max_len=max_len)) 73 | labels.append(item['label']) 74 | 75 | all_input_ids = torch.tensor(input_ids, dtype=torch.long) 76 | all_input_mask = torch.tensor(attention_mask, dtype=torch.long) 77 | all_segment_ids = torch.tensor(token_type_ids, dtype=torch.long) 78 | all_label_ids = torch.tensor(labels, dtype=torch.float) 79 | return all_input_ids, all_input_mask, all_segment_ids, all_label_ids 80 | 81 | 82 | if __name__ == "__main__": 83 | data_file = "data/train_dataset.csv" 84 | loader = Loader(data_file) 85 | sentence, label = loader.get_dataset() 86 | print("load question file done!") 87 | 88 | tokenizer = AutoTokenizer.from_pretrained(Params.pretrained_model) 89 | dataset = CustomerDataset(sentence, label, tokenizer) 90 | # print(dataset[0]) 91 | train_loader = DataLoader(dataset, 92 | shuffle=False, 93 | batch_size=Params.batch_size, 94 | collate_fn=collate_fn) 95 | 96 | for dl in train_loader: 97 | print(dl) 98 | exit(0) -------------------------------------------------------------------------------- /cosent/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/1/5 3 | # @Author : Maciel 4 | 5 | import torch.nn as nn 6 | from transformers import AutoConfig, AutoModel 7 | import torch 8 | 9 | 10 | class CoSent(nn.Module): 11 | def __init__(self, pretrained="hfl/chinese-bert-wwm-ext", pool_type="cls", dropout_prob=0.3): 12 | super().__init__() 13 | conf = AutoConfig.from_pretrained(pretrained) 14 | conf.attention_probs_dropout_prob = dropout_prob 15 | conf.hidden_dropout_prob = dropout_prob 16 | self.encoder = AutoModel.from_pretrained(pretrained, config=conf) 17 | assert pool_type in ["cls", "pooler", "mean"], "invalid pool_type: %s" % pool_type 18 | self.pool_type = pool_type 19 | 20 | 21 | def forward(self, input_ids, attention_mask, token_type_ids): 22 | if self.pool_type == "cls": 23 | output = self.encoder(input_ids, 24 | attention_mask=attention_mask, 25 | token_type_ids=token_type_ids) 26 | output = output.last_hidden_state[:, 0] 27 | elif self.pool_type == "pooler": 28 | output = self.encoder(input_ids, 29 
| attention_mask=attention_mask, 30 | token_type_ids=token_type_ids) 31 | output = output.pooler_output 32 | elif self.pool_type == "mean": 33 | output = self.get_mean_tensor(input_ids, attention_mask) 34 | return output 35 | 36 | 37 | def get_mean_tensor(self, input_ids, attention_mask): 38 | ''' 39 | get first and last layer avg tensor 40 | ''' 41 | encode_states = self.encoder(input_ids, attention_mask=attention_mask, output_hidden_states=True) 42 | hidden_states = encode_states.hidden_states 43 | last_avg_state = self.get_avg_tensor(hidden_states[-1], attention_mask) 44 | first_avg_state = self.get_avg_tensor(hidden_states[1], attention_mask) 45 | mean_avg_state = (last_avg_state + first_avg_state) / 2 46 | return mean_avg_state 47 | 48 | 49 | def get_avg_tensor(self, layer_hidden_state, attention_mask): 50 | ''' 51 | layer_hidden_state: 模型一层表征向量 [B * L * D] 52 | attention_mask: 句子padding位置 [B * L] 53 | return: avg_embeddings, 非零位置词语的平均向量 [B * D] 54 | ''' 55 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(layer_hidden_state.size()).float() 56 | sum_embeddings = torch.sum(layer_hidden_state * input_mask_expanded, 1) 57 | sum_mask = input_mask_expanded.sum(1) 58 | sum_mask = torch.clamp(sum_mask, min=1e-9) 59 | avg_embeddings = sum_embeddings / sum_mask 60 | return avg_embeddings 61 | 62 | 63 | def get_avg_tensor2(self, layer_hidden_state, attention_mask): 64 | ''' 65 | layer_hidden_state: 模型一层表征向量 [B * L * D] 66 | attention_mask: 句子padding位置 [B * L] 67 | return: 非零位置词语的平均向量 [B * D] 68 | ''' 69 | layer_hidden_dim = layer_hidden_state.shape[-1] 70 | attention_repeat_mask = attention_mask.unsqueeze(dim=-1).tile(layer_hidden_dim) 71 | layer_attention_state = torch.mul(layer_hidden_state, attention_repeat_mask) 72 | layer_sum_state = layer_attention_state.sum(dim=1) 73 | # print(last_attention_state.shape) 74 | 75 | attention_length_mask = attention_mask.sum(dim=-1) 76 | attention_length_repeat_mask = attention_length_mask.unsqueeze(dim=-1).tile(layer_hidden_dim) 77 | # print(attention_length_repeat_mask.shape) 78 | 79 | layer_avg_state = torch.mul(layer_sum_state, 1/attention_length_repeat_mask) 80 | return layer_avg_state -------------------------------------------------------------------------------- /cosent/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | from transformers import AutoTokenizer 4 | from model import CoSent 5 | from config import Params 6 | 7 | 8 | def read_test_data(test_file): 9 | test_df = pd.read_csv(test_file) 10 | test_df = test_df.fillna("") 11 | return test_df 12 | 13 | 14 | def split_similarity(similarity, threshold): 15 | if similarity >= threshold: 16 | return 1 17 | else: 18 | return 0 19 | 20 | 21 | def calculate_accuracy(data_df, threshold): 22 | data_df['pred'] = data_df.apply(lambda x: split_similarity(x['similarity'], threshold), axis=1) 23 | pred_correct = data_df[data_df['pred'] == data_df['label']] 24 | # pred_error = data_df[data_df['pred'] != data_df['label']] 25 | # print(pred_error) 26 | accuracy = len(pred_correct) / len(data_df) 27 | return accuracy, len(pred_correct), len(data_df) 28 | 29 | 30 | class CoSentRetrieval(object): 31 | def __init__(self, pretrained_model_path, cosent_path): 32 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path) 33 | model = CoSent(Params.pretrained_model, Params.pooler_type, Params.dropout) 34 | self.checkpoint = torch.load(cosent_path, map_location='cpu') 35 | 
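# Note (added for clarity): the checkpoint written by cosent/train.py stores the weights under
# 'model_state_dict' alongside 'epoch', 'loss' and 'acc' bookkeeping fields, so only the state
# dict is loaded here and the 'epoch'/'loss' fields are surfaced via print_checkpoint_info().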
model.load_state_dict(self.checkpoint['model_state_dict']) 36 | model.eval() 37 | self.model = model 38 | 39 | 40 | def print_checkpoint_info(self): 41 | loss = self.checkpoint['loss'] 42 | epoch = self.checkpoint['epoch'] 43 | model_info = {'loss': loss, 'epoch': epoch} 44 | return model_info 45 | 46 | 47 | def calculate_sentence_embedding(self, sentence): 48 | device = "cpu" 49 | input_encodings = self.tokenizer(sentence, 50 | padding=True, 51 | truncation=True, 52 | max_length=Params.max_length, 53 | return_tensors='pt') 54 | sentence_embedding = self.model(input_encodings['input_ids'].to(device), 55 | input_encodings['attention_mask'].to(device), 56 | input_encodings['token_type_ids'].to(device)) 57 | return sentence_embedding 58 | 59 | 60 | def calculate_sentence_similarity(self, sentence1, sentence2): 61 | sentence1 = sentence1.strip() 62 | sentence2 = sentence2.strip() 63 | sentence1_embedding = self.calculate_sentence_embedding(sentence1) 64 | sentence2_embedding = self.calculate_sentence_embedding(sentence2) 65 | similarity = torch.cosine_similarity(sentence1_embedding, sentence2_embedding, dim=-1) 66 | similarity = float(similarity.item()) 67 | return similarity 68 | 69 | 70 | cosent_retrieval = CoSentRetrieval(Params.pretrained_model, Params.cosent_model) 71 | model_info = cosent_retrieval.print_checkpoint_info() 72 | print("load model done, model_info: {}".format(model_info)) 73 | test_file = "data/test_dataset.csv" 74 | test_df = read_test_data(test_file) 75 | results = [] 76 | for rid, row in test_df.iterrows(): 77 | question1 = row['question1'] 78 | question2 = row['question2'] 79 | label = row['label'] 80 | similarity = cosent_retrieval.calculate_sentence_similarity(question1, question2) 81 | item = {'question1': question1, 82 | 'question2': question2, 83 | 'similarity': similarity, 84 | 'label': label} 85 | if rid % 100 == 0: 86 | print("rid: {}, item: {}".format(rid, item)) 87 | results.append(item) 88 | print("prediction done!") 89 | 90 | 91 | pred_df = pd.DataFrame(results) 92 | pred_file = "results/test_pred.csv" 93 | pred_df.to_csv(pred_file) 94 | max_acc = 0 95 | for t in range(50, 80, 1): 96 | t = t / 100 97 | acc, correct, num = calculate_accuracy(pred_df, t) 98 | if acc > max_acc: 99 | max_acc = acc 100 | print(t, acc, correct, num) 101 | 102 | print(max_acc) -------------------------------------------------------------------------------- /cosent/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import DataLoader 4 | from transformers import AutoTokenizer 5 | from transformers import AdamW, get_linear_schedule_with_warmup 6 | from sklearn.metrics import accuracy_score 7 | from tqdm import tqdm 8 | from loading import Loader, CustomerDataset, collate_fn 9 | from config import Params 10 | from model import CoSent 11 | import os 12 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 13 | 14 | 15 | def compute_sim(y_pred): 16 | # 1. 对输出的句子向量进行l2归一化 后面只需要对应为相乘 就可以得到cos值了 17 | norms = (y_pred ** 2).sum(axis=1, keepdims=True) ** 0.5 18 | # y_pred = y_pred / torch.clip(norms, 1e-8, torch.inf) 19 | y_pred = y_pred / norms 20 | 21 | # 2. 奇偶向量相乘 22 | sim = torch.sum(y_pred[::2] * y_pred[1::2], dim=1) 23 | return sim 24 | 25 | 26 | def compute_acc(y_true, y_sim, threshold): 27 | # 1. 取出真实的标签(每两行是一个文本匹配对) 28 | y_true = y_true[::2] # tensor([1, 0, 1]) 真实的标签 29 | 30 | # 2. 
根据阈值分割 31 | y_pred_label = (y_sim >= threshold).float() 32 | acc = accuracy_score(y_pred_label.detach().cpu().numpy(), y_true.cpu().numpy()) 33 | return acc 34 | 35 | 36 | def compute_loss(y_true, y_sim): 37 | # 1. 取出真实的标签(每两行是一个文本匹配对) 38 | y_true = y_true[::2] # tensor([1, 0, 1]) 真实的标签 39 | 40 | # 2. 根据句子间相似度进行放缩 41 | y_sim = y_sim * 20 42 | 43 | # 3. 取出负例-正例的差值 44 | y_sim = y_sim[:, None] - y_sim[None, :] # 这里是算出所有位置 两两之间余弦的差值 45 | # 矩阵中的第i行j列 表示的是第i个余弦值-第j个余弦值 46 | y_true = y_true[:, None] < y_true[None, :] # 取出负例-正例的差值 47 | y_true = y_true.float() 48 | y_sim = y_sim - (1 - y_true) * 1e12 49 | y_sim = y_sim.view(-1) 50 | if torch.cuda.is_available(): 51 | y_sim = torch.cat((torch.tensor([0]).float().cuda(), y_sim), dim=0) # 这里加0是因为e^0 = 1相当于在log中加了1 52 | else: 53 | y_sim = torch.cat((torch.tensor([0]).float(), y_sim), dim=0) # 这里加0是因为e^0 = 1相当于在log中加了1 54 | 55 | return torch.logsumexp(y_sim, dim=0) 56 | 57 | 58 | data_file = "data/train_dataset.csv" 59 | loader = Loader(data_file) 60 | sentence, label = loader.get_dataset() 61 | print("load question file done!") 62 | 63 | # load tokenizer 64 | tokenizer = AutoTokenizer.from_pretrained(Params.pretrained_model) 65 | train_dataset = CustomerDataset(sentence, label, tokenizer) 66 | train_loader = DataLoader(train_dataset, 67 | shuffle=False, 68 | batch_size=Params.batch_size, 69 | collate_fn=collate_fn) 70 | print("tokenize all batch done!") 71 | 72 | total_steps = len(train_loader) * Params.epoches 73 | train_optimization_steps = int(len(train_dataset) / Params.batch_size) * Params.epoches 74 | 75 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 76 | # load model 77 | model = CoSent(Params.pretrained_model, Params.pooler_type, Params.dropout) 78 | model.to(device) 79 | 80 | param_optimizer = list(model.named_parameters()) 81 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 82 | optimizer_grouped_parameters = [ 83 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 84 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 85 | ] 86 | 87 | # load optimizer and scheduler 88 | optimizer = AdamW(optimizer_grouped_parameters, lr=Params.learning_rate) 89 | scheduler = get_linear_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=0.05 * total_steps, 90 | num_training_steps=total_steps) 91 | 92 | best_loss = 1000000 93 | print("start training...") 94 | for epoch in range(Params.epoches): 95 | model.train() 96 | epoch_losses = [] 97 | epoch_acces = [] 98 | step = 0 99 | 100 | for batch in tqdm(train_loader): 101 | input_ids, input_mask, segment_ids, label_ids = batch 102 | input_ids = input_ids.to(device) 103 | input_mask = input_mask.to(device) 104 | segment_ids = segment_ids.to(device) 105 | label_ids = label_ids.to(device) 106 | 107 | output = model(input_ids, input_mask, segment_ids) 108 | y_sim = compute_sim(output) 109 | loss = compute_loss(label_ids, y_sim) 110 | acc = compute_acc(label_ids, y_sim, Params.threshold) 111 | 112 | # update gradient and scheduler 113 | optimizer.zero_grad() 114 | loss.backward() 115 | optimizer.step() 116 | scheduler.step() 117 | 118 | if step % Params.display_steps == 0: 119 | print("Epoch: {}, Step: {}, Batch loss: {}, acc: {}".format(epoch, step, loss.item(), acc), flush=True) 120 | epoch_losses.append(loss.item()) 121 | epoch_acces.append(acc) 122 | step += 1 123 | 124 | avg_epoch_loss = np.mean(epoch_losses) 125 | avg_epoch_acc = np.mean(epoch_acces) 126 | 
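# ---- illustrative note on the CoSent objective used above (hedged sketch, not executed) ----
# Sentences arrive in adjacent pairs, so `output` has shape [2B, D]; compute_sim() L2-normalises
# the rows and dots row 2k with row 2k+1, giving one cosine per pair. compute_loss() keeps the
# label of each even row (one per pair), scales the cosines by 20, and applies logsumexp over the
# differences sim(negative pair) - sim(positive pair), with an extra 0 term playing the role of
# the "+1" inside the log. A toy call would look like (hidden_dim is an assumed width, and the
# snippet is illustrative only, so it is left commented out):
#   hidden_dim = 1024
#   toy_emb = torch.randn(6, hidden_dim)               # three sentence pairs -> rows (0,1), (2,3), (4,5)
#   toy_lab = torch.tensor([1., 1., 0., 0., 1., 1.])   # label repeated for both rows of a pair
#   toy_loss = compute_loss(toy_lab, compute_sim(toy_emb))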
print("Epoch: {}, avg loss: {}, acc: {}".format(epoch, avg_epoch_loss, avg_epoch_acc), flush=True) 127 | if avg_epoch_loss < best_loss: 128 | best_loss = avg_epoch_loss 129 | print("Epoch: {}, best loss: {}, acc: {} save model".format(epoch, best_loss, avg_epoch_acc), flush=True) 130 | torch.save({ 131 | 'epoch': epoch, 132 | 'model_state_dict': model.state_dict(), 133 | 'loss': avg_epoch_loss, 134 | 'acc': avg_epoch_acc, 135 | }, Params.cosent_model) -------------------------------------------------------------------------------- /dataset/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Macielyoung/sentence_representation_matching/810bd68e366810814572876fd9cdb380238e5b19/dataset/.DS_Store -------------------------------------------------------------------------------- /esimcse/ESimCSE.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/9/17 3 | # @Author : Maciel 4 | 5 | import torch.nn as nn 6 | from transformers import AutoConfig, AutoModel 7 | import torch 8 | 9 | 10 | class ESimCSE(nn.Module): 11 | def __init__(self, pretrained="hfl/chinese-bert-wwm-ext", pool_type="cls", dropout_prob=0.3): 12 | super().__init__() 13 | conf = AutoConfig.from_pretrained(pretrained) 14 | conf.attention_probs_dropout_prob = dropout_prob 15 | conf.hidden_dropout_prob = dropout_prob 16 | self.encoder = AutoModel.from_pretrained(pretrained, config=conf) 17 | assert pool_type in ["cls", "pooler", "avg_first_last", "avg_last_two"], "invalid pool_type: %s" % pool_type 18 | self.pool_type = pool_type 19 | 20 | 21 | def forward(self, input_ids, attention_mask, token_type_ids): 22 | output = self.encoder(input_ids, 23 | attention_mask=attention_mask, 24 | token_type_ids=token_type_ids, 25 | output_hidden_states=True) 26 | hidden_states = output.hidden_states 27 | if self.pool_type == "cls": 28 | output = output.last_hidden_state[:, 0] 29 | elif self.pool_type == "pooler": 30 | output = output.pooler_output 31 | elif self.pool_type == "avg_first_last": 32 | top_first_state = self.get_avg_tensor(hidden_states[1], attention_mask) 33 | last_first_state = self.get_avg_tensor(hidden_states[-1], attention_mask) 34 | output = (top_first_state + last_first_state) / 2 35 | else: 36 | last_first_state = self.get_avg_tensor(hidden_states[-1], attention_mask) 37 | last_second_state = self.get_avg_tensor(hidden_states[-2], attention_mask) 38 | output = (last_first_state + last_second_state) / 2 39 | return output 40 | 41 | 42 | def get_avg_tensor(self, layer_hidden_state, attention_mask): 43 | ''' 44 | layer_hidden_state: 模型一层表征向量 [B * L * D] 45 | attention_mask: 句子padding位置 [B * L] 46 | return: avg_embeddings, 非零位置词语的平均向量 [B * D] 47 | ''' 48 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(layer_hidden_state.size()).float() 49 | sum_embeddings = torch.sum(layer_hidden_state * input_mask_expanded, 1) 50 | sum_mask = input_mask_expanded.sum(1) 51 | sum_mask = torch.clamp(sum_mask, min=1e-9) 52 | avg_embeddings = sum_embeddings / sum_mask 53 | return avg_embeddings -------------------------------------------------------------------------------- /esimcse/loading.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/9/17 3 | # @Author : Maciel 4 | 5 | from torch.utils.data import DataLoader 6 | 7 | 8 | class MatchingDataSet: 9 | def read_train_file(self, trainfile, devfile, testfile, 
filetype): 10 | sents = [] 11 | if filetype == "txt": 12 | with open(trainfile, 'r') as f: 13 | for line in f.readlines(): 14 | _, s1, s2, _ = line.strip().split(u"||") 15 | sents.append(s1) 16 | sents.append(s2) 17 | with open(devfile, 'r') as f: 18 | for line in f.readlines(): 19 | _, s1, s2, _ = line.strip().split(u"||") 20 | sents.append(s1) 21 | sents.append(s2) 22 | with open(testfile, 'r') as f: 23 | for line in f.readlines(): 24 | _, s1, s2, _ = line.strip().split(u"||") 25 | sents.append(s1) 26 | sents.append(s2) 27 | return sents 28 | 29 | def read_eval_file(self, file, filetype): 30 | sents = [] 31 | if filetype == "txt": 32 | with open(file, 'r') as f: 33 | for line in f.readlines(): 34 | _, s1, s2, s = line.strip().split(u"||") 35 | item = {'sent1': s1, 36 | 'sent2': s2, 37 | 'score': float(s)} 38 | sents.append(item) 39 | return sents 40 | 41 | 42 | if __name__ == "__main__": 43 | trainfile = "../dataset/STS-B/train.txt" 44 | devfile = "../dataset/STS-B/dev.txt" 45 | testfile = "../dataset/STS-B/test.txt" 46 | match_dataset = MatchingDataSet() 47 | 48 | train_list = match_dataset.read_train_file(trainfile, devfile, testfile, "txt") 49 | print(train_list[:5]) 50 | 51 | train_lengths = [len(sentence) for sentence in train_list] 52 | max_len = max(train_lengths) 53 | 54 | 55 | dev_list = match_dataset.read_eval_file(devfile, "txt") 56 | dev_sen1_length = [len(d['sent1']) for d in dev_list] 57 | dev_sen2_length = [len(d['sent2']) for d in dev_list] 58 | max_sen1 = max(dev_sen1_length) 59 | max_sen2 = max(dev_sen2_length) 60 | print(max_len, max_sen1, max_sen2) 61 | # dev_loader = DataLoader(dev_list, 62 | # batch_size=8) 63 | # for batch in dev_loader: 64 | # print(batch) 65 | # exit(0) -------------------------------------------------------------------------------- /esimcse/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/9/19 3 | # @Author : Maciel 4 | 5 | from loading import MatchingDataSet 6 | import torch 7 | import torch.nn.functional as F 8 | import scipy.stats 9 | from torch.utils.data import DataLoader 10 | from ESimCSE import ESimCSE 11 | from transformers import AutoTokenizer 12 | import os 13 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 14 | 15 | 16 | def eval(model, tokenizer, test_loader, device, max_length): 17 | model.eval() 18 | model.to(device) 19 | 20 | all_sims, all_scores = [], [] 21 | with torch.no_grad(): 22 | for data in test_loader: 23 | sent1 = data['sent1'] 24 | sent2 = data['sent2'] 25 | score = data['score'] 26 | sent1_encoding = tokenizer(sent1, 27 | padding=True, 28 | truncation=True, 29 | max_length=max_length, 30 | return_tensors='pt') 31 | sent1_encoding = {key: value.to(device) for key, value in sent1_encoding.items()} 32 | sent2_encoding = tokenizer(sent2, 33 | padding=True, 34 | truncation=True, 35 | max_length=max_length, 36 | return_tensors='pt') 37 | sent2_encoding = {key: value.to(device) for key, value in sent2_encoding.items()} 38 | 39 | sent1_output = model(**sent1_encoding) 40 | sent2_output = model(**sent2_encoding) 41 | sim_score = F.cosine_similarity(sent1_output, sent2_output).cpu().tolist() 42 | all_sims += sim_score 43 | all_scores += score.tolist() 44 | corr = scipy.stats.spearmanr(all_sims, all_scores).correlation 45 | return corr 46 | 47 | 48 | def test(testfile, pretrained, pool_type, dropout_rate, model_path, max_length): 49 | match_dataset = MatchingDataSet() 50 | testfile_type = "txt" 51 | test_list = 
match_dataset.read_eval_file(testfile, testfile_type) 52 | print("test samples num: {}".format(len(test_list))) 53 | 54 | test_loader = DataLoader(test_list, 55 | batch_size=128) 56 | print("test batch num: {}".format(len(test_loader))) 57 | 58 | tokenizer = AutoTokenizer.from_pretrained(pretrained) 59 | model = ESimCSE(pretrained, pool_type, dropout_rate) 60 | model.load_state_dict(torch.load(model_path)['model_state_dict']) 61 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 62 | 63 | test_corr = eval(model, tokenizer, test_loader, device, max_length) 64 | print("test corr: {}".format(test_corr)) 65 | 66 | 67 | if __name__ == "__main__": 68 | testfile = "../dataset/STS-B/test.txt" 69 | pretrained = "hfl/chinese-roberta-wwm-ext" 70 | pool_type = "avg_first_last" 71 | dropout_rate = 0.1 72 | max_length = 128 73 | model_path = "../models/esimcse_roberta_stsb.pth" 74 | 75 | test(testfile, pretrained, pool_type, dropout_rate, model_path, max_length) 76 | -------------------------------------------------------------------------------- /esimcse/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/9/20 3 | # @Author : Maciel 4 | 5 | 6 | from transformers import AutoTokenizer 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.data import DataLoader 11 | import scipy.stats 12 | from loading import MatchingDataSet 13 | from ESimCSE import ESimCSE 14 | import random 15 | import argparse 16 | from loguru import logger 17 | import copy 18 | import os 19 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 24 | parser.add_argument("--trainfile", type=str, default="../dataset/STS-B/train.txt", help="train file path") 25 | parser.add_argument("--devfile", type=str, default="../dataset/STS-B/dev.txt", help="dev file path") 26 | parser.add_argument("--testfile", type=str, default="../dataset/STS-B/test.txt", help="test file path") 27 | parser.add_argument("--filetype", type=str, default="txt", help="train and dev file type") 28 | parser.add_argument("--pretrained", type=str, default="hfl/chinese-roberta-wwm-ext", help="huggingface pretrained model") 29 | parser.add_argument("--model_out", type=str, default="../models/esimcse_roberta_stsb.pth", help="model output path") 30 | parser.add_argument("--dup_rate", type=float, default=0.2, help="repeat word probability") 31 | parser.add_argument("--queue_size", type=int, default=0.5, help="negative queue num / batch size") 32 | parser.add_argument("--momentum", type=float, default=0.995, help="momentum parameter") 33 | parser.add_argument("--max_length", type=int, default=128, help="sentence max length") 34 | parser.add_argument("--batch_size", type=int, default=64, help="batch size") 35 | parser.add_argument("--epochs", type=int, default=10, help="epochs") 36 | parser.add_argument("--lr", type=float, default=3e-5, help="learning rate") 37 | parser.add_argument("--tao", type=float, default=0.05, help="temperature") 38 | parser.add_argument("--device", type=str, default="cuda", help="device") 39 | parser.add_argument("--display_interval", type=int, default=100, help="display interval") 40 | parser.add_argument("--pool_type", type=str, default="avg_first_last", help="pool_type") 41 | parser.add_argument("--dropout_rate", type=float, default=0.3, help="dropout_rate") 42 | 
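# Note (added for clarity): --queue_size defined above is consumed as a fraction of the batch
# size (queue_num = int(queue_size * batch_size) in construct_queue/train); argparse applies
# type=int only to values passed on the command line, so the 0.5 default is kept as a float.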
parser.add_argument("--task", type=str, default="esimcse", help="task name") 43 | args = parser.parse_args() 44 | return args 45 | 46 | 47 | def compute_loss(query, key, queue, tao=0.05): 48 | ''' 49 | @function: 计算对比损失函数 50 | 51 | @input: 52 | query: tensor,查询原句向量 53 | key: tensor,增强原句向量 54 | queue: tensor,历史队列句向量 55 | tao: float,温度系数,超参数,默认0.05 56 | 57 | @return: loss(tensor),损失函数值 58 | ''' 59 | # N: batch, D: dim 60 | N, D = query.shape[0], query.shape[1] 61 | 62 | # calculate positive similarity 63 | pos = torch.exp(torch.div(torch.bmm(query.view(N,1,D), key.view(N,D,1)).view(N,1),tao)) 64 | 65 | # calculate inner_batch similarity 66 | batch_all = torch.sum(torch.exp(torch.div(torch.mm(query.view(N,D),torch.t(key)),tao)),dim=1) 67 | # calculate inner_queue similarity 68 | queue_all = torch.sum(torch.exp(torch.div(torch.mm(query.view(N,D),torch.t(queue)),tao)),dim=1) 69 | 70 | denominator = batch_all + queue_all 71 | 72 | loss = torch.mean(-torch.log(torch.div(pos, denominator))) 73 | return loss 74 | 75 | 76 | def construct_queue(args, train_loader, tokenizer, key_encoder): 77 | flag = 0 78 | queue_num = int(args.queue_size * args.batch_size) 79 | queue = None 80 | while True: 81 | with torch.no_grad(): 82 | for pid, data in enumerate(train_loader): 83 | # 和初始数据不同的数据作为反例 84 | if pid < 10: 85 | continue 86 | query_encodings = tokenizer(data, 87 | padding=True, 88 | truncation=True, 89 | max_length=args.max_length, 90 | return_tensors='pt') 91 | query_encodings = {key: value.to(args.device) for key, value in query_encodings.items()} 92 | query_embedding = key_encoder(**query_encodings) 93 | if queue is None: 94 | queue = query_embedding 95 | else: 96 | if queue.shape[0] < queue_num: 97 | queue = torch.cat((queue, query_embedding), 0) 98 | else: 99 | flag = 1 100 | if flag == 1: 101 | break 102 | if flag == 1: 103 | break 104 | queue = queue[-queue_num:] 105 | queue = torch.div(queue, torch.norm(queue, dim=1).reshape(-1, 1)) 106 | return queue 107 | 108 | 109 | def repeat_word(tokenizer, sentence, dup_rate): 110 | ''' 111 | @function: 重复句子中的部分token 112 | 113 | @input: 114 | sentence: string,输入语句 115 | 116 | @return: 117 | dup_sentence: string,重复token后生成的句子 118 | ''' 119 | word_tokens = tokenizer.tokenize(sentence) 120 | 121 | # dup_len ∈ [0, max(2, int(dup_rate ∗ N))] 122 | max_len = max(2, int(dup_rate * len(word_tokens))) 123 | # 防止随机挑选的数值大于token数量 124 | dup_len = min(random.choice(range(max_len+1)), len(word_tokens)) 125 | 126 | random_indices = random.sample(range(len(word_tokens)), dup_len) 127 | # print(max_len, dup_len, random_indices) 128 | 129 | dup_word_tokens = [] 130 | for index, word in enumerate(word_tokens): 131 | dup_word_tokens.append(word) 132 | if index in random_indices and "#" not in word: 133 | dup_word_tokens.append(word) 134 | dup_sentence = tokenizer.decode(tokenizer.convert_tokens_to_ids(dup_word_tokens)).replace(" ", "") 135 | # dup_sentence = "".join(dup_word_tokens) 136 | return dup_sentence 137 | 138 | 139 | def eval(model, tokenizer, dev_loader, args): 140 | model.eval() 141 | model.to(args.device) 142 | 143 | all_sims, all_scores = [], [] 144 | with torch.no_grad(): 145 | for data in dev_loader: 146 | sent1 = data['sent1'] 147 | sent2 = data['sent2'] 148 | score = data['score'] 149 | sent1_encoding = tokenizer(sent1, 150 | padding=True, 151 | truncation=True, 152 | max_length=args.max_length, 153 | return_tensors='pt') 154 | sent1_encoding = {key: value.to(args.device) for key, value in sent1_encoding.items()} 155 | sent2_encoding = tokenizer(sent2, 156 | 
padding=True, 157 | truncation=True, 158 | max_length=args.max_length, 159 | return_tensors='pt') 160 | sent2_encoding = {key: value.to(args.device) for key, value in sent2_encoding.items()} 161 | 162 | sent1_output = model(**sent1_encoding) 163 | sent2_output = model(**sent2_encoding) 164 | sim_score = F.cosine_similarity(sent1_output, sent2_output).cpu().tolist() 165 | all_sims += sim_score 166 | all_scores += score.tolist() 167 | corr = scipy.stats.spearmanr(all_sims, all_scores).correlation 168 | return corr 169 | 170 | 171 | def train(args): 172 | train_file = args.trainfile 173 | dev_file = args.devfile 174 | test_file = args.testfile 175 | file_type = args.filetype 176 | queue_num = int(args.queue_size * args.batch_size) 177 | match_dataset = MatchingDataSet() 178 | train_list = match_dataset.read_train_file(train_file, dev_file, test_file, file_type) 179 | dev_list = match_dataset.read_eval_file(dev_file, file_type) 180 | logger.info("train samples num: {}, dev samples num: {}".format(len(train_list), len(dev_list))) 181 | 182 | train_loader = DataLoader(train_list, 183 | batch_size=args.batch_size, 184 | shuffle=True) 185 | dev_loader = DataLoader(dev_list, 186 | batch_size=args.batch_size) 187 | logger.info("train batch num: {}, dev batch num: {}".format(len(train_loader), len(dev_loader))) 188 | 189 | tokenizer = AutoTokenizer.from_pretrained(args.pretrained) 190 | query_encoder = ESimCSE(args.pretrained, args.pool_type, args.dropout_rate) 191 | key_encoder = copy.deepcopy(query_encoder) 192 | query_encoder.train() 193 | query_encoder.to(args.device) 194 | key_encoder.eval() 195 | key_encoder.to(args.device) 196 | 197 | optimizer = torch.optim.AdamW(query_encoder.parameters(), lr=args.lr) 198 | # 构造反例样本队列 199 | queue_embeddings = construct_queue(args, train_loader, tokenizer, key_encoder) 200 | 201 | batch_idx = 0 202 | best_corr = 0 203 | best_loss = 1000000 204 | for epoch in range(args.epochs): 205 | epoch_losses = [] 206 | for data in train_loader: 207 | optimizer.zero_grad() 208 | # 构造正例样本 209 | key_data = [repeat_word(tokenizer, sentence, args.dup_rate) for sentence in data] 210 | 211 | query_encodings = tokenizer(data, 212 | padding=True, 213 | truncation=True, 214 | max_length=args.max_length, 215 | return_tensors='pt') 216 | query_encodings = {key: value.to(args.device) for key, value in query_encodings.items()} 217 | key_encodings = tokenizer(key_data, 218 | padding=True, 219 | truncation=True, 220 | max_length=args.max_length, 221 | return_tensors='pt') 222 | key_encodings = {key: value.to(args.device) for key, value in key_encodings.items()} 223 | 224 | query_embeddings = query_encoder(**query_encodings) 225 | key_embeddings = key_encoder(**key_encodings).detach() 226 | 227 | # 对表征进行归一化,便于后面相似度计算 228 | query_embeddings = F.normalize(query_embeddings, dim=1) 229 | key_embeddings = F.normalize(key_embeddings, dim=1) 230 | 231 | batch_loss = compute_loss(query_embeddings, key_embeddings, queue_embeddings, args.tao) 232 | epoch_losses.append(batch_loss.item()) 233 | # print(batch_idx, batch_loss.item()) 234 | 235 | batch_loss.backward() 236 | optimizer.step() 237 | 238 | # 更新队列中负样本表征 239 | # queue_embeddings = torch.cat((queue_embeddings, query_embeddings.detach()), 0) 240 | queue_embeddings = torch.cat((queue_embeddings, key_embeddings), 0) 241 | queue_embeddings = queue_embeddings[-queue_num:, :] 242 | 243 | # 更新key编码器的动量 244 | for query_params, key_params in zip(query_encoder.parameters(), key_encoder.parameters()): 245 | key_params.data.copy_(args.momentum * 
key_params + (1-args.momentum) * query_params) 246 | key_params.requires_grad = False 247 | 248 | if batch_idx % args.display_interval == 0: 249 | logger.info("Epoch: {}, batch: {}, loss: {}".format(epoch, batch_idx, batch_loss.item())) 250 | batch_idx += 1 251 | 252 | avg_epoch_loss = np.mean(epoch_losses) 253 | dev_corr = eval(query_encoder, tokenizer, dev_loader, args) 254 | logger.info("epoch: {}, avg loss: {}, dev corr: {}".format(epoch, avg_epoch_loss, dev_corr)) 255 | # if avg_epoch_loss <= best_loss and dev_corr >= best_corr: 256 | if dev_corr >= best_corr: 257 | best_corr = dev_corr 258 | best_loss = avg_epoch_loss 259 | torch.save({ 260 | 'epoch': epoch, 261 | 'batch': batch_idx, 262 | 'model_state_dict': query_encoder.state_dict(), 263 | 'loss': best_loss, 264 | 'corr': best_corr 265 | }, args.model_out) 266 | logger.info("epoch: {}, batch: {}, best loss: {}, best corr: {}, save model".format(epoch, batch_idx, avg_epoch_loss, dev_corr)) 267 | 268 | 269 | if __name__ == "__main__": 270 | args = parse_args() 271 | logger.info("args: {}".format(args)) 272 | train(args) 273 | -------------------------------------------------------------------------------- /matching/README.md: -------------------------------------------------------------------------------- 1 | ## Sentence_Matching 2 | 3 | 该项目主要是文本匹配服务,根据channel不同调用不同模型来计算两句话的相似度。 4 | 5 | 1. channel为0:使用SimCSE模型 6 | 2. channel为1:使用ESimCSE模型,并使用同一句话通过dropout作为正样本对,引入动量对比增加负样本对。 7 | 3. channel为2:使用ESimCSE模型,并重复一句话中部分词组构造正样本对,引入动量对比增加负样本对。 8 | 4. channel为3:同时使用ESimCSE和SimCSE模型,加上多任务损失函数。 9 | 5. channel为4:使用PromptBERT模型。 10 | 6. channel为5:使用SBert模型。 11 | 7. channel为6:使用CoSent模型。 12 | -------------------------------------------------------------------------------- /matching/config.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class Params: 4 | esimcse_same_dropout = 0.15 5 | esimcse_repeat_dropout = 0.15 6 | esimcse_multi_dropout = 0.15 7 | promptbert_dropout = 0.1 8 | simcse_dropout = 0.3 9 | sbert_dropout = 0.2 10 | cosent_dropout = 0.15 11 | max_length = 32 12 | pool_type = "pooler" 13 | sbert_pool_type = "mean" 14 | cosent_pool_type = "mean" 15 | mask_token = "[MASK]" 16 | replace_token = "[X]" 17 | # prompt_templates = ['“[UNK]”,它的意思是[MASK]。', '“[UNK]”,这句话的意思是[MASK]。'] 18 | prompt_templates = ['"{}",它的意思是[MASK]。'.format(replace_token), '"{}",这句话的意思是[MASK]。'.format(replace_token)] 19 | pretrained_model = "hfl/chinese-roberta-wwm-ext-large" 20 | esimcse_repeat_model = "models/esimcse_0.32_0.15_160.pth" 21 | esimcse_same_model = "models/esimcse_0.15_64.pth" 22 | esimcse_multi_model = "models/esimcse_multi_0.15_64.pth" 23 | promptbert_model = "models/promptbert_1231.pth" 24 | simcse_model = "models/simcse_1226.pth" 25 | sbert_model = "models/sbert_0106.pth" 26 | cosent_model = "models/cosent_0119.pth" -------------------------------------------------------------------------------- /matching/cosent.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/1/5 3 | # @Author : Maciel 4 | 5 | import torch.nn as nn 6 | from transformers import AutoConfig, AutoModel 7 | import torch 8 | 9 | 10 | class CoSent(nn.Module): 11 | def __init__(self, pretrained="hfl/chinese-bert-wwm-ext", pool_type="cls", dropout_prob=0.3): 12 | super().__init__() 13 | conf = AutoConfig.from_pretrained(pretrained) 14 | conf.attention_probs_dropout_prob = dropout_prob 15 | conf.hidden_dropout_prob = dropout_prob 16 | self.encoder = 
AutoModel.from_pretrained(pretrained, config=conf) 17 | assert pool_type in ["cls", "pooler", "mean"], "invalid pool_type: %s" % pool_type 18 | self.pool_type = pool_type 19 | 20 | 21 | def forward(self, input_ids, attention_mask, token_type_ids): 22 | if self.pool_type == "cls": 23 | output = self.encoder(input_ids, 24 | attention_mask=attention_mask, 25 | token_type_ids=token_type_ids) 26 | output = output.last_hidden_state[:, 0] 27 | elif self.pool_type == "pooler": 28 | output = self.encoder(input_ids, 29 | attention_mask=attention_mask, 30 | token_type_ids=token_type_ids) 31 | output = output.pooler_output 32 | elif self.pool_type == "mean": 33 | output = self.get_mean_tensor(input_ids, attention_mask) 34 | return output 35 | 36 | 37 | def get_mean_tensor(self, input_ids, attention_mask): 38 | ''' 39 | get first and last layer avg tensor 40 | ''' 41 | encode_states = self.encoder(input_ids, attention_mask=attention_mask, output_hidden_states=True) 42 | hidden_states = encode_states.hidden_states 43 | last_avg_state = self.get_avg_tensor(hidden_states[-1], attention_mask) 44 | first_avg_state = self.get_avg_tensor(hidden_states[1], attention_mask) 45 | mean_avg_state = (last_avg_state + first_avg_state) / 2 46 | return mean_avg_state 47 | 48 | 49 | def get_avg_tensor(self, layer_hidden_state, attention_mask): 50 | ''' 51 | layer_hidden_state: 模型一层表征向量 [B * L * D] 52 | attention_mask: 句子padding位置 [B * L] 53 | return: avg_embeddings, 非零位置词语的平均向量 [B * D] 54 | ''' 55 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(layer_hidden_state.size()).float() 56 | sum_embeddings = torch.sum(layer_hidden_state * input_mask_expanded, 1) 57 | sum_mask = input_mask_expanded.sum(1) 58 | sum_mask = torch.clamp(sum_mask, min=1e-9) 59 | avg_embeddings = sum_embeddings / sum_mask 60 | return avg_embeddings 61 | 62 | 63 | def get_avg_tensor2(self, layer_hidden_state, attention_mask): 64 | ''' 65 | layer_hidden_state: 模型一层表征向量 [B * L * D] 66 | attention_mask: 句子padding位置 [B * L] 67 | return: 非零位置词语的平均向量 [B * D] 68 | ''' 69 | layer_hidden_dim = layer_hidden_state.shape[-1] 70 | attention_repeat_mask = attention_mask.unsqueeze(dim=-1).tile(layer_hidden_dim) 71 | layer_attention_state = torch.mul(layer_hidden_state, attention_repeat_mask) 72 | layer_sum_state = layer_attention_state.sum(dim=1) 73 | # print(last_attention_state.shape) 74 | 75 | attention_length_mask = attention_mask.sum(dim=-1) 76 | attention_length_repeat_mask = attention_length_mask.unsqueeze(dim=-1).tile(layer_hidden_dim) 77 | # print(attention_length_repeat_mask.shape) 78 | 79 | layer_avg_state = torch.mul(layer_sum_state, 1/attention_length_repeat_mask) 80 | return layer_avg_state -------------------------------------------------------------------------------- /matching/esimcse.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/12/16 3 | # @Author : Maciel 4 | 5 | import torch.nn as nn 6 | from transformers import BertConfig, BertModel 7 | 8 | 9 | class ESimCSE(nn.Module): 10 | def __init__(self, pretrained="hfl/chinese-bert-wwm-ext", dropout_prob=0.15): 11 | super().__init__() 12 | conf = BertConfig.from_pretrained(pretrained) 13 | conf.attention_probs_dropout_prob = dropout_prob 14 | conf.hidden_dropout_prob = dropout_prob 15 | self.encoder = BertModel.from_pretrained(pretrained, config=conf) 16 | self.fc = nn.Linear(conf.hidden_size, conf.hidden_size) 17 | 18 | 19 | def forward(self, input_ids, attention_mask, token_type_ids): 20 | output = 
self.encoder(input_ids, 21 | attention_mask=attention_mask, 22 | token_type_ids=token_type_ids) 23 | output = output.last_hidden_state[:, 0] 24 | output = self.fc(output) 25 | return output -------------------------------------------------------------------------------- /matching/log.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | 3 | 4 | logger.add("logs/runtime.log") -------------------------------------------------------------------------------- /matching/promptbert.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModel, AutoConfig, AutoTokenizer 2 | import torch 3 | import torch.nn as nn 4 | from config import Params 5 | 6 | 7 | class PromptBERT(nn.Module): 8 | def __init__(self, pretrained_model_path, dropout_prob, mask_id): 9 | super().__init__() 10 | conf = AutoConfig.from_pretrained(pretrained_model_path) 11 | conf.attention_probs_dropout_prob = dropout_prob 12 | conf.hidden_dropout_prob = dropout_prob 13 | self.encoder = AutoModel.from_pretrained(pretrained_model_path, config=conf) 14 | self.mask_id = mask_id 15 | 16 | 17 | def forward(self, prompt_input_ids, prompt_attention_mask, prompt_token_type_ids, template_input_ids, template_attention_mask, template_token_type_ids): 18 | ''' 19 | @function: 计算prompt mask标签表征向量hi和模板向量表征h^i之间的差 20 | 21 | @input: 22 | prompt_input_ids: prompt句子输入id 23 | prompt_attention_mask: prompt句子注意力矩阵 24 | prompt_token_type_ids: prompt句子token类型id 25 | template_input_ids: 模板句子输入id 26 | template_attention_mask: 模板句子注意力矩阵 27 | template_token_type_ids: 模板句子token类型id 28 | 29 | @return: sentence_embedding: 句子表征向量 30 | ''' 31 | prompt_mask_embedding = self.calculate_mask_embedding(prompt_input_ids, prompt_attention_mask, prompt_token_type_ids) 32 | template_mask_embedding = self.calculate_mask_embedding(template_input_ids, template_attention_mask, template_token_type_ids) 33 | sentence_embedding = prompt_mask_embedding - template_mask_embedding 34 | return sentence_embedding 35 | 36 | 37 | def calculate_mask_embedding(self, input_ids, attention_mask, token_type_ids): 38 | # print("input_ids: ", input_ids) 39 | output = self.encoder(input_ids, 40 | attention_mask=attention_mask, 41 | token_type_ids=token_type_ids) 42 | token_embeddings = output[0] 43 | mask_index = (input_ids == self.mask_id).long() 44 | # print("mask_index: ", mask_index) 45 | mask_embedding = self.get_mask_embedding(token_embeddings, mask_index) 46 | return mask_embedding 47 | 48 | 49 | def get_mask_embedding(self, token_embeddings, mask_index): 50 | ''' 51 | @function: 获取[mask]标签的embedding输出 52 | 53 | @input: 54 | token_embeddings: Tensor, 编码层最后一层token输出 55 | mask_index: Tensor, mask标签位置 56 | 57 | @return: mask_embedding: Tensor, mask标签embedding 58 | ''' 59 | input_mask_expanded = mask_index.unsqueeze(-1).expand(token_embeddings.size()).float() 60 | # print("input_mask_expanded: ", input_mask_expanded) 61 | # print("input mask expaned shape: ", input_mask_expanded.shape) 62 | mask_embedding = torch.sum(token_embeddings * input_mask_expanded, 1) 63 | return mask_embedding 64 | 65 | 66 | if __name__ == "__main__": 67 | prompt_templates = '[UNK],这句话的意思是[MASK]' 68 | # sentence = "天气很好" 69 | sentence = '"[X][X][X]",这句话的意思是[MASK]。' 70 | tokenizer = AutoTokenizer.from_pretrained(Params.pretrained_model_path) 71 | special_token_dict = {'additional_special_tokens': ['[X]']} 72 | tokenizer.add_special_tokens(special_token_dict) 73 | sen_tokens = 
tokenizer.tokenize(sentence) 74 | sen_encodings = tokenizer(sentence, 75 | return_tensors='pt') 76 | print(sen_tokens) 77 | print(sen_encodings) 78 | exit(0) 79 | 80 | mask_id = tokenizer.convert_tokens_to_ids(Params.mask_token) 81 | 82 | model = PromptBERT(Params.pretrained_model_path, Params.dropout, mask_id) 83 | model.train() 84 | 85 | while True: 86 | print("input your sentence:") 87 | sentences = input() 88 | sentence_list = sentences.split(";") 89 | 90 | prompt_lines, template_lines = [], [] 91 | for sentence in sentence_list: 92 | words_num = len(tokenizer.tokenize(sentence)) 93 | prompt_line = prompt_templates.replace('[UNK]', sentence) 94 | template_line = prompt_templates.replace('[UNK]', '[UNK]'*words_num) 95 | print("prompt_line: {}, template_line: {}".format(prompt_line, template_line)) 96 | prompt_lines.append(prompt_line) 97 | template_lines.append(template_line) 98 | 99 | prompt_encodings = tokenizer(list(prompt_lines), 100 | padding=True, 101 | truncation=True, 102 | max_length=Params.max_length, 103 | return_tensors='pt') 104 | template_encodings = tokenizer(list(template_lines), 105 | padding=True, 106 | truncation=True, 107 | max_length=Params.max_length, 108 | return_tensors='pt') 109 | 110 | prompt_mask_embedding = model.calculate_mask_embedding(prompt_encodings['input_ids'], 111 | prompt_encodings['attention_mask'], 112 | prompt_encodings['token_type_ids']) 113 | template_mask_embedding = model.calculate_mask_embedding(template_encodings['input_ids'], 114 | template_encodings['attention_mask'], 115 | template_encodings['token_type_ids']) 116 | # print(prompt_mask_embedding.shape) 117 | -------------------------------------------------------------------------------- /matching/retrieval.py: -------------------------------------------------------------------------------- 1 | from simcse import SimCSE 2 | from esimcse import ESimCSE 3 | from promptbert import PromptBERT 4 | from sbert import SBERT 5 | from cosent import CoSent 6 | from config import Params 7 | from log import logger 8 | import torch 9 | from transformers import AutoTokenizer 10 | 11 | 12 | class SimCSERetrieval(object): 13 | def __init__(self, pretrained_model_path, simcse_path, pool_type, dropout): 14 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path) 15 | model = SimCSE(Params.pretrained_model, pool_type, dropout) 16 | self.checkpoint = torch.load(simcse_path, map_location='cpu') 17 | model.load_state_dict(self.checkpoint['model_state_dict']) 18 | model.eval() 19 | self.model = model 20 | 21 | 22 | def print_checkpoint_info(self): 23 | loss = self.checkpoint['loss'] 24 | epoch = self.checkpoint['epoch'] 25 | model_info = {'loss': loss, 'epoch': epoch} 26 | return model_info 27 | 28 | 29 | def calculate_sentence_embedding(self, sentence): 30 | device = "cpu" 31 | input_encodings = self.tokenizer(sentence, 32 | padding=True, 33 | truncation=True, 34 | max_length=Params.max_length, 35 | return_tensors='pt') 36 | sentence_embedding = self.model(input_encodings['input_ids'].to(device), 37 | input_encodings['attention_mask'].to(device), 38 | input_encodings['token_type_ids'].to(device)) 39 | return sentence_embedding 40 | 41 | 42 | def calculate_sentence_similarity(self, sentence1, sentence2): 43 | sentence1 = sentence1.strip() 44 | sentence2 = sentence2.strip() 45 | sentence1_embedding = self.calculate_sentence_embedding(sentence1) 46 | sentence2_embedding = self.calculate_sentence_embedding(sentence2) 47 | similarity = torch.cosine_similarity(sentence1_embedding, sentence2_embedding, 
dim=-1) 48 | similarity = float(similarity.item()) 49 | return similarity 50 | 51 | 52 | class ESimCSERetrieval(object): 53 | def __init__(self, pretrained_model_path, esimcse_path, dropout): 54 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path) 55 | model = ESimCSE(Params.pretrained_model, dropout) 56 | self.checkpoint = torch.load(esimcse_path, map_location='cpu') 57 | model.load_state_dict(self.checkpoint['model_state_dict']) 58 | model.eval() 59 | self.model = model 60 | 61 | 62 | def print_checkpoint_info(self): 63 | loss = self.checkpoint['loss'] 64 | epoch = self.checkpoint['epoch'] 65 | model_info = {'loss': loss, 'epoch': epoch} 66 | return model_info 67 | 68 | 69 | def calculate_sentence_embedding(self, sentence): 70 | device = "cpu" 71 | input_encodings = self.tokenizer(sentence, 72 | padding=True, 73 | truncation=True, 74 | max_length=Params.max_length, 75 | return_tensors='pt') 76 | sentence_embedding = self.model(input_encodings['input_ids'].to(device), 77 | input_encodings['attention_mask'].to(device), 78 | input_encodings['token_type_ids'].to(device)) 79 | return sentence_embedding 80 | 81 | 82 | def calculate_sentence_similarity(self, sentence1, sentence2): 83 | sentence1 = sentence1.strip() 84 | sentence2 = sentence2.strip() 85 | sentence1_embedding = self.calculate_sentence_embedding(sentence1) 86 | sentence2_embedding = self.calculate_sentence_embedding(sentence2) 87 | similarity = torch.cosine_similarity(sentence1_embedding, sentence2_embedding, dim=-1) 88 | similarity = float(similarity.item()) 89 | return similarity 90 | 91 | 92 | class PromptBertRetrieval(object): 93 | def __init__(self, pretrained_model_path, promptbert_path, dropout): 94 | super().__init__() 95 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path) 96 | special_token_dict = {'additional_special_tokens': ['[X]']} 97 | self.tokenizer.add_special_tokens(special_token_dict) 98 | mask_id = self.tokenizer.convert_tokens_to_ids(Params.mask_token) 99 | model = PromptBERT(pretrained_model_path, dropout, mask_id) 100 | model.encoder.resize_token_embeddings(len(self.tokenizer)) 101 | checkpoint = torch.load(promptbert_path, map_location='cpu') 102 | model.load_state_dict(checkpoint['model_state_dict']) 103 | self.checkpoint = checkpoint 104 | self.model = model 105 | 106 | 107 | def print_checkpoint_info(self): 108 | loss = self.checkpoint['loss'] 109 | epoch = self.checkpoint['epoch'] 110 | model_info = {'loss': loss, 'epoch': epoch} 111 | return model_info 112 | 113 | 114 | def calculate_sentence_mask_embedding(self, sentence): 115 | device = "cpu" 116 | prompt_sentence = Params.prompt_templates[0].replace("[X]", sentence) 117 | prompt_encodings = self.tokenizer(prompt_sentence, 118 | padding=True, 119 | truncation=True, 120 | max_length=Params.max_length, 121 | return_tensors='pt') 122 | sentence_mask_embedding = self.model.calculate_mask_embedding(prompt_encodings['input_ids'].to(device), 123 | prompt_encodings['attention_mask'].to(device), 124 | prompt_encodings['token_type_ids'].to(device)) 125 | return sentence_mask_embedding 126 | 127 | 128 | def calculate_sentence_embedding(self, sentence): 129 | device = "cpu" 130 | prompt_sentence = Params.prompt_templates[0].replace("[X]", sentence) 131 | sentence_num = len(self.tokenizer.tokenize(sentence)) 132 | template_sentence = Params.prompt_templates[0].replace("[X]", "[X]"*sentence_num) 133 | prompt_encodings = self.tokenizer(prompt_sentence, 134 | padding=True, 135 | truncation=True, 136 | 
max_length=Params.max_length, 137 | return_tensors='pt') 138 | template_encodings = self.tokenizer(template_sentence, 139 | padding=True, 140 | truncation=True, 141 | max_length=Params.max_length, 142 | return_tensors='pt') 143 | sentence_embedding = self.model(prompt_input_ids=prompt_encodings['input_ids'].to(device), 144 | prompt_attention_mask=prompt_encodings['attention_mask'].to(device), 145 | prompt_token_type_ids=prompt_encodings['token_type_ids'].to(device), 146 | template_input_ids=template_encodings['input_ids'].to(device), 147 | template_attention_mask=template_encodings['attention_mask'].to(device), 148 | template_token_type_ids=template_encodings['token_type_ids'].to(device)) 149 | return sentence_embedding 150 | 151 | 152 | def calculate_sentence_similarity(self, sentence1, sentence2): 153 | # sentence1_embedding = self.calculate_sentence_mask_embedding(sentence1) 154 | # sentence2_embedding = self.calculate_sentence_mask_embedding(sentence2) 155 | sentence1_embedding = self.calculate_sentence_embedding(sentence1) 156 | sentence2_embedding = self.calculate_sentence_embedding(sentence2) 157 | similarity = torch.cosine_similarity(sentence1_embedding, sentence2_embedding, dim=-1) 158 | similarity = float(similarity.item()) 159 | return similarity 160 | 161 | 162 | class SBERTRetrieval(object): 163 | def __init__(self, pretrained_model_path, sbert_path, pool_type, dropout): 164 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path) 165 | model = SBERT(Params.pretrained_model, pool_type, dropout) 166 | self.checkpoint = torch.load(sbert_path, map_location='cpu') 167 | model.load_state_dict(self.checkpoint['model_state_dict']) 168 | model.eval() 169 | self.model = model 170 | 171 | 172 | def print_checkpoint_info(self): 173 | loss = self.checkpoint['train_loss'] 174 | epoch = self.checkpoint['epoch'] 175 | model_info = {'loss': loss, 'epoch': epoch} 176 | return model_info 177 | 178 | 179 | def calculate_sentence_embedding(self, sentence): 180 | device = "cpu" 181 | input_encodings = self.tokenizer(sentence, 182 | padding=True, 183 | truncation=True, 184 | max_length=Params.max_length, 185 | return_tensors='pt') 186 | sentence_embedding = self.model(input_encodings['input_ids'].to(device), 187 | input_encodings['attention_mask'].to(device), 188 | input_encodings['token_type_ids'].to(device)) 189 | return sentence_embedding 190 | 191 | 192 | def calculate_sentence_similarity(self, sentence1, sentence2): 193 | sentence1 = sentence1.strip() 194 | sentence2 = sentence2.strip() 195 | sentence1_embedding = self.calculate_sentence_embedding(sentence1) 196 | sentence2_embedding = self.calculate_sentence_embedding(sentence2) 197 | similarity = torch.cosine_similarity(sentence1_embedding, sentence2_embedding, dim=-1) 198 | similarity = float(similarity.item()) 199 | return similarity 200 | 201 | 202 | class CoSentRetrieval(object): 203 | def __init__(self, pretrained_model_path, cosent_path): 204 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path) 205 | model = CoSent(Params.pretrained_model, Params.cosent_pool_type, Params.cosent_dropout) 206 | self.checkpoint = torch.load(cosent_path, map_location='cpu') 207 | model.load_state_dict(self.checkpoint['model_state_dict']) 208 | model.eval() 209 | self.model = model 210 | 211 | 212 | def print_checkpoint_info(self): 213 | loss = self.checkpoint['loss'] 214 | epoch = self.checkpoint['epoch'] 215 | model_info = {'loss': loss, 'epoch': epoch} 216 | return model_info 217 | 218 | 219 | def 
calculate_sentence_embedding(self, sentence): 220 | device = "cpu" 221 | input_encodings = self.tokenizer(sentence, 222 | padding=True, 223 | truncation=True, 224 | max_length=Params.max_length, 225 | return_tensors='pt') 226 | sentence_embedding = self.model(input_encodings['input_ids'].to(device), 227 | input_encodings['attention_mask'].to(device), 228 | input_encodings['token_type_ids'].to(device)) 229 | return sentence_embedding 230 | 231 | 232 | def calculate_sentence_similarity(self, sentence1, sentence2): 233 | sentence1 = sentence1.strip() 234 | sentence2 = sentence2.strip() 235 | sentence1_embedding = self.calculate_sentence_embedding(sentence1) 236 | sentence2_embedding = self.calculate_sentence_embedding(sentence2) 237 | similarity = torch.cosine_similarity(sentence1_embedding, sentence2_embedding, dim=-1) 238 | similarity = float(similarity.item()) 239 | return similarity 240 | 241 | 242 | simcse_retrieval = SimCSERetrieval(Params.pretrained_model, Params.simcse_model, Params.pool_type, Params.simcse_dropout) 243 | logger.info("start simcse model succussfully!") 244 | esimcse_repeat_retrieval = ESimCSERetrieval(Params.pretrained_model, Params.esimcse_repeat_model, Params.esimcse_repeat_dropout) 245 | logger.info("start esimcse repeat model succussfully!") 246 | esimcse_same_retrieval = ESimCSERetrieval(Params.pretrained_model, Params.esimcse_same_model, Params.esimcse_same_dropout) 247 | logger.info("start esimcse same model succussfully!") 248 | esimcse_multi_retrieval = ESimCSERetrieval(Params.pretrained_model, Params.esimcse_multi_model, Params.esimcse_multi_dropout) 249 | logger.info("start esimcse multi model succussfully!") 250 | promptbert_retrieval = PromptBertRetrieval(Params.pretrained_model, Params.promptbert_model, Params.promptbert_dropout) 251 | logger.info("start promptbert model succussfully!") 252 | sbert_retrieval = SBERTRetrieval(Params.pretrained_model, Params.sbert_model, Params.sbert_pool_type, Params.sbert_dropout) 253 | logger.info("start sbert model succussfully!") 254 | cosent_retrieval = CoSentRetrieval(Params.pretrained_model, Params.cosent_model) 255 | logger.info("start cosent model succussfully!") 256 | 257 | 258 | if __name__ == "__main__": 259 | # model_path = "models/esimcse_0.32_0.15_160.pth" 260 | # model_path = "models/esimcse_multi_0.15_64.pth" 261 | # model_path = "models/esimcse_0.15_64.pth" 262 | 263 | 264 | 265 | # simcse_retrieval = SimCSERetrieval(Params.pretrained_model, Params.simcse_model, Params.pool_type, Params.simcse_dropout) 266 | # model_info = simcse_retrieval.print_checkpoint_info() 267 | # print(model_info) 268 | 269 | model_info = sbert_retrieval.print_checkpoint_info() 270 | print(model_info) 271 | 272 | while True: 273 | print("input your sentence1:") 274 | sentence1 = input() 275 | print("input your sentence2:") 276 | sentence2 = input() 277 | 278 | sbert_sentence_similarity = sbert_retrieval.calculate_sentence_similarity(sentence1, sentence2) 279 | # promptbert_sentence_similarity = prom.calculate_sentence_similarity(sentence1, sentence2) 280 | # print("simcse sim: {}, promptbert sim: {}".format(simcse_sentence_similarity, promptbert_sentence_similarity)) 281 | print("sbert similarity: {}".format(sbert_sentence_similarity)) -------------------------------------------------------------------------------- /matching/sbert.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/1/5 3 | # @Author : Maciel 4 | 5 | import torch.nn as nn 6 | from 
transformers import BertConfig, BertModel 7 | import torch 8 | 9 | 10 | class SBERT(nn.Module): 11 | def __init__(self, pretrained="hfl/chinese-bert-wwm-ext", pool_type="cls", dropout_prob=0.3): 12 | super().__init__() 13 | conf = BertConfig.from_pretrained(pretrained) 14 | conf.attention_probs_dropout_prob = dropout_prob 15 | conf.hidden_dropout_prob = dropout_prob 16 | self.encoder = BertModel.from_pretrained(pretrained, config=conf) 17 | assert pool_type in ["cls", "pooler", "mean"], "invalid pool_type: %s" % pool_type 18 | self.pool_type = pool_type 19 | 20 | 21 | def forward(self, input_ids, attention_mask, token_type_ids): 22 | if self.pool_type == "cls": 23 | output = self.encoder(input_ids, 24 | attention_mask=attention_mask, 25 | token_type_ids=token_type_ids) 26 | output = output.last_hidden_state[:, 0] 27 | elif self.pool_type == "pooler": 28 | output = self.encoder(input_ids, 29 | attention_mask=attention_mask, 30 | token_type_ids=token_type_ids) 31 | output = output.pooler_output 32 | elif self.pool_type == "mean": 33 | output = self.get_mean_tensor(input_ids, attention_mask) 34 | return output 35 | 36 | 37 | def get_mean_tensor(self, input_ids, attention_mask): 38 | encode_states = self.encoder(input_ids, attention_mask=attention_mask, output_hidden_states=True) 39 | hidden_states = encode_states.hidden_states 40 | last_avg_state = self.get_avg_tensor(hidden_states[-1], attention_mask) 41 | first_avg_state = self.get_avg_tensor(hidden_states[1], attention_mask) 42 | mean_avg_state = (last_avg_state + first_avg_state) / 2 43 | return mean_avg_state 44 | 45 | 46 | def get_avg_tensor(self, layer_hidden_state, attention_mask): 47 | ''' 48 | layer_hidden_state: 模型一层表征向量 [B * L * D] 49 | attention_mask: 句子padding位置 [B * L] 50 | return: 非零位置词语的平均向量 [B * D] 51 | ''' 52 | layer_hidden_dim = layer_hidden_state.shape[-1] 53 | attention_repeat_mask = attention_mask.unsqueeze(dim=-1).tile(layer_hidden_dim) 54 | layer_attention_state = torch.mul(layer_hidden_state, attention_repeat_mask) 55 | layer_sum_state = layer_attention_state.sum(dim=1) 56 | # print(last_attention_state.shape) 57 | 58 | attention_length_mask = attention_mask.sum(dim=-1) 59 | attention_length_repeat_mask = attention_length_mask.unsqueeze(dim=-1).tile(layer_hidden_dim) 60 | # print(attention_length_repeat_mask.shape) 61 | 62 | layer_avg_state = torch.mul(layer_sum_state, 1/attention_length_repeat_mask) 63 | return layer_avg_state 64 | 65 | 66 | def get_avg_tensor2(self, layer_hidden_state, attention_mask): 67 | ''' 68 | layer_hidden_state: 模型一层表征向量 [B * L * D] 69 | attention_mask: 句子padding位置 [B * L] 70 | return: avg_embeddings, 非零位置词语的平均向量 [B * D] 71 | ''' 72 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(layer_hidden_state.size()).float() 73 | sum_embeddings = torch.sum(layer_hidden_state * input_mask_expanded, 1) 74 | sum_mask = input_mask_expanded.sum(1) 75 | sum_mask = torch.clamp(sum_mask, min=1e-9) 76 | avg_embeddings = sum_embeddings / sum_mask 77 | return avg_embeddings -------------------------------------------------------------------------------- /matching/simcse.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/6/10 3 | # @Author : kaka 4 | 5 | import torch.nn as nn 6 | from transformers import BertConfig, BertModel 7 | 8 | 9 | class SimCSE(nn.Module): 10 | def __init__(self, pretrained="hfl/chinese-bert-wwm-ext", pool_type="cls", dropout_prob=0.3): 11 | super().__init__() 12 | conf = 
BertConfig.from_pretrained(pretrained) 13 | conf.attention_probs_dropout_prob = dropout_prob 14 | conf.hidden_dropout_prob = dropout_prob 15 | self.encoder = BertModel.from_pretrained(pretrained, config=conf) 16 | assert pool_type in ["cls", "pooler"], "invalid pool_type: %s" % pool_type 17 | self.pool_type = pool_type 18 | 19 | 20 | def forward(self, input_ids, attention_mask, token_type_ids): 21 | output = self.encoder(input_ids, 22 | attention_mask=attention_mask, 23 | token_type_ids=token_type_ids) 24 | if self.pool_type == "cls": 25 | output = output.last_hidden_state[:, 0] 26 | elif self.pool_type == "pooler": 27 | output = output.pooler_output 28 | return output -------------------------------------------------------------------------------- /matching/view.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, request, jsonify 2 | from retrieval import simcse_retrieval, esimcse_same_retrieval, esimcse_repeat_retrieval, esimcse_multi_retrieval, promptbert_retrieval 3 | from retrieval import sbert_retrieval, cosent_retrieval 4 | 5 | 6 | app = Flask(__name__) 7 | @app.route('/compare', methods=['GET', 'POST']) 8 | def compare(): 9 | sentence1 = request.args.get("sentence1", "") 10 | sentence2 = request.args.get("sentence2", "") 11 | channel = request.args.get("channel", "") 12 | sentence1 = strip_sentence(sentence1) 13 | sentence2 = strip_sentence(sentence2) 14 | 15 | if channel == "0": 16 | model_name = "simcse" 17 | model_desc = "use simcse model for data augmentation and compare positive samples" 18 | model_info = simcse_retrieval.print_checkpoint_info() 19 | model_info['model_name'] = model_name 20 | model_info['model_desc'] = model_desc 21 | similarity = simcse_retrieval.calculate_sentence_similarity(sentence1, sentence2) 22 | elif channel == "1": 23 | model_name = "esimcse_same_positive" 24 | model_desc = "use same sentence as positive pairs and construct a negative queue" 25 | model_info = esimcse_same_retrieval.print_checkpoint_info() 26 | model_info['model_name'] = model_name 27 | model_info['model_desc'] = model_desc 28 | similarity = esimcse_same_retrieval.calculate_sentence_similarity(sentence1, sentence2) 29 | elif channel == "2": 30 | model_name = "esimcse_repeat_positive" 31 | model_desc = "use repeat word for data augmentation as positive pairs and construct a negative queue" 32 | model_info = esimcse_repeat_retrieval.print_checkpoint_info() 33 | model_info['model_name'] = model_name 34 | model_info['model_desc'] = model_desc 35 | similarity = esimcse_repeat_retrieval.calculate_sentence_similarity(sentence1, sentence2) 36 | elif channel == "3": 37 | model_name = "esimcse_multi_positive" 38 | model_desc = "Multi task loss: use same sentence and repeat word as positive pairs and construct a negative queue" 39 | model_info = esimcse_multi_retrieval.print_checkpoint_info() 40 | model_info['model_name'] = model_name 41 | model_info['model_desc'] = model_desc 42 | similarity = esimcse_multi_retrieval.calculate_sentence_similarity(sentence1, sentence2) 43 | elif channel == "4": 44 | model_name = "promptbert" 45 | model_desc = "use different templates to generate sentence embedding as positive pairs" 46 | model_info = promptbert_retrieval.print_checkpoint_info() 47 | model_info['model_name'] = model_name 48 | model_info['model_desc'] = model_desc 49 | similarity = promptbert_retrieval.calculate_sentence_similarity(sentence1, sentence2) 50 | elif channel == "5": 51 | model_name = "sbert" 52 | model_desc = "train 
sentence-bert structure model with cosine similarity and mse loss, using high prediction probability cases as training dataset" 53 | model_info = sbert_retrieval.print_checkpoint_info() 54 | model_info['model_name'] = model_name 55 | model_info['model_desc'] = model_desc 56 | similarity = sbert_retrieval.calculate_sentence_similarity(sentence1, sentence2) 57 | elif channel == "6": 58 | model_name = "cosent" 59 | model_desc = "train cosent structure model with contrastive loss" 60 | model_info = cosent_retrieval.print_checkpoint_info() 61 | model_info['model_name'] = model_name 62 | model_info['model_desc'] = model_desc 63 | similarity = cosent_retrieval.calculate_sentence_similarity(sentence1, sentence2) 64 | else: 65 | model_info = {'model_name': 'your channel is illegal'} 66 | similarity = None 67 | 68 | sent_info = {'sentence1': sentence1, 'sentence2': sentence2, 'similarity': similarity} 69 | result = {'model_info': model_info, 'sentence_info': sent_info} 70 | data = {'code': 200, 'message': 'OK', 'data': result} 71 | resp = jsonify(data) 72 | return resp 73 | 74 | 75 | def strip_sentence(sentence): 76 | sentence = sentence.strip().lower() 77 | sentence = sentence.replace("?", "").replace("?", "") 78 | return sentence 79 | 80 | 81 | if __name__ == "__main__": 82 | app.run(host='0.0.0.0', port=8080) -------------------------------------------------------------------------------- /pics/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Macielyoung/sentence_representation_matching/810bd68e366810814572876fd9cdb380238e5b19/pics/.DS_Store -------------------------------------------------------------------------------- /pics/cosent_loss.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /pics/esimcse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Macielyoung/sentence_representation_matching/810bd68e366810814572876fd9cdb380238e5b19/pics/esimcse.png -------------------------------------------------------------------------------- /pics/esimcse_loss.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /pics/promptbert_loss.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /pics/sbert.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Macielyoung/sentence_representation_matching/810bd68e366810814572876fd9cdb380238e5b19/pics/sbert.png -------------------------------------------------------------------------------- /pics/simcse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Macielyoung/sentence_representation_matching/810bd68e366810814572876fd9cdb380238e5b19/pics/simcse.png -------------------------------------------------------------------------------- /pics/simcse_loss.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /promptbert/PromptBert.py: -------------------------------------------------------------------------------- 1 
| # -*- coding: utf-8 -*- 2 | # @Time : 2022/9/21 3 | # @Author : Maciel 4 | 5 | import torch.nn as nn 6 | from transformers import AutoConfig, AutoModel 7 | import torch 8 | 9 | 10 | class PromptBERT(nn.Module): 11 | def __init__(self, pretrained_model_path, dropout_prob, mask_id): 12 | super().__init__() 13 | conf = AutoConfig.from_pretrained(pretrained_model_path) 14 | conf.attention_probs_dropout_prob = dropout_prob 15 | conf.hidden_dropout_prob = dropout_prob 16 | self.encoder = AutoModel.from_pretrained(pretrained_model_path, config=conf) 17 | self.mask_id = mask_id 18 | 19 | 20 | def forward(self, prompt_input_ids, prompt_attention_mask, prompt_token_type_ids, template_input_ids, template_attention_mask, template_token_type_ids): 21 | ''' 22 | @function: 计算prompt mask标签表征向量hi和模板向量表征h^i之间的差 23 | 24 | @input: 25 | prompt_input_ids: prompt句子输入id 26 | prompt_attention_mask: prompt句子注意力矩阵 27 | prompt_token_type_ids: prompt句子token类型id 28 | template_input_ids: 模板句子输入id 29 | template_attention_mask: 模板句子注意力矩阵 30 | template_token_type_ids: 模板句子token类型id 31 | 32 | @return: sentence_embedding: 句子表征向量 33 | ''' 34 | prompt_mask_embedding = self.calculate_mask_embedding(prompt_input_ids, prompt_attention_mask, prompt_token_type_ids) 35 | template_mask_embedding = self.calculate_mask_embedding(template_input_ids, template_attention_mask, template_token_type_ids) 36 | sentence_embedding = prompt_mask_embedding - template_mask_embedding 37 | return sentence_embedding 38 | 39 | 40 | def calculate_mask_embedding(self, input_ids, attention_mask, token_type_ids): 41 | # print("input_ids: ", input_ids) 42 | output = self.encoder(input_ids, 43 | attention_mask=attention_mask, 44 | token_type_ids=token_type_ids) 45 | token_embeddings = output[0] 46 | mask_index = (input_ids == self.mask_id).long() 47 | # print("mask_index: ", mask_index) 48 | mask_embedding = self.get_mask_embedding(token_embeddings, mask_index) 49 | return mask_embedding 50 | 51 | 52 | def get_mask_embedding(self, token_embeddings, mask_index): 53 | ''' 54 | @function: 获取[mask]标签的embedding输出 55 | 56 | @input: 57 | token_embeddings: Tensor, 编码层最后一层token输出 58 | mask_index: Tensor, mask标签位置 59 | 60 | @return: mask_embedding: Tensor, mask标签embedding 61 | ''' 62 | input_mask_expanded = mask_index.unsqueeze(-1).expand(token_embeddings.size()).float() 63 | # print("input_mask_expanded: ", input_mask_expanded) 64 | # print("input mask expaned shape: ", input_mask_expanded.shape) 65 | mask_embedding = torch.sum(token_embeddings * input_mask_expanded, 1) 66 | return mask_embedding 67 | 68 | 69 | def forward_sentence(self, input_ids, attention_mask, token_type_ids, pool_type): 70 | output = self.encoder(input_ids, 71 | attention_mask=attention_mask, 72 | token_type_ids=token_type_ids, 73 | output_hidden_states=True) 74 | hidden_states = output.hidden_states 75 | if pool_type == "cls": 76 | output = output.last_hidden_state[:, 0] 77 | elif pool_type == "pooler": 78 | output = output.pooler_output 79 | elif pool_type == "avg_first_last": 80 | top_first_state = self.get_avg_tensor(hidden_states[1], attention_mask) 81 | last_first_state = self.get_avg_tensor(hidden_states[-1], attention_mask) 82 | output = (top_first_state + last_first_state) / 2 83 | else: 84 | last_first_state = self.get_avg_tensor(hidden_states[-1], attention_mask) 85 | last_second_state = self.get_avg_tensor(hidden_states[-2], attention_mask) 86 | output = (last_first_state + last_second_state) / 2 87 | return output 88 | 89 | 90 | def get_avg_tensor(self, layer_hidden_state, 
attention_mask): 91 | ''' 92 | layer_hidden_state: 模型一层表征向量 [B * L * D] 93 | attention_mask: 句子padding位置 [B * L] 94 | return: avg_embeddings, 非零位置词语的平均向量 [B * D] 95 | ''' 96 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(layer_hidden_state.size()).float() 97 | sum_embeddings = torch.sum(layer_hidden_state * input_mask_expanded, 1) 98 | sum_mask = input_mask_expanded.sum(1) 99 | sum_mask = torch.clamp(sum_mask, min=1e-9) 100 | avg_embeddings = sum_embeddings / sum_mask 101 | return avg_embeddings -------------------------------------------------------------------------------- /promptbert/loading.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/9/17 3 | # @Author : Maciel 4 | 5 | from torch.utils.data import DataLoader 6 | 7 | 8 | class MatchingDataSet: 9 | def read_train_file(self, trainfile, devfile, testfile, filetype): 10 | sents = [] 11 | if filetype == "txt": 12 | with open(trainfile, 'r') as f: 13 | for line in f.readlines(): 14 | _, s1, s2, _ = line.strip().split(u"||") 15 | sents.append(s1) 16 | sents.append(s2) 17 | with open(devfile, 'r') as f: 18 | for line in f.readlines(): 19 | _, s1, s2, _ = line.strip().split(u"||") 20 | sents.append(s1) 21 | sents.append(s2) 22 | with open(testfile, 'r') as f: 23 | for line in f.readlines(): 24 | _, s1, s2, _ = line.strip().split(u"||") 25 | sents.append(s1) 26 | sents.append(s2) 27 | return sents 28 | 29 | def read_eval_file(self, file, filetype): 30 | sents = [] 31 | if filetype == "txt": 32 | with open(file, 'r') as f: 33 | for line in f.readlines(): 34 | _, s1, s2, s = line.strip().split(u"||") 35 | item = {'sent1': s1, 36 | 'sent2': s2, 37 | 'score': float(s)} 38 | sents.append(item) 39 | return sents 40 | 41 | 42 | if __name__ == "__main__": 43 | trainfile = "../dataset/STS-B/train.txt" 44 | devfile = "../dataset/STS-B/dev.txt" 45 | testfile = "../dataset/STS-B/test.txt" 46 | match_dataset = MatchingDataSet() 47 | 48 | train_list = match_dataset.read_train_file(trainfile, devfile, testfile, "txt") 49 | print(train_list[:5]) 50 | 51 | train_lengths = [len(sentence) for sentence in train_list] 52 | max_len = max(train_lengths) 53 | 54 | 55 | dev_list = match_dataset.read_eval_file(devfile, "txt") 56 | dev_sen1_length = [len(d['sent1']) for d in dev_list] 57 | dev_sen2_length = [len(d['sent2']) for d in dev_list] 58 | max_sen1 = max(dev_sen1_length) 59 | max_sen2 = max(dev_sen2_length) 60 | print(max_len, max_sen1, max_sen2) 61 | # dev_loader = DataLoader(dev_list, 62 | # batch_size=8) 63 | # for batch in dev_loader: 64 | # print(batch) 65 | # exit(0) -------------------------------------------------------------------------------- /promptbert/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/9/19 3 | # @Author : Maciel 4 | 5 | from loading import MatchingDataSet 6 | import torch 7 | import torch.nn.functional as F 8 | import scipy.stats 9 | from torch.utils.data import DataLoader 10 | from PromptBert import PromptBERT 11 | from transformers import AutoTokenizer 12 | import os 13 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 14 | 15 | 16 | def eval(model, tokenizer, dev_loader, device, max_length, pool_type): 17 | model.eval() 18 | model.to(device) 19 | 20 | all_sims, all_scores = [], [] 21 | with torch.no_grad(): 22 | for data in dev_loader: 23 | sent1 = data['sent1'] 24 | sent2 = data['sent2'] 25 | score = data['score'] 26 | sent1_encoding = tokenizer(sent1, 27 
| padding=True, 28 | truncation=True, 29 | max_length=max_length, 30 | return_tensors='pt') 31 | sent1_encoding = {key: value.to(device) for key, value in sent1_encoding.items()} 32 | sent2_encoding = tokenizer(sent2, 33 | padding=True, 34 | truncation=True, 35 | max_length=max_length, 36 | return_tensors='pt') 37 | sent2_encoding = {key: value.to(device) for key, value in sent2_encoding.items()} 38 | 39 | sent1_output = model.forward_sentence(sent1_encoding['input_ids'], 40 | sent1_encoding['attention_mask'], 41 | sent1_encoding['token_type_ids'], 42 | pool_type) 43 | sent2_output = model.forward_sentence(sent2_encoding['input_ids'], 44 | sent2_encoding['attention_mask'], 45 | sent2_encoding['token_type_ids'], 46 | pool_type) 47 | sim_score = F.cosine_similarity(sent1_output, sent2_output).cpu().tolist() 48 | all_sims += sim_score 49 | all_scores += score.tolist() 50 | corr = scipy.stats.spearmanr(all_sims, all_scores).correlation 51 | return corr 52 | 53 | 54 | def eval2(model, tokenizer, dev_loader, device, max_length, pool_type): 55 | model.eval() 56 | model.to(device) 57 | 58 | all_sims, all_scores = [], [] 59 | with torch.no_grad(): 60 | for data in dev_loader: 61 | sent1 = data['sent1'] 62 | sent2 = data['sent2'] 63 | score = data['score'] 64 | 65 | # print("sent1: {}, sent2: {}".format(sent1, sent2)) 66 | prompt_template_sent1 = [transform_sentence(s1, tokenizer, max_length) for s1 in sent1] 67 | prompt_sent1 = [pair[0] for pair in prompt_template_sent1] 68 | template_sent1 = [pair[1] for pair in prompt_template_sent1] 69 | # print("prompt sent1: {}".format(prompt_sent1)) 70 | # print("template sent1: {}".format(template_sent1)) 71 | prompt_encoding1 = encode_sentences(tokenizer, prompt_sent1, max_length) 72 | template_encoding1 = encode_sentences(tokenizer, template_sent1, max_length) 73 | prompt_encoding1 = {key: value.to(device) for key, value in prompt_encoding1.items()} 74 | template_encoding1 = {key: value.to(device) for key, value in template_encoding1.items()} 75 | # print("prompt_encoding1 input_ids {}".format(prompt_encoding1['input_ids'])) 76 | # print("template_encoding1 input_ids: {}".format(template_encoding1['input_ids'])) 77 | 78 | prompt_template_sent2 = [transform_sentence(s2, tokenizer, max_length) for s2 in sent2] 79 | prompt_sent2 = [pair[0] for pair in prompt_template_sent2] 80 | template_sent2 = [pair[1] for pair in prompt_template_sent2] 81 | # print("prompt sent2: {}".format(prompt_sent2)) 82 | # print("template sent2: {}".format(template_sent2)) 83 | prompt_encoding2 = encode_sentences(tokenizer, prompt_sent2, max_length) 84 | template_encoding2 = encode_sentences(tokenizer, template_sent2, max_length) 85 | prompt_encoding2 = {key: value.to(device) for key, value in prompt_encoding2.items()} 86 | template_encoding2 = {key: value.to(device) for key, value in template_encoding2.items()} 87 | # print("prompt_encoding2 input_ids: {}".format(prompt_encoding2['input_ids'])) 88 | # print("template_encoding2 input_ids: {}".format(template_encoding2['input_ids'])) 89 | 90 | sent1_output = model(prompt_encoding1['input_ids'], 91 | prompt_encoding1['attention_mask'], 92 | prompt_encoding1['token_type_ids'], 93 | template_encoding1['input_ids'], 94 | template_encoding1['attention_mask'], 95 | template_encoding1['token_type_ids']) 96 | sent2_output = model(prompt_encoding2['input_ids'], 97 | prompt_encoding2['attention_mask'], 98 | prompt_encoding2['token_type_ids'], 99 | template_encoding2['input_ids'], 100 | template_encoding2['attention_mask'], 101 | 
template_encoding2['token_type_ids']) 102 | # print("sen1 output shape: {}, sen2 output shape: {}".format(sent1_output.shape, sent2_output.shape)) 103 | sim_score = F.cosine_similarity(sent1_output, sent2_output).cpu().tolist() 104 | all_sims += sim_score 105 | all_scores += score.tolist() 106 | # print("sim_score: {}, score: {}".format(sim_score, score)) 107 | corr = scipy.stats.spearmanr(all_sims, all_scores).correlation 108 | return corr 109 | 110 | 111 | def transform_sentence(sentence, tokenizer, max_length): 112 | prompt_templates = ['[X],它的意思是[MASK]', '[X],这句话的意思是[MASK]'] 113 | words_list = tokenizer.tokenize(sentence) 114 | words_num = len(words_list) 115 | sentence_template = [] 116 | for template in prompt_templates: 117 | if words_num > max_length - 15: 118 | words_list = words_list[:-15] 119 | sentence = tokenizer.decode(tokenizer.convert_tokens_to_ids(words_list)).replace(" ", "") 120 | 121 | words_len = len(tokenizer.tokenize(sentence)) 122 | prompt_sentence = template.replace("[X]", sentence) 123 | template_sentence = template.replace("[X]", "[X]"*words_len) 124 | sentence_template += [prompt_sentence, template_sentence] 125 | return sentence_template 126 | 127 | 128 | def encode_sentences(tokenizer, sen_list, max_length): 129 | sen_encoding = tokenizer(sen_list, 130 | padding=True, 131 | truncation=True, 132 | max_length=max_length, 133 | return_tensors='pt') 134 | return sen_encoding 135 | 136 | 137 | def test(testfile, pretrained, pool_type, dropout_rate, model_path, max_length): 138 | match_dataset = MatchingDataSet() 139 | testfile_type = "txt" 140 | test_list = match_dataset.read_eval_file(testfile, testfile_type) 141 | print("test samples num: {}".format(len(test_list))) 142 | 143 | test_loader = DataLoader(test_list, 144 | batch_size=4) 145 | print("test batch num: {}".format(len(test_loader))) 146 | 147 | tokenizer = AutoTokenizer.from_pretrained(pretrained) 148 | special_token_dict = {'additional_special_tokens': ['[X]']} 149 | tokenizer.add_special_tokens(special_token_dict) 150 | mask_id = tokenizer.mask_token_id 151 | model = PromptBERT(pretrained, dropout_rate, mask_id) 152 | model.encoder.resize_token_embeddings(len(tokenizer)) 153 | model.load_state_dict(torch.load(model_path, map_location='cpu')['model_state_dict']) 154 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 155 | 156 | # test_corr = eval(model, tokenizer, test_loader, device, max_length, pool_type) 157 | test_corr2 = eval2(model, tokenizer, test_loader, device, max_length, pool_type) 158 | # print("test corr: {}, test_corr2: {}".format(test_corr, test_corr2)) 159 | print(test_corr2) 160 | 161 | 162 | if __name__ == "__main__": 163 | testfile = "../dataset/STS-B/test.txt" 164 | pretrained = "hfl/chinese-roberta-wwm-ext" 165 | pool_type = "avg_first_last" 166 | dropout_rate = 0.1 167 | max_length = 128 168 | model_path = "../models/promptbert_roberta_stsb.pth" 169 | 170 | test(testfile, pretrained, pool_type, dropout_rate, model_path, max_length) 171 | -------------------------------------------------------------------------------- /promptbert/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/9/21 3 | # @Author : Maciel 4 | 5 | 6 | from transformers import AutoTokenizer 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.data import DataLoader 11 | import scipy.stats 12 | from loading import MatchingDataSet 13 | from PromptBert import PromptBERT 
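# Training recipe used in this script (summary of the functions defined below):
# each sentence is rendered with the two prompt templates in transform_sentence(),
# PromptBERT.forward() subtracts the template-only [MASK] embedding to remove the
# template bias, and the two resulting views of the same sentence form a positive
# pair in the in-batch InfoNCE loss (compute_loss, temperature --tao). A checkpoint
# is saved whenever the average epoch loss and the STS-B dev Spearman correlation
# (eval2) are both no worse than the previous best.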
14 | import argparse 15 | from loguru import logger 16 | import os 17 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument("--trainfile", type=str, default="../dataset/STS-B/train.txt", help="train file path") 23 | parser.add_argument("--devfile", type=str, default="../dataset/STS-B/dev.txt", help="dev file path") 24 | parser.add_argument("--testfile", type=str, default="../dataset/STS-B/test.txt", help="test file path") 25 | parser.add_argument("--filetype", type=str, default="txt", help="train and dev file type") 26 | parser.add_argument("--pretrained", type=str, default="hfl/chinese-roberta-wwm-ext", help="huggingface pretrained model") 27 | parser.add_argument("--model_out", type=str, default="../models/promptbert_roberta_stsb.pth", help="model output path") 28 | parser.add_argument("--max_length", type=int, default=128, help="sentence max length") 29 | parser.add_argument("--batch_size", type=int, default=32, help="batch size") 30 | parser.add_argument("--epochs", type=int, default=10, help="epochs") 31 | parser.add_argument("--lr", type=float, default=5e-5, help="learning rate") 32 | parser.add_argument("--tao", type=float, default=0.05, help="temperature") 33 | parser.add_argument("--device", type=str, default="cuda", help="device") 34 | parser.add_argument("--display_interval", type=int, default=100, help="display interval") 35 | parser.add_argument("--pool_type", type=str, default="avg_first_last", help="pool_type") 36 | parser.add_argument("--dropout_rate", type=float, default=0.3, help="dropout_rate") 37 | parser.add_argument("--task", type=str, default="promptbert", help="task name") 38 | args = parser.parse_args() 39 | return args 40 | 41 | 42 | def compute_loss(query, key, tao=0.05): 43 | # 对表征进行归一化,便于后面相似度计算 44 | query = F.normalize(query, dim=1) 45 | key = F.normalize(key, dim=1) 46 | # print(query.shape, key.shape) 47 | N, D = query.shape[0], query.shape[1] 48 | 49 | # calculate positive similarity 50 | batch_pos = torch.exp(torch.div(torch.bmm(query.view(N, 1, D), key.view(N, D, 1)).view(N, 1), tao)) 51 | 52 | # calculate inner_batch all similarity 53 | batch_all = torch.sum(torch.exp(torch.div(torch.mm(query.view(N, D), torch.t(key)), tao)), dim=1) 54 | 55 | loss = torch.mean(-torch.log(torch.div(batch_pos, batch_all))) 56 | return loss 57 | 58 | 59 | def eval(model, tokenizer, dev_loader, args): 60 | model.eval() 61 | model.to(args.device) 62 | 63 | all_sims, all_scores = [], [] 64 | with torch.no_grad(): 65 | for data in dev_loader: 66 | sent1 = data['sent1'] 67 | sent2 = data['sent2'] 68 | score = data['score'] 69 | sent1_encoding = tokenizer(sent1, 70 | padding=True, 71 | truncation=True, 72 | max_length=args.max_length, 73 | return_tensors='pt') 74 | sent1_encoding = {key: value.to(args.device) for key, value in sent1_encoding.items()} 75 | sent2_encoding = tokenizer(sent2, 76 | padding=True, 77 | truncation=True, 78 | max_length=args.max_length, 79 | return_tensors='pt') 80 | sent2_encoding = {key: value.to(args.device) for key, value in sent2_encoding.items()} 81 | 82 | sent1_output = model.forward_sentence(sent1_encoding['input_ids'], 83 | sent1_encoding['attention_mask'], 84 | sent1_encoding['token_type_ids'], 85 | args.pool_type) 86 | sent2_output = model.forward_sentence(sent2_encoding['input_ids'], 87 | sent2_encoding['attention_mask'], 88 | sent2_encoding['token_type_ids'], 89 | args.pool_type) 90 | sim_score = 
F.cosine_similarity(sent1_output, sent2_output).cpu().tolist() 91 | all_sims += sim_score 92 | all_scores += score.tolist() 93 | corr = scipy.stats.spearmanr(all_sims, all_scores).correlation 94 | return corr 95 | 96 | 97 | def eval2(model, tokenizer, dev_loader, args): 98 | model.eval() 99 | model.to(args.device) 100 | 101 | all_sims, all_scores = [], [] 102 | with torch.no_grad(): 103 | for data in dev_loader: 104 | sent1 = data['sent1'] 105 | sent2 = data['sent2'] 106 | score = data['score'] 107 | 108 | prompt_template_sent1 = [transform_sentence(s1, tokenizer, args) for s1 in sent1] 109 | prompt_sent1 = [pair[0] for pair in prompt_template_sent1] 110 | template_sent1 = [pair[1] for pair in prompt_template_sent1] 111 | prompt_encoding1 = encode_sentences(tokenizer, prompt_sent1, args) 112 | template_encoding1 = encode_sentences(tokenizer, template_sent1, args) 113 | prompt_encoding1 = {key: value.to(args.device) for key, value in prompt_encoding1.items()} 114 | template_encoding1 = {key: value.to(args.device) for key, value in template_encoding1.items()} 115 | 116 | prompt_template_sent2 = [transform_sentence(s2, tokenizer, args) for s2 in sent2] 117 | prompt_sent2 = [pair[0] for pair in prompt_template_sent2] 118 | template_sent2 = [pair[1] for pair in prompt_template_sent2] 119 | prompt_encoding2 = encode_sentences(tokenizer, prompt_sent2, args) 120 | template_encoding2 = encode_sentences(tokenizer, template_sent2, args) 121 | prompt_encoding2 = {key: value.to(args.device) for key, value in prompt_encoding2.items()} 122 | template_encoding2 = {key: value.to(args.device) for key, value in template_encoding2.items()} 123 | 124 | sent1_output = model(prompt_encoding1['input_ids'], 125 | prompt_encoding1['attention_mask'], 126 | prompt_encoding1['token_type_ids'], 127 | template_encoding1['input_ids'], 128 | template_encoding1['attention_mask'], 129 | template_encoding1['token_type_ids']) 130 | sent2_output = model(prompt_encoding2['input_ids'], 131 | prompt_encoding2['attention_mask'], 132 | prompt_encoding2['token_type_ids'], 133 | template_encoding2['input_ids'], 134 | template_encoding2['attention_mask'], 135 | template_encoding2['token_type_ids']) 136 | sim_score = F.cosine_similarity(sent1_output, sent2_output).cpu().tolist() 137 | all_sims += sim_score 138 | all_scores += score.tolist() 139 | corr = scipy.stats.spearmanr(all_sims, all_scores).correlation 140 | return corr 141 | 142 | 143 | def transform_sentence(sentence, tokenizer, args): 144 | prompt_templates = ['[X],它的意思是[MASK]', '[X],这句话的意思是[MASK]'] 145 | words_list = tokenizer.tokenize(sentence) 146 | words_num = len(words_list) 147 | sentence_template = [] 148 | for template in prompt_templates: 149 | if words_num > args.max_length - 15: 150 | words_list = words_list[:-15] 151 | sentence = tokenizer.decode(tokenizer.convert_tokens_to_ids(words_list)).replace(" ", "") 152 | 153 | words_len = len(tokenizer.tokenize(sentence)) 154 | prompt_sentence = template.replace("[X]", sentence) 155 | template_sentence = template.replace("[X]", "[X]"*words_len) 156 | sentence_template += [prompt_sentence, template_sentence] 157 | return sentence_template 158 | 159 | 160 | def encode_sentences(tokenizer, sen_list, args): 161 | sen_encoding = tokenizer(sen_list, 162 | padding=True, 163 | truncation=True, 164 | max_length=args.max_length, 165 | return_tensors='pt') 166 | return sen_encoding 167 | 168 | 169 | def build_dataset(dataloader, tokenizer, args): 170 | data_encodings = [] 171 | for data in dataloader: 172 | prompt_template_sentences 
= [transform_sentence(sentence, tokenizer, args) for sentence in data] 173 | prompt_sent1_list = [pair[0] for pair in prompt_template_sentences] 174 | template_sent1_list = [pair[1] for pair in prompt_template_sentences] 175 | prompt_sent2_list = [pair[2] for pair in prompt_template_sentences] 176 | template_sent2_list = [pair[3] for pair in prompt_template_sentences] 177 | 178 | prompt_encoding1 = encode_sentences(tokenizer, prompt_sent1_list, args) 179 | template_encoding1 = encode_sentences(tokenizer, template_sent1_list, args) 180 | prompt_encoding2 = encode_sentences(tokenizer, prompt_sent2_list, args) 181 | template_encoding2 = encode_sentences(tokenizer, template_sent2_list, args) 182 | data_encodings.append([prompt_encoding1, template_encoding1, prompt_encoding2, template_encoding2]) 183 | return data_encodings 184 | 185 | 186 | def train(args): 187 | train_file = args.trainfile 188 | dev_file = args.devfile 189 | test_file = args.testfile 190 | file_type = args.filetype 191 | match_dataset = MatchingDataSet() 192 | train_list = match_dataset.read_train_file(train_file, dev_file, test_file, file_type) 193 | dev_list = match_dataset.read_eval_file(dev_file, file_type) 194 | logger.info("train samples num: {}, dev samples num: {}".format(len(train_list), len(dev_list))) 195 | 196 | train_loader = DataLoader(train_list, 197 | batch_size=args.batch_size, 198 | shuffle=True) 199 | dev_loader = DataLoader(dev_list, 200 | batch_size=args.batch_size) 201 | logger.info("train batch num: {}, dev batch num: {}".format(len(train_loader), len(dev_loader))) 202 | 203 | tokenizer = AutoTokenizer.from_pretrained(args.pretrained) 204 | special_token_dict = {'additional_special_tokens': ['[X]']} 205 | tokenizer.add_special_tokens(special_token_dict) 206 | mask_id = tokenizer.mask_token_id 207 | 208 | model = PromptBERT(args.pretrained, args.dropout_rate, mask_id) 209 | model.encoder.resize_token_embeddings(len(tokenizer)) 210 | model.train() 211 | model.to(args.device) 212 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr) 213 | 214 | train_encodings = build_dataset(train_loader, tokenizer, args) 215 | 216 | batch_idx = 0 217 | best_corr = 0 218 | best_loss = 1000000 219 | for epoch in range(args.epochs): 220 | epoch_losses = [] 221 | for data in train_encodings: 222 | optimizer.zero_grad() 223 | prompt_encoding1, template_encoding1, prompt_encoding2, template_encoding2 = data 224 | prompt_encoding1 = {key: value.to(args.device) for key, value in prompt_encoding1.items()} 225 | template_encoding1 = {key: value.to(args.device) for key, value in template_encoding1.items()} 226 | prompt_encoding2 = {key: value.to(args.device) for key, value in prompt_encoding2.items()} 227 | template_encoding2 = {key: value.to(args.device) for key, value in template_encoding2.items()} 228 | 229 | query_embedding = model(prompt_encoding1['input_ids'], 230 | prompt_encoding1['attention_mask'], 231 | prompt_encoding1['token_type_ids'], 232 | template_encoding1['input_ids'], 233 | template_encoding1['attention_mask'], 234 | template_encoding1['token_type_ids']) 235 | key_embedding = model(prompt_encoding2['input_ids'], 236 | prompt_encoding2['attention_mask'], 237 | prompt_encoding2['token_type_ids'], 238 | template_encoding2['input_ids'], 239 | template_encoding2['attention_mask'], 240 | template_encoding2['token_type_ids']) 241 | batch_loss = compute_loss(query_embedding, key_embedding, args.tao) 242 | 243 | batch_loss.backward() 244 | optimizer.step() 245 | epoch_losses.append(batch_loss.item()) 246 | if 
batch_idx % args.display_interval == 0: 247 | logger.info("Epoch: {}, batch: {}, loss: {}".format(epoch, batch_idx, batch_loss.item())) 248 | batch_idx += 1 249 | 250 | avg_epoch_loss = np.mean(epoch_losses) 251 | dev_corr = eval2(model, tokenizer, dev_loader, args) 252 | logger.info("epoch: {}, avg loss: {}, dev corr: {}".format(epoch, avg_epoch_loss, dev_corr)) 253 | if avg_epoch_loss <= best_loss and dev_corr >= best_corr: 254 | best_corr = dev_corr 255 | best_loss = avg_epoch_loss 256 | torch.save({ 257 | 'epoch': epoch, 258 | 'batch': batch_idx, 259 | 'model_state_dict': model.state_dict(), 260 | 'loss': best_loss, 261 | 'corr': best_corr 262 | }, args.model_out) 263 | logger.info("epoch: {}, batch: {}, best loss: {}, best corr: {}, save model".format(epoch, batch_idx, avg_epoch_loss, dev_corr)) 264 | 265 | 266 | if __name__ == "__main__": 267 | args = parse_args() 268 | logger.info("args: {}".format(args)) 269 | train(args) 270 | 271 | # sentence = "今天天气真不错啊" 272 | # tokenizer = AutoTokenizer.from_pretrained(args.pretrained) 273 | # special_token_dict = {'additional_special_tokens': ['[X]']} 274 | # tokenizer.add_special_tokens(special_token_dict) 275 | # sentence_template = transform_sentence(sentence, tokenizer, args) 276 | # print(sentence_template) 277 | 278 | # st_encoding = encode_sentences(tokenizer, sentence_template, args) 279 | # print(st_encoding['input_ids']) 280 | -------------------------------------------------------------------------------- /sbert/config.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class Params: 4 | epoches = 100 5 | batch_size = 32 6 | max_length = 32 7 | learning_rate = 2e-5 8 | dropout = 0.2 9 | warmup_steps = 100 10 | display_interval = 500 11 | pretrained_model = "hfl/chinese-roberta-wwm-ext-large" 12 | sbert_model = "models/sbert_0106.pth" 13 | # pretrained_model = "clue/roberta_chinese_pair_large" 14 | pool_type = "mean" 15 | train_file = "data/train_dataset.csv" 16 | test_file = "data/test_dataset.csv" -------------------------------------------------------------------------------- /sbert/loading.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | from torch.utils.data import Dataset 4 | from transformers import AutoTokenizer 5 | from config import Params 6 | 7 | 8 | class LoadDataset(): 9 | def __init__(self, data_file): 10 | data_df = pd.read_csv(data_file) 11 | self.data_df = data_df.fillna("") 12 | 13 | 14 | def get_dataset(self): 15 | question1_list = list(self.data_df['question1']) 16 | question2_list = list(self.data_df['question2']) 17 | label_list = list(self.data_df['label']) 18 | return question1_list, question2_list, label_list 19 | 20 | 21 | def get_encodings(self, tokenzier, questions): 22 | question_encodings = tokenzier(questions, 23 | truncation=True, 24 | padding=True, 25 | max_length=Params.max_length, 26 | return_tensors='pt') 27 | return question_encodings 28 | 29 | 30 | class PairDataset(Dataset): 31 | def __init__(self, q1_encodings, q2_encodings, labels): 32 | self.q1_encodings = q1_encodings 33 | self.q2_encodings = q2_encodings 34 | self.labels = labels 35 | 36 | 37 | # 读取单个样本 38 | def __getitem__(self, idx): 39 | item1 = {key: torch.tensor(val[idx]) for key, val in self.q1_encodings.items()} 40 | labels = torch.tensor(int(self.labels[idx])) 41 | item2 = {key: torch.tensor(val[idx]) for key, val in self.q2_encodings.items()} 42 | return item1, item2, labels 43 | 44 | 45 | def __len__(self): 
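        # dataset size = number of (question1, question2, label) pairs read from the CSV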
46 | return len(self.labels) 47 | 48 | 49 | if __name__ == "__main__": 50 | print(1) -------------------------------------------------------------------------------- /sbert/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/1/5 3 | # @Author : Maciel 4 | 5 | import torch.nn as nn 6 | from transformers import BertConfig, BertModel 7 | import torch 8 | 9 | 10 | class SBERT(nn.Module): 11 | def __init__(self, pretrained="hfl/chinese-bert-wwm-ext", pool_type="cls", dropout_prob=0.3): 12 | super().__init__() 13 | conf = BertConfig.from_pretrained(pretrained) 14 | conf.attention_probs_dropout_prob = dropout_prob 15 | conf.hidden_dropout_prob = dropout_prob 16 | self.encoder = BertModel.from_pretrained(pretrained, config=conf) 17 | assert pool_type in ["cls", "pooler", "mean"], "invalid pool_type: %s" % pool_type 18 | self.pool_type = pool_type 19 | 20 | 21 | def forward(self, input_ids, attention_mask, token_type_ids): 22 | if self.pool_type == "cls": 23 | output = self.encoder(input_ids, 24 | attention_mask=attention_mask, 25 | token_type_ids=token_type_ids) 26 | output = output.last_hidden_state[:, 0] 27 | elif self.pool_type == "pooler": 28 | output = self.encoder(input_ids, 29 | attention_mask=attention_mask, 30 | token_type_ids=token_type_ids) 31 | output = output.pooler_output 32 | elif self.pool_type == "mean": 33 | output = self.get_mean_tensor(input_ids, attention_mask) 34 | return output 35 | 36 | 37 | def get_mean_tensor(self, input_ids, attention_mask): 38 | encode_states = self.encoder(input_ids, attention_mask=attention_mask, output_hidden_states=True) 39 | hidden_states = encode_states.hidden_states 40 | last_avg_state = self.get_avg_tensor(hidden_states[-1], attention_mask) 41 | first_avg_state = self.get_avg_tensor(hidden_states[1], attention_mask) 42 | mean_avg_state = (last_avg_state + first_avg_state) / 2 43 | return mean_avg_state 44 | 45 | 46 | def get_avg_tensor(self, layer_hidden_state, attention_mask): 47 | ''' 48 | layer_hidden_state: 模型一层表征向量 [B * L * D] 49 | attention_mask: 句子padding位置 [B * L] 50 | return: 非零位置词语的平均向量 [B * D] 51 | ''' 52 | layer_hidden_dim = layer_hidden_state.shape[-1] 53 | attention_repeat_mask = attention_mask.unsqueeze(dim=-1).tile(layer_hidden_dim) 54 | layer_attention_state = torch.mul(layer_hidden_state, attention_repeat_mask) 55 | layer_sum_state = layer_attention_state.sum(dim=1) 56 | # print(last_attention_state.shape) 57 | 58 | attention_length_mask = attention_mask.sum(dim=-1) 59 | attention_length_repeat_mask = attention_length_mask.unsqueeze(dim=-1).tile(layer_hidden_dim) 60 | # print(attention_length_repeat_mask.shape) 61 | 62 | layer_avg_state = torch.mul(layer_sum_state, 1/attention_length_repeat_mask) 63 | return layer_avg_state 64 | 65 | 66 | def get_avg_tensor2(self, layer_hidden_state, attention_mask): 67 | ''' 68 | layer_hidden_state: 模型一层表征向量 [B * L * D] 69 | attention_mask: 句子padding位置 [B * L] 70 | return: avg_embeddings, 非零位置词语的平均向量 [B * D] 71 | ''' 72 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(layer_hidden_state.size()).float() 73 | sum_embeddings = torch.sum(layer_hidden_state * input_mask_expanded, 1) 74 | sum_mask = input_mask_expanded.sum(1) 75 | sum_mask = torch.clamp(sum_mask, min=1e-9) 76 | avg_embeddings = sum_embeddings / sum_mask 77 | return avg_embeddings -------------------------------------------------------------------------------- /sbert/train.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import DataLoader 4 | from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup 5 | import numpy as np 6 | from tqdm import tqdm 7 | from sklearn.metrics import accuracy_score 8 | from loading import LoadDataset, PairDataset 9 | from model import SBERT 10 | from config import Params 11 | import os 12 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 13 | 14 | 15 | def compute_loss(similarity, label, loss_fn): 16 | # mse loss 17 | loss = loss_fn(similarity, label) 18 | return loss 19 | 20 | 21 | def compute_acc(similarity, label): 22 | pred = (similarity >= 0.5).long() 23 | acc = accuracy_score(pred.detach().cpu().numpy(), label.cpu().numpy()) 24 | return acc 25 | 26 | 27 | def testing(model, test_loader): 28 | model.eval() 29 | test_loss, test_acc = [], [] 30 | for test_q1, test_q2, test_label in test_loader: 31 | test_q1_input_ids = test_q1['input_ids'].to(device) 32 | test_q1_attention_mask = test_q1['attention_mask'].to(device) 33 | test_q1_token_type_ids = test_q1['token_type_ids'].to(device) 34 | 35 | test_q2_input_ids = test_q2['input_ids'].to(device) 36 | test_q2_attention_mask = test_q2['attention_mask'].to(device) 37 | test_q2_token_type_ids = test_q2['token_type_ids'].to(device) 38 | 39 | test_label = test_label.float().to(device) 40 | 41 | test_q1_embedding = model(test_q1_input_ids, test_q1_attention_mask, test_q1_token_type_ids) 42 | test_q2_embedding = model(test_q2_input_ids, test_q2_attention_mask, test_q2_token_type_ids) 43 | test_similarity = torch.cosine_similarity(test_q1_embedding, test_q2_embedding, dim=1) 44 | batch_test_loss = compute_loss(test_similarity, test_label, loss_fn) 45 | batch_test_acc = compute_acc(test_similarity, test_label) 46 | 47 | test_loss.append(batch_test_loss.item()) 48 | test_acc.append(batch_test_acc) 49 | 50 | test_avg_loss = np.mean(test_loss) 51 | test_avg_acc = np.mean(test_acc) 52 | return test_avg_loss, test_avg_acc 53 | 54 | 55 | train_loading = LoadDataset(Params.train_file) 56 | test_loading = LoadDataset(Params.test_file) 57 | 58 | # load tokenizer 59 | tokenizer = AutoTokenizer.from_pretrained(Params.pretrained_model) 60 | 61 | # load train dataset 62 | train_question1, train_question2, train_labels = train_loading.get_dataset() 63 | train_q1_encodings = train_loading.get_encodings(tokenizer, train_question1) 64 | train_q2_encodings = train_loading.get_encodings(tokenizer, train_question2) 65 | train_dataset = PairDataset(train_q1_encodings, train_q2_encodings, train_labels) 66 | train_loader = DataLoader(train_dataset, 67 | batch_size=Params.batch_size, 68 | shuffle=True) 69 | 70 | # load test dataset 71 | test_question1, test_question2, test_labels = test_loading.get_dataset() 72 | test_q1_encodings = test_loading.get_encodings(tokenizer, test_question1) 73 | test_q2_encodings = test_loading.get_encodings(tokenizer, test_question2) 74 | test_dataset = PairDataset(test_q1_encodings, test_q2_encodings, test_labels) 75 | test_loader = DataLoader(test_dataset, 76 | batch_size=Params.batch_size) 77 | 78 | # load model 79 | model = SBERT(Params.pretrained_model, Params.pool_type, Params.dropout) 80 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 81 | model.to(device) 82 | 83 | # load loss function 84 | loss_fn = nn.MSELoss() 85 | 86 | # load optimizer 87 | optim = AdamW(model.parameters(), lr=Params.learning_rate) 88 | total_steps = len(train_loader) 
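# note: as written, num_training_steps covers only the first epoch's batches, so the linear
# schedule reaches zero after one epoch; a decay over the whole run would typically pass
# len(train_loader) * Params.epoches instead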
89 | scheduler = get_linear_schedule_with_warmup(optim, 90 | num_warmup_steps = Params.warmup_steps, # Default value in run_glue.py 91 | num_training_steps = total_steps) 92 | 93 | 94 | best_loss = 100000 95 | for epoch in range(Params.epoches): 96 | model.train() 97 | batch_num = 0 98 | epoch_losses = [] 99 | epoch_acces = [] 100 | for q1, q2, label in tqdm(train_loader): 101 | 102 | q1_input_ids = q1['input_ids'].to(device) 103 | q1_attention_mask = q1['attention_mask'].to(device) 104 | q1_token_type_ids = q1['token_type_ids'].to(device) 105 | 106 | q2_input_ids = q2['input_ids'].to(device) 107 | q2_attention_mask = q2['attention_mask'].to(device) 108 | q2_token_type_ids = q2['token_type_ids'].to(device) 109 | 110 | label = label.float().to(device) 111 | # print(q1_input_ids, q2_input_ids, label) 112 | 113 | optim.zero_grad() 114 | q1_embedding = model(q1_input_ids, q1_attention_mask, q1_token_type_ids) 115 | q2_embedding = model(q2_input_ids, q2_attention_mask, q2_token_type_ids) 116 | similarity = torch.cosine_similarity(q1_embedding, q2_embedding, dim=1) 117 | batch_loss = compute_loss(similarity, label, loss_fn) 118 | # print("batch loss: {}, type: {}".format(batch_loss.item(), batch_loss.dtype)) 119 | batch_acc = compute_acc(similarity, label) 120 | 121 | # 梯度更新+裁剪 122 | batch_loss.backward() 123 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 124 | 125 | # 参数更新 126 | optim.step() 127 | scheduler.step() 128 | 129 | epoch_losses.append(batch_loss.item()) 130 | epoch_acces.append(batch_acc) 131 | 132 | if batch_num % Params.display_interval == 0: 133 | print("Epoch: {}, batch: {}/{}, loss: {}, acc: {}".format(epoch, batch_num, total_steps, batch_loss, batch_acc), flush=True) 134 | batch_num += 1 135 | 136 | epoch_avg_loss = np.mean(epoch_losses) 137 | epoch_avg_acc = np.mean(epoch_acces) 138 | print("Epoch: {}, avg loss: {}, acc: {}".format(epoch, epoch_avg_loss, epoch_avg_acc), flush=True) 139 | if epoch_avg_loss < best_loss: 140 | test_avg_loss, test_avg_acc = testing(model, test_loader) 141 | 142 | print("Epoch: {}, train loss: {}, acc: {}, test loss: {}, acc: {}, save best model".format(epoch, epoch_avg_loss, epoch_avg_acc, test_avg_loss, test_avg_acc), flush=True) 143 | torch.save({ 144 | 'epoch': epoch, 145 | 'model_state_dict': model.state_dict(), 146 | 'train_loss': epoch_avg_loss, 147 | 'train_acc': epoch_avg_acc, 148 | 'test_loss': test_avg_loss, 149 | 'test_acc': test_avg_acc 150 | }, Params.sbert_model) 151 | best_loss = epoch_avg_loss 152 | 153 | 154 | -------------------------------------------------------------------------------- /simcse/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Macielyoung/sentence_representation_matching/810bd68e366810814572876fd9cdb380238e5b19/simcse/.DS_Store -------------------------------------------------------------------------------- /simcse/SimCSE.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/9/17 3 | # @Author : Maciel 4 | 5 | import torch.nn as nn 6 | from transformers import BertConfig, BertModel 7 | import torch 8 | 9 | 10 | class SimCSE(nn.Module): 11 | def __init__(self, pretrained="hfl/chinese-bert-wwm-ext", pool_type="cls", dropout_prob=0.3): 12 | super().__init__() 13 | conf = BertConfig.from_pretrained(pretrained) 14 | conf.attention_probs_dropout_prob = dropout_prob 15 | conf.hidden_dropout_prob = dropout_prob 16 | self.encoder = 
BertModel.from_pretrained(pretrained, config=conf) 17 | assert pool_type in ["cls", "pooler", "avg_first_last", "avg_last_two"], "invalid pool_type: %s" % pool_type 18 | self.pool_type = pool_type 19 | 20 | 21 | def forward(self, input_ids, attention_mask, token_type_ids): 22 | output = self.encoder(input_ids, 23 | attention_mask=attention_mask, 24 | token_type_ids=token_type_ids, 25 | output_hidden_states=True) 26 | hidden_states = output.hidden_states 27 | if self.pool_type == "cls": 28 | output = output.last_hidden_state[:, 0] 29 | elif self.pool_type == "pooler": 30 | output = output.pooler_output 31 | elif self.pool_type == "avg_first_last": 32 | top_first_state = self.get_avg_tensor(hidden_states[1], attention_mask) 33 | last_first_state = self.get_avg_tensor(hidden_states[-1], attention_mask) 34 | output = (top_first_state + last_first_state) / 2 35 | else: 36 | last_first_state = self.get_avg_tensor(hidden_states[-1], attention_mask) 37 | last_second_state = self.get_avg_tensor(hidden_states[-2], attention_mask) 38 | output = (last_first_state + last_second_state) / 2 39 | return output 40 | 41 | 42 | def get_avg_tensor(self, layer_hidden_state, attention_mask): 43 | ''' 44 | layer_hidden_state: 模型一层表征向量 [B * L * D] 45 | attention_mask: 句子padding位置 [B * L] 46 | return: avg_embeddings, 非零位置词语的平均向量 [B * D] 47 | ''' 48 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(layer_hidden_state.size()).float() 49 | sum_embeddings = torch.sum(layer_hidden_state * input_mask_expanded, 1) 50 | sum_mask = input_mask_expanded.sum(1) 51 | sum_mask = torch.clamp(sum_mask, min=1e-9) 52 | avg_embeddings = sum_embeddings / sum_mask 53 | return avg_embeddings -------------------------------------------------------------------------------- /simcse/loading.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/9/17 3 | # @Author : Maciel 4 | 5 | from torch.utils.data import DataLoader 6 | 7 | 8 | class MatchingDataSet: 9 | def read_train_file(self, trainfile, devfile, testfile, filetype): 10 | sents = [] 11 | if filetype == "txt": 12 | with open(trainfile, 'r') as f: 13 | for line in f.readlines(): 14 | _, s1, s2, _ = line.strip().split(u"||") 15 | sents.append(s1) 16 | sents.append(s2) 17 | with open(devfile, 'r') as f: 18 | for line in f.readlines(): 19 | _, s1, s2, _ = line.strip().split(u"||") 20 | sents.append(s1) 21 | sents.append(s2) 22 | with open(testfile, 'r') as f: 23 | for line in f.readlines(): 24 | _, s1, s2, _ = line.strip().split(u"||") 25 | sents.append(s1) 26 | sents.append(s2) 27 | return sents 28 | 29 | def read_eval_file(self, file, filetype): 30 | sents = [] 31 | if filetype == "txt": 32 | with open(file, 'r') as f: 33 | for line in f.readlines(): 34 | _, s1, s2, s = line.strip().split(u"||") 35 | item = {'sent1': s1, 36 | 'sent2': s2, 37 | 'score': float(s)} 38 | sents.append(item) 39 | return sents 40 | 41 | 42 | if __name__ == "__main__": 43 | trainfile = "../dataset/STS-B/train.txt" 44 | devfile = "../dataset/STS-B/dev.txt" 45 | testfile = "../dataset/STS-B/test.txt" 46 | match_dataset = MatchingDataSet() 47 | 48 | train_list = match_dataset.read_train_file(trainfile, devfile, testfile, "txt") 49 | print(train_list[:5]) 50 | 51 | train_lengths = [len(sentence) for sentence in train_list] 52 | max_len = max(train_lengths) 53 | 54 | 55 | dev_list = match_dataset.read_eval_file(devfile, "txt") 56 | dev_sen1_length = [len(d['sent1']) for d in dev_list] 57 | dev_sen2_length = [len(d['sent2']) for d 
in dev_list] 58 | max_sen1 = max(dev_sen1_length) 59 | max_sen2 = max(dev_sen2_length) 60 | print(max_len, max_sen1, max_sen2) 61 | # dev_loader = DataLoader(dev_list, 62 | # batch_size=8) 63 | # for batch in dev_loader: 64 | # print(batch) 65 | # exit(0) -------------------------------------------------------------------------------- /simcse/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/9/19 3 | # @Author : Maciel 4 | 5 | from loading import MatchingDataSet 6 | import torch 7 | import torch.nn.functional as F 8 | import scipy.stats 9 | from torch.utils.data import DataLoader 10 | import numpy as np 11 | from SimCSE import SimCSE 12 | from transformers import AutoTokenizer 13 | import os 14 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 15 | 16 | 17 | def eval(model, tokenizer, test_loader, device, max_length): 18 | model.eval() 19 | model.to(device) 20 | 21 | all_sims, all_scores = [], [] 22 | with torch.no_grad(): 23 | for data in test_loader: 24 | sent1 = data['sent1'] 25 | sent2 = data['sent2'] 26 | score = data['score'] 27 | sent1_encoding = tokenizer(sent1, 28 | padding=True, 29 | truncation=True, 30 | max_length=max_length, 31 | return_tensors='pt') 32 | sent1_encoding = {key: value.to(device) for key, value in sent1_encoding.items()} 33 | sent2_encoding = tokenizer(sent2, 34 | padding=True, 35 | truncation=True, 36 | max_length=max_length, 37 | return_tensors='pt') 38 | sent2_encoding = {key: value.to(device) for key, value in sent2_encoding.items()} 39 | 40 | sent1_output = model(**sent1_encoding) 41 | sent2_output = model(**sent2_encoding) 42 | sim_score = F.cosine_similarity(sent1_output, sent2_output).cpu().tolist() 43 | all_sims += sim_score 44 | all_scores += score.tolist() 45 | corr = scipy.stats.spearmanr(all_sims, all_scores).correlation 46 | return corr 47 | 48 | 49 | def test(testfile, pretrained, pool_type, dropout_rate, model_path, max_length): 50 | match_dataset = MatchingDataSet() 51 | testfile_type = "txt" 52 | test_list = match_dataset.read_eval_file(testfile, testfile_type) 53 | print("test samples num: {}".format(len(test_list))) 54 | 55 | test_loader = DataLoader(test_list, 56 | batch_size=128) 57 | print("test batch num: {}".format(len(test_loader))) 58 | 59 | tokenizer = AutoTokenizer.from_pretrained(pretrained) 60 | model = SimCSE(pretrained, pool_type, dropout_rate) 61 | model.load_state_dict(torch.load(model_path)['model_state_dict']) 62 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 63 | 64 | test_corr = eval(model, tokenizer, test_loader, device, max_length) 65 | print("test corr: {}".format(test_corr)) 66 | 67 | 68 | if __name__ == "__main__": 69 | testfile = "../dataset/STS-B/test.txt" 70 | pretrained = "hfl/chinese-roberta-wwm-ext-large" 71 | pool_type = "avg_first_last" 72 | dropout_rate = 0.3 73 | max_length = 128 74 | model_path = "../models/simcse_roberta_large_stsb.pth" 75 | 76 | test(testfile, pretrained, pool_type, dropout_rate, model_path, max_length) 77 | -------------------------------------------------------------------------------- /simcse/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/9/17 3 | # @Author : Maciel 4 | 5 | 6 | import argparse 7 | import os 8 | from loading import MatchingDataSet 9 | import torch 10 | import torch.nn.functional as F 11 | import scipy.stats 12 | from torch.utils.data import DataLoader 13 | from transformers 
import AutoTokenizer 14 | import numpy as np 15 | from SimCSE import SimCSE 16 | from loguru import logger 17 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 18 | logger.add("../runtime.log") 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 23 | parser.add_argument("--trainfile", type=str, default="../dataset/STS-B/train.txt", help="train file path") 24 | parser.add_argument("--devfile", type=str, default="../dataset/STS-B/dev.txt", help="dev file path") 25 | parser.add_argument("--testfile", type=str, default="../dataset/STS-B/test.txt", help="test file path") 26 | parser.add_argument("--filetype", type=str, default="txt", help="train and dev file type") 27 | parser.add_argument("--pretrained", type=str, default="hfl/chinese-roberta-wwm-ext-large", help="huggingface pretrained model") 28 | parser.add_argument("--model_out", type=str, default="../models/simcse_roberta_large_stsb.pth", help="model output path") 29 | # parser.add_argument("--num_proc", type=int, default=1, help="dataset process thread num") 30 | parser.add_argument("--max_length", type=int, default=128, help="sentence max length") 31 | parser.add_argument("--batch_size", type=int, default=16, help="batch size") 32 | parser.add_argument("--epochs", type=int, default=100, help="epochs") 33 | parser.add_argument("--lr", type=float, default=3e-5, help="learning rate") 34 | parser.add_argument("--tao", type=float, default=0.05, help="temperature") 35 | parser.add_argument("--device", type=str, default="cuda", help="device") 36 | parser.add_argument("--display_interval", type=int, default=100, help="display interval") 37 | # parser.add_argument("--save_interval", type=int, default=10, help="save interval") 38 | parser.add_argument("--pool_type", type=str, default="avg_first_last", help="pool_type") 39 | parser.add_argument("--dropout_rate", type=float, default=0.3, help="dropout_rate") 40 | parser.add_argument("--task", type=str, default="simcse", help="task name") 41 | args = parser.parse_args() 42 | return args 43 | 44 | 45 | def duplicate_batch(batch): 46 | ''' 47 | 重复两次数据 48 | ''' 49 | new_batch = [] 50 | for sentence in batch: 51 | new_batch += [sentence, sentence] 52 | return new_batch 53 | 54 | 55 | def compute_loss(y_pred, tao=0.05, device="cuda"): 56 | idxs = torch.arange(0, y_pred.shape[0], device=device) 57 | y_true = idxs + 1 - idxs % 2 * 2 58 | similarities = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=2) 59 | similarities = similarities - torch.eye(y_pred.shape[0], device=device) * 1e12 60 | similarities = similarities / tao 61 | loss = F.cross_entropy(similarities, y_true) 62 | return torch.mean(loss) 63 | 64 | 65 | def eval(model, tokenizer, dev_loader, args): 66 | model.eval() 67 | model.to(args.device) 68 | 69 | all_sims, all_scores = [], [] 70 | with torch.no_grad(): 71 | for data in dev_loader: 72 | sent1 = data['sent1'] 73 | sent2 = data['sent2'] 74 | score = data['score'] 75 | sent1_encoding = tokenizer(sent1, 76 | padding=True, 77 | truncation=True, 78 | max_length=args.max_length, 79 | return_tensors='pt') 80 | sent1_encoding = {key: value.to(args.device) for key, value in sent1_encoding.items()} 81 | sent2_encoding = tokenizer(sent2, 82 | padding=True, 83 | truncation=True, 84 | max_length=args.max_length, 85 | return_tensors='pt') 86 | sent2_encoding = {key: value.to(args.device) for key, value in sent2_encoding.items()} 87 | 88 | sent1_output = model(**sent1_encoding) 89 | sent2_output = model(**sent2_encoding) 
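            # cosine similarity per sentence pair; the Spearman correlation against the gold scores is computed once the loop finishes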
90 | sim_score = F.cosine_similarity(sent1_output, sent2_output).cpu().tolist() 91 | all_sims += sim_score 92 | all_scores += score.tolist() 93 | corr = scipy.stats.spearmanr(all_sims, all_scores).correlation 94 | return corr 95 | 96 | 97 | def train(args): 98 | train_file = args.trainfile 99 | dev_file = args.devfile 100 | test_file = args.testfile 101 | file_type = args.filetype 102 | match_dataset = MatchingDataSet() 103 | train_list = match_dataset.read_train_file(train_file, dev_file, test_file, file_type) 104 | dev_list = match_dataset.read_eval_file(dev_file, file_type) 105 | logger.info("train samples num: {}, dev samples num: {}".format(len(train_list), len(dev_list))) 106 | 107 | train_loader = DataLoader(train_list, 108 | batch_size=args.batch_size, 109 | shuffle=True) 110 | dev_loader = DataLoader(dev_list, 111 | batch_size=args.batch_size) 112 | logger.info("train batch num: {}, dev batch num: {}".format(len(train_loader), len(dev_loader))) 113 | 114 | tokenizer = AutoTokenizer.from_pretrained(args.pretrained) 115 | model = SimCSE(args.pretrained, args.pool_type, args.dropout_rate) 116 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr) 117 | model.train() 118 | model.to(args.device) 119 | 120 | batch_idx = 0 121 | best_corr = 0 122 | best_loss = 1000000 123 | for epoch in range(args.epochs): 124 | epoch_losses = [] 125 | for data in train_loader: 126 | batch_idx += 1 127 | batch_data = duplicate_batch(data) 128 | encodings = tokenizer(batch_data, 129 | padding=True, 130 | truncation=True, 131 | max_length=args.max_length, 132 | return_tensors='pt') 133 | encodings = {key: value.to(args.device) for key, value in encodings.items()} 134 | output = model(**encodings) 135 | batch_loss = compute_loss(output, args.tao, args.device) 136 | optimizer.zero_grad() 137 | batch_loss.backward() 138 | optimizer.step() 139 | epoch_losses.append(batch_loss.item()) 140 | if batch_idx % args.display_interval == 0: 141 | logger.info("epoch: {}, batch: {}, loss: {}".format(epoch, batch_idx, batch_loss.item())) 142 | avg_epoch_loss = np.mean(epoch_losses) 143 | dev_corr = eval(model, tokenizer, dev_loader, args) 144 | logger.info("epoch: {}, avg loss: {}, dev corr: {}".format(epoch, avg_epoch_loss, dev_corr)) 145 | if dev_corr >= best_corr and avg_epoch_loss <= best_loss: 146 | best_corr = dev_corr 147 | best_loss = avg_epoch_loss 148 | torch.save({ 149 | 'epoch': epoch, 150 | 'batch': batch_idx, 151 | 'model_state_dict': model.state_dict(), 152 | 'loss': best_loss, 153 | 'corr': best_corr 154 | }, args.model_out) 155 | logger.info("epoch: {}, batch: {}, best loss: {}, best corr: {}, save model".format(epoch, batch_idx, avg_epoch_loss, dev_corr)) 156 | 157 | 158 | if __name__ == "__main__": 159 | args = parse_args() 160 | train(args) 161 | --------------------------------------------------------------------------------
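For quick reference, the sketch below shows one way to reuse a checkpoint written by simcse/train.py for inference: rebuild the SimCSE encoder, restore model_state_dict, and score a single sentence pair with cosine similarity, mirroring what simcse/test.py does batch-wise. The helper name `similarity`, the example sentences and the hard-coded paths are illustrative assumptions rather than part of the repository.

# -*- coding: utf-8 -*-
# Minimal inference sketch (not part of the repository): load the checkpoint saved by
# simcse/train.py and score the similarity of two sentences. The paths, the example
# sentences and the helper name `similarity` are assumptions for illustration.

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from SimCSE import SimCSE

pretrained = "hfl/chinese-roberta-wwm-ext-large"        # same as --pretrained in train.py
model_path = "../models/simcse_roberta_large_stsb.pth"  # default --model_out in train.py
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(pretrained)
model = SimCSE(pretrained, pool_type="avg_first_last", dropout_prob=0.3)
model.load_state_dict(torch.load(model_path, map_location=device)['model_state_dict'])
model.to(device)
model.eval()


def similarity(sent1, sent2, max_length=128):
    # encode both sentences in one batch and compare their pooled embeddings
    encodings = tokenizer([sent1, sent2],
                          padding=True,
                          truncation=True,
                          max_length=max_length,
                          return_tensors='pt')
    encodings = {key: value.to(device) for key, value in encodings.items()}
    with torch.no_grad():
        embeddings = model(**encodings)
    return F.cosine_similarity(embeddings[0:1], embeddings[1:2]).item()


if __name__ == "__main__":
    print(similarity("今天天气真不错啊", "今天天气很好"))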