├── .gitignore ├── LICENSE ├── README.md ├── README_CN.md ├── examples ├── data │ ├── sup_sample.csv │ ├── sup_with_neg_sample.csv │ └── unsup_sample.csv ├── supervised_neg_train.py ├── supervised_train.py ├── unsupervised_train.py └── wentian_train.py ├── requirements.txt └── simcse_tf2 ├── data.py ├── losses.py └── simcse.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SimCSE for TensorFlow v2 2 | [**中文**](https://github.com/jifei/simcse-tf2/blob/master/README_CN.md) | **English** 3 | 4 | ![Python Versions](https://img.shields.io/badge/Python-3.0+-blue.svg) 5 | ![TensorFlow Versions](https://img.shields.io/badge/TensorFlow-2.0+-blue.svg) 6 | 7 | A TensorFlow 2 Keras implementation of SimCSE with unsupervised and supervised. 8 | ## Examples 9 | - supervised. [code](https://github.com/jifei/simcse-tf2/blob/master/examples/supervised_train.py) 10 | - supervised train with negative sampling(include random or hard). [code](https://github.com/jifei/simcse-tf2/blob/master/examples/supervised_neg_train.py) 11 | - unsupervised. 
[code](https://github.com/jifei/simcse-tf2/blob/master/examples/unsupervised_train.py) 12 | 13 | ## References 14 | - [bojone/Bert4Keras](https://github.com/bojone/bert4keras) & [bojone/SimCSE](https://github.com/bojone/SimCSE) 15 | - [princeton-nlp/SimCSE](https://github.com/princeton-nlp/SimCSE) 16 | - [enze5088/WenTianSearch](https://github.com/enze5088/WenTianSearch) 17 | - [muyuuuu/E-commerce-Search-Recall](https://github.com/muyuuuu/E-commerce-Search-Recall) -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | # SimCSE for TensorFlow v2 2 | **中文** | [**English**](https://github.com/jifei/simcse-tf2) 3 | 4 | ![Python Versions](https://img.shields.io/badge/Python-3.0+-blue.svg) 5 | ![TensorFlow Versions](https://img.shields.io/badge/TensorFlow-2.0+-blue.svg) 6 | 7 | 基于TensorFlow 2.x keras 开发的 SimCSE,支持无监督和有监督的训练. 8 | ## 例子 9 | - 有监督. [代码](https://github.com/jifei/simcse-tf2/blob/master/examples/supervised_train.py) 10 | - 增加负采样的有监督(包含随机负采样和难负采样). [代码](https://github.com/jifei/simcse-tf2/blob/master/examples/supervised_neg_train.py) 11 | - 无监督. [代码](https://github.com/jifei/simcse-tf2/blob/master/examples/unsupervised_train.py) 12 | - 阿里问天引擎训练部分. 
[代码](https://github.com/jifei/simcse-tf2/blob/master/examples/wentian_train.py) 13 | 14 | [“阿里灵杰”问天引擎电商搜索算法赛](https://tianchi.aliyun.com/competition/entrance/531946/introduction?spm=5176.12281957.1004.5.38b02448HKvsCR) 数据处理,打包提交可以参考 [enze5088/WenTianSearch](https://github.com/enze5088/WenTianSearch) 中的代码。本代码中的训练结果可以达到0.24左右的成绩,有各种超参数可以自行选择和调试。 15 | 16 | ## 参考 17 | - [bojone/Bert4Keras](https://github.com/bojone/bert4keras) & [bojone/SimCSE](https://github.com/bojone/SimCSE) 18 | - [princeton-nlp/SimCSE](https://github.com/princeton-nlp/SimCSE) 19 | - [enze5088/WenTianSearch](https://github.com/enze5088/WenTianSearch) 20 | - [muyuuuu/E-commerce-Search-Recall](https://github.com/muyuuuu/E-commerce-Search-Recall) -------------------------------------------------------------------------------- /examples/data/sup_sample.csv: -------------------------------------------------------------------------------- 1 | ximapex 猛鸡Ultra阿杰xim apex pc ps5键鼠转换器 送最强配置一对一教学 2 | catia软件安装 Catia V5R20/21 V5-6R2013/R2014/R2015/2016/2017/2020远程安装 3 | 杨琴琴架盒 扬琴架子盒铝合金包边万象轮401、402杨琴架专用架子盒杨琴盒箱包 4 | 高亮度长交通指挥棒 交通指挥棒充电消防应急疏散照明发光闪光棒LED演唱会手持荧光棒 5 | 专业乒乓球运动袜 乒乓袜男士超厚运动袜中筒全棉加厚男袜毛巾底吸汗防臭羽毛球袜子 6 | 易蒙停药 易蒙停盐酸洛哌丁胺胶囊2mg*12粒/盒增加大便稠硬度油滴状粪便慢性腹泻小儿腹泻排便失禁造瘘急慢性腹泻 7 | 树苗家居生活馆 简约植物绿植玻璃贴纸阳台浴室厨房装饰磨砂贴膜半透窗户静电贴画 8 | 德亚一斤装牛奶 德亚旗舰店折上折购物金 充值800得830/充值1000得1050 9 | 折叠床罩防尘罩 欧润哲 折叠床防尘罩 办公室单人午休床套罩袋子收纳对折床防尘袋 10 | 公鸡母 鸡画 超清晰彩色十字绣重绘图纸/源文件名画 公鸡母鸡和小鸡 11 | 天山雪莲花 雪莲花天山雪莲正品新疆天山干雪莲花茶干花泡酒中药非野生中药材 12 | 二胡弦松香 液体松香二胡大中小提琴贝司板胡无尘拉弦乐器通用演奏琴防滑配件 13 | 综合滤盒6006 3M6001CN过滤盒6002 6003 6004 6005 6006滤毒盒6200面具配件 14 | 秒变瓜子脸 【小红书推荐】大脸克星 脸大不求人 妙变瓜子脸 v脸神器 拍1发5~ 15 | 奥尔良烤翅 Tyson/泰森新奥尔良鸡翅300g*5烧烤烤肉料烤鸡翅腌料鸡肉零食冷冻 16 | 扣背神器 祝浩康硅胶拍痰器老人催痰杯拍背器按摩器扣痰碗成人儿童拍痰器 17 | 小椅包 空钩悟道2021款钓椅后挂包侧挂包多功能防水轻便型配套收纳渔具包 18 | 浙江教育出版社旗舰店论语 当当网 正版书籍 论语新版全文注音+详尽注释+精准翻译国学启蒙经典 19 | 反渗透滤布 刮刀式离心机滤布丙纶耐酸碱单复丝双层滤片100-6000目高精度滤网 20 | 年画挂历 13张名胜墙上挂历大号山水画2022家用挂墙虎年日历月历定制财源滚 -------------------------------------------------------------------------------- /examples/data/sup_with_neg_sample.csv: 
-------------------------------------------------------------------------------- 1 | ximapex 猛鸡Ultra阿杰xim apex pc ps5键鼠转换器 送最强配置一对一教学 测试负样本1 2 | catia软件安装 Catia V5R20/21 V5-6R2013/R2014/R2015/2016/2017/2020远程安装 测试负样本2 3 | 杨琴琴架盒 扬琴架子盒铝合金包边万象轮401、402杨琴架专用架子盒杨琴盒箱包 测试负样本3 4 | 高亮度长交通指挥棒 交通指挥棒充电消防应急疏散照明发光闪光棒LED演唱会手持荧光棒 测试负样本4 5 | 专业乒乓球运动袜 乒乓袜男士超厚运动袜中筒全棉加厚男袜毛巾底吸汗防臭羽毛球袜子 测试负样本5 6 | 易蒙停药 易蒙停盐酸洛哌丁胺胶囊2mg*12粒/盒增加大便稠硬度油滴状粪便慢性腹泻小儿腹泻排便失禁造瘘急慢性腹泻 测试负样本6 7 | 树苗家居生活馆 简约植物绿植玻璃贴纸阳台浴室厨房装饰磨砂贴膜半透窗户静电贴画 测试负样本7 8 | 德亚一斤装牛奶 德亚旗舰店折上折购物金 充值800得830/充值1000得1050 测试负样本8 9 | 折叠床罩防尘罩 欧润哲 折叠床防尘罩 办公室单人午休床套罩袋子收纳对折床防尘袋 测试负样本9 10 | 公鸡母 鸡画 超清晰彩色十字绣重绘图纸/源文件名画 公鸡母鸡和小鸡 测试负样本10 11 | 天山雪莲花 雪莲花天山雪莲正品新疆天山干雪莲花茶干花泡酒中药非野生中药材 测试负样本11 12 | 二胡弦松香 液体松香二胡大中小提琴贝司板胡无尘拉弦乐器通用演奏琴防滑配件 测试负样本12 13 | 综合滤盒6006 3M6001CN过滤盒6002 6003 6004 6005 6006滤毒盒6200面具配件 测试负样本13 14 | 秒变瓜子脸 【小红书推荐】大脸克星 脸大不求人 妙变瓜子脸 v脸神器 拍1发5~ 测试负样本14 15 | 奥尔良烤翅 Tyson/泰森新奥尔良鸡翅300g*5烧烤烤肉料烤鸡翅腌料鸡肉零食冷冻 测试负样本15 16 | 扣背神器 祝浩康硅胶拍痰器老人催痰杯拍背器按摩器扣痰碗成人儿童拍痰器 测试负样本16 17 | 小椅包 空钩悟道2021款钓椅后挂包侧挂包多功能防水轻便型配套收纳渔具包 测试负样本17 18 | 浙江教育出版社旗舰店论语 当当网 正版书籍 论语新版全文注音+详尽注释+精准翻译国学启蒙经典 测试负样本18 19 | 反渗透滤布 刮刀式离心机滤布丙纶耐酸碱单复丝双层滤片100-6000目高精度滤网 测试负样本19 20 | 年画挂历 13张名胜墙上挂历大号山水画2022家用挂墙虎年日历月历定制财源滚 测试负样本20 -------------------------------------------------------------------------------- /examples/data/unsup_sample.csv: -------------------------------------------------------------------------------- 1 | 猛鸡Ultra阿杰xim apex pc ps5键鼠转换器 送最强配置一对一教学 2 | Catia V5R20/21 V5-6R2013/R2014/R2015/2016/2017/2020远程安装 3 | 扬琴架子盒铝合金包边万象轮401、402杨琴架专用架子盒杨琴盒箱包 4 | 交通指挥棒充电消防应急疏散照明发光闪光棒LED演唱会手持荧光棒 5 | 乒乓袜男士超厚运动袜中筒全棉加厚男袜毛巾底吸汗防臭羽毛球袜子 6 | 易蒙停盐酸洛哌丁胺胶囊2mg*12粒/盒增加大便稠硬度油滴状粪便慢性腹泻小儿腹泻排便失禁造瘘急慢性腹泻 7 | 简约植物绿植玻璃贴纸阳台浴室厨房装饰磨砂贴膜半透窗户静电贴画 8 | 德亚旗舰店折上折购物金 充值800得830/充值1000得1050 9 | 欧润哲 折叠床防尘罩 办公室单人午休床套罩袋子收纳对折床防尘袋 10 | 超清晰彩色十字绣重绘图纸/源文件名画 公鸡母鸡和小鸡 11 | 雪莲花天山雪莲正品新疆天山干雪莲花茶干花泡酒中药非野生中药材 12 | 液体松香二胡大中小提琴贝司板胡无尘拉弦乐器通用演奏琴防滑配件 13 | 3M6001CN过滤盒6002 6003 6004 6005 6006滤毒盒6200面具配件 14 | 【小红书推荐】大脸克星 脸大不求人 妙变瓜子脸 v脸神器 拍1发5~ 15 | Tyson/泰森新奥尔良鸡翅300g*5烧烤烤肉料烤鸡翅腌料鸡肉零食冷冻 16 
| 祝浩康硅胶拍痰器老人催痰杯拍背器按摩器扣痰碗成人儿童拍痰器 17 | 空钩悟道2021款钓椅后挂包侧挂包多功能防水轻便型配套收纳渔具包 18 | 当当网 正版书籍 论语新版全文注音+详尽注释+精准翻译国学启蒙经典 19 | 刮刀式离心机滤布丙纶耐酸碱单复丝双层滤片100-6000目高精度滤网 20 | 13张名胜墙上挂历大号山水画2022家用挂墙虎年日历月历定制财源滚 -------------------------------------------------------------------------------- /examples/supervised_neg_train.py: -------------------------------------------------------------------------------- 1 | from simcse_tf2.simcse import simcse 2 | from simcse_tf2.data import load_data, SimCseDataGenerator 3 | from simcse_tf2.losses import simcse_hard_neg_loss 4 | import tensorflow as tf 5 | 6 | if __name__ == '__main__': 7 | # 1. bert config 8 | model_path = '/Users/jifei/models/bert/chinese_L-12_H-768_A-12' 9 | # model_path = '/Users/jifei/models/bert/chinese_roberta_wwm_ext_L-12_H-768_A-12' 10 | checkpoint_path = '%s/bert_model.ckpt' % model_path 11 | config_path = '%s/bert_config.json' % model_path 12 | dict_path = '%s/vocab.txt' % model_path 13 | 14 | # 2. set hyper parameters 15 | max_len = 64 16 | pooling = 'cls' # in ['first-last-avg', 'last-avg', 'cls', 'pooler'] 17 | dropout_rate = 0.1 18 | batch_size = 64 19 | learning_rate = 5e-5 20 | epochs = 2 21 | output_units = 128 22 | activation = 'tanh' 23 | 24 | # 3. data generator 25 | # random negative sampling 26 | train_data = load_data('./data/sup_sample.csv', delimiter='\t', skip_header=False, random_negative_sampling=True) 27 | # with hard negative sampling 28 | # train_data = load_data('./data/sup_with_neg_sample.csv', delimiter='\t') 29 | 30 | train_generator = SimCseDataGenerator(train_data, dict_path, batch_size, max_len, text_tuple_size=3) 31 | # print(next(train_generator.forfit())) 32 | 33 | # 4. build model 34 | model = simcse(config_path, checkpoint_path, dropout_rate=dropout_rate, output_units=output_units, 35 | output_activation=activation) 36 | 37 | # 5. model compile 38 | optimizer = tf.keras.optimizers.Adam(learning_rate) 39 | model.compile(loss=simcse_hard_neg_loss, optimizer=optimizer) 40 | 41 | # 6. 
model fit 42 | model.fit(train_generator.forfit(), steps_per_epoch=len(train_generator), epochs=epochs) 43 | -------------------------------------------------------------------------------- /examples/supervised_train.py: -------------------------------------------------------------------------------- 1 | from simcse_tf2.simcse import simcse 2 | from simcse_tf2.data import texts_to_ids, load_data, SimCseDataGenerator 3 | from simcse_tf2.losses import simcse_loss 4 | import tensorflow as tf 5 | import numpy as np 6 | import csv 7 | 8 | if __name__ == '__main__': 9 | # 1. bert config 10 | model_path = '/Users/jifei/models/bert/chinese_L-12_H-768_A-12' 11 | # model_path = '/Users/jifei/models/bert/chinese_roberta_wwm_ext_L-12_H-768_A-12' 12 | checkpoint_path = '%s/bert_model.ckpt' % model_path 13 | config_path = '%s/bert_config.json' % model_path 14 | dict_path = '%s/vocab.txt' % model_path 15 | 16 | # 2. set hyper parameters 17 | max_len = 64 18 | pooling = 'cls' # in ['first-last-avg', 'last-avg', 'cls', 'pooler'] 19 | dropout_rate = 0.1 20 | batch_size = 20 21 | learning_rate = 5e-5 22 | epochs = 2 23 | output_units = 128 24 | activation = 'tanh' 25 | 26 | # 3. data generator 27 | train_data = load_data('./data/sup_sample.csv', delimiter='\t') 28 | # print(train_data) # text tuple list 29 | train_generator = SimCseDataGenerator(train_data, dict_path, batch_size, max_len) 30 | # print(next(train_generator.forfit())) 31 | # 4. build model 32 | model = simcse(config_path, checkpoint_path, dropout_rate=dropout_rate, output_units=output_units, 33 | output_activation=activation) 34 | 35 | # 5. model compile 36 | optimizer = tf.keras.optimizers.Adam(learning_rate) 37 | model.compile(loss=simcse_loss, optimizer=optimizer) 38 | 39 | # 6. model fit 40 | model.fit(train_generator.forfit(), steps_per_epoch=len(train_generator), epochs=epochs) 41 | 42 | # 7. 
predict 43 | corpus = [line[0] for line in csv.reader(open("./data/unsup_sample.csv"), delimiter='\t')] 44 | inputs = texts_to_ids(corpus[:10], dict_path, max_len) 45 | print(model.predict([inputs, np.zeros_like(inputs)])) 46 | -------------------------------------------------------------------------------- /examples/unsupervised_train.py: -------------------------------------------------------------------------------- 1 | from simcse_tf2.simcse import simcse 2 | from simcse_tf2.data import load_data, SimCseDataGenerator 3 | from simcse_tf2.losses import simcse_loss 4 | import tensorflow as tf 5 | 6 | if __name__ == '__main__': 7 | # 1. bert config 8 | model_path = '/Users/jifei/models/bert/chinese_L-12_H-768_A-12' 9 | # model_path = '/Users/jifei/models/bert/chinese_roberta_wwm_ext_L-12_H-768_A-12' 10 | checkpoint_path = '%s/bert_model.ckpt' % model_path 11 | config_path = '%s/bert_config.json' % model_path 12 | dict_path = '%s/vocab.txt' % model_path 13 | 14 | # 2. set hyper parameters 15 | max_len = 64 16 | pooling = 'cls' # in ['first-last-avg', 'last-avg', 'cls', 'pooler'] 17 | dropout_rate = 0.1 18 | batch_size = 64 19 | learning_rate = 5e-5 20 | epochs = 2 21 | output_units = 128 22 | activation = 'tanh' 23 | 24 | # 3. data generator 25 | train_data = load_data('./data/unsup_sample.csv', delimiter='\t') 26 | 27 | train_generator = SimCseDataGenerator(train_data, dict_path, batch_size, max_len, text_tuple_size=1) 28 | # print(next(train_generator.forfit())) 29 | 30 | # 4. build model 31 | model = simcse(config_path, checkpoint_path, dropout_rate=dropout_rate, output_units=output_units, 32 | output_activation=activation) 33 | 34 | # 5. model compile 35 | optimizer = tf.keras.optimizers.Adam(learning_rate) 36 | model.compile(loss=simcse_loss, optimizer=optimizer) 37 | 38 | # 6. 
model fit 39 | model.fit(train_generator.forfit(), steps_per_epoch=len(train_generator), epochs=epochs) 40 | -------------------------------------------------------------------------------- /examples/wentian_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["TF_KERAS"] = '1' 4 | from bert4keras.snippets import sequence_padding 5 | from simcse_tf2.simcse import simcse 6 | from simcse_tf2.data import get_tokenizer, load_data, SimCseDataGenerator 7 | from simcse_tf2.losses import simcse_loss 8 | import tensorflow as tf 9 | import numpy as np 10 | 11 | 12 | def texts_to_ids(data, tokenizer, max_len=64): 13 | """转换文本数据为id形式 14 | """ 15 | token_ids = [] 16 | for d in data: 17 | token_ids.append(tokenizer.encode(d, maxlen=max_len)[0]) 18 | return sequence_padding(token_ids) 19 | 20 | 21 | def encode_fun(texts, model, tokenizer, maxlen): 22 | inputs = texts_to_ids(texts, tokenizer, maxlen) 23 | 24 | embeddings = model.predict([inputs, np.zeros_like(inputs)]) 25 | return embeddings 26 | 27 | 28 | if __name__ == '__main__': 29 | # 1. bert config 30 | model_path = '/Users/jifei/models/bert/chinese_roberta_wwm_ext_L-12_H-768_A-12' 31 | checkpoint_path = '%s/bert_model.ckpt' % model_path 32 | config_path = '%s/bert_config.json' % model_path 33 | dict_path = '%s/vocab.txt' % model_path 34 | 35 | # 2. set hyper parameters 36 | max_len = 64 37 | pooling = 'cls' # in ['first-last-avg', 'last-avg', 'cls', 'pooler'] 38 | dropout_rate = 0.1 39 | batch_size = 64 40 | learning_rate = 5e-5 41 | epochs = 3 42 | output_units = 128 43 | activation = 'tanh' 44 | 45 | # 3. data generator 46 | train_data = load_data('./query_doc.csv', delimiter=",") 47 | # train_data = load_data('./examples/data/sup_sample.csv', delimiter = "\t") 48 | train_generator = SimCseDataGenerator(train_data, dict_path, batch_size, max_len) 49 | print(next(train_generator.forfit())) 50 | 51 | # 4. 
build model 52 | model = simcse(config_path, checkpoint_path, dropout_rate=dropout_rate, output_units=output_units, 53 | output_activation=activation) 54 | print(model.summary()) 55 | # 5. model compile 56 | optimizer = tf.keras.optimizers.Adam(learning_rate) 57 | model.compile(loss=simcse_loss, optimizer=optimizer) 58 | 59 | # 6. model fit 60 | model.fit(train_generator.forfit(), steps_per_epoch=len(train_generator), epochs=epochs) 61 | 62 | import csv 63 | from tqdm import tqdm 64 | 65 | pre_batch_size = 5000 66 | corpus = [line[1] for line in csv.reader(open("./data/corpus.tsv"), delimiter='\t')] 67 | query = [line[1] for line in csv.reader(open("./data/dev.query.txt"), delimiter='\t')] 68 | tokenizer = get_tokenizer(dict_path) 69 | query_embedding_file = csv.writer(open('./query_embedding', 'w'), delimiter='\t') 70 | 71 | for i in tqdm(range(0, len(query), pre_batch_size)): 72 | batch_text = query[i:i + pre_batch_size] 73 | print("query size:", len(batch_text)) 74 | temp_embedding = encode_fun(batch_text, model, tokenizer, max_len) 75 | for j in range(len(temp_embedding)): 76 | writer_str = temp_embedding[j].tolist() 77 | writer_str = [format(s, '.8f') for s in writer_str] 78 | writer_str = ','.join(writer_str) 79 | query_embedding_file.writerow([i + j + 200001, writer_str]) 80 | print("query end!") 81 | doc_embedding_file = csv.writer(open('./doc_embedding', 'w'), delimiter='\t') 82 | for i in tqdm(range(0, len(corpus), pre_batch_size)): 83 | batch_text = corpus[i:i + pre_batch_size] 84 | temp_embedding = encode_fun(batch_text, model, tokenizer, max_len) 85 | for j in range(len(temp_embedding)): 86 | writer_str = temp_embedding[j].tolist() 87 | writer_str = [format(s, '.8f') for s in writer_str] 88 | writer_str = ','.join(writer_str) 89 | doc_embedding_file.writerow([i + j + 1, writer_str]) 90 | print("doc end!") 91 | -------------------------------------------------------------------------------- /requirements.txt: 
# --------------------------------------------------------------------------------
# requirements.txt:
#   bert4keras==0.10.8
#   tensorflow >= 2.0.0
#   numpy
# --------------------------------------------------------------------------------
# /simcse_tf2/data.py:
# --------------------------------------------------------------------------------
# -*- coding:utf-8 -*-
"""
Author:
    jifei, jifei@outlook.com
"""
import os

os.environ["TF_KERAS"] = '1'  # must be set before bert4keras imports keras
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import DataGenerator, sequence_padding
import random
import numpy as np


def get_tokenizer(dict_path):
    """Build a lower-casing bert4keras Tokenizer from a BERT vocab file.

    :param dict_path: string, path to the vocab file.
    :return: Tokenizer instance.
    """
    return Tokenizer(dict_path, do_lower_case=True)


def texts_to_ids(texts, dict_path, max_len=64):
    """Encode texts into a padded token-id matrix.

    :param texts: iterable of strings.
    :param dict_path: string, path to the vocab file.
    :param max_len: int, maximum sequence length.
    :return: np.ndarray, shape (len(texts), padded_len).
    """
    tokenizer = get_tokenizer(dict_path)
    token_ids = [tokenizer.encode(t, maxlen=max_len)[0] for t in texts]
    return sequence_padding(token_ids)


class SimCseDataGenerator(DataGenerator):
    """Batch generator for SimCSE training.

    Each sample tuple holds 1 text (unsupervised: the text is repeated so
    dropout produces two views), 2 texts (supervised pair) or 3 texts
    (pair + hard negative). The texts of one tuple are emitted as
    consecutive rows, which is the layout the SimCSE losses rely on.
    """

    def __init__(self, data, dict_path, batch_size=32, max_len=64, text_tuple_size=2, buffer_size=None):
        """
        :param data: list of text tuples, see `load_data`.
        :param dict_path: string, path to the vocab file.
        :param batch_size: int, number of tuples per batch.
        :param max_len: int, maximum sequence length.
        :param text_tuple_size: int in {1, 2, 3}, texts per sample tuple.
        :param buffer_size: passed through to bert4keras DataGenerator.
        """
        super().__init__(data, batch_size, buffer_size)
        assert text_tuple_size in [1, 2, 3]
        self.tokenizer = get_tokenizer(dict_path)
        self.max_len = max_len
        self.text_tuple_size = text_tuple_size

    def __iter__(self, random=False):
        # NOTE(review): the parameter name `random` shadows the imported
        # `random` module inside this method; kept as-is because bert4keras'
        # DataGenerator.forfit passes it positionally.
        def encode(text):
            return self.tokenizer.encode(text, maxlen=self.max_len)[0]

        batch_token_ids = []
        for is_end, texts in self.sample(random):
            if self.text_tuple_size == 1:
                # unsupervised: encode the same text twice (dropout makes the views differ)
                batch_token_ids.append(encode(texts[0]))
                batch_token_ids.append(encode(texts[0]))
            else:
                # pair (2) or pair + negative (3), kept in consecutive rows
                for text in texts[:self.text_tuple_size]:
                    batch_token_ids.append(encode(text))

            if len(batch_token_ids) == self.batch_size * self.text_tuple_size or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = np.zeros_like(batch_token_ids)
                # labels are dummies: the SimCSE losses rebuild targets from row order
                batch_labels = np.zeros_like(batch_token_ids[:, :1])
                yield [batch_token_ids, batch_segment_ids], batch_labels
                batch_token_ids = []


def load_data(file_name, delimiter='\t', skip_header=False, shuffle=True, random_negative_sampling=False):
    """Load samples from a delimited text file.

    :param file_name: string, file path.
    :param delimiter: string, column delimiter.
    :param skip_header: bool, skip the first line.
    :param shuffle: bool, shuffle the rows.
    :param random_negative_sampling: bool, attach a random second-column
        text from another row as the negative of every two-column row.
    :return: list, [(text1,), ...] or [(text1, text2), ...]
        or [(text1, text2, neg_text), ...]
    """
    lines = []
    negs = []
    with open(file_name, encoding='utf-8') as f:
        for line in f:
            if skip_header:
                skip_header = False
                continue
            columns = tuple(line.strip().split(delimiter))
            lines.append(columns)
            if random_negative_sampling and len(columns) == 2:
                negs.append(columns[1])

    if shuffle:
        random.shuffle(lines)
    if random_negative_sampling:
        random.shuffle(negs)
        # Bug fix: only two-column rows contributed to `negs`, but the
        # original popped one negative for *every* row, so a single
        # malformed row raised IndexError (negs.pop() on an empty list or
        # row[1] on a 1-column row). Pair negatives only with the rows
        # that supplied candidates; behavior is unchanged for well-formed
        # files where every row has exactly two columns.
        return [(row[0], row[1], negs.pop()) for row in lines if len(row) == 2]
    return lines
# --------------------------------------------------------------------------------
# /simcse_tf2/losses.py:
# --------------------------------------------------------------------------------
# -*- coding:utf-8 -*-
"""
Author:
    jifei, jifei@outlook.com
"""

import tensorflow as tf


def simcse_loss(y_true, y_pred):
    """In-batch contrastive SimCSE loss.

    Rows arrive as consecutive (anchor, positive) pairs, so the target of
    row i is row i^1. The incoming `y_true` is a dummy from the generator
    and is rebuilt here from row order.

    :param y_true: ignored placeholder labels.
    :param y_pred: (batch, dim) sentence embeddings.
    :return: scalar mean cross-entropy loss.
    """
    idxs = tf.range(0, tf.shape(y_pred)[0])
    idxs_1 = idxs[None, :]
    # partner index of row i: i+1 for even i, i-1 for odd i
    idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]
    y_true = tf.equal(idxs_1, idxs_2)
    y_true = tf.cast(y_true, tf.keras.backend.floatx())
    y_pred = tf.math.l2_normalize(y_pred, axis=1)
    similarities = tf.matmul(y_pred, y_pred, transpose_b=True)
    # mask self-similarity with a large negative before softmax
    similarities = similarities - tf.eye(tf.shape(y_pred)[0]) * 1e12
    similarities = similarities / 0.05  # SimCSE temperature
    loss = tf.keras.losses.categorical_crossentropy(y_true, similarities, from_logits=True)
    return tf.reduce_mean(loss)


def simcse_hard_neg_loss(y_true, y_pred):
    """SimCSE loss with hard (or random) negatives.

    Rows arrive as consecutive (anchor, positive, negative) triples.
    Anchors (rows 0, 3, 6, ...) are scored against all positive/negative
    rows, and each anchor's target is its own positive. The incoming
    `y_true` is a dummy and is rebuilt here.

    :param y_true: ignored placeholder labels.
    :param y_pred: (batch, dim) sentence embeddings, batch % 3 == 0.
    :return: scalar mean cross-entropy loss.
    """
    row = tf.range(0, tf.shape(y_pred)[0], 3)          # anchor rows
    col = tf.range(tf.shape(y_pred)[0])
    col = tf.squeeze(tf.where(col % 3 != 0), axis=1)   # positive + negative rows
    # Fix: tf.shape(col)[0] instead of len(col) — len() fails on a tensor
    # whose leading dimension is dynamic (tf.where output) when the loss
    # is traced in graph mode with batch size None.
    y_true = tf.range(0, tf.shape(col)[0], 2)
    y_pred = tf.math.l2_normalize(y_pred, axis=1)
    similarities = tf.matmul(y_pred, y_pred, transpose_b=True)
    similarities = tf.gather(similarities, row, axis=0)
    similarities = tf.gather(similarities, col, axis=1)
    similarities = similarities / 0.05  # SimCSE temperature
    loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, similarities, from_logits=True)
    return tf.reduce_mean(loss)
# --------------------------------------------------------------------------------
# /simcse_tf2/simcse.py:
# --------------------------------------------------------------------------------
# -*- coding:utf-8 -*-
"""
Author:
    jifei, jifei@outlook.com
"""
import tensorflow as tf
import os
import json

os.environ["TF_KERAS"] = '1'  # must be set before bert4keras imports keras
from bert4keras.models import build_transformer_model


def simcse(config_path,
           checkpoint_path,
           model='bert',
           pooling='first-last-avg',
           dropout_rate=0.1,
           output_units=None,
           output_activation=None,
           ):
    """Build SimCSE model

    :param config_path:string
    :param checkpoint_path:string
    :param model:string, model name
    :param pooling:string, in ['first-last-avg', 'last-avg', 'cls', 'pooler']
    :param dropout_rate:float
    :param output_units:int, optional projection head size
    :param output_activation:string, optional projection head activation
    :return: A Keras model instance.
    """
    assert pooling in ['first-last-avg', 'last-avg', 'cls', 'pooler']
    # read layer count so pooling can address the first/last transformer blocks
    with open(config_path, 'r', encoding='utf-8') as load_f:
        num_hidden_layers = json.load(load_f)['num_hidden_layers']

    if pooling == 'pooler':
        bert = build_transformer_model(
            config_path,
            checkpoint_path,
            model=model,
            with_pool='linear',  # use BERT's own pooler output
            dropout_rate=dropout_rate
        )
    else:
        bert = build_transformer_model(config_path, checkpoint_path, model=model, dropout_rate=dropout_rate)

    last_layer_output = bert.get_layer('Transformer-%d-FeedForward-Norm' % (num_hidden_layers - 1)).output
    if pooling == 'first-last-avg':
        # average of mean-pooled first and last transformer layers
        outputs = [
            tf.keras.layers.GlobalAveragePooling1D()(bert.get_layer('Transformer-%d-FeedForward-Norm' % 0).output),
            tf.keras.layers.GlobalAveragePooling1D()(last_layer_output)
        ]
        output = tf.keras.layers.Average()(outputs)
    elif pooling == 'last-avg':
        output = tf.keras.layers.GlobalAveragePooling1D()(last_layer_output)
    elif pooling == 'cls':
        # first token ([CLS]) of the last layer
        output = tf.keras.layers.Lambda(lambda x: x[:, 0])(last_layer_output)
    else:
        output = bert.output

    # optional dense projection head (both args must be given)
    if output_units and output_activation:
        output = tf.keras.layers.Dense(output_units, activation=output_activation)(output)
    model = tf.keras.Model(bert.inputs, output)
    return model