├── CoSENT_ PAWSX.ipynb ├── CoSENT_ATEC.ipynb ├── CoSENT_BQ.ipynb ├── CoSENT_CHIP-STS.ipynb ├── CoSENT_LCQMC.ipynb ├── README.md └── text_match_CHIP-STS.ipynb /CoSENT_ PAWSX.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "9ced17a4", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import pandas as pd\n", 12 | "\n", 13 | "from ark_nlp.nn import BertConfig as ModuleConfig\n", 14 | "from ark_nlp.dataset import TwinTowersSentenceClassificationDataset as Dataset\n", 15 | "from ark_nlp.processor.tokenizer.transfomer import SentenceTokenizer as Tokenizer" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "id": "bfcea417", 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "# 目录地址\n", 26 | "train_data_path = '../data/source_datasets/PAWSX/PAWSX.train.data'\n", 27 | "dev_data_path = '../data/source_datasets/PAWSX/PAWSX.test.data'" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "id": "ccea726a", 33 | "metadata": { 34 | "tags": [] 35 | }, 36 | "source": [ 37 | "### 一、数据读入与处理" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "8d5c3337", 43 | "metadata": { 44 | "tags": [] 45 | }, 46 | "source": [ 47 | "#### 1. 数据读入" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "id": "1e651e5f-8535-43c2-893b-10f3275517ba", 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "train_data_df = pd.read_csv(train_data_path, sep='\\t', header=None, names=['text_a', 'text_b', 'label'])\n", 58 | "dev_data_df = pd.read_csv(dev_data_path, sep='\\t', header=None, names=['text_a', 'text_b', 'label'])" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "8cbc9c3e-35e0-4bbb-82fe-c139be87f7f4", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "train_data_df = train_data_df[~train_data_df['text_a'].isnull()]\n", 69 | "dev_data_df = dev_data_df[~dev_data_df['text_a'].isnull()]" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "90876062", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "cosent_train_dataset = Dataset(train_data_df)\n", 80 | "cosent_dev_dataset = Dataset(dev_data_df)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "id": "e061890a", 86 | "metadata": {}, 87 | "source": [ 88 | "#### 2. 词典创建和生成分词器" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "id": "be116454", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "# 加载分词器\n", 99 | "tokenizer = Tokenizer(vocab='bert-base-chinese', max_seq_len=128)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "id": "0d6c3b3d", 105 | "metadata": {}, 106 | "source": [ 107 | "#### 3. ID化" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "id": "566dd6b6", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "cosent_train_dataset.convert_to_ids(tokenizer)\n", 118 | "cosent_dev_dataset.convert_to_ids(tokenizer)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "id": "981b4160", 124 | "metadata": {}, 125 | "source": [ 126 | "
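The `convert_to_ids` step above tokenizes text_a and text_b independently, producing twin-tower features for each sentence pair. As a rough illustration, the per-sample fields look like the sketch below; this is only an approximation that substitutes the Hugging Face BertTokenizer for ark_nlp's SentenceTokenizer, and the helper name `to_twin_tower_features` is ours. The field names themselves are taken from `_convert_to_transfomer_ids` in the predictor later in this notebook.

from transformers import BertTokenizer

_tok = BertTokenizer.from_pretrained('bert-base-chinese')

def to_twin_tower_features(text_a, text_b, max_seq_len=128):
    # each tower gets its own input_ids / attention_mask / token_type_ids
    enc_a = _tok(text_a, max_length=max_seq_len, padding='max_length', truncation=True)
    enc_b = _tok(text_b, max_length=max_seq_len, padding='max_length', truncation=True)
    return {
        'input_ids_a': enc_a['input_ids'],
        'attention_mask_a': enc_a['attention_mask'],
        'token_type_ids_a': enc_a['token_type_ids'],
        'input_ids_b': enc_b['input_ids'],
        'attention_mask_b': enc_b['attention_mask'],
        'token_type_ids_b': enc_b['token_type_ids'],
    }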
\n", 127 | "\n", 128 | "### 二、模型构建" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "id": "72753ee8", 134 | "metadata": {}, 135 | "source": [ 136 | "#### 1. 模型参数设置" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "id": "527535d3", 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "from transformers import BertConfig\n", 147 | "\n", 148 | "bert_config = BertConfig.from_pretrained(\n", 149 | " 'bert-base-chinese',\n", 150 | " num_labels=len(cosent_train_dataset.cat2id)\n", 151 | ")" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "id": "77d8b580", 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "torch.cuda.empty_cache()" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "id": "700d7752", 167 | "metadata": {}, 168 | "source": [ 169 | "#### 2. 模型创建" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "id": "58f380f4", 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "import torch\n", 180 | "import torch.nn.functional as F\n", 181 | "\n", 182 | "from torch import nn\n", 183 | "from transformers import BertModel\n", 184 | "from ark_nlp.nn import Bert\n", 185 | "\n", 186 | "\n", 187 | "class CoSENT(Bert):\n", 188 | " \"\"\"\n", 189 | " CoSENT模型\n", 190 | "\n", 191 | " Args:\n", 192 | " config:\n", 193 | " 模型的配置对象\n", 194 | " encoder_trained (:obj:`bool`, optional, defaults to True):\n", 195 | " bert参数是否可训练,默认可训练\n", 196 | " pooling (:obj:`str`, optional, defaults to \"last_avg\"):\n", 197 | " bert输出的池化方式,默认为\"last_avg\",\n", 198 | " 可选有[\"cls\", \"cls_with_pooler\", \"first_last_avg\", \"last_avg\", \"last_2_avg\"]\n", 199 | " dropout (:obj:`float` or :obj:`None`, optional, defaults to None):\n", 200 | " dropout比例,默认为None,实际设置时会设置成0\n", 201 | " output_emb_size (:obj:`int`, optional, defaults to 0):\n", 202 | " 输出的矩阵的维度,默认为0,即不进行矩阵维度变换\n", 203 | "\n", 204 | " Reference:\n", 205 | " [1] https://kexue.fm/archives/8847\n", 206 | " [2] https://github.com/bojone/CoSENT \n", 207 | " \"\"\" # noqa: ignore flake8\"\n", 208 | "\n", 209 | " def __init__(\n", 210 | " self,\n", 211 | " config,\n", 212 | " encoder_trained=True,\n", 213 | " pooling='last_avg',\n", 214 | " dropout=None,\n", 215 | " output_emb_size=0\n", 216 | " ):\n", 217 | "\n", 218 | " super(CoSENT, self).__init__(config)\n", 219 | "\n", 220 | " self.bert = BertModel(config)\n", 221 | " self.pooling = pooling\n", 222 | "\n", 223 | " self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)\n", 224 | "\n", 225 | " # if output_emb_size is greater than 0, then add Linear layer to reduce embedding_size,\n", 226 | " # we recommend set output_emb_size = 256 considering the trade-off beteween\n", 227 | " # recall performance and efficiency\n", 228 | " self.output_emb_size = output_emb_size\n", 229 | " if self.output_emb_size > 0:\n", 230 | " self.emb_reduce_linear = nn.Linear(\n", 231 | " config.hidden_size,\n", 232 | " self.output_emb_size\n", 233 | " )\n", 234 | " torch.nn.init.trunc_normal_(\n", 235 | " self.emb_reduce_linear.weight,\n", 236 | " std=0.02\n", 237 | " )\n", 238 | "\n", 239 | " for param in self.bert.parameters():\n", 240 | " param.requires_grad = encoder_trained\n", 241 | "\n", 242 | " self.init_weights()\n", 243 | "\n", 244 | " def get_pooled_embedding(\n", 245 | " self,\n", 246 | " input_ids,\n", 247 | " token_type_ids=None,\n", 248 | " position_ids=None,\n", 249 | " attention_mask=None\n", 250 | " ):\n", 251 | " outputs = 
self.bert(\n", 252 | " input_ids,\n", 253 | " attention_mask=attention_mask,\n", 254 | " token_type_ids=token_type_ids,\n", 255 | " position_ids=position_ids,\n", 256 | " return_dict=True,\n", 257 | " output_hidden_states=True\n", 258 | " )\n", 259 | "\n", 260 | " encoder_feature = self.get_encoder_feature(\n", 261 | " outputs,\n", 262 | " attention_mask\n", 263 | " )\n", 264 | "\n", 265 | " if self.output_emb_size > 0:\n", 266 | " encoder_feature = self.emb_reduce_linear(encoder_feature)\n", 267 | "\n", 268 | " encoder_feature = self.dropout(encoder_feature)\n", 269 | " out = F.normalize(encoder_feature, p=2, dim=-1, eps=1e-8)\n", 270 | "\n", 271 | " return out\n", 272 | "\n", 273 | " def cosine_sim(\n", 274 | " self,\n", 275 | " input_ids_a,\n", 276 | " input_ids_b,\n", 277 | " token_type_ids_a=None,\n", 278 | " position_ids_ids_a=None,\n", 279 | " attention_mask_a=None,\n", 280 | " token_type_ids_b=None,\n", 281 | " position_ids_b=None,\n", 282 | " attention_mask_b=None,\n", 283 | " **kwargs\n", 284 | " ):\n", 285 | "\n", 286 | " query_cls_embedding = self.get_pooled_embedding(\n", 287 | " input_ids_a,\n", 288 | " token_type_ids_a,\n", 289 | " position_ids_ids_a,\n", 290 | " attention_mask_a\n", 291 | " )\n", 292 | "\n", 293 | " title_cls_embedding = self.get_pooled_embedding(\n", 294 | " input_ids_b,\n", 295 | " token_type_ids_b,\n", 296 | " position_ids_b,\n", 297 | " attention_mask_b\n", 298 | " )\n", 299 | "\n", 300 | " cosine_sim = torch.sum(\n", 301 | " query_cls_embedding * title_cls_embedding,\n", 302 | " axis=-1\n", 303 | " )\n", 304 | "\n", 305 | " return cosine_sim\n", 306 | "\n", 307 | " def forward(\n", 308 | " self,\n", 309 | " input_ids_a,\n", 310 | " input_ids_b,\n", 311 | " token_type_ids_a=None,\n", 312 | " position_ids_ids_a=None,\n", 313 | " attention_mask_a=None,\n", 314 | " token_type_ids_b=None,\n", 315 | " position_ids_b=None,\n", 316 | " attention_mask_b=None,\n", 317 | " label_ids=None,\n", 318 | " **kwargs\n", 319 | " ):\n", 320 | "\n", 321 | " cls_embedding_a = self.get_pooled_embedding(\n", 322 | " input_ids_a,\n", 323 | " token_type_ids_a,\n", 324 | " position_ids_ids_a,\n", 325 | " attention_mask_a\n", 326 | " )\n", 327 | "\n", 328 | " cls_embedding_b = self.get_pooled_embedding(\n", 329 | " input_ids_b,\n", 330 | " token_type_ids_b,\n", 331 | " position_ids_b,\n", 332 | " attention_mask_b\n", 333 | " )\n", 334 | "\n", 335 | " cosine_sim = torch.sum(cls_embedding_a * cls_embedding_b, dim=1) * 20\n", 336 | " cosine_sim = cosine_sim[:, None] - cosine_sim[None, :]\n", 337 | " \n", 338 | " labels = label_ids[:, None] < label_ids[None, :]\n", 339 | " labels = labels.long()\n", 340 | " \n", 341 | " cosine_sim = cosine_sim - (1 - labels) * 1e12\n", 342 | " cosine_sim = torch.cat((torch.zeros(1).to(cosine_sim.device), cosine_sim.view(-1)), dim=0)\n", 343 | " loss = torch.logsumexp(cosine_sim, dim=0)\n", 344 | "\n", 345 | " return cosine_sim, loss\n" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": null, 351 | "id": "e630530b", 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "dl_module = CoSENT.from_pretrained(\n", 356 | " 'bert-base-chinese', \n", 357 | " config=bert_config\n", 358 | ")" 359 | ] 360 | }, 361 | { 362 | "cell_type": "markdown", 363 | "id": "13e3c8ac", 364 | "metadata": {}, 365 | "source": [ 366 | "
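The `forward` above is the CoSENT rank loss: cosine similarities are scaled by 20, and every (negative pair, positive pair) combination in the batch contributes exp(s_neg - s_pos), giving loss = log(1 + Σ exp(s_neg - s_pos)). A standalone sketch of the same computation (the function name `cosent_loss` is ours, not part of ark_nlp):

import torch

def cosent_loss(cos_sim, labels, scale=20.0):
    # cos_sim: (batch,) cosine similarities; labels: (batch,) 1 = similar, 0 = dissimilar
    sim = cos_sim * scale
    diff = sim[:, None] - sim[None, :]                  # all pairwise s_i - s_j
    keep = (labels[:, None] < labels[None, :]).float()  # keep pairs where i is negative, j is positive
    diff = diff - (1 - keep) * 1e12                     # mask out every other pair
    diff = torch.cat((torch.zeros(1, device=diff.device), diff.view(-1)))
    return torch.logsumexp(diff, dim=0)                 # log(1 + sum exp(s_neg - s_pos))

# A well-ordered batch yields a near-zero loss, e.g.
# cosent_loss(torch.tensor([0.9, 0.2]), torch.tensor([1, 0])) == log(1 + e**-14), about 8.3e-07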
\n", 367 | "\n", 368 | "### 三、任务构建" 369 | ] 370 | }, 371 | { 372 | "cell_type": "markdown", 373 | "id": "31d1f76c", 374 | "metadata": {}, 375 | "source": [ 376 | "#### 1. 任务参数和必要部件设定" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": null, 382 | "id": "943bf64c", 383 | "metadata": {}, 384 | "outputs": [], 385 | "source": [ 386 | "# 设置运行次数\n", 387 | "num_epoches = 5\n", 388 | "batch_size = 32" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": null, 394 | "id": "74641ede", 395 | "metadata": {}, 396 | "outputs": [], 397 | "source": [ 398 | "param_optimizer = list(dl_module.named_parameters())\n", 399 | "param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]\n", 400 | "no_decay = [\"bias\", \"LayerNorm.weight\"]\n", 401 | "optimizer_grouped_parameters = [\n", 402 | " {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n", 403 | " 'weight_decay': 0.01},\n", 404 | " {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n", 405 | "] " 406 | ] 407 | }, 408 | { 409 | "cell_type": "markdown", 410 | "id": "bd5a9361", 411 | "metadata": {}, 412 | "source": [ 413 | "#### 2. 任务创建" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": null, 419 | "id": "5bc04465", 420 | "metadata": {}, 421 | "outputs": [], 422 | "source": [ 423 | "import torch\n", 424 | "import numpy as np\n", 425 | "from scipy import stats\n", 426 | "\n", 427 | "from ark_nlp.factory.task.base._sequence_classification import SequenceClassificationTask\n", 428 | "\n", 429 | "\n", 430 | "class CoSENTTask(SequenceClassificationTask):\n", 431 | " \"\"\"\n", 432 | " 用于CoSENT模型文本匹配任务的Task\n", 433 | " \n", 434 | " Args:\n", 435 | " module: 深度学习模型\n", 436 | " optimizer: 训练模型使用的优化器名或者优化器对象\n", 437 | " loss_function: 训练模型使用的损失函数名或损失函数对象\n", 438 | " class_num (:obj:`int` or :obj:`None`, optional, defaults to None): 标签数目\n", 439 | " scheduler (:obj:`class`, optional, defaults to None): scheduler对象\n", 440 | " n_gpu (:obj:`int`, optional, defaults to 1): GPU数目\n", 441 | " device (:obj:`class`, optional, defaults to None): torch.device对象,当device为None时,会自动检测是否有GPU\n", 442 | " cuda_device (:obj:`int`, optional, defaults to 0): GPU编号,当device为None时,根据cuda_device设置device\n", 443 | " ema_decay (:obj:`int` or :obj:`None`, optional, defaults to None): EMA的加权系数\n", 444 | " **kwargs (optional): 其他可选参数\n", 445 | " \"\"\" # noqa: ignore flake8\"\n", 446 | "\n", 447 | " def _on_evaluate_begin_record(self, **kwargs):\n", 448 | "\n", 449 | " self.evaluate_logs['eval_loss'] = 0\n", 450 | " self.evaluate_logs['eval_step'] = 0\n", 451 | " self.evaluate_logs['eval_example'] = 0\n", 452 | "\n", 453 | " self.evaluate_logs['labels'] = []\n", 454 | " self.evaluate_logs['eval_sim'] = []\n", 455 | "\n", 456 | " def _on_evaluate_step_end(self, inputs, outputs, **kwargs):\n", 457 | "\n", 458 | " with torch.no_grad():\n", 459 | " # compute loss\n", 460 | " logits, loss = self._get_evaluate_loss(inputs, outputs, **kwargs)\n", 461 | " self.evaluate_logs['eval_loss'] += loss.item()\n", 462 | "\n", 463 | " if 'label_ids' in inputs:\n", 464 | " cosine_sim = self.module.cosine_sim(**inputs).cpu().numpy()\n", 465 | " self.evaluate_logs['eval_sim'].append(cosine_sim)\n", 466 | " self.evaluate_logs['labels'].append(inputs['label_ids'].cpu().numpy())\n", 467 | "\n", 468 | " self.evaluate_logs['eval_example'] += logits.shape[0]\n", 469 | " self.evaluate_logs['eval_step'] += 1\n", 470 | "\n", 471 | " def 
_on_evaluate_epoch_end(\n",
472 |     "        self,\n",
473 |     "        validation_data,\n",
474 |     "        epoch=1,\n",
475 |     "        is_evaluate_print=True,\n",
476 |     "        **kwargs\n",
477 |     "    ):\n",
478 |     "\n",
479 |     "        if is_evaluate_print:\n",
480 |     "            if self.evaluate_logs['labels']:\n",
481 |     "                _sims = np.concatenate(self.evaluate_logs['eval_sim'], axis=0)\n",
482 |     "                _labels = np.concatenate(self.evaluate_logs['labels'], axis=0)\n",
483 |     "                spearman_corr = stats.spearmanr(_labels, _sims).correlation\n",
484 |     "                print('evaluate spearman corr is:{:.4f}, evaluate loss is:{:.6f}'.format(\n",
485 |     "                    spearman_corr,\n",
486 |     "                    self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']\n",
487 |     "                    )\n",
488 |     "                )\n",
489 |     "            else:\n",
490 |     "                print('evaluate loss is:{:.6f}'.format(self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']))"
491 |    ]
492 |   },
493 |   {
494 |    "cell_type": "code",
495 |    "execution_count": null,
496 |    "id": "3dfc61d9",
497 |    "metadata": {},
498 |    "outputs": [],
499 |    "source": [
500 |     "model = CoSENTTask(dl_module, 'adamw', None, cuda_device=0)"
501 |    ]
502 |   },
503 |   {
504 |    "cell_type": "markdown",
505 |    "id": "35c96cf8",
506 |    "metadata": {
507 |     "tags": []
508 |    },
509 |    "source": [
510 |     "#### 3. 训练"
511 |    ]
512 |   },
513 |   {
514 |    "cell_type": "code",
515 |    "execution_count": null,
516 |    "id": "62e3e9ff",
517 |    "metadata": {},
518 |    "outputs": [],
519 |    "source": [
520 |     "model.fit(\n",
521 |     "    cosent_train_dataset,\n",
522 |     "    cosent_dev_dataset,\n",
523 |     "    lr=2e-5,\n",
524 |     "    epochs=num_epoches,\n",
525 |     "    batch_size=batch_size,\n",
526 |     "    params=optimizer_grouped_parameters\n",
527 |     ")"
528 |    ]
529 |   },
530 |   {
531 |    "cell_type": "markdown",
532 |    "id": "9b27e57b",
533 |    "metadata": {},
534 |    "source": [
535 |     "
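The task above reports Spearman correlation rather than accuracy because CoSENT produces a continuous similarity score, not a 0/1 decision. If a hard label is needed downstream, a cut-off can be tuned on the dev predictions; a hypothetical post-processing sketch (not part of ark_nlp):

import numpy as np
from scipy import stats

def spearman_and_best_threshold(sims, labels):
    # sims: (n,) cosine similarities; labels: (n,) 0/1 gold labels
    corr = stats.spearmanr(labels, sims).correlation
    # choose the cut-off that maximises accuracy on this data
    candidates = np.unique(sims)
    accs = [((sims >= t).astype(int) == labels).mean() for t in candidates]
    return corr, candidates[int(np.argmax(accs))]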
\n", 536 | "\n", 537 | "### 四、模型验证" 538 | ] 539 | }, 540 | { 541 | "cell_type": "code", 542 | "execution_count": null, 543 | "id": "adb8effd-c437-4dd5-a10b-db6537780f91", 544 | "metadata": {}, 545 | "outputs": [], 546 | "source": [ 547 | "import torch\n", 548 | "\n", 549 | "from torch.utils.data import DataLoader\n", 550 | "from ark_nlp.factory.predictor import SequenceClassificationPredictor\n", 551 | "\n", 552 | "\n", 553 | "class CoSENTPredictor(SequenceClassificationPredictor):\n", 554 | " \"\"\"\n", 555 | " CoSENT的预测器\n", 556 | " \n", 557 | " Args:\n", 558 | " module: 深度学习模型\n", 559 | " tokernizer: 分词器\n", 560 | " cat2id (:obj:`dict`): 标签映射\n", 561 | " \"\"\" # noqa: ignore flake8\"\n", 562 | "\n", 563 | " def _get_input_ids(\n", 564 | " self,\n", 565 | " text_a,\n", 566 | " text_b\n", 567 | " ):\n", 568 | " if self.tokenizer.tokenizer_type == 'vanilla':\n", 569 | " return self._convert_to_vanilla_ids(text_a, text_b)\n", 570 | " elif self.tokenizer.tokenizer_type == 'transfomer':\n", 571 | " return self._convert_to_transfomer_ids(text_a, text_b)\n", 572 | " elif self.tokenizer.tokenizer_type == 'customized':\n", 573 | " return self._convert_to_customized_ids(text_a, text_b)\n", 574 | " else:\n", 575 | " raise ValueError(\"The tokenizer type does not exist\")\n", 576 | "\n", 577 | " def _convert_to_transfomer_ids(\n", 578 | " self,\n", 579 | " text_a,\n", 580 | " text_b\n", 581 | " ):\n", 582 | " input_ids_a = self.tokenizer.sequence_to_ids(text_a)\n", 583 | " input_ids_b = self.tokenizer.sequence_to_ids(text_b)\n", 584 | "\n", 585 | " input_ids_a, input_mask_a, segment_ids_a = input_ids_a\n", 586 | " input_ids_b, input_mask_b, segment_ids_b = input_ids_b\n", 587 | "\n", 588 | " features = {\n", 589 | " 'input_ids_a': input_ids_a,\n", 590 | " 'attention_mask_a': input_mask_a,\n", 591 | " 'token_type_ids_a': segment_ids_a,\n", 592 | " 'input_ids_b': input_ids_b,\n", 593 | " 'attention_mask_b': input_mask_b,\n", 594 | " 'token_type_ids_b': segment_ids_b\n", 595 | " }\n", 596 | "\n", 597 | " return features\n", 598 | "\n", 599 | " def predict_one_sample(\n", 600 | " self,\n", 601 | " text,\n", 602 | " topk=None,\n", 603 | " threshold=0.5,\n", 604 | " return_label_name=True,\n", 605 | " return_proba=False\n", 606 | " ):\n", 607 | " if topk is None:\n", 608 | " topk = len(self.cat2id) if len(self.cat2id) > 2 else 1\n", 609 | " text_a, text_b = text\n", 610 | " features = self._get_input_ids(text_a, text_b)\n", 611 | " self.module.eval()\n", 612 | "\n", 613 | " with torch.no_grad():\n", 614 | " inputs = self._get_module_one_sample_inputs(features)\n", 615 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n", 616 | "\n", 617 | " _proba = logits[0]\n", 618 | " \n", 619 | " if threshold is not None:\n", 620 | " _pred = self._threshold(_proba, threshold)\n", 621 | "\n", 622 | " if return_label_name and threshold is not None:\n", 623 | " _pred = self.id2cat[_pred]\n", 624 | "\n", 625 | " if threshold is not None:\n", 626 | " if return_proba:\n", 627 | " return [_pred, _proba]\n", 628 | " else:\n", 629 | " return _pred\n", 630 | "\n", 631 | " return _proba\n", 632 | "\n", 633 | " def predict_batch(\n", 634 | " self,\n", 635 | " test_data,\n", 636 | " batch_size=16,\n", 637 | " shuffle=False\n", 638 | " ):\n", 639 | " self.inputs_cols = test_data.dataset_cols\n", 640 | "\n", 641 | " preds = []\n", 642 | "\n", 643 | " self.module.eval()\n", 644 | " generator = DataLoader(test_data, batch_size=batch_size, shuffle=shuffle)\n", 645 | "\n", 646 | " with torch.no_grad():\n", 647 | " for step, 
inputs in enumerate(generator):\n", 648 | " inputs = self._get_module_batch_inputs(inputs)\n", 649 | "\n", 650 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n", 651 | "\n", 652 | " preds.extend(logits)\n", 653 | "\n", 654 | " return preds" 655 | ] 656 | }, 657 | { 658 | "cell_type": "code", 659 | "execution_count": null, 660 | "id": "fce50880-d821-4fc5-8e34-c4ff7c356287", 661 | "metadata": {}, 662 | "outputs": [], 663 | "source": [ 664 | "cosent_predictor_instance = CoSENTPredictor(model.module, tokenizer, cosent_train_dataset.cat2id)" 665 | ] 666 | }, 667 | { 668 | "cell_type": "code", 669 | "execution_count": null, 670 | "id": "26a67a56-1f8e-4579-81a2-1c578aaf92ab", 671 | "metadata": {}, 672 | "outputs": [], 673 | "source": [ 674 | "cosent_predictor_instance.predict_one_sample(\n", 675 | " ['1975年的nba赛季 - 76赛季是全美篮球协会的第30个赛季。', \n", 676 | " '1975-76赛季的全国篮球协会是nba的第30个赛季。'],\n", 677 | " threshold=None\n", 678 | ")" 679 | ] 680 | } 681 | ], 682 | "metadata": { 683 | "kernelspec": { 684 | "display_name": "Python 3 (ipykernel)", 685 | "language": "python", 686 | "name": "python3" 687 | }, 688 | "language_info": { 689 | "codemirror_mode": { 690 | "name": "ipython", 691 | "version": 3 692 | }, 693 | "file_extension": ".py", 694 | "mimetype": "text/x-python", 695 | "name": "python", 696 | "nbconvert_exporter": "python", 697 | "pygments_lexer": "ipython3", 698 | "version": "3.8.10" 699 | } 700 | }, 701 | "nbformat": 4, 702 | "nbformat_minor": 5 703 | } 704 | -------------------------------------------------------------------------------- /CoSENT_ATEC.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "9ced17a4", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import pandas as pd\n", 12 | "\n", 13 | "from ark_nlp.nn import BertConfig as ModuleConfig\n", 14 | "from ark_nlp.dataset import TwinTowersSentenceClassificationDataset as Dataset\n", 15 | "from ark_nlp.processor.tokenizer.transfomer import SentenceTokenizer as Tokenizer" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "id": "bfcea417", 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "# 目录地址\n", 26 | "train_data_path = '../data/source_datasets/ATEC/ATEC.train.data'\n", 27 | "dev_data_path = '../data/source_datasets/ATEC/ATEC.test.data'" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "id": "ccea726a", 33 | "metadata": { 34 | "tags": [] 35 | }, 36 | "source": [ 37 | "### 一、数据读入与处理" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "8d5c3337", 43 | "metadata": { 44 | "tags": [] 45 | }, 46 | "source": [ 47 | "#### 1. 
数据读入" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "id": "1e651e5f-8535-43c2-893b-10f3275517ba", 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "train_data_df = pd.read_csv(train_data_path, sep='\\t', header=None, names=['text_a', 'text_b', 'label'])\n", 58 | "dev_data_df = pd.read_csv(dev_data_path, sep='\\t', header=None, names=['text_a', 'text_b', 'label'])" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "90876062", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "cosent_train_dataset = Dataset(train_data_df)\n", 69 | "cosent_dev_dataset = Dataset(dev_data_df)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "id": "e061890a", 75 | "metadata": {}, 76 | "source": [ 77 | "#### 2. 词典创建和生成分词器" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "be116454", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "# 加载分词器\n", 88 | "tokenizer = Tokenizer(vocab='bert-base-chinese', max_seq_len=64)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "id": "0d6c3b3d", 94 | "metadata": {}, 95 | "source": [ 96 | "#### 3. ID化" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "id": "566dd6b6", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "cosent_train_dataset.convert_to_ids(tokenizer)\n", 107 | "cosent_dev_dataset.convert_to_ids(tokenizer)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "id": "981b4160", 113 | "metadata": {}, 114 | "source": [ 115 | "
\n", 116 | "\n", 117 | "### 二、模型构建" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "id": "72753ee8", 123 | "metadata": { 124 | "tags": [] 125 | }, 126 | "source": [ 127 | "#### 1. 模型参数设置" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "id": "527535d3", 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "from transformers import BertConfig\n", 138 | "\n", 139 | "bert_config = BertConfig.from_pretrained(\n", 140 | " 'bert-base-chinese',\n", 141 | " num_labels=len(cosent_train_dataset.cat2id)\n", 142 | ")" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "id": "77d8b580", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "torch.cuda.empty_cache()" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "id": "700d7752", 158 | "metadata": {}, 159 | "source": [ 160 | "#### 2. 模型创建" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "id": "58f380f4", 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "import torch\n", 171 | "import torch.nn.functional as F\n", 172 | "\n", 173 | "from torch import nn\n", 174 | "from transformers import BertModel\n", 175 | "from ark_nlp.nn import Bert\n", 176 | "\n", 177 | "\n", 178 | "class CoSENT(Bert):\n", 179 | " \"\"\"\n", 180 | " CoSENT模型\n", 181 | "\n", 182 | " Args:\n", 183 | " config:\n", 184 | " 模型的配置对象\n", 185 | " encoder_trained (:obj:`bool`, optional, defaults to True):\n", 186 | " bert参数是否可训练,默认可训练\n", 187 | " pooling (:obj:`str`, optional, defaults to \"last_avg\"):\n", 188 | " bert输出的池化方式,默认为\"last_avg\",\n", 189 | " 可选有[\"cls\", \"cls_with_pooler\", \"first_last_avg\", \"last_avg\", \"last_2_avg\"]\n", 190 | " dropout (:obj:`float` or :obj:`None`, optional, defaults to None):\n", 191 | " dropout比例,默认为None,实际设置时会设置成0\n", 192 | " output_emb_size (:obj:`int`, optional, defaults to 0):\n", 193 | " 输出的矩阵的维度,默认为0,即不进行矩阵维度变换\n", 194 | "\n", 195 | " Reference:\n", 196 | " [1] https://kexue.fm/archives/8847\n", 197 | " [2] https://github.com/bojone/CoSENT \n", 198 | " \"\"\" # noqa: ignore flake8\"\n", 199 | "\n", 200 | " def __init__(\n", 201 | " self,\n", 202 | " config,\n", 203 | " encoder_trained=True,\n", 204 | " pooling='last_avg',\n", 205 | " dropout=None,\n", 206 | " output_emb_size=0\n", 207 | " ):\n", 208 | "\n", 209 | " super(CoSENT, self).__init__(config)\n", 210 | "\n", 211 | " self.bert = BertModel(config)\n", 212 | " self.pooling = pooling\n", 213 | "\n", 214 | " self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)\n", 215 | "\n", 216 | " # if output_emb_size is greater than 0, then add Linear layer to reduce embedding_size,\n", 217 | " # we recommend set output_emb_size = 256 considering the trade-off beteween\n", 218 | " # recall performance and efficiency\n", 219 | " self.output_emb_size = output_emb_size\n", 220 | " if self.output_emb_size > 0:\n", 221 | " self.emb_reduce_linear = nn.Linear(\n", 222 | " config.hidden_size,\n", 223 | " self.output_emb_size\n", 224 | " )\n", 225 | " torch.nn.init.trunc_normal_(\n", 226 | " self.emb_reduce_linear.weight,\n", 227 | " std=0.02\n", 228 | " )\n", 229 | "\n", 230 | " for param in self.bert.parameters():\n", 231 | " param.requires_grad = encoder_trained\n", 232 | "\n", 233 | " self.init_weights()\n", 234 | "\n", 235 | " def get_pooled_embedding(\n", 236 | " self,\n", 237 | " input_ids,\n", 238 | " token_type_ids=None,\n", 239 | " position_ids=None,\n", 240 | " attention_mask=None\n", 241 | " 
):\n", 242 | " outputs = self.bert(\n", 243 | " input_ids,\n", 244 | " attention_mask=attention_mask,\n", 245 | " token_type_ids=token_type_ids,\n", 246 | " position_ids=position_ids,\n", 247 | " return_dict=True,\n", 248 | " output_hidden_states=True\n", 249 | " )\n", 250 | "\n", 251 | " encoder_feature = self.get_encoder_feature(\n", 252 | " outputs,\n", 253 | " attention_mask\n", 254 | " )\n", 255 | "\n", 256 | " if self.output_emb_size > 0:\n", 257 | " encoder_feature = self.emb_reduce_linear(encoder_feature)\n", 258 | "\n", 259 | " encoder_feature = self.dropout(encoder_feature)\n", 260 | " out = F.normalize(encoder_feature, p=2, dim=-1, eps=1e-8)\n", 261 | "\n", 262 | " return out\n", 263 | "\n", 264 | " def cosine_sim(\n", 265 | " self,\n", 266 | " input_ids_a,\n", 267 | " input_ids_b,\n", 268 | " token_type_ids_a=None,\n", 269 | " position_ids_ids_a=None,\n", 270 | " attention_mask_a=None,\n", 271 | " token_type_ids_b=None,\n", 272 | " position_ids_b=None,\n", 273 | " attention_mask_b=None,\n", 274 | " **kwargs\n", 275 | " ):\n", 276 | "\n", 277 | " query_cls_embedding = self.get_pooled_embedding(\n", 278 | " input_ids_a,\n", 279 | " token_type_ids_a,\n", 280 | " position_ids_ids_a,\n", 281 | " attention_mask_a\n", 282 | " )\n", 283 | "\n", 284 | " title_cls_embedding = self.get_pooled_embedding(\n", 285 | " input_ids_b,\n", 286 | " token_type_ids_b,\n", 287 | " position_ids_b,\n", 288 | " attention_mask_b\n", 289 | " )\n", 290 | "\n", 291 | " cosine_sim = torch.sum(\n", 292 | " query_cls_embedding * title_cls_embedding,\n", 293 | " axis=-1\n", 294 | " )\n", 295 | "\n", 296 | " return cosine_sim\n", 297 | "\n", 298 | " def forward(\n", 299 | " self,\n", 300 | " input_ids_a,\n", 301 | " input_ids_b,\n", 302 | " token_type_ids_a=None,\n", 303 | " position_ids_ids_a=None,\n", 304 | " attention_mask_a=None,\n", 305 | " token_type_ids_b=None,\n", 306 | " position_ids_b=None,\n", 307 | " attention_mask_b=None,\n", 308 | " label_ids=None,\n", 309 | " **kwargs\n", 310 | " ):\n", 311 | "\n", 312 | " cls_embedding_a = self.get_pooled_embedding(\n", 313 | " input_ids_a,\n", 314 | " token_type_ids_a,\n", 315 | " position_ids_ids_a,\n", 316 | " attention_mask_a\n", 317 | " )\n", 318 | "\n", 319 | " cls_embedding_b = self.get_pooled_embedding(\n", 320 | " input_ids_b,\n", 321 | " token_type_ids_b,\n", 322 | " position_ids_b,\n", 323 | " attention_mask_b\n", 324 | " )\n", 325 | "\n", 326 | " cosine_sim = torch.sum(cls_embedding_a * cls_embedding_b, dim=1) * 20\n", 327 | " cosine_sim = cosine_sim[:, None] - cosine_sim[None, :]\n", 328 | " \n", 329 | " labels = label_ids[:, None] < label_ids[None, :]\n", 330 | " labels = labels.long()\n", 331 | " \n", 332 | " cosine_sim = cosine_sim - (1 - labels) * 1e12\n", 333 | " cosine_sim = torch.cat((torch.zeros(1).to(cosine_sim.device), cosine_sim.view(-1)), dim=0)\n", 334 | " loss = torch.logsumexp(cosine_sim, dim=0)\n", 335 | "\n", 336 | " return cosine_sim, loss\n" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": null, 342 | "id": "e630530b", 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "dl_module = CoSENT.from_pretrained(\n", 347 | " 'bert-base-chinese', \n", 348 | " config=bert_config\n", 349 | ")" 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "id": "13e3c8ac", 355 | "metadata": {}, 356 | "source": [ 357 | "
\n", 358 | "\n", 359 | "### 三、任务构建" 360 | ] 361 | }, 362 | { 363 | "cell_type": "markdown", 364 | "id": "31d1f76c", 365 | "metadata": {}, 366 | "source": [ 367 | "#### 1. 任务参数和必要部件设定" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": null, 373 | "id": "943bf64c", 374 | "metadata": {}, 375 | "outputs": [], 376 | "source": [ 377 | "# 设置运行次数\n", 378 | "num_epoches = 5\n", 379 | "batch_size = 32" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": null, 385 | "id": "74641ede", 386 | "metadata": {}, 387 | "outputs": [], 388 | "source": [ 389 | "param_optimizer = list(dl_module.named_parameters())\n", 390 | "param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]\n", 391 | "no_decay = [\"bias\", \"LayerNorm.weight\"]\n", 392 | "optimizer_grouped_parameters = [\n", 393 | " {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n", 394 | " 'weight_decay': 0.01},\n", 395 | " {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n", 396 | "] " 397 | ] 398 | }, 399 | { 400 | "cell_type": "markdown", 401 | "id": "bd5a9361", 402 | "metadata": {}, 403 | "source": [ 404 | "#### 2. 任务创建" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": null, 410 | "id": "5bc04465", 411 | "metadata": {}, 412 | "outputs": [], 413 | "source": [ 414 | "import torch\n", 415 | "import numpy as np\n", 416 | "from scipy import stats\n", 417 | "\n", 418 | "from ark_nlp.factory.task.base._sequence_classification import SequenceClassificationTask\n", 419 | "\n", 420 | "\n", 421 | "class CoSENTTask(SequenceClassificationTask):\n", 422 | " \"\"\"\n", 423 | " 用于CoSENT模型文本匹配任务的Task\n", 424 | " \n", 425 | " Args:\n", 426 | " module: 深度学习模型\n", 427 | " optimizer: 训练模型使用的优化器名或者优化器对象\n", 428 | " loss_function: 训练模型使用的损失函数名或损失函数对象\n", 429 | " class_num (:obj:`int` or :obj:`None`, optional, defaults to None): 标签数目\n", 430 | " scheduler (:obj:`class`, optional, defaults to None): scheduler对象\n", 431 | " n_gpu (:obj:`int`, optional, defaults to 1): GPU数目\n", 432 | " device (:obj:`class`, optional, defaults to None): torch.device对象,当device为None时,会自动检测是否有GPU\n", 433 | " cuda_device (:obj:`int`, optional, defaults to 0): GPU编号,当device为None时,根据cuda_device设置device\n", 434 | " ema_decay (:obj:`int` or :obj:`None`, optional, defaults to None): EMA的加权系数\n", 435 | " **kwargs (optional): 其他可选参数\n", 436 | " \"\"\" # noqa: ignore flake8\"\n", 437 | "\n", 438 | " def _on_evaluate_begin_record(self, **kwargs):\n", 439 | "\n", 440 | " self.evaluate_logs['eval_loss'] = 0\n", 441 | " self.evaluate_logs['eval_step'] = 0\n", 442 | " self.evaluate_logs['eval_example'] = 0\n", 443 | "\n", 444 | " self.evaluate_logs['labels'] = []\n", 445 | " self.evaluate_logs['eval_sim'] = []\n", 446 | "\n", 447 | " def _on_evaluate_step_end(self, inputs, outputs, **kwargs):\n", 448 | "\n", 449 | " with torch.no_grad():\n", 450 | " # compute loss\n", 451 | " logits, loss = self._get_evaluate_loss(inputs, outputs, **kwargs)\n", 452 | " self.evaluate_logs['eval_loss'] += loss.item()\n", 453 | "\n", 454 | " if 'label_ids' in inputs:\n", 455 | " cosine_sim = self.module.cosine_sim(**inputs).cpu().numpy()\n", 456 | " self.evaluate_logs['eval_sim'].append(cosine_sim)\n", 457 | " self.evaluate_logs['labels'].append(inputs['label_ids'].cpu().numpy())\n", 458 | "\n", 459 | " self.evaluate_logs['eval_example'] += logits.shape[0]\n", 460 | " self.evaluate_logs['eval_step'] += 1\n", 461 | "\n", 462 | " def 
_on_evaluate_epoch_end(\n",
463 |     "        self,\n",
464 |     "        validation_data,\n",
465 |     "        epoch=1,\n",
466 |     "        is_evaluate_print=True,\n",
467 |     "        **kwargs\n",
468 |     "    ):\n",
469 |     "\n",
470 |     "        if is_evaluate_print:\n",
471 |     "            if self.evaluate_logs['labels']:\n",
472 |     "                _sims = np.concatenate(self.evaluate_logs['eval_sim'], axis=0)\n",
473 |     "                _labels = np.concatenate(self.evaluate_logs['labels'], axis=0)\n",
474 |     "                spearman_corr = stats.spearmanr(_labels, _sims).correlation\n",
475 |     "                print('evaluate spearman corr is:{:.4f}, evaluate loss is:{:.6f}'.format(\n",
476 |     "                    spearman_corr,\n",
477 |     "                    self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']\n",
478 |     "                    )\n",
479 |     "                )\n",
480 |     "            else:\n",
481 |     "                print('evaluate loss is:{:.6f}'.format(self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']))"
482 |    ]
483 |   },
484 |   {
485 |    "cell_type": "code",
486 |    "execution_count": null,
487 |    "id": "3dfc61d9",
488 |    "metadata": {},
489 |    "outputs": [],
490 |    "source": [
491 |     "model = CoSENTTask(dl_module, 'adamw', None, cuda_device=0)"
492 |    ]
493 |   },
494 |   {
495 |    "cell_type": "markdown",
496 |    "id": "35c96cf8",
497 |    "metadata": {
498 |     "tags": []
499 |    },
500 |    "source": [
501 |     "#### 3. 训练"
502 |    ]
503 |   },
504 |   {
505 |    "cell_type": "code",
506 |    "execution_count": null,
507 |    "id": "62e3e9ff",
508 |    "metadata": {},
509 |    "outputs": [],
510 |    "source": [
511 |     "model.fit(\n",
512 |     "    cosent_train_dataset,\n",
513 |     "    cosent_dev_dataset,\n",
514 |     "    lr=2e-5,\n",
515 |     "    epochs=num_epoches,\n",
516 |     "    batch_size=batch_size,\n",
517 |     "    params=optimizer_grouped_parameters\n",
518 |     ")"
519 |    ]
520 |   },
521 |   {
522 |    "cell_type": "markdown",
523 |    "id": "9b27e57b",
524 |    "metadata": {},
525 |    "source": [
526 |     "
\n", 527 | "\n", 528 | "### 四、模型验证" 529 | ] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "execution_count": null, 534 | "id": "0800c2a3-2b57-435a-87d7-cfafbbc69fd1", 535 | "metadata": {}, 536 | "outputs": [], 537 | "source": [ 538 | "import torch\n", 539 | "\n", 540 | "from torch.utils.data import DataLoader\n", 541 | "from ark_nlp.factory.predictor import SequenceClassificationPredictor\n", 542 | "\n", 543 | "\n", 544 | "class CoSENTPredictor(SequenceClassificationPredictor):\n", 545 | " \"\"\"\n", 546 | " CoSENT的预测器\n", 547 | " \n", 548 | " Args:\n", 549 | " module: 深度学习模型\n", 550 | " tokernizer: 分词器\n", 551 | " cat2id (:obj:`dict`): 标签映射\n", 552 | " \"\"\" # noqa: ignore flake8\"\n", 553 | "\n", 554 | " def _get_input_ids(\n", 555 | " self,\n", 556 | " text_a,\n", 557 | " text_b\n", 558 | " ):\n", 559 | " if self.tokenizer.tokenizer_type == 'vanilla':\n", 560 | " return self._convert_to_vanilla_ids(text_a, text_b)\n", 561 | " elif self.tokenizer.tokenizer_type == 'transfomer':\n", 562 | " return self._convert_to_transfomer_ids(text_a, text_b)\n", 563 | " elif self.tokenizer.tokenizer_type == 'customized':\n", 564 | " return self._convert_to_customized_ids(text_a, text_b)\n", 565 | " else:\n", 566 | " raise ValueError(\"The tokenizer type does not exist\")\n", 567 | "\n", 568 | " def _convert_to_transfomer_ids(\n", 569 | " self,\n", 570 | " text_a,\n", 571 | " text_b\n", 572 | " ):\n", 573 | " input_ids_a = self.tokenizer.sequence_to_ids(text_a)\n", 574 | " input_ids_b = self.tokenizer.sequence_to_ids(text_b)\n", 575 | "\n", 576 | " input_ids_a, input_mask_a, segment_ids_a = input_ids_a\n", 577 | " input_ids_b, input_mask_b, segment_ids_b = input_ids_b\n", 578 | "\n", 579 | " features = {\n", 580 | " 'input_ids_a': input_ids_a,\n", 581 | " 'attention_mask_a': input_mask_a,\n", 582 | " 'token_type_ids_a': segment_ids_a,\n", 583 | " 'input_ids_b': input_ids_b,\n", 584 | " 'attention_mask_b': input_mask_b,\n", 585 | " 'token_type_ids_b': segment_ids_b\n", 586 | " }\n", 587 | "\n", 588 | " return features\n", 589 | "\n", 590 | " def predict_one_sample(\n", 591 | " self,\n", 592 | " text,\n", 593 | " topk=None,\n", 594 | " threshold=0.5,\n", 595 | " return_label_name=True,\n", 596 | " return_proba=False\n", 597 | " ):\n", 598 | " if topk is None:\n", 599 | " topk = len(self.cat2id) if len(self.cat2id) > 2 else 1\n", 600 | " text_a, text_b = text\n", 601 | " features = self._get_input_ids(text_a, text_b)\n", 602 | " self.module.eval()\n", 603 | "\n", 604 | " with torch.no_grad():\n", 605 | " inputs = self._get_module_one_sample_inputs(features)\n", 606 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n", 607 | "\n", 608 | " _proba = logits[0]\n", 609 | " \n", 610 | " if threshold is not None:\n", 611 | " _pred = self._threshold(_proba, threshold)\n", 612 | "\n", 613 | " if return_label_name and threshold is not None:\n", 614 | " _pred = self.id2cat[_pred]\n", 615 | "\n", 616 | " if threshold is not None:\n", 617 | " if return_proba:\n", 618 | " return [_pred, _proba]\n", 619 | " else:\n", 620 | " return _pred\n", 621 | "\n", 622 | " return _proba\n", 623 | "\n", 624 | " def predict_batch(\n", 625 | " self,\n", 626 | " test_data,\n", 627 | " batch_size=16,\n", 628 | " shuffle=False\n", 629 | " ):\n", 630 | " self.inputs_cols = test_data.dataset_cols\n", 631 | "\n", 632 | " preds = []\n", 633 | "\n", 634 | " self.module.eval()\n", 635 | " generator = DataLoader(test_data, batch_size=batch_size, shuffle=shuffle)\n", 636 | "\n", 637 | " with torch.no_grad():\n", 638 | " for step, 
inputs in enumerate(generator):\n", 639 | " inputs = self._get_module_batch_inputs(inputs)\n", 640 | "\n", 641 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n", 642 | "\n", 643 | " preds.extend(logits)\n", 644 | "\n", 645 | " return preds" 646 | ] 647 | }, 648 | { 649 | "cell_type": "code", 650 | "execution_count": null, 651 | "id": "6f416a92-4e5d-4f78-91f1-b42f9db71b25", 652 | "metadata": {}, 653 | "outputs": [], 654 | "source": [ 655 | "cosent_predictor_instance = CoSENTPredictor(model.module, tokenizer, cosent_train_dataset.cat2id)" 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": null, 661 | "id": "55fe09db-9e27-4300-b57f-df6603643e50", 662 | "metadata": {}, 663 | "outputs": [], 664 | "source": [ 665 | "cosent_predictor_instance.predict_one_sample(\n", 666 | " ['怎么花呗不可以用来生活缴费了呀', \n", 667 | " '怎么我的花呗不能付电费了'], \n", 668 | " threshold=None\n", 669 | ")" 670 | ] 671 | } 672 | ], 673 | "metadata": { 674 | "kernelspec": { 675 | "display_name": "Python 3 (ipykernel)", 676 | "language": "python", 677 | "name": "python3" 678 | }, 679 | "language_info": { 680 | "codemirror_mode": { 681 | "name": "ipython", 682 | "version": 3 683 | }, 684 | "file_extension": ".py", 685 | "mimetype": "text/x-python", 686 | "name": "python", 687 | "nbconvert_exporter": "python", 688 | "pygments_lexer": "ipython3", 689 | "version": "3.8.10" 690 | } 691 | }, 692 | "nbformat": 4, 693 | "nbformat_minor": 5 694 | } 695 | -------------------------------------------------------------------------------- /CoSENT_BQ.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "9ced17a4", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import pandas as pd\n", 12 | "\n", 13 | "from ark_nlp.nn import BertConfig as ModuleConfig\n", 14 | "from ark_nlp.dataset import TwinTowersSentenceClassificationDataset as Dataset\n", 15 | "from ark_nlp.processor.tokenizer.transfomer import SentenceTokenizer as Tokenizer" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "id": "bfcea417", 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "# 目录地址\n", 26 | "train_data_path = '../data/source_datasets/BQ/BQ.train.data'\n", 27 | "dev_data_path = '../data/source_datasets/BQ/BQ.test.data'" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "id": "ccea726a", 33 | "metadata": { 34 | "tags": [] 35 | }, 36 | "source": [ 37 | "### 一、数据读入与处理" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "8d5c3337", 43 | "metadata": { 44 | "tags": [] 45 | }, 46 | "source": [ 47 | "#### 1. 数据读入" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "id": "1e651e5f-8535-43c2-893b-10f3275517ba", 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "train_data_df = pd.read_csv(train_data_path, sep='\\t', header=None, names=['text_a', 'text_b', 'label'])\n", 58 | "dev_data_df = pd.read_csv(dev_data_path, sep='\\t', header=None, names=['text_a', 'text_b', 'label'])" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "90876062", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "cosent_train_dataset = Dataset(train_data_df)\n", 69 | "cosent_dev_dataset = Dataset(dev_data_df)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "id": "e061890a", 75 | "metadata": {}, 76 | "source": [ 77 | "#### 2. 
词典创建和生成分词器" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "be116454", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "# 加载分词器\n", 88 | "tokenizer = Tokenizer(vocab='bert-base-chinese', max_seq_len=64)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "id": "0d6c3b3d", 94 | "metadata": {}, 95 | "source": [ 96 | "#### 3. ID化" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "id": "566dd6b6", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "cosent_train_dataset.convert_to_ids(tokenizer)\n", 107 | "cosent_dev_dataset.convert_to_ids(tokenizer)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "id": "981b4160", 113 | "metadata": {}, 114 | "source": [ 115 | "
\n", 116 | "\n", 117 | "### 二、模型构建" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "id": "72753ee8", 123 | "metadata": {}, 124 | "source": [ 125 | "#### 1. 模型参数设置" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "id": "527535d3", 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "from transformers import BertConfig\n", 136 | "\n", 137 | "bert_config = BertConfig.from_pretrained(\n", 138 | " 'bert-base-chinese',\n", 139 | " num_labels=2\n", 140 | ")" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "id": "77d8b580", 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "torch.cuda.empty_cache()" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "id": "700d7752", 156 | "metadata": {}, 157 | "source": [ 158 | "#### 2. 模型创建" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "id": "58f380f4", 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "import torch\n", 169 | "import torch.nn.functional as F\n", 170 | "\n", 171 | "from torch import nn\n", 172 | "from transformers import BertModel\n", 173 | "from ark_nlp.nn import Bert\n", 174 | "\n", 175 | "\n", 176 | "class CoSENT(Bert):\n", 177 | " \"\"\"\n", 178 | " CoSENT模型\n", 179 | "\n", 180 | " Args:\n", 181 | " config:\n", 182 | " 模型的配置对象\n", 183 | " encoder_trained (:obj:`bool`, optional, defaults to True):\n", 184 | " bert参数是否可训练,默认可训练\n", 185 | " pooling (:obj:`str`, optional, defaults to \"last_avg\"):\n", 186 | " bert输出的池化方式,默认为\"last_avg\",\n", 187 | " 可选有[\"cls\", \"cls_with_pooler\", \"first_last_avg\", \"last_avg\", \"last_2_avg\"]\n", 188 | " dropout (:obj:`float` or :obj:`None`, optional, defaults to None):\n", 189 | " dropout比例,默认为None,实际设置时会设置成0\n", 190 | " output_emb_size (:obj:`int`, optional, defaults to 0):\n", 191 | " 输出的矩阵的维度,默认为0,即不进行矩阵维度变换\n", 192 | "\n", 193 | " Reference:\n", 194 | " [1] https://kexue.fm/archives/8847\n", 195 | " [2] https://github.com/bojone/CoSENT \n", 196 | " \"\"\" # noqa: ignore flake8\"\n", 197 | "\n", 198 | " def __init__(\n", 199 | " self,\n", 200 | " config,\n", 201 | " encoder_trained=True,\n", 202 | " pooling='last_avg',\n", 203 | " dropout=None,\n", 204 | " output_emb_size=0\n", 205 | " ):\n", 206 | "\n", 207 | " super(CoSENT, self).__init__(config)\n", 208 | "\n", 209 | " self.bert = BertModel(config)\n", 210 | " self.pooling = pooling\n", 211 | "\n", 212 | " self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)\n", 213 | "\n", 214 | " # if output_emb_size is greater than 0, then add Linear layer to reduce embedding_size,\n", 215 | " # we recommend set output_emb_size = 256 considering the trade-off beteween\n", 216 | " # recall performance and efficiency\n", 217 | " self.output_emb_size = output_emb_size\n", 218 | " if self.output_emb_size > 0:\n", 219 | " self.emb_reduce_linear = nn.Linear(\n", 220 | " config.hidden_size,\n", 221 | " self.output_emb_size\n", 222 | " )\n", 223 | " torch.nn.init.trunc_normal_(\n", 224 | " self.emb_reduce_linear.weight,\n", 225 | " std=0.02\n", 226 | " )\n", 227 | "\n", 228 | " for param in self.bert.parameters():\n", 229 | " param.requires_grad = encoder_trained\n", 230 | "\n", 231 | " self.init_weights()\n", 232 | "\n", 233 | " def get_pooled_embedding(\n", 234 | " self,\n", 235 | " input_ids,\n", 236 | " token_type_ids=None,\n", 237 | " position_ids=None,\n", 238 | " attention_mask=None\n", 239 | " ):\n", 240 | " outputs = self.bert(\n", 241 | " 
input_ids,\n", 242 | " attention_mask=attention_mask,\n", 243 | " token_type_ids=token_type_ids,\n", 244 | " position_ids=position_ids,\n", 245 | " return_dict=True,\n", 246 | " output_hidden_states=True\n", 247 | " )\n", 248 | "\n", 249 | " encoder_feature = self.get_encoder_feature(\n", 250 | " outputs,\n", 251 | " attention_mask\n", 252 | " )\n", 253 | "\n", 254 | " if self.output_emb_size > 0:\n", 255 | " encoder_feature = self.emb_reduce_linear(encoder_feature)\n", 256 | "\n", 257 | " encoder_feature = self.dropout(encoder_feature)\n", 258 | " out = F.normalize(encoder_feature, p=2, dim=-1, eps=1e-8)\n", 259 | "\n", 260 | " return out\n", 261 | "\n", 262 | " def cosine_sim(\n", 263 | " self,\n", 264 | " input_ids_a,\n", 265 | " input_ids_b,\n", 266 | " token_type_ids_a=None,\n", 267 | " position_ids_ids_a=None,\n", 268 | " attention_mask_a=None,\n", 269 | " token_type_ids_b=None,\n", 270 | " position_ids_b=None,\n", 271 | " attention_mask_b=None,\n", 272 | " **kwargs\n", 273 | " ):\n", 274 | "\n", 275 | " query_cls_embedding = self.get_pooled_embedding(\n", 276 | " input_ids_a,\n", 277 | " token_type_ids_a,\n", 278 | " position_ids_ids_a,\n", 279 | " attention_mask_a\n", 280 | " )\n", 281 | "\n", 282 | " title_cls_embedding = self.get_pooled_embedding(\n", 283 | " input_ids_b,\n", 284 | " token_type_ids_b,\n", 285 | " position_ids_b,\n", 286 | " attention_mask_b\n", 287 | " )\n", 288 | "\n", 289 | " cosine_sim = torch.sum(\n", 290 | " query_cls_embedding * title_cls_embedding,\n", 291 | " axis=-1\n", 292 | " )\n", 293 | "\n", 294 | " return cosine_sim\n", 295 | "\n", 296 | " def forward(\n", 297 | " self,\n", 298 | " input_ids_a,\n", 299 | " input_ids_b,\n", 300 | " token_type_ids_a=None,\n", 301 | " position_ids_ids_a=None,\n", 302 | " attention_mask_a=None,\n", 303 | " token_type_ids_b=None,\n", 304 | " position_ids_b=None,\n", 305 | " attention_mask_b=None,\n", 306 | " label_ids=None,\n", 307 | " **kwargs\n", 308 | " ):\n", 309 | "\n", 310 | " cls_embedding_a = self.get_pooled_embedding(\n", 311 | " input_ids_a,\n", 312 | " token_type_ids_a,\n", 313 | " position_ids_ids_a,\n", 314 | " attention_mask_a\n", 315 | " )\n", 316 | "\n", 317 | " cls_embedding_b = self.get_pooled_embedding(\n", 318 | " input_ids_b,\n", 319 | " token_type_ids_b,\n", 320 | " position_ids_b,\n", 321 | " attention_mask_b\n", 322 | " )\n", 323 | "\n", 324 | " cosine_sim = torch.sum(cls_embedding_a * cls_embedding_b, dim=1) * 20\n", 325 | " cosine_sim = cosine_sim[:, None] - cosine_sim[None, :]\n", 326 | " \n", 327 | " labels = label_ids[:, None] < label_ids[None, :]\n", 328 | " labels = labels.long()\n", 329 | " \n", 330 | " cosine_sim = cosine_sim - (1 - labels) * 1e12\n", 331 | " cosine_sim = torch.cat((torch.zeros(1).to(cosine_sim.device), cosine_sim.view(-1)), dim=0)\n", 332 | " loss = torch.logsumexp(cosine_sim, dim=0)\n", 333 | "\n", 334 | " return cosine_sim, loss\n" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": null, 340 | "id": "e630530b", 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "dl_module = CoSENT.from_pretrained(\n", 345 | " 'bert-base-chinese', \n", 346 | " config=bert_config\n", 347 | ")" 348 | ] 349 | }, 350 | { 351 | "cell_type": "markdown", 352 | "id": "13e3c8ac", 353 | "metadata": {}, 354 | "source": [ 355 | "
\n", 356 | "\n", 357 | "### 三、任务构建" 358 | ] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "id": "31d1f76c", 363 | "metadata": {}, 364 | "source": [ 365 | "#### 1. 任务参数和必要部件设定" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": null, 371 | "id": "943bf64c", 372 | "metadata": {}, 373 | "outputs": [], 374 | "source": [ 375 | "# 设置运行次数\n", 376 | "num_epoches = 5\n", 377 | "batch_size = 32" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": null, 383 | "id": "74641ede", 384 | "metadata": {}, 385 | "outputs": [], 386 | "source": [ 387 | "param_optimizer = list(dl_module.named_parameters())\n", 388 | "param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]\n", 389 | "no_decay = [\"bias\", \"LayerNorm.weight\"]\n", 390 | "optimizer_grouped_parameters = [\n", 391 | " {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n", 392 | " 'weight_decay': 0.01},\n", 393 | " {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n", 394 | "] " 395 | ] 396 | }, 397 | { 398 | "cell_type": "markdown", 399 | "id": "bd5a9361", 400 | "metadata": {}, 401 | "source": [ 402 | "#### 2. 任务创建" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": null, 408 | "id": "5bc04465", 409 | "metadata": {}, 410 | "outputs": [], 411 | "source": [ 412 | "import torch\n", 413 | "import numpy as np\n", 414 | "from scipy import stats\n", 415 | "\n", 416 | "from ark_nlp.factory.task.base._sequence_classification import SequenceClassificationTask\n", 417 | "\n", 418 | "\n", 419 | "class CoSENTTask(SequenceClassificationTask):\n", 420 | " \"\"\"\n", 421 | " 用于CoSENT模型文本匹配任务的Task\n", 422 | " \n", 423 | " Args:\n", 424 | " module: 深度学习模型\n", 425 | " optimizer: 训练模型使用的优化器名或者优化器对象\n", 426 | " loss_function: 训练模型使用的损失函数名或损失函数对象\n", 427 | " class_num (:obj:`int` or :obj:`None`, optional, defaults to None): 标签数目\n", 428 | " scheduler (:obj:`class`, optional, defaults to None): scheduler对象\n", 429 | " n_gpu (:obj:`int`, optional, defaults to 1): GPU数目\n", 430 | " device (:obj:`class`, optional, defaults to None): torch.device对象,当device为None时,会自动检测是否有GPU\n", 431 | " cuda_device (:obj:`int`, optional, defaults to 0): GPU编号,当device为None时,根据cuda_device设置device\n", 432 | " ema_decay (:obj:`int` or :obj:`None`, optional, defaults to None): EMA的加权系数\n", 433 | " **kwargs (optional): 其他可选参数\n", 434 | " \"\"\" # noqa: ignore flake8\"\n", 435 | "\n", 436 | " def _on_evaluate_begin_record(self, **kwargs):\n", 437 | "\n", 438 | " self.evaluate_logs['eval_loss'] = 0\n", 439 | " self.evaluate_logs['eval_step'] = 0\n", 440 | " self.evaluate_logs['eval_example'] = 0\n", 441 | "\n", 442 | " self.evaluate_logs['labels'] = []\n", 443 | " self.evaluate_logs['eval_sim'] = []\n", 444 | "\n", 445 | " def _on_evaluate_step_end(self, inputs, outputs, **kwargs):\n", 446 | "\n", 447 | " with torch.no_grad():\n", 448 | " # compute loss\n", 449 | " logits, loss = self._get_evaluate_loss(inputs, outputs, **kwargs)\n", 450 | " self.evaluate_logs['eval_loss'] += loss.item()\n", 451 | "\n", 452 | " if 'label_ids' in inputs:\n", 453 | " cosine_sim = self.module.cosine_sim(**inputs).cpu().numpy()\n", 454 | " self.evaluate_logs['eval_sim'].append(cosine_sim)\n", 455 | " self.evaluate_logs['labels'].append(inputs['label_ids'].cpu().numpy())\n", 456 | "\n", 457 | " self.evaluate_logs['eval_example'] += logits.shape[0]\n", 458 | " self.evaluate_logs['eval_step'] += 1\n", 459 | "\n", 460 | " def 
_on_evaluate_epoch_end(\n",
461 |     "        self,\n",
462 |     "        validation_data,\n",
463 |     "        epoch=1,\n",
464 |     "        is_evaluate_print=True,\n",
465 |     "        **kwargs\n",
466 |     "    ):\n",
467 |     "\n",
468 |     "        if is_evaluate_print:\n",
469 |     "            if self.evaluate_logs['labels']:\n",
470 |     "                _sims = np.concatenate(self.evaluate_logs['eval_sim'], axis=0)\n",
471 |     "                _labels = np.concatenate(self.evaluate_logs['labels'], axis=0)\n",
472 |     "                spearman_corr = stats.spearmanr(_labels, _sims).correlation\n",
473 |     "                print('evaluate spearman corr is:{:.4f}, evaluate loss is:{:.6f}'.format(\n",
474 |     "                    spearman_corr,\n",
475 |     "                    self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']\n",
476 |     "                    )\n",
477 |     "                )\n",
478 |     "            else:\n",
479 |     "                print('evaluate loss is:{:.6f}'.format(self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']))"
480 |    ]
481 |   },
482 |   {
483 |    "cell_type": "code",
484 |    "execution_count": null,
485 |    "id": "3dfc61d9",
486 |    "metadata": {},
487 |    "outputs": [],
488 |    "source": [
489 |     "model = CoSENTTask(dl_module, 'adamw', None, cuda_device=0)"
490 |    ]
491 |   },
492 |   {
493 |    "cell_type": "markdown",
494 |    "id": "35c96cf8",
495 |    "metadata": {
496 |     "tags": []
497 |    },
498 |    "source": [
499 |     "#### 3. 训练"
500 |    ]
501 |   },
502 |   {
503 |    "cell_type": "code",
504 |    "execution_count": null,
505 |    "id": "62e3e9ff",
506 |    "metadata": {},
507 |    "outputs": [],
508 |    "source": [
509 |     "model.fit(\n",
510 |     "    cosent_train_dataset,\n",
511 |     "    cosent_dev_dataset,\n",
512 |     "    lr=2e-5,\n",
513 |     "    epochs=num_epoches,\n",
514 |     "    batch_size=batch_size,\n",
515 |     "    params=optimizer_grouped_parameters\n",
516 |     ")"
517 |    ]
518 |   },
519 |   {
520 |    "cell_type": "markdown",
521 |    "id": "9b27e57b",
522 |    "metadata": {},
523 |    "source": [
524 |     "
\n", 525 | "\n", 526 | "### 四、模型验证" 527 | ] 528 | }, 529 | { 530 | "cell_type": "code", 531 | "execution_count": null, 532 | "id": "e6fe3272-777a-45b5-b736-6546af69ec34", 533 | "metadata": {}, 534 | "outputs": [], 535 | "source": [ 536 | "import torch\n", 537 | "\n", 538 | "from torch.utils.data import DataLoader\n", 539 | "from ark_nlp.factory.predictor import SequenceClassificationPredictor\n", 540 | "\n", 541 | "\n", 542 | "class CoSENTPredictor(SequenceClassificationPredictor):\n", 543 | " \"\"\"\n", 544 | " CoSENT的预测器\n", 545 | " \n", 546 | " Args:\n", 547 | " module: 深度学习模型\n", 548 | " tokernizer: 分词器\n", 549 | " cat2id (:obj:`dict`): 标签映射\n", 550 | " \"\"\" # noqa: ignore flake8\"\n", 551 | "\n", 552 | " def _get_input_ids(\n", 553 | " self,\n", 554 | " text_a,\n", 555 | " text_b\n", 556 | " ):\n", 557 | " if self.tokenizer.tokenizer_type == 'vanilla':\n", 558 | " return self._convert_to_vanilla_ids(text_a, text_b)\n", 559 | " elif self.tokenizer.tokenizer_type == 'transfomer':\n", 560 | " return self._convert_to_transfomer_ids(text_a, text_b)\n", 561 | " elif self.tokenizer.tokenizer_type == 'customized':\n", 562 | " return self._convert_to_customized_ids(text_a, text_b)\n", 563 | " else:\n", 564 | " raise ValueError(\"The tokenizer type does not exist\")\n", 565 | "\n", 566 | " def _convert_to_transfomer_ids(\n", 567 | " self,\n", 568 | " text_a,\n", 569 | " text_b\n", 570 | " ):\n", 571 | " input_ids_a = self.tokenizer.sequence_to_ids(text_a)\n", 572 | " input_ids_b = self.tokenizer.sequence_to_ids(text_b)\n", 573 | "\n", 574 | " input_ids_a, input_mask_a, segment_ids_a = input_ids_a\n", 575 | " input_ids_b, input_mask_b, segment_ids_b = input_ids_b\n", 576 | "\n", 577 | " features = {\n", 578 | " 'input_ids_a': input_ids_a,\n", 579 | " 'attention_mask_a': input_mask_a,\n", 580 | " 'token_type_ids_a': segment_ids_a,\n", 581 | " 'input_ids_b': input_ids_b,\n", 582 | " 'attention_mask_b': input_mask_b,\n", 583 | " 'token_type_ids_b': segment_ids_b\n", 584 | " }\n", 585 | "\n", 586 | " return features\n", 587 | "\n", 588 | " def predict_one_sample(\n", 589 | " self,\n", 590 | " text,\n", 591 | " topk=None,\n", 592 | " threshold=0.5,\n", 593 | " return_label_name=True,\n", 594 | " return_proba=False\n", 595 | " ):\n", 596 | " if topk is None:\n", 597 | " topk = len(self.cat2id) if len(self.cat2id) > 2 else 1\n", 598 | " text_a, text_b = text\n", 599 | " features = self._get_input_ids(text_a, text_b)\n", 600 | " self.module.eval()\n", 601 | "\n", 602 | " with torch.no_grad():\n", 603 | " inputs = self._get_module_one_sample_inputs(features)\n", 604 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n", 605 | "\n", 606 | " _proba = logits[0]\n", 607 | " \n", 608 | " if threshold is not None:\n", 609 | " _pred = self._threshold(_proba, threshold)\n", 610 | "\n", 611 | " if return_label_name and threshold is not None:\n", 612 | " _pred = self.id2cat[_pred]\n", 613 | "\n", 614 | " if threshold is not None:\n", 615 | " if return_proba:\n", 616 | " return [_pred, _proba]\n", 617 | " else:\n", 618 | " return _pred\n", 619 | "\n", 620 | " return _proba\n", 621 | "\n", 622 | " def predict_batch(\n", 623 | " self,\n", 624 | " test_data,\n", 625 | " batch_size=16,\n", 626 | " shuffle=False\n", 627 | " ):\n", 628 | " self.inputs_cols = test_data.dataset_cols\n", 629 | "\n", 630 | " preds = []\n", 631 | "\n", 632 | " self.module.eval()\n", 633 | " generator = DataLoader(test_data, batch_size=batch_size, shuffle=shuffle)\n", 634 | "\n", 635 | " with torch.no_grad():\n", 636 | " for step, 
inputs in enumerate(generator):\n", 637 | " inputs = self._get_module_batch_inputs(inputs)\n", 638 | "\n", 639 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n", 640 | "\n", 641 | " preds.extend(logits)\n", 642 | "\n", 643 | " return preds" 644 | ] 645 | }, 646 | { 647 | "cell_type": "code", 648 | "execution_count": null, 649 | "id": "2d14331b-bf8f-4431-9f8f-e6947cbe7bfd", 650 | "metadata": {}, 651 | "outputs": [], 652 | "source": [ 653 | "cosent_predictor_instance = CoSENTPredictor(model.module, tokenizer, cosent_train_dataset.cat2id)" 654 | ] 655 | }, 656 | { 657 | "cell_type": "code", 658 | "execution_count": null, 659 | "id": "ab835fab-8d52-4fad-9610-a9125c56b825", 660 | "metadata": {}, 661 | "outputs": [], 662 | "source": [ 663 | "cosent_predictor_instance.predict_one_sample(\n", 664 | " ['用微信都6年,微信没有微粒贷功能', \n", 665 | " '4。 号码来微粒贷'],\n", 666 | " threshold=None\n", 667 | ")" 668 | ] 669 | } 670 | ], 671 | "metadata": { 672 | "kernelspec": { 673 | "display_name": "Python 3 (ipykernel)", 674 | "language": "python", 675 | "name": "python3" 676 | }, 677 | "language_info": { 678 | "codemirror_mode": { 679 | "name": "ipython", 680 | "version": 3 681 | }, 682 | "file_extension": ".py", 683 | "mimetype": "text/x-python", 684 | "name": "python", 685 | "nbconvert_exporter": "python", 686 | "pygments_lexer": "ipython3", 687 | "version": "3.8.10" 688 | } 689 | }, 690 | "nbformat": 4, 691 | "nbformat_minor": 5 692 | } 693 | -------------------------------------------------------------------------------- /CoSENT_CHIP-STS.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "9ced17a4", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import pandas as pd\n", 12 | "\n", 13 | "from ark_nlp.nn import BertConfig as ModuleConfig\n", 14 | "from ark_nlp.dataset import TwinTowersSentenceClassificationDataset as Dataset\n", 15 | "from ark_nlp.processor.tokenizer.transfomer import SentenceTokenizer as Tokenizer" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "id": "ccea726a", 21 | "metadata": { 22 | "tags": [] 23 | }, 24 | "source": [ 25 | "### 一、数据读入与处理" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "id": "8d5c3337", 31 | "metadata": { 32 | "tags": [] 33 | }, 34 | "source": [ 35 | "#### 1. 
数据读入" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "id": "1e651e5f-8535-43c2-893b-10f3275517ba", 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "train_data_df = pd.read_json('../data/source_datasets/CHIP-STS/CHIP-STS_train.json')\n", 46 | "train_data_df = (train_data_df\n", 47 | " .rename(columns={'text1': 'text_a', 'text2': 'text_b', 'category': 'condition'})\n", 48 | " .loc[:,['text_a', 'text_b', 'condition', 'label']])\n", 49 | "\n", 50 | "dev_data_df = pd.read_json('../data/source_datasets/CHIP-STS/CHIP-STS_dev.json')\n", 51 | "dev_data_df = dev_data_df[dev_data_df['label'] != \"NA\"]\n", 52 | "dev_data_df = (dev_data_df\n", 53 | " .rename(columns={'text1': 'text_a', 'text2': 'text_b', 'category': 'condition'})\n", 54 | " .loc[:,['text_a', 'text_b', 'condition', 'label']])" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "id": "90876062", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "cosent_train_dataset = Dataset(train_data_df)\n", 65 | "cosent_dev_dataset = Dataset(dev_data_df)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "id": "e061890a", 71 | "metadata": {}, 72 | "source": [ 73 | "#### 2. 词典创建和生成分词器" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "id": "be116454", 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "# 加载分词器\n", 84 | "tokenizer = Tokenizer(vocab='bert-base-chinese', max_seq_len=64)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "id": "0d6c3b3d", 90 | "metadata": {}, 91 | "source": [ 92 | "#### 3. ID化" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "id": "566dd6b6", 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "cosent_train_dataset.convert_to_ids(tokenizer)\n", 103 | "cosent_dev_dataset.convert_to_ids(tokenizer)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "id": "981b4160", 109 | "metadata": {}, 110 | "source": [ 111 | "
\n", 112 | "\n", 113 | "### 二、模型构建" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "id": "72753ee8", 119 | "metadata": {}, 120 | "source": [ 121 | "#### 1. 模型参数设置" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "id": "527535d3", 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "from transformers import BertConfig\n", 132 | "\n", 133 | "bert_config = BertConfig.from_pretrained(\n", 134 | " 'bert-base-chinese',\n", 135 | " num_labels=len(cosent_train_dataset.cat2id)\n", 136 | ")" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "id": "77d8b580", 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "torch.cuda.empty_cache()" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "id": "700d7752", 152 | "metadata": {}, 153 | "source": [ 154 | "#### 2. 模型创建" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "id": "58f380f4", 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "import torch\n", 165 | "import torch.nn.functional as F\n", 166 | "\n", 167 | "from torch import nn\n", 168 | "from transformers import BertModel\n", 169 | "from ark_nlp.nn import Bert\n", 170 | "\n", 171 | "\n", 172 | "class CoSENT(Bert):\n", 173 | " \"\"\"\n", 174 | " CoSENT模型\n", 175 | "\n", 176 | " Args:\n", 177 | " config:\n", 178 | " 模型的配置对象\n", 179 | " encoder_trained (:obj:`bool`, optional, defaults to True):\n", 180 | " bert参数是否可训练,默认可训练\n", 181 | " pooling (:obj:`str`, optional, defaults to \"last_avg\"):\n", 182 | " bert输出的池化方式,默认为\"last_avg\",\n", 183 | " 可选有[\"cls\", \"cls_with_pooler\", \"first_last_avg\", \"last_avg\", \"last_2_avg\"]\n", 184 | " dropout (:obj:`float` or :obj:`None`, optional, defaults to None):\n", 185 | " dropout比例,默认为None,实际设置时会设置成0\n", 186 | " output_emb_size (:obj:`int`, optional, defaults to 0):\n", 187 | " 输出的矩阵的维度,默认为0,即不进行矩阵维度变换\n", 188 | "\n", 189 | " Reference:\n", 190 | " [1] https://kexue.fm/archives/8847\n", 191 | " [2] https://github.com/bojone/CoSENT \n", 192 | " \"\"\" # noqa: ignore flake8\"\n", 193 | "\n", 194 | " def __init__(\n", 195 | " self,\n", 196 | " config,\n", 197 | " encoder_trained=True,\n", 198 | " pooling='last_avg',\n", 199 | " dropout=None,\n", 200 | " output_emb_size=0\n", 201 | " ):\n", 202 | "\n", 203 | " super(CoSENT, self).__init__(config)\n", 204 | "\n", 205 | " self.bert = BertModel(config)\n", 206 | " self.pooling = pooling\n", 207 | "\n", 208 | " self.dropout = nn.Dropout(dropout if dropout is not None else 0)\n", 209 | "\n", 210 | " # if output_emb_size is greater than 0, then add Linear layer to reduce embedding_size,\n", 211 | " # we recommend set output_emb_size = 256 considering the trade-off beteween\n", 212 | " # recall performance and efficiency\n", 213 | " self.output_emb_size = output_emb_size\n", 214 | " if self.output_emb_size > 0:\n", 215 | " self.emb_reduce_linear = nn.Linear(\n", 216 | " config.hidden_size,\n", 217 | " self.output_emb_size\n", 218 | " )\n", 219 | " torch.nn.init.trunc_normal_(\n", 220 | " self.emb_reduce_linear.weight,\n", 221 | " std=0.02\n", 222 | " )\n", 223 | "\n", 224 | " for param in self.bert.parameters():\n", 225 | " param.requires_grad = encoder_trained\n", 226 | "\n", 227 | " self.init_weights()\n", 228 | "\n", 229 | " def get_pooled_embedding(\n", 230 | " self,\n", 231 | " input_ids,\n", 232 | " token_type_ids=None,\n", 233 | " position_ids=None,\n", 234 | " attention_mask=None\n", 235 | " ):\n", 236 | " outputs = 
self.bert(\n", 237 | " input_ids,\n", 238 | " attention_mask=attention_mask,\n", 239 | " token_type_ids=token_type_ids,\n", 240 | " position_ids=position_ids,\n", 241 | " return_dict=True,\n", 242 | " output_hidden_states=True\n", 243 | " )\n", 244 | "\n", 245 | " encoder_feature = self.get_encoder_feature(\n", 246 | " outputs,\n", 247 | " attention_mask\n", 248 | " )\n", 249 | "\n", 250 | " if self.output_emb_size > 0:\n", 251 | " encoder_feature = self.emb_reduce_linear(encoder_feature)\n", 252 | "\n", 253 | " encoder_feature = self.dropout(encoder_feature)\n", 254 | " out = F.normalize(encoder_feature, p=2, dim=-1)\n", 255 | "\n", 256 | " return out\n", 257 | "\n", 258 | " def cosine_sim(\n", 259 | " self,\n", 260 | " input_ids_a,\n", 261 | " input_ids_b,\n", 262 | " token_type_ids_a=None,\n", 263 | " position_ids_ids_a=None,\n", 264 | " attention_mask_a=None,\n", 265 | " token_type_ids_b=None,\n", 266 | " position_ids_b=None,\n", 267 | " attention_mask_b=None,\n", 268 | " **kwargs\n", 269 | " ):\n", 270 | "\n", 271 | " query_cls_embedding = self.get_pooled_embedding(\n", 272 | " input_ids_a,\n", 273 | " token_type_ids_a,\n", 274 | " position_ids_ids_a,\n", 275 | " attention_mask_a\n", 276 | " )\n", 277 | "\n", 278 | " title_cls_embedding = self.get_pooled_embedding(\n", 279 | " input_ids_b,\n", 280 | " token_type_ids_b,\n", 281 | " position_ids_b,\n", 282 | " attention_mask_b\n", 283 | " )\n", 284 | "\n", 285 | " cosine_sim = torch.sum(\n", 286 | " query_cls_embedding * title_cls_embedding,\n", 287 | " axis=-1\n", 288 | " )\n", 289 | "\n", 290 | " return cosine_sim\n", 291 | "\n", 292 | " def forward(\n", 293 | " self,\n", 294 | " input_ids_a,\n", 295 | " input_ids_b,\n", 296 | " token_type_ids_a=None,\n", 297 | " position_ids_ids_a=None,\n", 298 | " attention_mask_a=None,\n", 299 | " token_type_ids_b=None,\n", 300 | " position_ids_b=None,\n", 301 | " attention_mask_b=None,\n", 302 | " label_ids=None,\n", 303 | " **kwargs\n", 304 | " ):\n", 305 | "\n", 306 | " cls_embedding_a = self.get_pooled_embedding(\n", 307 | " input_ids_a,\n", 308 | " token_type_ids_a,\n", 309 | " position_ids_ids_a,\n", 310 | " attention_mask_a\n", 311 | " )\n", 312 | "\n", 313 | " cls_embedding_b = self.get_pooled_embedding(\n", 314 | " input_ids_b,\n", 315 | " token_type_ids_b,\n", 316 | " position_ids_b,\n", 317 | " attention_mask_b\n", 318 | " )\n", 319 | "\n", 320 | " cosine_sim = torch.sum(cls_embedding_a * cls_embedding_b, dim=1) * 20\n", 321 | " cosine_sim = cosine_sim[:, None] - cosine_sim[None, :]\n", 322 | " \n", 323 | " labels = label_ids[:, None] < label_ids[None, :]\n", 324 | " labels = labels.long()\n", 325 | " \n", 326 | " cosine_sim = cosine_sim - (1 - labels) * 1e12\n", 327 | " cosine_sim = torch.cat((torch.zeros(1).to(cosine_sim.device), cosine_sim.view(-1)), dim=0)\n", 328 | " loss = torch.logsumexp(cosine_sim.view(-1), dim=0)\n", 329 | "\n", 330 | " return cosine_sim, loss\n" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": null, 336 | "id": "e630530b", 337 | "metadata": {}, 338 | "outputs": [], 339 | "source": [ 340 | "dl_module = CoSENT.from_pretrained(\n", 341 | " 'bert-base-chinese', \n", 342 | " config=bert_config\n", 343 | ")" 344 | ] 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "id": "13e3c8ac", 349 | "metadata": {}, 350 | "source": [ 351 | "
\n", 352 | "\n", 353 | "### 三、任务构建" 354 | ] 355 | }, 356 | { 357 | "cell_type": "markdown", 358 | "id": "31d1f76c", 359 | "metadata": {}, 360 | "source": [ 361 | "#### 1. 任务参数和必要部件设定" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": null, 367 | "id": "943bf64c", 368 | "metadata": {}, 369 | "outputs": [], 370 | "source": [ 371 | "# 设置运行次数\n", 372 | "num_epoches = 5\n", 373 | "batch_size = 32" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": null, 379 | "id": "74641ede", 380 | "metadata": {}, 381 | "outputs": [], 382 | "source": [ 383 | "param_optimizer = list(dl_module.named_parameters())\n", 384 | "param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]\n", 385 | "no_decay = [\"bias\", \"LayerNorm.weight\"]\n", 386 | "optimizer_grouped_parameters = [\n", 387 | " {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n", 388 | " 'weight_decay': 0.01},\n", 389 | " {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n", 390 | "] " 391 | ] 392 | }, 393 | { 394 | "cell_type": "markdown", 395 | "id": "bd5a9361", 396 | "metadata": {}, 397 | "source": [ 398 | "#### 2. 任务创建" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": null, 404 | "id": "5bc04465", 405 | "metadata": {}, 406 | "outputs": [], 407 | "source": [ 408 | "import torch\n", 409 | "import numpy as np\n", 410 | "from scipy import stats\n", 411 | "\n", 412 | "from ark_nlp.factory.task.base._sequence_classification import SequenceClassificationTask\n", 413 | "\n", 414 | "\n", 415 | "class CoSENTTask(SequenceClassificationTask):\n", 416 | " \"\"\"\n", 417 | " 用于CoSENT模型文本匹配任务的Task\n", 418 | " \n", 419 | " Args:\n", 420 | " module: 深度学习模型\n", 421 | " optimizer: 训练模型使用的优化器名或者优化器对象\n", 422 | " loss_function: 训练模型使用的损失函数名或损失函数对象\n", 423 | " class_num (:obj:`int` or :obj:`None`, optional, defaults to None): 标签数目\n", 424 | " scheduler (:obj:`class`, optional, defaults to None): scheduler对象\n", 425 | " n_gpu (:obj:`int`, optional, defaults to 1): GPU数目\n", 426 | " device (:obj:`class`, optional, defaults to None): torch.device对象,当device为None时,会自动检测是否有GPU\n", 427 | " cuda_device (:obj:`int`, optional, defaults to 0): GPU编号,当device为None时,根据cuda_device设置device\n", 428 | " ema_decay (:obj:`int` or :obj:`None`, optional, defaults to None): EMA的加权系数\n", 429 | " **kwargs (optional): 其他可选参数\n", 430 | " \"\"\" # noqa: ignore flake8\"\n", 431 | "\n", 432 | " def _on_evaluate_begin_record(self, **kwargs):\n", 433 | "\n", 434 | " self.evaluate_logs['eval_loss'] = 0\n", 435 | " self.evaluate_logs['eval_step'] = 0\n", 436 | " self.evaluate_logs['eval_example'] = 0\n", 437 | "\n", 438 | " self.evaluate_logs['labels'] = []\n", 439 | " self.evaluate_logs['eval_sim'] = []\n", 440 | "\n", 441 | " def _on_evaluate_step_end(self, inputs, outputs, **kwargs):\n", 442 | "\n", 443 | " with torch.no_grad():\n", 444 | " # compute loss\n", 445 | " logits, loss = self._get_evaluate_loss(inputs, outputs, **kwargs)\n", 446 | " self.evaluate_logs['eval_loss'] += loss.item()\n", 447 | "\n", 448 | " if 'label_ids' in inputs:\n", 449 | " cosine_sim = self.module.cosine_sim(**inputs).cpu().numpy()\n", 450 | " self.evaluate_logs['eval_sim'].append(cosine_sim)\n", 451 | " self.evaluate_logs['labels'].append(inputs['label_ids'].cpu().numpy())\n", 452 | "\n", 453 | " self.evaluate_logs['eval_example'] += logits.shape[0]\n", 454 | " self.evaluate_logs['eval_step'] += 1\n", 455 | "\n", 456 | " def 
_on_evaluate_epoch_end(\n", 457 | " self,\n", 458 | " validation_data,\n", 459 | " epoch=1,\n", 460 | " is_evaluate_print=True,\n", 461 | " **kwargs\n", 462 | " ):\n", 463 | "\n", 464 | " if is_evaluate_print:\n", 465 | " if 'labels' in self.evaluate_logs:\n", 466 | " _sims = np.concatenate(self.evaluate_logs['eval_sim'], axis=0)\n", 467 | " _labels = np.concatenate(self.evaluate_logs['labels'], axis=0)\n", 468 | " spearman_corr = stats.spearmanr(_labels, _sims).correlation\n", 469 | " print('evaluate spearman corr is:{:.4f}, evaluate loss is:{:.6f}'.format(\n", 470 | " spearman_corr,\n", 471 | " self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']\n", 472 | " )\n", 473 | " )\n", 474 | " else:\n", 475 | " print('evaluate loss is:{:.6f}'.format(self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']))" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": null, 481 | "id": "3dfc61d9", 482 | "metadata": {}, 483 | "outputs": [], 484 | "source": [ 485 | "model = CoSENTTask(dl_module, 'adamw', None, cuda_device=0)" 486 | ] 487 | }, 488 | { 489 | "cell_type": "markdown", 490 | "id": "35c96cf8", 491 | "metadata": { 492 | "tags": [] 493 | }, 494 | "source": [ 495 | "#### 3. 训练" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": null, 501 | "id": "62e3e9ff", 502 | "metadata": {}, 503 | "outputs": [], 504 | "source": [ 505 | "model.fit(\n", 506 | " cosent_train_dataset,\n", 507 | " cosent_dev_dataset,\n", 508 | " lr=2e-5,\n", 509 | " epochs=num_epoches,\n", 510 | " batch_size=batch_size,\n", 511 | " params=optimizer_grouped_parameters\n", 512 | ")" 513 | ] 514 | }, 515 | { 516 | "cell_type": "markdown", 517 | "id": "9b27e57b", 518 | "metadata": {}, 519 | "source": [ 520 | "
\n", 521 | "\n", 522 | "### 四、模型验证" 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": null, 528 | "id": "971c58b0-23fe-45c2-a9e8-0c8d48ab9326", 529 | "metadata": {}, 530 | "outputs": [], 531 | "source": [ 532 | "import torch\n", 533 | "\n", 534 | "from torch.utils.data import DataLoader\n", 535 | "from ark_nlp.factory.predictor import SequenceClassificationPredictor\n", 536 | "\n", 537 | "\n", 538 | "class CoSENTPredictor(SequenceClassificationPredictor):\n", 539 | " \"\"\"\n", 540 | " CoSENT的预测器\n", 541 | " \n", 542 | " Args:\n", 543 | " module: 深度学习模型\n", 544 | " tokernizer: 分词器\n", 545 | " cat2id (:obj:`dict`): 标签映射\n", 546 | " \"\"\" # noqa: ignore flake8\"\n", 547 | "\n", 548 | " def _get_input_ids(\n", 549 | " self,\n", 550 | " text_a,\n", 551 | " text_b\n", 552 | " ):\n", 553 | " if self.tokenizer.tokenizer_type == 'vanilla':\n", 554 | " return self._convert_to_vanilla_ids(text_a, text_b)\n", 555 | " elif self.tokenizer.tokenizer_type == 'transfomer':\n", 556 | " return self._convert_to_transfomer_ids(text_a, text_b)\n", 557 | " elif self.tokenizer.tokenizer_type == 'customized':\n", 558 | " return self._convert_to_customized_ids(text_a, text_b)\n", 559 | " else:\n", 560 | " raise ValueError(\"The tokenizer type does not exist\")\n", 561 | "\n", 562 | " def _convert_to_transfomer_ids(\n", 563 | " self,\n", 564 | " text_a,\n", 565 | " text_b\n", 566 | " ):\n", 567 | " input_ids_a = self.tokenizer.sequence_to_ids(text_a)\n", 568 | " input_ids_b = self.tokenizer.sequence_to_ids(text_b)\n", 569 | "\n", 570 | " input_ids_a, input_mask_a, segment_ids_a = input_ids_a\n", 571 | " input_ids_b, input_mask_b, segment_ids_b = input_ids_b\n", 572 | "\n", 573 | " features = {\n", 574 | " 'input_ids_a': input_ids_a,\n", 575 | " 'attention_mask_a': input_mask_a,\n", 576 | " 'token_type_ids_a': segment_ids_a,\n", 577 | " 'input_ids_b': input_ids_b,\n", 578 | " 'attention_mask_b': input_mask_b,\n", 579 | " 'token_type_ids_b': segment_ids_b\n", 580 | " }\n", 581 | "\n", 582 | " return features\n", 583 | "\n", 584 | " def predict_one_sample(\n", 585 | " self,\n", 586 | " text,\n", 587 | " topk=None,\n", 588 | " threshold=0.5,\n", 589 | " return_label_name=True,\n", 590 | " return_proba=False\n", 591 | " ):\n", 592 | " if topk is None:\n", 593 | " topk = len(self.cat2id) if len(self.cat2id) > 2 else 1\n", 594 | " text_a, text_b = text\n", 595 | " features = self._get_input_ids(text_a, text_b)\n", 596 | " self.module.eval()\n", 597 | "\n", 598 | " with torch.no_grad():\n", 599 | " inputs = self._get_module_one_sample_inputs(features)\n", 600 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n", 601 | "\n", 602 | " _proba = logits[0]\n", 603 | " \n", 604 | " if threshold is not None:\n", 605 | " _pred = self._threshold(_proba, threshold)\n", 606 | "\n", 607 | " if return_label_name and threshold is not None:\n", 608 | " _pred = self.id2cat[_pred]\n", 609 | "\n", 610 | " if threshold is not None:\n", 611 | " if return_proba:\n", 612 | " return [_pred, _proba]\n", 613 | " else:\n", 614 | " return _pred\n", 615 | "\n", 616 | " return _proba\n", 617 | "\n", 618 | " def predict_batch(\n", 619 | " self,\n", 620 | " test_data,\n", 621 | " batch_size=16,\n", 622 | " shuffle=False\n", 623 | " ):\n", 624 | " self.inputs_cols = test_data.dataset_cols\n", 625 | "\n", 626 | " preds = []\n", 627 | "\n", 628 | " self.module.eval()\n", 629 | " generator = DataLoader(test_data, batch_size=batch_size, shuffle=shuffle)\n", 630 | "\n", 631 | " with torch.no_grad():\n", 632 | " for step, 
inputs in enumerate(generator):\n", 633 | " inputs = self._get_module_batch_inputs(inputs)\n", 634 | "\n", 635 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n", 636 | "\n", 637 | " preds.extend(logits)\n", 638 | "\n", 639 | " return preds" 640 | ] 641 | }, 642 | { 643 | "cell_type": "code", 644 | "execution_count": null, 645 | "id": "4b2f71d7-8547-4d17-95f2-77952673b871", 646 | "metadata": {}, 647 | "outputs": [], 648 | "source": [ 649 | "cosent_predictor_instance = CoSENTPredictor(model.module, tokenizer, cosent_train_dataset.cat2id)" 650 | ] 651 | }, 652 | { 653 | "cell_type": "code", 654 | "execution_count": null, 655 | "id": "ec9094a9-cdee-4709-80c7-35285290a992", 656 | "metadata": {}, 657 | "outputs": [], 658 | "source": [ 659 | "cosent_predictor_instance.predict_one_sample(\n", 660 | " ['糖尿病能吃减肥药吗?能治愈吗?', \n", 661 | " '糖尿病为什么不能吃减肥药'],\n", 662 | " threshold=None\n", 663 | ")" 664 | ] 665 | }, 666 | { 667 | "cell_type": "markdown", 668 | "id": "fd6fdff6-1a80-4f6e-bebd-82511121365c", 669 | "metadata": {}, 670 | "source": [ 671 | "### 五、CBLUE打榜提交" 672 | ] 673 | }, 674 | { 675 | "cell_type": "code", 676 | "execution_count": null, 677 | "id": "4e66ade4-b0f6-4e77-863c-6c62762ef793", 678 | "metadata": {}, 679 | "outputs": [], 680 | "source": [ 681 | "import pandas as pd\n", 682 | "\n", 683 | "from tqdm import tqdm\n", 684 | "\n", 685 | "test_df = pd.read_json('../data/source_datasets/CHIP-STS/CHIP-STS_test.json')\n", 686 | "\n", 687 | "submit = []\n", 688 | "for _id, _text_a, _text_b, _condition in tqdm(zip(\n", 689 | " test_df['id'],\n", 690 | " test_df['text1'],\n", 691 | " test_df['text2'],\n", 692 | " test_df['category']\n", 693 | ")):\n", 694 | " if _condition == 'daibetes':\n", 695 | " _condition = 'diabetes'\n", 696 | "\n", 697 | " predict_ = cosent_predictor_instance.predict_one_sample([_text_a, _text_b], threshold=0.6)\n", 698 | " \n", 699 | " submit.append({\n", 700 | " 'id': str(_id),\n", 701 | " 'text1': _text_a,\n", 702 | " 'text2': _text_b,\n", 703 | " 'label': predict_,\n", 704 | " 'category': _condition\n", 705 | " })" 706 | ] 707 | }, 708 | { 709 | "cell_type": "code", 710 | "execution_count": null, 711 | "id": "1fa52f20-7f8f-4279-ae2b-fc1d54c5f0f2", 712 | "metadata": {}, 713 | "outputs": [], 714 | "source": [ 715 | "import json\n", 716 | "\n", 717 | "output_path = '../data/output_datasets/CHIP-STS_test.json'\n", 718 | "\n", 719 | "with open(output_path, 'w', encoding='utf-8') as f:\n", 720 | " f.write(json.dumps(submit, ensure_ascii=False))" 721 | ] 722 | } 723 | ], 724 | "metadata": { 725 | "kernelspec": { 726 | "display_name": "Python 3 (ipykernel)", 727 | "language": "python", 728 | "name": "python3" 729 | }, 730 | "language_info": { 731 | "codemirror_mode": { 732 | "name": "ipython", 733 | "version": 3 734 | }, 735 | "file_extension": ".py", 736 | "mimetype": "text/x-python", 737 | "name": "python", 738 | "nbconvert_exporter": "python", 739 | "pygments_lexer": "ipython3", 740 | "version": "3.8.10" 741 | } 742 | }, 743 | "nbformat": 4, 744 | "nbformat_minor": 5 745 | } 746 | -------------------------------------------------------------------------------- /CoSENT_LCQMC.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "9ced17a4", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import pandas as pd\n", 12 | "\n", 13 | "from ark_nlp.nn import BertConfig as ModuleConfig\n", 14 | "from ark_nlp.dataset 
import TwinTowersSentenceClassificationDataset as Dataset\n", 15 | "from ark_nlp.processor.tokenizer.transfomer import SentenceTokenizer as Tokenizer" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "id": "bfcea417", 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "# 目录地址\n", 26 | "train_data_path = '../data/source_datasets/LCQMC/LCQMC.train.data'\n", 27 | "dev_data_path = '../data/source_datasets/LCQMC/LCQMC.test.data'" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "id": "ccea726a", 33 | "metadata": { 34 | "tags": [] 35 | }, 36 | "source": [ 37 | "### 一、数据读入与处理" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "8d5c3337", 43 | "metadata": { 44 | "tags": [] 45 | }, 46 | "source": [ 47 | "#### 1. 数据读入" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "id": "1e651e5f-8535-43c2-893b-10f3275517ba", 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "train_data_df = pd.read_csv(train_data_path, sep='\\t', header=None, names=['text_a', 'text_b', 'label'])\n", 58 | "dev_data_df = pd.read_csv(dev_data_path, sep='\\t', header=None, names=['text_a', 'text_b', 'label'])" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "90876062", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "cosent_train_dataset = Dataset(train_data_df)\n", 69 | "cosent_dev_dataset = Dataset(dev_data_df)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "id": "e061890a", 75 | "metadata": {}, 76 | "source": [ 77 | "#### 2. 词典创建和生成分词器" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "be116454", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "# 加载分词器\n", 88 | "tokenizer = Tokenizer(vocab='bert-base-chinese', max_seq_len=64)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "id": "0d6c3b3d", 94 | "metadata": {}, 95 | "source": [ 96 | "#### 3. ID化" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "id": "566dd6b6", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "cosent_train_dataset.convert_to_ids(tokenizer)\n", 107 | "cosent_dev_dataset.convert_to_ids(tokenizer)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "id": "981b4160", 113 | "metadata": {}, 114 | "source": [ 115 | "
\n", 116 | "\n", 117 | "### 二、模型构建" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "id": "72753ee8", 123 | "metadata": {}, 124 | "source": [ 125 | "#### 1. 模型参数设置" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "id": "527535d3", 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "from transformers import BertConfig\n", 136 | "\n", 137 | "bert_config = BertConfig.from_pretrained(\n", 138 | " 'bert-base-chinese',\n", 139 | " num_labels=2\n", 140 | ")" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "id": "77d8b580", 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "torch.cuda.empty_cache()" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "id": "700d7752", 156 | "metadata": {}, 157 | "source": [ 158 | "#### 2. 模型创建" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "id": "58f380f4", 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "import torch\n", 169 | "import torch.nn.functional as F\n", 170 | "\n", 171 | "from torch import nn\n", 172 | "from transformers import BertModel\n", 173 | "from ark_nlp.nn import Bert\n", 174 | "\n", 175 | "\n", 176 | "class CoSENT(Bert):\n", 177 | " \"\"\"\n", 178 | " CoSENT模型\n", 179 | "\n", 180 | " Args:\n", 181 | " config:\n", 182 | " 模型的配置对象\n", 183 | " encoder_trained (:obj:`bool`, optional, defaults to True):\n", 184 | " bert参数是否可训练,默认可训练\n", 185 | " pooling (:obj:`str`, optional, defaults to \"last_avg\"):\n", 186 | " bert输出的池化方式,默认为\"last_avg\",\n", 187 | " 可选有[\"cls\", \"cls_with_pooler\", \"first_last_avg\", \"last_avg\", \"last_2_avg\"]\n", 188 | " dropout (:obj:`float` or :obj:`None`, optional, defaults to None):\n", 189 | " dropout比例,默认为None,实际设置时会设置成0\n", 190 | " output_emb_size (:obj:`int`, optional, defaults to 0):\n", 191 | " 输出的矩阵的维度,默认为0,即不进行矩阵维度变换\n", 192 | "\n", 193 | " Reference:\n", 194 | " [1] https://kexue.fm/archives/8847\n", 195 | " [2] https://github.com/bojone/CoSENT \n", 196 | " \"\"\" # noqa: ignore flake8\"\n", 197 | "\n", 198 | " def __init__(\n", 199 | " self,\n", 200 | " config,\n", 201 | " encoder_trained=True,\n", 202 | " pooling='last_avg',\n", 203 | " dropout=None,\n", 204 | " output_emb_size=0\n", 205 | " ):\n", 206 | "\n", 207 | " super(CoSENT, self).__init__(config)\n", 208 | "\n", 209 | " self.bert = BertModel(config)\n", 210 | " self.pooling = pooling\n", 211 | "\n", 212 | " self.dropout = nn.Dropout(dropout if dropout is not None else 0)\n", 213 | "\n", 214 | " # if output_emb_size is greater than 0, then add Linear layer to reduce embedding_size,\n", 215 | " # we recommend set output_emb_size = 256 considering the trade-off beteween\n", 216 | " # recall performance and efficiency\n", 217 | " self.output_emb_size = output_emb_size\n", 218 | " if self.output_emb_size > 0:\n", 219 | " self.emb_reduce_linear = nn.Linear(\n", 220 | " config.hidden_size,\n", 221 | " self.output_emb_size\n", 222 | " )\n", 223 | " torch.nn.init.trunc_normal_(\n", 224 | " self.emb_reduce_linear.weight,\n", 225 | " std=0.02\n", 226 | " )\n", 227 | "\n", 228 | " for param in self.bert.parameters():\n", 229 | " param.requires_grad = encoder_trained\n", 230 | "\n", 231 | " self.init_weights()\n", 232 | "\n", 233 | " def get_pooled_embedding(\n", 234 | " self,\n", 235 | " input_ids,\n", 236 | " token_type_ids=None,\n", 237 | " position_ids=None,\n", 238 | " attention_mask=None\n", 239 | " ):\n", 240 | " outputs = self.bert(\n", 241 | " 
input_ids,\n", 242 | " attention_mask=attention_mask,\n", 243 | " token_type_ids=token_type_ids,\n", 244 | " position_ids=position_ids,\n", 245 | " return_dict=True,\n", 246 | " output_hidden_states=True\n", 247 | " )\n", 248 | "\n", 249 | " encoder_feature = self.get_encoder_feature(\n", 250 | " outputs,\n", 251 | " attention_mask\n", 252 | " )\n", 253 | "\n", 254 | " if self.output_emb_size > 0:\n", 255 | " encoder_feature = self.emb_reduce_linear(encoder_feature)\n", 256 | "\n", 257 | " encoder_feature = self.dropout(encoder_feature)\n", 258 | " out = F.normalize(encoder_feature, p=2, dim=-1, eps=1e-8)\n", 259 | "\n", 260 | " return out\n", 261 | "\n", 262 | " def cosine_sim(\n", 263 | " self,\n", 264 | " input_ids_a,\n", 265 | " input_ids_b,\n", 266 | " token_type_ids_a=None,\n", 267 | " position_ids_ids_a=None,\n", 268 | " attention_mask_a=None,\n", 269 | " token_type_ids_b=None,\n", 270 | " position_ids_b=None,\n", 271 | " attention_mask_b=None,\n", 272 | " **kwargs\n", 273 | " ):\n", 274 | "\n", 275 | " query_cls_embedding = self.get_pooled_embedding(\n", 276 | " input_ids_a,\n", 277 | " token_type_ids_a,\n", 278 | " position_ids_ids_a,\n", 279 | " attention_mask_a\n", 280 | " )\n", 281 | "\n", 282 | " title_cls_embedding = self.get_pooled_embedding(\n", 283 | " input_ids_b,\n", 284 | " token_type_ids_b,\n", 285 | " position_ids_b,\n", 286 | " attention_mask_b\n", 287 | " )\n", 288 | "\n", 289 | " cosine_sim = torch.sum(\n", 290 | " query_cls_embedding * title_cls_embedding,\n", 291 | " axis=-1\n", 292 | " )\n", 293 | "\n", 294 | " return cosine_sim\n", 295 | "\n", 296 | " def forward(\n", 297 | " self,\n", 298 | " input_ids_a,\n", 299 | " input_ids_b,\n", 300 | " token_type_ids_a=None,\n", 301 | " position_ids_ids_a=None,\n", 302 | " attention_mask_a=None,\n", 303 | " token_type_ids_b=None,\n", 304 | " position_ids_b=None,\n", 305 | " attention_mask_b=None,\n", 306 | " label_ids=None,\n", 307 | " **kwargs\n", 308 | " ):\n", 309 | "\n", 310 | " cls_embedding_a = self.get_pooled_embedding(\n", 311 | " input_ids_a,\n", 312 | " token_type_ids_a,\n", 313 | " position_ids_ids_a,\n", 314 | " attention_mask_a\n", 315 | " )\n", 316 | "\n", 317 | " cls_embedding_b = self.get_pooled_embedding(\n", 318 | " input_ids_b,\n", 319 | " token_type_ids_b,\n", 320 | " position_ids_b,\n", 321 | " attention_mask_b\n", 322 | " )\n", 323 | "\n", 324 | " cosine_sim = torch.sum(cls_embedding_a * cls_embedding_b, dim=1) * 20\n", 325 | " cosine_sim = cosine_sim[:, None] - cosine_sim[None, :]\n", 326 | " \n", 327 | " labels = label_ids[:, None] < label_ids[None, :]\n", 328 | " labels = labels.long()\n", 329 | " \n", 330 | " cosine_sim = cosine_sim - (1 - labels) * 1e12\n", 331 | " cosine_sim = torch.cat((torch.zeros(1).to(cosine_sim.device), cosine_sim.view(-1)), dim=0)\n", 332 | " loss = torch.logsumexp(cosine_sim.view(-1), dim=0)\n", 333 | "\n", 334 | " return cosine_sim, loss\n" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": null, 340 | "id": "e630530b", 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "dl_module = CoSENT.from_pretrained(\n", 345 | " 'bert-base-chinese', \n", 346 | " config=bert_config\n", 347 | ")" 348 | ] 349 | }, 350 | { 351 | "cell_type": "markdown", 352 | "id": "13e3c8ac", 353 | "metadata": {}, 354 | "source": [ 355 | "
\n", 356 | "\n", 357 | "### 三、任务构建" 358 | ] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "id": "31d1f76c", 363 | "metadata": {}, 364 | "source": [ 365 | "#### 1. 任务参数和必要部件设定" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": null, 371 | "id": "943bf64c", 372 | "metadata": {}, 373 | "outputs": [], 374 | "source": [ 375 | "# 设置运行次数\n", 376 | "num_epoches = 5\n", 377 | "batch_size = 32" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": null, 383 | "id": "74641ede", 384 | "metadata": {}, 385 | "outputs": [], 386 | "source": [ 387 | "param_optimizer = list(dl_module.named_parameters())\n", 388 | "param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]\n", 389 | "no_decay = [\"bias\", \"LayerNorm.weight\"]\n", 390 | "optimizer_grouped_parameters = [\n", 391 | " {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n", 392 | " 'weight_decay': 0.01},\n", 393 | " {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n", 394 | "] " 395 | ] 396 | }, 397 | { 398 | "cell_type": "markdown", 399 | "id": "bd5a9361", 400 | "metadata": {}, 401 | "source": [ 402 | "#### 2. 任务创建" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": null, 408 | "id": "5bc04465", 409 | "metadata": {}, 410 | "outputs": [], 411 | "source": [ 412 | "import torch\n", 413 | "import numpy as np\n", 414 | "from scipy import stats\n", 415 | "\n", 416 | "from ark_nlp.factory.task.base._sequence_classification import SequenceClassificationTask\n", 417 | "\n", 418 | "\n", 419 | "class CoSENTTask(SequenceClassificationTask):\n", 420 | " \"\"\"\n", 421 | " 用于CoSENT模型文本匹配任务的Task\n", 422 | " \n", 423 | " Args:\n", 424 | " module: 深度学习模型\n", 425 | " optimizer: 训练模型使用的优化器名或者优化器对象\n", 426 | " loss_function: 训练模型使用的损失函数名或损失函数对象\n", 427 | " class_num (:obj:`int` or :obj:`None`, optional, defaults to None): 标签数目\n", 428 | " scheduler (:obj:`class`, optional, defaults to None): scheduler对象\n", 429 | " n_gpu (:obj:`int`, optional, defaults to 1): GPU数目\n", 430 | " device (:obj:`class`, optional, defaults to None): torch.device对象,当device为None时,会自动检测是否有GPU\n", 431 | " cuda_device (:obj:`int`, optional, defaults to 0): GPU编号,当device为None时,根据cuda_device设置device\n", 432 | " ema_decay (:obj:`int` or :obj:`None`, optional, defaults to None): EMA的加权系数\n", 433 | " **kwargs (optional): 其他可选参数\n", 434 | " \"\"\" # noqa: ignore flake8\"\n", 435 | "\n", 436 | " def _on_evaluate_begin_record(self, **kwargs):\n", 437 | "\n", 438 | " self.evaluate_logs['eval_loss'] = 0\n", 439 | " self.evaluate_logs['eval_step'] = 0\n", 440 | " self.evaluate_logs['eval_example'] = 0\n", 441 | "\n", 442 | " self.evaluate_logs['labels'] = []\n", 443 | " self.evaluate_logs['eval_sim'] = []\n", 444 | "\n", 445 | " def _on_evaluate_step_end(self, inputs, outputs, **kwargs):\n", 446 | "\n", 447 | " with torch.no_grad():\n", 448 | " # compute loss\n", 449 | " logits, loss = self._get_evaluate_loss(inputs, outputs, **kwargs)\n", 450 | " self.evaluate_logs['eval_loss'] += loss.item()\n", 451 | "\n", 452 | " if 'label_ids' in inputs:\n", 453 | " cosine_sim = self.module.cosine_sim(**inputs).cpu().numpy()\n", 454 | " self.evaluate_logs['eval_sim'].append(cosine_sim)\n", 455 | " self.evaluate_logs['labels'].append(inputs['label_ids'].cpu().numpy())\n", 456 | "\n", 457 | " self.evaluate_logs['eval_example'] += logits.shape[0]\n", 458 | " self.evaluate_logs['eval_step'] += 1\n", 459 | "\n", 460 | " def 
_on_evaluate_epoch_end(\n", 461 | " self,\n", 462 | " validation_data,\n", 463 | " epoch=1,\n", 464 | " is_evaluate_print=True,\n", 465 | " **kwargs\n", 466 | " ):\n", 467 | "\n", 468 | " if is_evaluate_print:\n", 469 | " if 'labels' in self.evaluate_logs:\n", 470 | " _sims = np.concatenate(self.evaluate_logs['eval_sim'], axis=0)\n", 471 | " _labels = np.concatenate(self.evaluate_logs['labels'], axis=0)\n", 472 | " spearman_corr = stats.spearmanr(_labels, _sims).correlation\n", 473 | " print('evaluate spearman corr is:{:.4f}, evaluate loss is:{:.6f}'.format(\n", 474 | " spearman_corr,\n", 475 | " self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']\n", 476 | " )\n", 477 | " )\n", 478 | " else:\n", 479 | " print('evaluate loss is:{:.6f}'.format(self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']))" 480 | ] 481 | }, 482 | { 483 | "cell_type": "code", 484 | "execution_count": null, 485 | "id": "3dfc61d9", 486 | "metadata": {}, 487 | "outputs": [], 488 | "source": [ 489 | "model = CoSENTTask(dl_module, 'adamw', None, cuda_device=0)" 490 | ] 491 | }, 492 | { 493 | "cell_type": "markdown", 494 | "id": "35c96cf8", 495 | "metadata": { 496 | "tags": [] 497 | }, 498 | "source": [ 499 | "#### 3. 训练" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": null, 505 | "id": "62e3e9ff", 506 | "metadata": {}, 507 | "outputs": [], 508 | "source": [ 509 | "model.fit(\n", 510 | " cosent_train_dataset,\n", 511 | " cosent_dev_dataset,\n", 512 | " lr=2e-5,\n", 513 | " epochs=num_epoches,\n", 514 | " batch_size=batch_size,\n", 515 | " params=optimizer_grouped_parameters\n", 516 | ")" 517 | ] 518 | }, 519 | { 520 | "cell_type": "markdown", 521 | "id": "9b27e57b", 522 | "metadata": {}, 523 | "source": [ 524 | "
\n", 525 | "\n", 526 | "### 四、模型验证" 527 | ] 528 | }, 529 | { 530 | "cell_type": "code", 531 | "execution_count": null, 532 | "id": "6b296f3f-7adb-41df-9a4e-43847b17a900", 533 | "metadata": {}, 534 | "outputs": [], 535 | "source": [ 536 | "import torch\n", 537 | "\n", 538 | "from torch.utils.data import DataLoader\n", 539 | "from ark_nlp.factory.predictor import SequenceClassificationPredictor\n", 540 | "\n", 541 | "\n", 542 | "class CoSENTPredictor(SequenceClassificationPredictor):\n", 543 | " \"\"\"\n", 544 | " CoSENT的预测器\n", 545 | " \n", 546 | " Args:\n", 547 | " module: 深度学习模型\n", 548 | " tokernizer: 分词器\n", 549 | " cat2id (:obj:`dict`): 标签映射\n", 550 | " \"\"\" # noqa: ignore flake8\"\n", 551 | "\n", 552 | " def _get_input_ids(\n", 553 | " self,\n", 554 | " text_a,\n", 555 | " text_b\n", 556 | " ):\n", 557 | " if self.tokenizer.tokenizer_type == 'vanilla':\n", 558 | " return self._convert_to_vanilla_ids(text_a, text_b)\n", 559 | " elif self.tokenizer.tokenizer_type == 'transfomer':\n", 560 | " return self._convert_to_transfomer_ids(text_a, text_b)\n", 561 | " elif self.tokenizer.tokenizer_type == 'customized':\n", 562 | " return self._convert_to_customized_ids(text_a, text_b)\n", 563 | " else:\n", 564 | " raise ValueError(\"The tokenizer type does not exist\")\n", 565 | "\n", 566 | " def _convert_to_transfomer_ids(\n", 567 | " self,\n", 568 | " text_a,\n", 569 | " text_b\n", 570 | " ):\n", 571 | " input_ids_a = self.tokenizer.sequence_to_ids(text_a)\n", 572 | " input_ids_b = self.tokenizer.sequence_to_ids(text_b)\n", 573 | "\n", 574 | " input_ids_a, input_mask_a, segment_ids_a = input_ids_a\n", 575 | " input_ids_b, input_mask_b, segment_ids_b = input_ids_b\n", 576 | "\n", 577 | " features = {\n", 578 | " 'input_ids_a': input_ids_a,\n", 579 | " 'attention_mask_a': input_mask_a,\n", 580 | " 'token_type_ids_a': segment_ids_a,\n", 581 | " 'input_ids_b': input_ids_b,\n", 582 | " 'attention_mask_b': input_mask_b,\n", 583 | " 'token_type_ids_b': segment_ids_b\n", 584 | " }\n", 585 | "\n", 586 | " return features\n", 587 | "\n", 588 | " def predict_one_sample(\n", 589 | " self,\n", 590 | " text,\n", 591 | " topk=None,\n", 592 | " threshold=0.5,\n", 593 | " return_label_name=True,\n", 594 | " return_proba=False\n", 595 | " ):\n", 596 | " if topk is None:\n", 597 | " topk = len(self.cat2id) if len(self.cat2id) > 2 else 1\n", 598 | " text_a, text_b = text\n", 599 | " features = self._get_input_ids(text_a, text_b)\n", 600 | " self.module.eval()\n", 601 | "\n", 602 | " with torch.no_grad():\n", 603 | " inputs = self._get_module_one_sample_inputs(features)\n", 604 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n", 605 | "\n", 606 | " _proba = logits[0]\n", 607 | " \n", 608 | " if threshold is not None:\n", 609 | " _pred = self._threshold(_proba, threshold)\n", 610 | "\n", 611 | " if return_label_name and threshold is not None:\n", 612 | " _pred = self.id2cat[_pred]\n", 613 | "\n", 614 | " if threshold is not None:\n", 615 | " if return_proba:\n", 616 | " return [_pred, _proba]\n", 617 | " else:\n", 618 | " return _pred\n", 619 | "\n", 620 | " return _proba\n", 621 | "\n", 622 | " def predict_batch(\n", 623 | " self,\n", 624 | " test_data,\n", 625 | " batch_size=16,\n", 626 | " shuffle=False\n", 627 | " ):\n", 628 | " self.inputs_cols = test_data.dataset_cols\n", 629 | "\n", 630 | " preds = []\n", 631 | "\n", 632 | " self.module.eval()\n", 633 | " generator = DataLoader(test_data, batch_size=batch_size, shuffle=shuffle)\n", 634 | "\n", 635 | " with torch.no_grad():\n", 636 | " for step, 
inputs in enumerate(generator):\n", 637 | " inputs = self._get_module_batch_inputs(inputs)\n", 638 | "\n", 639 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n", 640 | "\n", 641 | " preds.extend(logits)\n", 642 | "\n", 643 | " return preds" 644 | ] 645 | }, 646 | { 647 | "cell_type": "code", 648 | "execution_count": null, 649 | "id": "b1df061b", 650 | "metadata": {}, 651 | "outputs": [], 652 | "source": [ 653 | "cosent_predictor_instance = CoSENTPredictor(model.module, tokenizer, cosent_train_dataset.cat2id)" 654 | ] 655 | }, 656 | { 657 | "cell_type": "code", 658 | "execution_count": null, 659 | "id": "24994fe9", 660 | "metadata": {}, 661 | "outputs": [], 662 | "source": [ 663 | "cosent_predictor_instance.predict_one_sample(\n", 664 | " ['喜欢打篮球的男生喜欢什么样的女生', \n", 665 | " '爱打篮球的男生喜欢什么样的女生'],\n", 666 | " threshold=None\n", 667 | ")" 668 | ] 669 | } 670 | ], 671 | "metadata": { 672 | "kernelspec": { 673 | "display_name": "Python 3 (ipykernel)", 674 | "language": "python", 675 | "name": "python3" 676 | }, 677 | "language_info": { 678 | "codemirror_mode": { 679 | "name": "ipython", 680 | "version": 3 681 | }, 682 | "file_extension": ".py", 683 | "mimetype": "text/x-python", 684 | "name": "python", 685 | "nbconvert_exporter": "python", 686 | "pygments_lexer": "ipython3", 687 | "version": "3.8.10" 688 | } 689 | }, 690 | "nbformat": 4, 691 | "nbformat_minor": 5 692 | } 693 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch_CoSENT 2 | 3 | ## 简介 4 | 5 | 比Sentence-BERT更有效的句向量方案,复现苏神提出的[CoSENT](https://github.com/bojone/CoSENT),模型细节可以参考苏神的文章:https://kexue.fm/archives/8847 6 | 7 | **数据下载** 8 | 9 | * ATEC、BQ、LCQMC和PAWSX:https://github.com/bojone/BERT-whitening/tree/main/chn 10 | * CHIP-STS(平安医疗科技疾病问答迁移学习):https://tianchi.aliyun.com/dataset/dataDetail?dataId=95414 11 | 12 | 13 | ## 环境 14 | 15 | ``` 16 | pip install ark-nlp 17 | pip install pandas 18 | ``` 19 | 20 | ## 使用说明 21 | 22 | 项目目录按以下格式设置 23 | 24 | ```shell 25 | │ 26 | ├── data # 数据文件夹 27 | │ ├── source_datasets 28 | │ ├── task_datasets 29 | │ └── output_datasets 30 | │ 31 | ├── checkpoint # 存放训练好的模型 32 | │ ├── ... 33 | │ └── ... 
34 | │ 35 | └── code # 代码 36 | ``` 37 | 下载数据并解压到`data/source_datasets`中,运行`code`文件夹中的`.ipynb`文件,最终提交文件会生成在`data/output_datasets` 38 | 39 | ## 参数设置 40 | 41 | 代码参数设置如下: 42 | 43 | ``` 44 | 句子截断长度:64(PAWSX数据集截断长度为128) 45 | batch_size:32 46 | epochs:5 47 | ``` 48 | 49 | 50 | ## 效果 51 | 52 | 使用spearman系数作为测评指标,ATEC、BQ、LCQMC和PAWSX使用test集进行测试实验,CHIP-STS则使用验证集 53 | 54 | | | ATEC | BQ | LCQMC | PAWSX | CHIP-STS | 55 | | :-: | :-: | :-: | :-: | :-: | :-: | 56 | | BERT+CoSENT(ark-nlp) | 49.80 | 72.46 | 79.00 | 59.17 | 76.22 | 57 | | BERT+CoSENT(bert4keras) | 49.74 | 72.38 | 78.69 | 60.00 | | 58 | | Sentence-BERT(bert4keras)| 46.36 | 70.36 | 78.72 | 46.86 | | 59 | 60 | PS:上表ark-nlp展示的是5轮里最好的结果,由于没有深入了解bert4keras,所以设置参数可能还是存在差异,因此对比仅供参考 61 | 62 | 针对CHIP-STS测试数据集,选择阈值为0.6生成结果进行提交,结果如下: 63 | 64 | | | Precision | Recall | Macro-F1| 65 | | :-: | :-: | :-: | :-:| 66 | | BERT+CoSENT| 82.89 | 81.528 | 81.688 | 67 | | BERT句子对分类 | 84.331 | 83.799 | 83.924| 68 | 69 | 70 | ## Acknowledge 71 | 72 | 感谢苏神无私的分享 -------------------------------------------------------------------------------- /text_match_CHIP-STS.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import warnings\n", 10 | "warnings.filterwarnings(\"ignore\")\n", 11 | "\n", 12 | "import os\n", 13 | "import jieba\n", 14 | "import torch\n", 15 | "import pickle\n", 16 | "import torch.nn as nn\n", 17 | "import torch.optim as optim\n", 18 | "import pandas as pd\n", 19 | "\n", 20 | "from ark_nlp.model.tm.bert import Bert\n", 21 | "from ark_nlp.model.tm.bert import BertConfig\n", 22 | "from ark_nlp.model.tm.bert import Dataset\n", 23 | "from ark_nlp.model.tm.bert import Task\n", 24 | "from ark_nlp.model.tm.bert import get_default_model_optimizer\n", 25 | "from ark_nlp.model.tm.bert import Tokenizer" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "### 一、数据读入与处理" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "#### 1. 数据读入" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "train_data_df = pd.read_json('../data/source_datasets/CHIP-STS/CHIP-STS_train.json')\n", 49 | "train_data_df = (train_data_df\n", 50 | " .rename(columns={'text1': 'text_a', 'text2': 'text_b', 'category': 'condition'})\n", 51 | " .loc[:,['text_a', 'text_b', 'condition', 'label']])\n", 52 | "\n", 53 | "dev_data_df = pd.read_json('../data/source_datasets/CHIP-STS/CHIP-STS_dev.json')\n", 54 | "dev_data_df = dev_data_df[dev_data_df['label'] != \"NA\"]\n", 55 | "dev_data_df = (dev_data_df\n", 56 | " .rename(columns={'text1': 'text_a', 'text2': 'text_b', 'category': 'condition'})\n", 57 | " .loc[:,['text_a', 'text_b', 'condition', 'label']])" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "tm_train_dataset = Dataset(train_data_df)\n", 67 | "tm_dev_dataset = Dataset(dev_data_df)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "#### 2. 
词典创建和生成分词器" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "tokenizer = Tokenizer(vocab='bert-base-chinese', max_seq_len=64)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "#### 3. ID化" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "tm_train_dataset.convert_to_ids(tokenizer)\n", 100 | "tm_dev_dataset.convert_to_ids(tokenizer)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "
\n", 108 | "\n", 109 | "### 二、模型构建" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "#### 1. 模型参数设置" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "config = BertConfig.from_pretrained('bert-base-chinese',\n", 126 | " num_labels=len(tm_train_dataset.cat2id))" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": {}, 132 | "source": [ 133 | "#### 2. 模型创建" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "torch.cuda.empty_cache()" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "dl_module = Bert.from_pretrained('bert-base-chinese', \n", 152 | " config=config)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "
\n", 160 | "\n", 161 | "### 三、任务构建" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "#### 1. 任务参数和必要部件设定" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "# 设置运行次数\n", 178 | "num_epoches = 5\n", 179 | "batch_size = 32" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "optimizer = get_default_model_optimizer(dl_module)" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "#### 2. 任务创建" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "model = Task(dl_module, optimizer, 'ce', cuda_device=0)" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": {}, 210 | "source": [ 211 | "#### 3. 训练" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "model.fit(tm_train_dataset, \n", 221 | " tm_dev_dataset,\n", 222 | " lr=2e-5,\n", 223 | " epochs=num_epoches, \n", 224 | " batch_size=batch_size\n", 225 | " )" 226 | ] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "metadata": {}, 231 | "source": [ 232 | "
\n", 233 | "\n", 234 | "### 四、CBLUE提交生成" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [ 243 | "from ark_nlp.model.tm.bert import Predictor" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "tm_predictor_instance = Predictor(model.module, tokenizer, tm_train_dataset.cat2id)" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": null, 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": [ 261 | "import pandas as pd\n", 262 | "\n", 263 | "from tqdm import tqdm\n", 264 | "\n", 265 | "test_df = pd.read_json('../data/source_datasets/CHIP-STS/CHIP-STS_test.json')\n", 266 | "\n", 267 | "submit = []\n", 268 | "for _id, _text_a, _text_b, _condition in tqdm(zip(\n", 269 | " test_df['id'],\n", 270 | " test_df['text1'],\n", 271 | " test_df['text2'],\n", 272 | " test_df['category']\n", 273 | ")):\n", 274 | " if _condition == 'daibetes':\n", 275 | " _condition = 'diabetes'\n", 276 | "\n", 277 | " predict_ = tm_predictor_instance.predict_one_sample([_text_a, _text_b])[0]\n", 278 | " \n", 279 | " submit.append({\n", 280 | " 'id': str(_id),\n", 281 | " 'text1': _text_a,\n", 282 | " 'text2': _text_b,\n", 283 | " 'label': predict_,\n", 284 | " 'category': _condition\n", 285 | " })" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "metadata": {}, 292 | "outputs": [], 293 | "source": [ 294 | "import json\n", 295 | "\n", 296 | "output_path = '../data/output_datasets/CHIP-STS_test.json'\n", 297 | "\n", 298 | "with open(output_path, 'w', encoding='utf-8') as f:\n", 299 | " f.write(json.dumps(submit, ensure_ascii=False))" 300 | ] 301 | } 302 | ], 303 | "metadata": { 304 | "kernelspec": { 305 | "display_name": "Python 3 (ipykernel)", 306 | "language": "python", 307 | "name": "python3" 308 | }, 309 | "language_info": { 310 | "codemirror_mode": { 311 | "name": "ipython", 312 | "version": 3 313 | }, 314 | "file_extension": ".py", 315 | "mimetype": "text/x-python", 316 | "name": "python", 317 | "nbconvert_exporter": "python", 318 | "pygments_lexer": "ipython3", 319 | "version": "3.8.10" 320 | } 321 | }, 322 | "nbformat": 4, 323 | "nbformat_minor": 4 324 | } 325 | --------------------------------------------------------------------------------