├── CoSENT_ PAWSX.ipynb
├── CoSENT_ATEC.ipynb
├── CoSENT_BQ.ipynb
├── CoSENT_CHIP-STS.ipynb
├── CoSENT_LCQMC.ipynb
├── README.md
└── text_match_CHIP-STS.ipynb
/CoSENT_ PAWSX.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "9ced17a4",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import torch\n",
11 | "import pandas as pd\n",
12 | "\n",
13 | "from ark_nlp.nn import BertConfig as ModuleConfig\n",
14 | "from ark_nlp.dataset import TwinTowersSentenceClassificationDataset as Dataset\n",
15 | "from ark_nlp.processor.tokenizer.transfomer import SentenceTokenizer as Tokenizer"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": null,
21 | "id": "bfcea417",
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "# 目录地址\n",
26 | "train_data_path = '../data/source_datasets/PAWSX/PAWSX.train.data'\n",
27 | "dev_data_path = '../data/source_datasets/PAWSX/PAWSX.test.data'"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "id": "ccea726a",
33 | "metadata": {
34 | "tags": []
35 | },
36 | "source": [
37 | "### 一、数据读入与处理"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "id": "8d5c3337",
43 | "metadata": {
44 | "tags": []
45 | },
46 | "source": [
47 | "#### 1. 数据读入"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "id": "1e651e5f-8535-43c2-893b-10f3275517ba",
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "train_data_df = pd.read_csv(train_data_path, sep='\\t', header=None, names=['text_a', 'text_b', 'label'])\n",
58 | "dev_data_df = pd.read_csv(dev_data_path, sep='\\t', header=None, names=['text_a', 'text_b', 'label'])"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "id": "8cbc9c3e-35e0-4bbb-82fe-c139be87f7f4",
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "train_data_df = train_data_df[~train_data_df['text_a'].isnull()]\n",
69 | "dev_data_df = dev_data_df[~dev_data_df['text_a'].isnull()]"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "id": "90876062",
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "cosent_train_dataset = Dataset(train_data_df)\n",
80 | "cosent_dev_dataset = Dataset(dev_data_df)"
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "id": "e061890a",
86 | "metadata": {},
87 | "source": [
88 | "#### 2. 词典创建和生成分词器"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "id": "be116454",
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "# 加载分词器\n",
99 | "tokenizer = Tokenizer(vocab='bert-base-chinese', max_seq_len=128)"
100 | ]
101 | },
102 | {
103 | "cell_type": "markdown",
104 | "id": "0d6c3b3d",
105 | "metadata": {},
106 | "source": [
107 | "#### 3. ID化"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "id": "566dd6b6",
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "cosent_train_dataset.convert_to_ids(tokenizer)\n",
118 | "cosent_dev_dataset.convert_to_ids(tokenizer)"
119 | ]
120 | },
121 | {
122 | "cell_type": "markdown",
123 | "id": "981b4160",
124 | "metadata": {},
125 | "source": [
126 | "
\n",
127 | "\n",
128 | "### 二、模型构建"
129 | ]
130 | },
131 | {
132 | "cell_type": "markdown",
133 | "id": "72753ee8",
134 | "metadata": {},
135 | "source": [
136 | "#### 1. 模型参数设置"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": null,
142 | "id": "527535d3",
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "from transformers import BertConfig\n",
147 | "\n",
148 | "bert_config = BertConfig.from_pretrained(\n",
149 | " 'bert-base-chinese',\n",
150 | " num_labels=len(cosent_train_dataset.cat2id)\n",
151 | ")"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": null,
157 | "id": "77d8b580",
158 | "metadata": {},
159 | "outputs": [],
160 | "source": [
161 | "torch.cuda.empty_cache()"
162 | ]
163 | },
164 | {
165 | "cell_type": "markdown",
166 | "id": "700d7752",
167 | "metadata": {},
168 | "source": [
169 | "#### 2. 模型创建"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": null,
175 | "id": "58f380f4",
176 | "metadata": {},
177 | "outputs": [],
178 | "source": [
179 | "import torch\n",
180 | "import torch.nn.functional as F\n",
181 | "\n",
182 | "from torch import nn\n",
183 | "from transformers import BertModel\n",
184 | "from ark_nlp.nn import Bert\n",
185 | "\n",
186 | "\n",
187 | "class CoSENT(Bert):\n",
188 | " \"\"\"\n",
189 | " CoSENT模型\n",
190 | "\n",
191 | " Args:\n",
192 | " config:\n",
193 | " 模型的配置对象\n",
194 | " encoder_trained (:obj:`bool`, optional, defaults to True):\n",
195 | " bert参数是否可训练,默认可训练\n",
196 | " pooling (:obj:`str`, optional, defaults to \"last_avg\"):\n",
197 | " bert输出的池化方式,默认为\"last_avg\",\n",
198 | " 可选有[\"cls\", \"cls_with_pooler\", \"first_last_avg\", \"last_avg\", \"last_2_avg\"]\n",
199 | " dropout (:obj:`float` or :obj:`None`, optional, defaults to None):\n",
200 | " dropout比例,默认为None,实际设置时会设置成0\n",
201 | " output_emb_size (:obj:`int`, optional, defaults to 0):\n",
202 | " 输出的矩阵的维度,默认为0,即不进行矩阵维度变换\n",
203 | "\n",
204 | " Reference:\n",
205 | " [1] https://kexue.fm/archives/8847\n",
206 | " [2] https://github.com/bojone/CoSENT \n",
207 | " \"\"\" # noqa: ignore flake8\"\n",
208 | "\n",
209 | " def __init__(\n",
210 | " self,\n",
211 | " config,\n",
212 | " encoder_trained=True,\n",
213 | " pooling='last_avg',\n",
214 | " dropout=None,\n",
215 | " output_emb_size=0\n",
216 | " ):\n",
217 | "\n",
218 | " super(CoSENT, self).__init__(config)\n",
219 | "\n",
220 | " self.bert = BertModel(config)\n",
221 | " self.pooling = pooling\n",
222 | "\n",
223 | " self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)\n",
224 | "\n",
225 | " # if output_emb_size is greater than 0, then add Linear layer to reduce embedding_size,\n",
226 | " # we recommend set output_emb_size = 256 considering the trade-off beteween\n",
227 | " # recall performance and efficiency\n",
228 | " self.output_emb_size = output_emb_size\n",
229 | " if self.output_emb_size > 0:\n",
230 | " self.emb_reduce_linear = nn.Linear(\n",
231 | " config.hidden_size,\n",
232 | " self.output_emb_size\n",
233 | " )\n",
234 | " torch.nn.init.trunc_normal_(\n",
235 | " self.emb_reduce_linear.weight,\n",
236 | " std=0.02\n",
237 | " )\n",
238 | "\n",
239 | " for param in self.bert.parameters():\n",
240 | " param.requires_grad = encoder_trained\n",
241 | "\n",
242 | " self.init_weights()\n",
243 | "\n",
244 | " def get_pooled_embedding(\n",
245 | " self,\n",
246 | " input_ids,\n",
247 | " token_type_ids=None,\n",
248 | " position_ids=None,\n",
249 | " attention_mask=None\n",
250 | " ):\n",
251 | " outputs = self.bert(\n",
252 | " input_ids,\n",
253 | " attention_mask=attention_mask,\n",
254 | " token_type_ids=token_type_ids,\n",
255 | " position_ids=position_ids,\n",
256 | " return_dict=True,\n",
257 | " output_hidden_states=True\n",
258 | " )\n",
259 | "\n",
260 | " encoder_feature = self.get_encoder_feature(\n",
261 | " outputs,\n",
262 | " attention_mask\n",
263 | " )\n",
264 | "\n",
265 | " if self.output_emb_size > 0:\n",
266 | " encoder_feature = self.emb_reduce_linear(encoder_feature)\n",
267 | "\n",
268 | " encoder_feature = self.dropout(encoder_feature)\n",
269 | " out = F.normalize(encoder_feature, p=2, dim=-1, eps=1e-8)\n",
270 | "\n",
271 | " return out\n",
272 | "\n",
273 | " def cosine_sim(\n",
274 | " self,\n",
275 | " input_ids_a,\n",
276 | " input_ids_b,\n",
277 | " token_type_ids_a=None,\n",
278 | " position_ids_ids_a=None,\n",
279 | " attention_mask_a=None,\n",
280 | " token_type_ids_b=None,\n",
281 | " position_ids_b=None,\n",
282 | " attention_mask_b=None,\n",
283 | " **kwargs\n",
284 | " ):\n",
285 | "\n",
286 | " query_cls_embedding = self.get_pooled_embedding(\n",
287 | " input_ids_a,\n",
288 | " token_type_ids_a,\n",
289 | " position_ids_ids_a,\n",
290 | " attention_mask_a\n",
291 | " )\n",
292 | "\n",
293 | " title_cls_embedding = self.get_pooled_embedding(\n",
294 | " input_ids_b,\n",
295 | " token_type_ids_b,\n",
296 | " position_ids_b,\n",
297 | " attention_mask_b\n",
298 | " )\n",
299 | "\n",
300 | " cosine_sim = torch.sum(\n",
301 | " query_cls_embedding * title_cls_embedding,\n",
302 | " axis=-1\n",
303 | " )\n",
304 | "\n",
305 | " return cosine_sim\n",
306 | "\n",
307 | " def forward(\n",
308 | " self,\n",
309 | " input_ids_a,\n",
310 | " input_ids_b,\n",
311 | " token_type_ids_a=None,\n",
312 | " position_ids_ids_a=None,\n",
313 | " attention_mask_a=None,\n",
314 | " token_type_ids_b=None,\n",
315 | " position_ids_b=None,\n",
316 | " attention_mask_b=None,\n",
317 | " label_ids=None,\n",
318 | " **kwargs\n",
319 | " ):\n",
320 | "\n",
321 | " cls_embedding_a = self.get_pooled_embedding(\n",
322 | " input_ids_a,\n",
323 | " token_type_ids_a,\n",
324 | " position_ids_ids_a,\n",
325 | " attention_mask_a\n",
326 | " )\n",
327 | "\n",
328 | " cls_embedding_b = self.get_pooled_embedding(\n",
329 | " input_ids_b,\n",
330 | " token_type_ids_b,\n",
331 | " position_ids_b,\n",
332 | " attention_mask_b\n",
333 | " )\n",
334 | "\n",
335 | " cosine_sim = torch.sum(cls_embedding_a * cls_embedding_b, dim=1) * 20\n",
336 | " cosine_sim = cosine_sim[:, None] - cosine_sim[None, :]\n",
337 | " \n",
338 | " labels = label_ids[:, None] < label_ids[None, :]\n",
339 | " labels = labels.long()\n",
340 | " \n",
341 | " cosine_sim = cosine_sim - (1 - labels) * 1e12\n",
342 | " cosine_sim = torch.cat((torch.zeros(1).to(cosine_sim.device), cosine_sim.view(-1)), dim=0)\n",
343 | " loss = torch.logsumexp(cosine_sim, dim=0)\n",
344 | "\n",
345 | " return cosine_sim, loss\n"
346 | ]
347 | },
348 | {
349 | "cell_type": "code",
350 | "execution_count": null,
351 | "id": "e630530b",
352 | "metadata": {},
353 | "outputs": [],
354 | "source": [
355 | "dl_module = CoSENT.from_pretrained(\n",
356 | " 'bert-base-chinese', \n",
357 | " config=bert_config\n",
358 | ")"
359 | ]
360 | },
361 | {
362 | "cell_type": "markdown",
363 | "id": "13e3c8ac",
364 | "metadata": {},
365 | "source": [
366 | "
\n",
367 | "\n",
368 | "### 三、任务构建"
369 | ]
370 | },
371 | {
372 | "cell_type": "markdown",
373 | "id": "31d1f76c",
374 | "metadata": {},
375 | "source": [
376 | "#### 1. 任务参数和必要部件设定"
377 | ]
378 | },
379 | {
380 | "cell_type": "code",
381 | "execution_count": null,
382 | "id": "943bf64c",
383 | "metadata": {},
384 | "outputs": [],
385 | "source": [
386 | "# 设置运行次数\n",
387 | "num_epoches = 5\n",
388 | "batch_size = 32"
389 | ]
390 | },
391 | {
392 | "cell_type": "code",
393 | "execution_count": null,
394 | "id": "74641ede",
395 | "metadata": {},
396 | "outputs": [],
397 | "source": [
398 | "param_optimizer = list(dl_module.named_parameters())\n",
399 | "param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]\n",
400 | "no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
401 | "optimizer_grouped_parameters = [\n",
402 | " {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n",
403 | " 'weight_decay': 0.01},\n",
404 | " {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
405 | "] "
406 | ]
407 | },
408 | {
409 | "cell_type": "markdown",
410 | "id": "bd5a9361",
411 | "metadata": {},
412 | "source": [
413 | "#### 2. 任务创建"
414 | ]
415 | },
416 | {
417 | "cell_type": "code",
418 | "execution_count": null,
419 | "id": "5bc04465",
420 | "metadata": {},
421 | "outputs": [],
422 | "source": [
423 | "import torch\n",
424 | "import numpy as np\n",
425 | "from scipy import stats\n",
426 | "\n",
427 | "from ark_nlp.factory.task.base._sequence_classification import SequenceClassificationTask\n",
428 | "\n",
429 | "\n",
430 | "class CoSENTTask(SequenceClassificationTask):\n",
431 | " \"\"\"\n",
432 | " 用于CoSENT模型文本匹配任务的Task\n",
433 | " \n",
434 | " Args:\n",
435 | " module: 深度学习模型\n",
436 | " optimizer: 训练模型使用的优化器名或者优化器对象\n",
437 | " loss_function: 训练模型使用的损失函数名或损失函数对象\n",
438 | " class_num (:obj:`int` or :obj:`None`, optional, defaults to None): 标签数目\n",
439 | " scheduler (:obj:`class`, optional, defaults to None): scheduler对象\n",
440 | " n_gpu (:obj:`int`, optional, defaults to 1): GPU数目\n",
441 | " device (:obj:`class`, optional, defaults to None): torch.device对象,当device为None时,会自动检测是否有GPU\n",
442 | " cuda_device (:obj:`int`, optional, defaults to 0): GPU编号,当device为None时,根据cuda_device设置device\n",
443 | " ema_decay (:obj:`int` or :obj:`None`, optional, defaults to None): EMA的加权系数\n",
444 | " **kwargs (optional): 其他可选参数\n",
445 | " \"\"\" # noqa: ignore flake8\"\n",
446 | "\n",
447 | " def _on_evaluate_begin_record(self, **kwargs):\n",
448 | "\n",
449 | " self.evaluate_logs['eval_loss'] = 0\n",
450 | " self.evaluate_logs['eval_step'] = 0\n",
451 | " self.evaluate_logs['eval_example'] = 0\n",
452 | "\n",
453 | " self.evaluate_logs['labels'] = []\n",
454 | " self.evaluate_logs['eval_sim'] = []\n",
455 | "\n",
456 | " def _on_evaluate_step_end(self, inputs, outputs, **kwargs):\n",
457 | "\n",
458 | " with torch.no_grad():\n",
459 | " # compute loss\n",
460 | " logits, loss = self._get_evaluate_loss(inputs, outputs, **kwargs)\n",
461 | " self.evaluate_logs['eval_loss'] += loss.item()\n",
462 | "\n",
463 | " if 'label_ids' in inputs:\n",
464 | " cosine_sim = self.module.cosine_sim(**inputs).cpu().numpy()\n",
465 | " self.evaluate_logs['eval_sim'].append(cosine_sim)\n",
466 | " self.evaluate_logs['labels'].append(inputs['label_ids'].cpu().numpy())\n",
467 | "\n",
468 | " self.evaluate_logs['eval_example'] += logits.shape[0]\n",
469 | " self.evaluate_logs['eval_step'] += 1\n",
470 | "\n",
471 | " def _on_evaluate_epoch_end(\n",
472 | " self,\n",
473 | " validation_data,\n",
474 | " epoch=1,\n",
475 | " is_evaluate_print=True,\n",
476 | " **kwargs\n",
477 | " ):\n",
478 | "\n",
479 | " if is_evaluate_print:\n",
480 | " if 'labels' in self.evaluate_logs:\n",
481 | " _sims = np.concatenate(self.evaluate_logs['eval_sim'], axis=0)\n",
482 | " _labels = np.concatenate(self.evaluate_logs['labels'], axis=0)\n",
483 | " spearman_corr = stats.spearmanr(_labels, _sims).correlation\n",
484 | " print('evaluate spearman corr is:{:.4f}, evaluate loss is:{:.6f}'.format(\n",
485 | " spearman_corr,\n",
486 | " self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']\n",
487 | " )\n",
488 | " )\n",
489 | " else:\n",
490 | " print('evaluate loss is:{:.6f}'.format(self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']))"
491 | ]
492 | },
493 | {
494 | "cell_type": "code",
495 | "execution_count": null,
496 | "id": "3dfc61d9",
497 | "metadata": {},
498 | "outputs": [],
499 | "source": [
500 | "model = CoSENTTask(dl_module, 'adamw', None, cuda_device=0)"
501 | ]
502 | },
503 | {
504 | "cell_type": "markdown",
505 | "id": "35c96cf8",
506 | "metadata": {
507 | "tags": []
508 | },
509 | "source": [
510 | "#### 3. 训练"
511 | ]
512 | },
513 | {
514 | "cell_type": "code",
515 | "execution_count": null,
516 | "id": "62e3e9ff",
517 | "metadata": {},
518 | "outputs": [],
519 | "source": [
520 | "model.fit(\n",
521 | " cosent_train_dataset,\n",
522 | " cosent_dev_dataset,\n",
523 | " lr=2e-5,\n",
524 | " epochs=num_epoches,\n",
525 | " batch_size=batch_size,\n",
526 | " params=optimizer_grouped_parameters\n",
527 | ")"
528 | ]
529 | },
530 | {
531 | "cell_type": "markdown",
532 | "id": "9b27e57b",
533 | "metadata": {},
534 | "source": [
535 | "
\n",
536 | "\n",
537 | "### 四、模型验证"
538 | ]
539 | },
540 | {
541 | "cell_type": "code",
542 | "execution_count": null,
543 | "id": "adb8effd-c437-4dd5-a10b-db6537780f91",
544 | "metadata": {},
545 | "outputs": [],
546 | "source": [
547 | "import torch\n",
548 | "\n",
549 | "from torch.utils.data import DataLoader\n",
550 | "from ark_nlp.factory.predictor import SequenceClassificationPredictor\n",
551 | "\n",
552 | "\n",
553 | "class CoSENTPredictor(SequenceClassificationPredictor):\n",
554 | " \"\"\"\n",
555 | " CoSENT的预测器\n",
556 | " \n",
557 | " Args:\n",
558 | " module: 深度学习模型\n",
559 | " tokernizer: 分词器\n",
560 | " cat2id (:obj:`dict`): 标签映射\n",
561 | " \"\"\" # noqa: ignore flake8\"\n",
562 | "\n",
563 | " def _get_input_ids(\n",
564 | " self,\n",
565 | " text_a,\n",
566 | " text_b\n",
567 | " ):\n",
568 | " if self.tokenizer.tokenizer_type == 'vanilla':\n",
569 | " return self._convert_to_vanilla_ids(text_a, text_b)\n",
570 | " elif self.tokenizer.tokenizer_type == 'transfomer':\n",
571 | " return self._convert_to_transfomer_ids(text_a, text_b)\n",
572 | " elif self.tokenizer.tokenizer_type == 'customized':\n",
573 | " return self._convert_to_customized_ids(text_a, text_b)\n",
574 | " else:\n",
575 | " raise ValueError(\"The tokenizer type does not exist\")\n",
576 | "\n",
577 | " def _convert_to_transfomer_ids(\n",
578 | " self,\n",
579 | " text_a,\n",
580 | " text_b\n",
581 | " ):\n",
582 | " input_ids_a = self.tokenizer.sequence_to_ids(text_a)\n",
583 | " input_ids_b = self.tokenizer.sequence_to_ids(text_b)\n",
584 | "\n",
585 | " input_ids_a, input_mask_a, segment_ids_a = input_ids_a\n",
586 | " input_ids_b, input_mask_b, segment_ids_b = input_ids_b\n",
587 | "\n",
588 | " features = {\n",
589 | " 'input_ids_a': input_ids_a,\n",
590 | " 'attention_mask_a': input_mask_a,\n",
591 | " 'token_type_ids_a': segment_ids_a,\n",
592 | " 'input_ids_b': input_ids_b,\n",
593 | " 'attention_mask_b': input_mask_b,\n",
594 | " 'token_type_ids_b': segment_ids_b\n",
595 | " }\n",
596 | "\n",
597 | " return features\n",
598 | "\n",
599 | " def predict_one_sample(\n",
600 | " self,\n",
601 | " text,\n",
602 | " topk=None,\n",
603 | " threshold=0.5,\n",
604 | " return_label_name=True,\n",
605 | " return_proba=False\n",
606 | " ):\n",
607 | " if topk is None:\n",
608 | " topk = len(self.cat2id) if len(self.cat2id) > 2 else 1\n",
609 | " text_a, text_b = text\n",
610 | " features = self._get_input_ids(text_a, text_b)\n",
611 | " self.module.eval()\n",
612 | "\n",
613 | " with torch.no_grad():\n",
614 | " inputs = self._get_module_one_sample_inputs(features)\n",
615 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n",
616 | "\n",
617 | " _proba = logits[0]\n",
618 | " \n",
619 | " if threshold is not None:\n",
620 | " _pred = self._threshold(_proba, threshold)\n",
621 | "\n",
622 | " if return_label_name and threshold is not None:\n",
623 | " _pred = self.id2cat[_pred]\n",
624 | "\n",
625 | " if threshold is not None:\n",
626 | " if return_proba:\n",
627 | " return [_pred, _proba]\n",
628 | " else:\n",
629 | " return _pred\n",
630 | "\n",
631 | " return _proba\n",
632 | "\n",
633 | " def predict_batch(\n",
634 | " self,\n",
635 | " test_data,\n",
636 | " batch_size=16,\n",
637 | " shuffle=False\n",
638 | " ):\n",
639 | " self.inputs_cols = test_data.dataset_cols\n",
640 | "\n",
641 | " preds = []\n",
642 | "\n",
643 | " self.module.eval()\n",
644 | " generator = DataLoader(test_data, batch_size=batch_size, shuffle=shuffle)\n",
645 | "\n",
646 | " with torch.no_grad():\n",
647 | " for step, inputs in enumerate(generator):\n",
648 | " inputs = self._get_module_batch_inputs(inputs)\n",
649 | "\n",
650 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n",
651 | "\n",
652 | " preds.extend(logits)\n",
653 | "\n",
654 | " return preds"
655 | ]
656 | },
657 | {
658 | "cell_type": "code",
659 | "execution_count": null,
660 | "id": "fce50880-d821-4fc5-8e34-c4ff7c356287",
661 | "metadata": {},
662 | "outputs": [],
663 | "source": [
664 | "cosent_predictor_instance = CoSENTPredictor(model.module, tokenizer, cosent_train_dataset.cat2id)"
665 | ]
666 | },
667 | {
668 | "cell_type": "code",
669 | "execution_count": null,
670 | "id": "26a67a56-1f8e-4579-81a2-1c578aaf92ab",
671 | "metadata": {},
672 | "outputs": [],
673 | "source": [
674 | "cosent_predictor_instance.predict_one_sample(\n",
675 | " ['1975年的nba赛季 - 76赛季是全美篮球协会的第30个赛季。', \n",
676 | " '1975-76赛季的全国篮球协会是nba的第30个赛季。'],\n",
677 | " threshold=None\n",
678 | ")"
679 | ]
680 | }
681 | ],
682 | "metadata": {
683 | "kernelspec": {
684 | "display_name": "Python 3 (ipykernel)",
685 | "language": "python",
686 | "name": "python3"
687 | },
688 | "language_info": {
689 | "codemirror_mode": {
690 | "name": "ipython",
691 | "version": 3
692 | },
693 | "file_extension": ".py",
694 | "mimetype": "text/x-python",
695 | "name": "python",
696 | "nbconvert_exporter": "python",
697 | "pygments_lexer": "ipython3",
698 | "version": "3.8.10"
699 | }
700 | },
701 | "nbformat": 4,
702 | "nbformat_minor": 5
703 | }
704 |
--------------------------------------------------------------------------------
/CoSENT_ATEC.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "9ced17a4",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import torch\n",
11 | "import pandas as pd\n",
12 | "\n",
13 | "from ark_nlp.nn import BertConfig as ModuleConfig\n",
14 | "from ark_nlp.dataset import TwinTowersSentenceClassificationDataset as Dataset\n",
15 | "from ark_nlp.processor.tokenizer.transfomer import SentenceTokenizer as Tokenizer"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": null,
21 | "id": "bfcea417",
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "# 目录地址\n",
26 | "train_data_path = '../data/source_datasets/ATEC/ATEC.train.data'\n",
27 | "dev_data_path = '../data/source_datasets/ATEC/ATEC.test.data'"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "id": "ccea726a",
33 | "metadata": {
34 | "tags": []
35 | },
36 | "source": [
37 | "### 一、数据读入与处理"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "id": "8d5c3337",
43 | "metadata": {
44 | "tags": []
45 | },
46 | "source": [
47 | "#### 1. 数据读入"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "id": "1e651e5f-8535-43c2-893b-10f3275517ba",
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "train_data_df = pd.read_csv(train_data_path, sep='\\t', header=None, names=['text_a', 'text_b', 'label'])\n",
58 | "dev_data_df = pd.read_csv(dev_data_path, sep='\\t', header=None, names=['text_a', 'text_b', 'label'])"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "id": "90876062",
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "cosent_train_dataset = Dataset(train_data_df)\n",
69 | "cosent_dev_dataset = Dataset(dev_data_df)"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "id": "e061890a",
75 | "metadata": {},
76 | "source": [
77 | "#### 2. 词典创建和生成分词器"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": null,
83 | "id": "be116454",
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "# 加载分词器\n",
88 | "tokenizer = Tokenizer(vocab='bert-base-chinese', max_seq_len=64)"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "id": "0d6c3b3d",
94 | "metadata": {},
95 | "source": [
96 | "#### 3. ID化"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "id": "566dd6b6",
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "cosent_train_dataset.convert_to_ids(tokenizer)\n",
107 | "cosent_dev_dataset.convert_to_ids(tokenizer)"
108 | ]
109 | },
110 | {
111 | "cell_type": "markdown",
112 | "id": "981b4160",
113 | "metadata": {},
114 | "source": [
115 | "
\n",
116 | "\n",
117 | "### 二、模型构建"
118 | ]
119 | },
120 | {
121 | "cell_type": "markdown",
122 | "id": "72753ee8",
123 | "metadata": {
124 | "tags": []
125 | },
126 | "source": [
127 | "#### 1. 模型参数设置"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": null,
133 | "id": "527535d3",
134 | "metadata": {},
135 | "outputs": [],
136 | "source": [
137 | "from transformers import BertConfig\n",
138 | "\n",
139 | "bert_config = BertConfig.from_pretrained(\n",
140 | " 'bert-base-chinese',\n",
141 | " num_labels=len(cosent_train_dataset.cat2id)\n",
142 | ")"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": null,
148 | "id": "77d8b580",
149 | "metadata": {},
150 | "outputs": [],
151 | "source": [
152 | "torch.cuda.empty_cache()"
153 | ]
154 | },
155 | {
156 | "cell_type": "markdown",
157 | "id": "700d7752",
158 | "metadata": {},
159 | "source": [
160 | "#### 2. 模型创建"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "id": "58f380f4",
167 | "metadata": {},
168 | "outputs": [],
169 | "source": [
170 | "import torch\n",
171 | "import torch.nn.functional as F\n",
172 | "\n",
173 | "from torch import nn\n",
174 | "from transformers import BertModel\n",
175 | "from ark_nlp.nn import Bert\n",
176 | "\n",
177 | "\n",
178 | "class CoSENT(Bert):\n",
179 | " \"\"\"\n",
180 | " CoSENT模型\n",
181 | "\n",
182 | " Args:\n",
183 | " config:\n",
184 | " 模型的配置对象\n",
185 | " encoder_trained (:obj:`bool`, optional, defaults to True):\n",
186 | " bert参数是否可训练,默认可训练\n",
187 | " pooling (:obj:`str`, optional, defaults to \"last_avg\"):\n",
188 | " bert输出的池化方式,默认为\"last_avg\",\n",
189 | " 可选有[\"cls\", \"cls_with_pooler\", \"first_last_avg\", \"last_avg\", \"last_2_avg\"]\n",
190 | " dropout (:obj:`float` or :obj:`None`, optional, defaults to None):\n",
191 | " dropout比例,默认为None,实际设置时会设置成0\n",
192 | " output_emb_size (:obj:`int`, optional, defaults to 0):\n",
193 | " 输出的矩阵的维度,默认为0,即不进行矩阵维度变换\n",
194 | "\n",
195 | " Reference:\n",
196 | " [1] https://kexue.fm/archives/8847\n",
197 | " [2] https://github.com/bojone/CoSENT \n",
198 | " \"\"\" # noqa: ignore flake8\"\n",
199 | "\n",
200 | " def __init__(\n",
201 | " self,\n",
202 | " config,\n",
203 | " encoder_trained=True,\n",
204 | " pooling='last_avg',\n",
205 | " dropout=None,\n",
206 | " output_emb_size=0\n",
207 | " ):\n",
208 | "\n",
209 | " super(CoSENT, self).__init__(config)\n",
210 | "\n",
211 | " self.bert = BertModel(config)\n",
212 | " self.pooling = pooling\n",
213 | "\n",
214 | " self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)\n",
215 | "\n",
216 | " # if output_emb_size is greater than 0, then add Linear layer to reduce embedding_size,\n",
217 | " # we recommend set output_emb_size = 256 considering the trade-off beteween\n",
218 | " # recall performance and efficiency\n",
219 | " self.output_emb_size = output_emb_size\n",
220 | " if self.output_emb_size > 0:\n",
221 | " self.emb_reduce_linear = nn.Linear(\n",
222 | " config.hidden_size,\n",
223 | " self.output_emb_size\n",
224 | " )\n",
225 | " torch.nn.init.trunc_normal_(\n",
226 | " self.emb_reduce_linear.weight,\n",
227 | " std=0.02\n",
228 | " )\n",
229 | "\n",
230 | " for param in self.bert.parameters():\n",
231 | " param.requires_grad = encoder_trained\n",
232 | "\n",
233 | " self.init_weights()\n",
234 | "\n",
235 | " def get_pooled_embedding(\n",
236 | " self,\n",
237 | " input_ids,\n",
238 | " token_type_ids=None,\n",
239 | " position_ids=None,\n",
240 | " attention_mask=None\n",
241 | " ):\n",
242 | " outputs = self.bert(\n",
243 | " input_ids,\n",
244 | " attention_mask=attention_mask,\n",
245 | " token_type_ids=token_type_ids,\n",
246 | " position_ids=position_ids,\n",
247 | " return_dict=True,\n",
248 | " output_hidden_states=True\n",
249 | " )\n",
250 | "\n",
251 | " encoder_feature = self.get_encoder_feature(\n",
252 | " outputs,\n",
253 | " attention_mask\n",
254 | " )\n",
255 | "\n",
256 | " if self.output_emb_size > 0:\n",
257 | " encoder_feature = self.emb_reduce_linear(encoder_feature)\n",
258 | "\n",
259 | " encoder_feature = self.dropout(encoder_feature)\n",
260 | " out = F.normalize(encoder_feature, p=2, dim=-1, eps=1e-8)\n",
261 | "\n",
262 | " return out\n",
263 | "\n",
264 | " def cosine_sim(\n",
265 | " self,\n",
266 | " input_ids_a,\n",
267 | " input_ids_b,\n",
268 | " token_type_ids_a=None,\n",
269 | " position_ids_ids_a=None,\n",
270 | " attention_mask_a=None,\n",
271 | " token_type_ids_b=None,\n",
272 | " position_ids_b=None,\n",
273 | " attention_mask_b=None,\n",
274 | " **kwargs\n",
275 | " ):\n",
276 | "\n",
277 | " query_cls_embedding = self.get_pooled_embedding(\n",
278 | " input_ids_a,\n",
279 | " token_type_ids_a,\n",
280 | " position_ids_ids_a,\n",
281 | " attention_mask_a\n",
282 | " )\n",
283 | "\n",
284 | " title_cls_embedding = self.get_pooled_embedding(\n",
285 | " input_ids_b,\n",
286 | " token_type_ids_b,\n",
287 | " position_ids_b,\n",
288 | " attention_mask_b\n",
289 | " )\n",
290 | "\n",
291 | " cosine_sim = torch.sum(\n",
292 | " query_cls_embedding * title_cls_embedding,\n",
293 | " axis=-1\n",
294 | " )\n",
295 | "\n",
296 | " return cosine_sim\n",
297 | "\n",
298 | " def forward(\n",
299 | " self,\n",
300 | " input_ids_a,\n",
301 | " input_ids_b,\n",
302 | " token_type_ids_a=None,\n",
303 | " position_ids_ids_a=None,\n",
304 | " attention_mask_a=None,\n",
305 | " token_type_ids_b=None,\n",
306 | " position_ids_b=None,\n",
307 | " attention_mask_b=None,\n",
308 | " label_ids=None,\n",
309 | " **kwargs\n",
310 | " ):\n",
311 | "\n",
312 | " cls_embedding_a = self.get_pooled_embedding(\n",
313 | " input_ids_a,\n",
314 | " token_type_ids_a,\n",
315 | " position_ids_ids_a,\n",
316 | " attention_mask_a\n",
317 | " )\n",
318 | "\n",
319 | " cls_embedding_b = self.get_pooled_embedding(\n",
320 | " input_ids_b,\n",
321 | " token_type_ids_b,\n",
322 | " position_ids_b,\n",
323 | " attention_mask_b\n",
324 | " )\n",
325 | "\n",
326 | " cosine_sim = torch.sum(cls_embedding_a * cls_embedding_b, dim=1) * 20\n",
327 | " cosine_sim = cosine_sim[:, None] - cosine_sim[None, :]\n",
328 | " \n",
329 | " labels = label_ids[:, None] < label_ids[None, :]\n",
330 | " labels = labels.long()\n",
331 | " \n",
332 | " cosine_sim = cosine_sim - (1 - labels) * 1e12\n",
333 | " cosine_sim = torch.cat((torch.zeros(1).to(cosine_sim.device), cosine_sim.view(-1)), dim=0)\n",
334 | " loss = torch.logsumexp(cosine_sim, dim=0)\n",
335 | "\n",
336 | " return cosine_sim, loss\n"
337 | ]
338 | },
339 | {
340 | "cell_type": "code",
341 | "execution_count": null,
342 | "id": "e630530b",
343 | "metadata": {},
344 | "outputs": [],
345 | "source": [
346 | "dl_module = CoSENT.from_pretrained(\n",
347 | " 'bert-base-chinese', \n",
348 | " config=bert_config\n",
349 | ")"
350 | ]
351 | },
352 | {
353 | "cell_type": "markdown",
354 | "id": "13e3c8ac",
355 | "metadata": {},
356 | "source": [
357 | "
\n",
358 | "\n",
359 | "### 三、任务构建"
360 | ]
361 | },
362 | {
363 | "cell_type": "markdown",
364 | "id": "31d1f76c",
365 | "metadata": {},
366 | "source": [
367 | "#### 1. 任务参数和必要部件设定"
368 | ]
369 | },
370 | {
371 | "cell_type": "code",
372 | "execution_count": null,
373 | "id": "943bf64c",
374 | "metadata": {},
375 | "outputs": [],
376 | "source": [
377 | "# 设置运行次数\n",
378 | "num_epoches = 5\n",
379 | "batch_size = 32"
380 | ]
381 | },
382 | {
383 | "cell_type": "code",
384 | "execution_count": null,
385 | "id": "74641ede",
386 | "metadata": {},
387 | "outputs": [],
388 | "source": [
389 | "param_optimizer = list(dl_module.named_parameters())\n",
390 | "param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]\n",
391 | "no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
392 | "optimizer_grouped_parameters = [\n",
393 | " {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n",
394 | " 'weight_decay': 0.01},\n",
395 | " {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
396 | "] "
397 | ]
398 | },
399 | {
400 | "cell_type": "markdown",
401 | "id": "bd5a9361",
402 | "metadata": {},
403 | "source": [
404 | "#### 2. 任务创建"
405 | ]
406 | },
407 | {
408 | "cell_type": "code",
409 | "execution_count": null,
410 | "id": "5bc04465",
411 | "metadata": {},
412 | "outputs": [],
413 | "source": [
414 | "import torch\n",
415 | "import numpy as np\n",
416 | "from scipy import stats\n",
417 | "\n",
418 | "from ark_nlp.factory.task.base._sequence_classification import SequenceClassificationTask\n",
419 | "\n",
420 | "\n",
421 | "class CoSENTTask(SequenceClassificationTask):\n",
422 | " \"\"\"\n",
423 | " 用于CoSENT模型文本匹配任务的Task\n",
424 | " \n",
425 | " Args:\n",
426 | " module: 深度学习模型\n",
427 | " optimizer: 训练模型使用的优化器名或者优化器对象\n",
428 | " loss_function: 训练模型使用的损失函数名或损失函数对象\n",
429 | " class_num (:obj:`int` or :obj:`None`, optional, defaults to None): 标签数目\n",
430 | " scheduler (:obj:`class`, optional, defaults to None): scheduler对象\n",
431 | " n_gpu (:obj:`int`, optional, defaults to 1): GPU数目\n",
432 | " device (:obj:`class`, optional, defaults to None): torch.device对象,当device为None时,会自动检测是否有GPU\n",
433 | " cuda_device (:obj:`int`, optional, defaults to 0): GPU编号,当device为None时,根据cuda_device设置device\n",
434 | " ema_decay (:obj:`int` or :obj:`None`, optional, defaults to None): EMA的加权系数\n",
435 | " **kwargs (optional): 其他可选参数\n",
436 | " \"\"\" # noqa: ignore flake8\"\n",
437 | "\n",
438 | " def _on_evaluate_begin_record(self, **kwargs):\n",
439 | "\n",
440 | " self.evaluate_logs['eval_loss'] = 0\n",
441 | " self.evaluate_logs['eval_step'] = 0\n",
442 | " self.evaluate_logs['eval_example'] = 0\n",
443 | "\n",
444 | " self.evaluate_logs['labels'] = []\n",
445 | " self.evaluate_logs['eval_sim'] = []\n",
446 | "\n",
447 | " def _on_evaluate_step_end(self, inputs, outputs, **kwargs):\n",
448 | "\n",
449 | " with torch.no_grad():\n",
450 | " # compute loss\n",
451 | " logits, loss = self._get_evaluate_loss(inputs, outputs, **kwargs)\n",
452 | " self.evaluate_logs['eval_loss'] += loss.item()\n",
453 | "\n",
454 | " if 'label_ids' in inputs:\n",
455 | " cosine_sim = self.module.cosine_sim(**inputs).cpu().numpy()\n",
456 | " self.evaluate_logs['eval_sim'].append(cosine_sim)\n",
457 | " self.evaluate_logs['labels'].append(inputs['label_ids'].cpu().numpy())\n",
458 | "\n",
459 | " self.evaluate_logs['eval_example'] += logits.shape[0]\n",
460 | " self.evaluate_logs['eval_step'] += 1\n",
461 | "\n",
462 | " def _on_evaluate_epoch_end(\n",
463 | " self,\n",
464 | " validation_data,\n",
465 | " epoch=1,\n",
466 | " is_evaluate_print=True,\n",
467 | " **kwargs\n",
468 | " ):\n",
469 | "\n",
470 | " if is_evaluate_print:\n",
471 | " if 'labels' in self.evaluate_logs:\n",
472 | " _sims = np.concatenate(self.evaluate_logs['eval_sim'], axis=0)\n",
473 | " _labels = np.concatenate(self.evaluate_logs['labels'], axis=0)\n",
474 | " spearman_corr = stats.spearmanr(_labels, _sims).correlation\n",
475 | " print('evaluate spearman corr is:{:.4f}, evaluate loss is:{:.6f}'.format(\n",
476 | " spearman_corr,\n",
477 | " self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']\n",
478 | " )\n",
479 | " )\n",
480 | " else:\n",
481 | " print('evaluate loss is:{:.6f}'.format(self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']))"
482 | ]
483 | },
484 | {
485 | "cell_type": "code",
486 | "execution_count": null,
487 | "id": "3dfc61d9",
488 | "metadata": {},
489 | "outputs": [],
490 | "source": [
491 | "model = CoSENTTask(dl_module, 'adamw', None, cuda_device=0)"
492 | ]
493 | },
494 | {
495 | "cell_type": "markdown",
496 | "id": "35c96cf8",
497 | "metadata": {
498 | "tags": []
499 | },
500 | "source": [
501 | "#### 3. 训练"
502 | ]
503 | },
504 | {
505 | "cell_type": "code",
506 | "execution_count": null,
507 | "id": "62e3e9ff",
508 | "metadata": {},
509 | "outputs": [],
510 | "source": [
511 | "model.fit(\n",
512 | " cosent_train_dataset,\n",
513 | " cosent_dev_dataset,\n",
514 | " lr=2e-5,\n",
515 | " epochs=num_epoches,\n",
516 | " batch_size=batch_size,\n",
517 | " params=optimizer_grouped_parameters\n",
518 | ")"
519 | ]
520 | },
521 | {
522 | "cell_type": "markdown",
523 | "id": "9b27e57b",
524 | "metadata": {},
525 | "source": [
526 | "
\n",
527 | "\n",
528 | "### 四、模型验证"
529 | ]
530 | },
531 | {
532 | "cell_type": "code",
533 | "execution_count": null,
534 | "id": "0800c2a3-2b57-435a-87d7-cfafbbc69fd1",
535 | "metadata": {},
536 | "outputs": [],
537 | "source": [
538 | "import torch\n",
539 | "\n",
540 | "from torch.utils.data import DataLoader\n",
541 | "from ark_nlp.factory.predictor import SequenceClassificationPredictor\n",
542 | "\n",
543 | "\n",
544 | "class CoSENTPredictor(SequenceClassificationPredictor):\n",
545 | " \"\"\"\n",
546 | " CoSENT的预测器\n",
547 | " \n",
548 | " Args:\n",
549 | " module: 深度学习模型\n",
550 | " tokernizer: 分词器\n",
551 | " cat2id (:obj:`dict`): 标签映射\n",
552 | " \"\"\" # noqa: ignore flake8\"\n",
553 | "\n",
554 | " def _get_input_ids(\n",
555 | " self,\n",
556 | " text_a,\n",
557 | " text_b\n",
558 | " ):\n",
559 | " if self.tokenizer.tokenizer_type == 'vanilla':\n",
560 | " return self._convert_to_vanilla_ids(text_a, text_b)\n",
561 | " elif self.tokenizer.tokenizer_type == 'transfomer':\n",
562 | " return self._convert_to_transfomer_ids(text_a, text_b)\n",
563 | " elif self.tokenizer.tokenizer_type == 'customized':\n",
564 | " return self._convert_to_customized_ids(text_a, text_b)\n",
565 | " else:\n",
566 | " raise ValueError(\"The tokenizer type does not exist\")\n",
567 | "\n",
568 | " def _convert_to_transfomer_ids(\n",
569 | " self,\n",
570 | " text_a,\n",
571 | " text_b\n",
572 | " ):\n",
573 | " input_ids_a = self.tokenizer.sequence_to_ids(text_a)\n",
574 | " input_ids_b = self.tokenizer.sequence_to_ids(text_b)\n",
575 | "\n",
576 | " input_ids_a, input_mask_a, segment_ids_a = input_ids_a\n",
577 | " input_ids_b, input_mask_b, segment_ids_b = input_ids_b\n",
578 | "\n",
579 | " features = {\n",
580 | " 'input_ids_a': input_ids_a,\n",
581 | " 'attention_mask_a': input_mask_a,\n",
582 | " 'token_type_ids_a': segment_ids_a,\n",
583 | " 'input_ids_b': input_ids_b,\n",
584 | " 'attention_mask_b': input_mask_b,\n",
585 | " 'token_type_ids_b': segment_ids_b\n",
586 | " }\n",
587 | "\n",
588 | " return features\n",
589 | "\n",
590 | " def predict_one_sample(\n",
591 | " self,\n",
592 | " text,\n",
593 | " topk=None,\n",
594 | " threshold=0.5,\n",
595 | " return_label_name=True,\n",
596 | " return_proba=False\n",
597 | " ):\n",
598 | " if topk is None:\n",
599 | " topk = len(self.cat2id) if len(self.cat2id) > 2 else 1\n",
600 | " text_a, text_b = text\n",
601 | " features = self._get_input_ids(text_a, text_b)\n",
602 | " self.module.eval()\n",
603 | "\n",
604 | " with torch.no_grad():\n",
605 | " inputs = self._get_module_one_sample_inputs(features)\n",
606 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n",
607 | "\n",
608 | " _proba = logits[0]\n",
609 | " \n",
610 | " if threshold is not None:\n",
611 | " _pred = self._threshold(_proba, threshold)\n",
612 | "\n",
613 | " if return_label_name and threshold is not None:\n",
614 | " _pred = self.id2cat[_pred]\n",
615 | "\n",
616 | " if threshold is not None:\n",
617 | " if return_proba:\n",
618 | " return [_pred, _proba]\n",
619 | " else:\n",
620 | " return _pred\n",
621 | "\n",
622 | " return _proba\n",
623 | "\n",
624 | " def predict_batch(\n",
625 | " self,\n",
626 | " test_data,\n",
627 | " batch_size=16,\n",
628 | " shuffle=False\n",
629 | " ):\n",
630 | " self.inputs_cols = test_data.dataset_cols\n",
631 | "\n",
632 | " preds = []\n",
633 | "\n",
634 | " self.module.eval()\n",
635 | " generator = DataLoader(test_data, batch_size=batch_size, shuffle=shuffle)\n",
636 | "\n",
637 | " with torch.no_grad():\n",
638 | " for step, inputs in enumerate(generator):\n",
639 | " inputs = self._get_module_batch_inputs(inputs)\n",
640 | "\n",
641 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n",
642 | "\n",
643 | " preds.extend(logits)\n",
644 | "\n",
645 | " return preds"
646 | ]
647 | },
648 | {
649 | "cell_type": "code",
650 | "execution_count": null,
651 | "id": "6f416a92-4e5d-4f78-91f1-b42f9db71b25",
652 | "metadata": {},
653 | "outputs": [],
654 | "source": [
655 | "cosent_predictor_instance = CoSENTPredictor(model.module, tokenizer, cosent_train_dataset.cat2id)"
656 | ]
657 | },
658 | {
659 | "cell_type": "code",
660 | "execution_count": null,
661 | "id": "55fe09db-9e27-4300-b57f-df6603643e50",
662 | "metadata": {},
663 | "outputs": [],
664 | "source": [
665 | "cosent_predictor_instance.predict_one_sample(\n",
666 | " ['怎么花呗不可以用来生活缴费了呀', \n",
667 | " '怎么我的花呗不能付电费了'], \n",
668 | " threshold=None\n",
669 | ")"
670 | ]
671 | }
672 | ],
673 | "metadata": {
674 | "kernelspec": {
675 | "display_name": "Python 3 (ipykernel)",
676 | "language": "python",
677 | "name": "python3"
678 | },
679 | "language_info": {
680 | "codemirror_mode": {
681 | "name": "ipython",
682 | "version": 3
683 | },
684 | "file_extension": ".py",
685 | "mimetype": "text/x-python",
686 | "name": "python",
687 | "nbconvert_exporter": "python",
688 | "pygments_lexer": "ipython3",
689 | "version": "3.8.10"
690 | }
691 | },
692 | "nbformat": 4,
693 | "nbformat_minor": 5
694 | }
695 |
--------------------------------------------------------------------------------
/CoSENT_BQ.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "9ced17a4",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import torch\n",
11 | "import pandas as pd\n",
12 | "\n",
13 | "from ark_nlp.nn import BertConfig as ModuleConfig\n",
14 | "from ark_nlp.dataset import TwinTowersSentenceClassificationDataset as Dataset\n",
15 | "from ark_nlp.processor.tokenizer.transfomer import SentenceTokenizer as Tokenizer"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": null,
21 | "id": "bfcea417",
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "# 目录地址\n",
26 | "train_data_path = '../data/source_datasets/BQ/BQ.train.data'\n",
27 | "dev_data_path = '../data/source_datasets/BQ/BQ.test.data'"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "id": "ccea726a",
33 | "metadata": {
34 | "tags": []
35 | },
36 | "source": [
37 | "### 一、数据读入与处理"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "id": "8d5c3337",
43 | "metadata": {
44 | "tags": []
45 | },
46 | "source": [
47 | "#### 1. 数据读入"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "id": "1e651e5f-8535-43c2-893b-10f3275517ba",
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "train_data_df = pd.read_csv(train_data_path, sep='\\t', header=None, names=['text_a', 'text_b', 'label'])\n",
58 | "dev_data_df = pd.read_csv(dev_data_path, sep='\\t', header=None, names=['text_a', 'text_b', 'label'])"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "id": "90876062",
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "cosent_train_dataset = Dataset(train_data_df)\n",
69 | "cosent_dev_dataset = Dataset(dev_data_df)"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "id": "e061890a",
75 | "metadata": {},
76 | "source": [
77 | "#### 2. 词典创建和生成分词器"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": null,
83 | "id": "be116454",
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "# 加载分词器\n",
88 | "tokenizer = Tokenizer(vocab='bert-base-chinese', max_seq_len=64)"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "id": "0d6c3b3d",
94 | "metadata": {},
95 | "source": [
96 | "#### 3. ID化"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "id": "566dd6b6",
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "cosent_train_dataset.convert_to_ids(tokenizer)\n",
107 | "cosent_dev_dataset.convert_to_ids(tokenizer)"
108 | ]
109 | },
110 | {
111 | "cell_type": "markdown",
112 | "id": "981b4160",
113 | "metadata": {},
114 | "source": [
115 | "
\n",
116 | "\n",
117 | "### 二、模型构建"
118 | ]
119 | },
120 | {
121 | "cell_type": "markdown",
122 | "id": "72753ee8",
123 | "metadata": {},
124 | "source": [
125 | "#### 1. 模型参数设置"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": null,
131 | "id": "527535d3",
132 | "metadata": {},
133 | "outputs": [],
134 | "source": [
135 | "from transformers import BertConfig\n",
136 | "\n",
137 | "bert_config = BertConfig.from_pretrained(\n",
138 | " 'bert-base-chinese',\n",
139 | " num_labels=2\n",
140 | ")"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": null,
146 | "id": "77d8b580",
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "torch.cuda.empty_cache()"
151 | ]
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "id": "700d7752",
156 | "metadata": {},
157 | "source": [
158 | "#### 2. 模型创建"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": null,
164 | "id": "58f380f4",
165 | "metadata": {},
166 | "outputs": [],
167 | "source": [
168 | "import torch\n",
169 | "import torch.nn.functional as F\n",
170 | "\n",
171 | "from torch import nn\n",
172 | "from transformers import BertModel\n",
173 | "from ark_nlp.nn import Bert\n",
174 | "\n",
175 | "\n",
176 | "class CoSENT(Bert):\n",
177 | " \"\"\"\n",
178 | " CoSENT模型\n",
179 | "\n",
180 | " Args:\n",
181 | " config:\n",
182 | " 模型的配置对象\n",
183 | " encoder_trained (:obj:`bool`, optional, defaults to True):\n",
184 | " bert参数是否可训练,默认可训练\n",
185 | " pooling (:obj:`str`, optional, defaults to \"last_avg\"):\n",
186 | " bert输出的池化方式,默认为\"last_avg\",\n",
187 | " 可选有[\"cls\", \"cls_with_pooler\", \"first_last_avg\", \"last_avg\", \"last_2_avg\"]\n",
188 | " dropout (:obj:`float` or :obj:`None`, optional, defaults to None):\n",
189 | " dropout比例,默认为None,实际设置时会设置成0\n",
190 | " output_emb_size (:obj:`int`, optional, defaults to 0):\n",
191 | " 输出的矩阵的维度,默认为0,即不进行矩阵维度变换\n",
192 | "\n",
193 | " Reference:\n",
194 | " [1] https://kexue.fm/archives/8847\n",
195 | " [2] https://github.com/bojone/CoSENT \n",
196 | " \"\"\" # noqa: ignore flake8\"\n",
197 | "\n",
198 | " def __init__(\n",
199 | " self,\n",
200 | " config,\n",
201 | " encoder_trained=True,\n",
202 | " pooling='last_avg',\n",
203 | " dropout=None,\n",
204 | " output_emb_size=0\n",
205 | " ):\n",
206 | "\n",
207 | " super(CoSENT, self).__init__(config)\n",
208 | "\n",
209 | " self.bert = BertModel(config)\n",
210 | " self.pooling = pooling\n",
211 | "\n",
212 | " self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)\n",
213 | "\n",
214 | " # if output_emb_size is greater than 0, then add Linear layer to reduce embedding_size,\n",
215 | " # we recommend set output_emb_size = 256 considering the trade-off beteween\n",
216 | " # recall performance and efficiency\n",
217 | " self.output_emb_size = output_emb_size\n",
218 | " if self.output_emb_size > 0:\n",
219 | " self.emb_reduce_linear = nn.Linear(\n",
220 | " config.hidden_size,\n",
221 | " self.output_emb_size\n",
222 | " )\n",
223 | " torch.nn.init.trunc_normal_(\n",
224 | " self.emb_reduce_linear.weight,\n",
225 | " std=0.02\n",
226 | " )\n",
227 | "\n",
228 | " for param in self.bert.parameters():\n",
229 | " param.requires_grad = encoder_trained\n",
230 | "\n",
231 | " self.init_weights()\n",
232 | "\n",
233 | " def get_pooled_embedding(\n",
234 | " self,\n",
235 | " input_ids,\n",
236 | " token_type_ids=None,\n",
237 | " position_ids=None,\n",
238 | " attention_mask=None\n",
239 | " ):\n",
240 | " outputs = self.bert(\n",
241 | " input_ids,\n",
242 | " attention_mask=attention_mask,\n",
243 | " token_type_ids=token_type_ids,\n",
244 | " position_ids=position_ids,\n",
245 | " return_dict=True,\n",
246 | " output_hidden_states=True\n",
247 | " )\n",
248 | "\n",
249 | " encoder_feature = self.get_encoder_feature(\n",
250 | " outputs,\n",
251 | " attention_mask\n",
252 | " )\n",
253 | "\n",
254 | " if self.output_emb_size > 0:\n",
255 | " encoder_feature = self.emb_reduce_linear(encoder_feature)\n",
256 | "\n",
257 | " encoder_feature = self.dropout(encoder_feature)\n",
258 | " out = F.normalize(encoder_feature, p=2, dim=-1, eps=1e-8)\n",
259 | "\n",
260 | " return out\n",
261 | "\n",
262 | " def cosine_sim(\n",
263 | " self,\n",
264 | " input_ids_a,\n",
265 | " input_ids_b,\n",
266 | " token_type_ids_a=None,\n",
267 | " position_ids_ids_a=None,\n",
268 | " attention_mask_a=None,\n",
269 | " token_type_ids_b=None,\n",
270 | " position_ids_b=None,\n",
271 | " attention_mask_b=None,\n",
272 | " **kwargs\n",
273 | " ):\n",
274 | "\n",
275 | " query_cls_embedding = self.get_pooled_embedding(\n",
276 | " input_ids_a,\n",
277 | " token_type_ids_a,\n",
278 | " position_ids_ids_a,\n",
279 | " attention_mask_a\n",
280 | " )\n",
281 | "\n",
282 | " title_cls_embedding = self.get_pooled_embedding(\n",
283 | " input_ids_b,\n",
284 | " token_type_ids_b,\n",
285 | " position_ids_b,\n",
286 | " attention_mask_b\n",
287 | " )\n",
288 | "\n",
289 | " cosine_sim = torch.sum(\n",
290 | " query_cls_embedding * title_cls_embedding,\n",
291 | " axis=-1\n",
292 | " )\n",
293 | "\n",
294 | " return cosine_sim\n",
295 | "\n",
296 | " def forward(\n",
297 | " self,\n",
298 | " input_ids_a,\n",
299 | " input_ids_b,\n",
300 | " token_type_ids_a=None,\n",
301 | " position_ids_ids_a=None,\n",
302 | " attention_mask_a=None,\n",
303 | " token_type_ids_b=None,\n",
304 | " position_ids_b=None,\n",
305 | " attention_mask_b=None,\n",
306 | " label_ids=None,\n",
307 | " **kwargs\n",
308 | " ):\n",
309 | "\n",
310 | " cls_embedding_a = self.get_pooled_embedding(\n",
311 | " input_ids_a,\n",
312 | " token_type_ids_a,\n",
313 | " position_ids_ids_a,\n",
314 | " attention_mask_a\n",
315 | " )\n",
316 | "\n",
317 | " cls_embedding_b = self.get_pooled_embedding(\n",
318 | " input_ids_b,\n",
319 | " token_type_ids_b,\n",
320 | " position_ids_b,\n",
321 | " attention_mask_b\n",
322 | " )\n",
323 | "\n",
324 | " cosine_sim = torch.sum(cls_embedding_a * cls_embedding_b, dim=1) * 20\n",
325 | " cosine_sim = cosine_sim[:, None] - cosine_sim[None, :]\n",
326 | " \n",
327 | " labels = label_ids[:, None] < label_ids[None, :]\n",
328 | " labels = labels.long()\n",
329 | " \n",
330 | " cosine_sim = cosine_sim - (1 - labels) * 1e12\n",
331 | " cosine_sim = torch.cat((torch.zeros(1).to(cosine_sim.device), cosine_sim.view(-1)), dim=0)\n",
332 | " loss = torch.logsumexp(cosine_sim, dim=0)\n",
333 | "\n",
334 | " return cosine_sim, loss\n"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": null,
340 | "id": "e630530b",
341 | "metadata": {},
342 | "outputs": [],
343 | "source": [
344 | "dl_module = CoSENT.from_pretrained(\n",
345 | " 'bert-base-chinese', \n",
346 | " config=bert_config\n",
347 | ")"
348 | ]
349 | },
350 | {
351 | "cell_type": "markdown",
352 | "id": "13e3c8ac",
353 | "metadata": {},
354 | "source": [
355 | "
\n",
356 | "\n",
357 | "### 三、任务构建"
358 | ]
359 | },
360 | {
361 | "cell_type": "markdown",
362 | "id": "31d1f76c",
363 | "metadata": {},
364 | "source": [
365 | "#### 1. 任务参数和必要部件设定"
366 | ]
367 | },
368 | {
369 | "cell_type": "code",
370 | "execution_count": null,
371 | "id": "943bf64c",
372 | "metadata": {},
373 | "outputs": [],
374 | "source": [
375 | "# 设置运行次数\n",
376 | "num_epoches = 5\n",
377 | "batch_size = 32"
378 | ]
379 | },
380 | {
381 | "cell_type": "code",
382 | "execution_count": null,
383 | "id": "74641ede",
384 | "metadata": {},
385 | "outputs": [],
386 | "source": [
387 | "param_optimizer = list(dl_module.named_parameters())\n",
388 | "param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]\n",
389 | "no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
390 | "optimizer_grouped_parameters = [\n",
391 | " {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n",
392 | " 'weight_decay': 0.01},\n",
393 | " {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
394 | "] "
395 | ]
396 | },
397 | {
398 | "cell_type": "markdown",
399 | "id": "bd5a9361",
400 | "metadata": {},
401 | "source": [
402 | "#### 2. 任务创建"
403 | ]
404 | },
405 | {
406 | "cell_type": "code",
407 | "execution_count": null,
408 | "id": "5bc04465",
409 | "metadata": {},
410 | "outputs": [],
411 | "source": [
412 | "import torch\n",
413 | "import numpy as np\n",
414 | "from scipy import stats\n",
415 | "\n",
416 | "from ark_nlp.factory.task.base._sequence_classification import SequenceClassificationTask\n",
417 | "\n",
418 | "\n",
419 | "class CoSENTTask(SequenceClassificationTask):\n",
420 | " \"\"\"\n",
421 | " 用于CoSENT模型文本匹配任务的Task\n",
422 | " \n",
423 | " Args:\n",
424 | " module: 深度学习模型\n",
425 | " optimizer: 训练模型使用的优化器名或者优化器对象\n",
426 | " loss_function: 训练模型使用的损失函数名或损失函数对象\n",
427 | " class_num (:obj:`int` or :obj:`None`, optional, defaults to None): 标签数目\n",
428 | " scheduler (:obj:`class`, optional, defaults to None): scheduler对象\n",
429 | " n_gpu (:obj:`int`, optional, defaults to 1): GPU数目\n",
430 | " device (:obj:`class`, optional, defaults to None): torch.device对象,当device为None时,会自动检测是否有GPU\n",
431 | " cuda_device (:obj:`int`, optional, defaults to 0): GPU编号,当device为None时,根据cuda_device设置device\n",
432 | " ema_decay (:obj:`int` or :obj:`None`, optional, defaults to None): EMA的加权系数\n",
433 | " **kwargs (optional): 其他可选参数\n",
434 | " \"\"\" # noqa: ignore flake8\"\n",
435 | "\n",
436 | " def _on_evaluate_begin_record(self, **kwargs):\n",
437 | "\n",
438 | " self.evaluate_logs['eval_loss'] = 0\n",
439 | " self.evaluate_logs['eval_step'] = 0\n",
440 | " self.evaluate_logs['eval_example'] = 0\n",
441 | "\n",
442 | " self.evaluate_logs['labels'] = []\n",
443 | " self.evaluate_logs['eval_sim'] = []\n",
444 | "\n",
445 | " def _on_evaluate_step_end(self, inputs, outputs, **kwargs):\n",
446 | "\n",
447 | " with torch.no_grad():\n",
448 | " # compute loss\n",
449 | " logits, loss = self._get_evaluate_loss(inputs, outputs, **kwargs)\n",
450 | " self.evaluate_logs['eval_loss'] += loss.item()\n",
451 | "\n",
452 | " if 'label_ids' in inputs:\n",
453 | " cosine_sim = self.module.cosine_sim(**inputs).cpu().numpy()\n",
454 | " self.evaluate_logs['eval_sim'].append(cosine_sim)\n",
455 | " self.evaluate_logs['labels'].append(inputs['label_ids'].cpu().numpy())\n",
456 | "\n",
457 | " self.evaluate_logs['eval_example'] += logits.shape[0]\n",
458 | " self.evaluate_logs['eval_step'] += 1\n",
459 | "\n",
460 | " def _on_evaluate_epoch_end(\n",
461 | " self,\n",
462 | " validation_data,\n",
463 | " epoch=1,\n",
464 | " is_evaluate_print=True,\n",
465 | " **kwargs\n",
466 | " ):\n",
467 | "\n",
468 | " if is_evaluate_print:\n",
469 | " if 'labels' in self.evaluate_logs:\n",
470 | " _sims = np.concatenate(self.evaluate_logs['eval_sim'], axis=0)\n",
471 | " _labels = np.concatenate(self.evaluate_logs['labels'], axis=0)\n",
472 | " spearman_corr = stats.spearmanr(_labels, _sims).correlation\n",
473 | " print('evaluate spearman corr is:{:.4f}, evaluate loss is:{:.6f}'.format(\n",
474 | " spearman_corr,\n",
475 | " self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']\n",
476 | " )\n",
477 | " )\n",
478 | " else:\n",
479 | " print('evaluate loss is:{:.6f}'.format(self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']))"
480 | ]
481 | },
482 | {
483 | "cell_type": "code",
484 | "execution_count": null,
485 | "id": "3dfc61d9",
486 | "metadata": {},
487 | "outputs": [],
488 | "source": [
489 | "model = CoSENTTask(dl_module, 'adamw', None, cuda_device=0)"
490 | ]
491 | },
492 | {
493 | "cell_type": "markdown",
494 | "id": "35c96cf8",
495 | "metadata": {
496 | "tags": []
497 | },
498 | "source": [
499 | "#### 3. 训练"
500 | ]
501 | },
502 | {
503 | "cell_type": "code",
504 | "execution_count": null,
505 | "id": "62e3e9ff",
506 | "metadata": {},
507 | "outputs": [],
508 | "source": [
509 | "model.fit(\n",
510 | " cosent_train_dataset,\n",
511 | " cosent_dev_dataset,\n",
512 | " lr=2e-5,\n",
513 | " epochs=num_epoches,\n",
514 | " batch_size=batch_size,\n",
515 | " params=optimizer_grouped_parameters\n",
516 | ")"
517 | ]
518 | },
519 | {
520 | "cell_type": "markdown",
521 | "id": "9b27e57b",
522 | "metadata": {},
523 | "source": [
524 | "
\n",
525 | "\n",
526 | "### 四、模型验证"
527 | ]
528 | },
529 | {
530 | "cell_type": "code",
531 | "execution_count": null,
532 | "id": "e6fe3272-777a-45b5-b736-6546af69ec34",
533 | "metadata": {},
534 | "outputs": [],
535 | "source": [
536 | "import torch\n",
537 | "\n",
538 | "from torch.utils.data import DataLoader\n",
539 | "from ark_nlp.factory.predictor import SequenceClassificationPredictor\n",
540 | "\n",
541 | "\n",
542 | "class CoSENTPredictor(SequenceClassificationPredictor):\n",
543 | " \"\"\"\n",
544 | " CoSENT的预测器\n",
545 | " \n",
546 | " Args:\n",
547 | " module: 深度学习模型\n",
548 | " tokernizer: 分词器\n",
549 | " cat2id (:obj:`dict`): 标签映射\n",
550 | " \"\"\" # noqa: ignore flake8\"\n",
551 | "\n",
552 | " def _get_input_ids(\n",
553 | " self,\n",
554 | " text_a,\n",
555 | " text_b\n",
556 | " ):\n",
557 | " if self.tokenizer.tokenizer_type == 'vanilla':\n",
558 | " return self._convert_to_vanilla_ids(text_a, text_b)\n",
559 | " elif self.tokenizer.tokenizer_type == 'transfomer':\n",
560 | " return self._convert_to_transfomer_ids(text_a, text_b)\n",
561 | " elif self.tokenizer.tokenizer_type == 'customized':\n",
562 | " return self._convert_to_customized_ids(text_a, text_b)\n",
563 | " else:\n",
564 | " raise ValueError(\"The tokenizer type does not exist\")\n",
565 | "\n",
566 | " def _convert_to_transfomer_ids(\n",
567 | " self,\n",
568 | " text_a,\n",
569 | " text_b\n",
570 | " ):\n",
571 | " input_ids_a = self.tokenizer.sequence_to_ids(text_a)\n",
572 | " input_ids_b = self.tokenizer.sequence_to_ids(text_b)\n",
573 | "\n",
574 | " input_ids_a, input_mask_a, segment_ids_a = input_ids_a\n",
575 | " input_ids_b, input_mask_b, segment_ids_b = input_ids_b\n",
576 | "\n",
577 | " features = {\n",
578 | " 'input_ids_a': input_ids_a,\n",
579 | " 'attention_mask_a': input_mask_a,\n",
580 | " 'token_type_ids_a': segment_ids_a,\n",
581 | " 'input_ids_b': input_ids_b,\n",
582 | " 'attention_mask_b': input_mask_b,\n",
583 | " 'token_type_ids_b': segment_ids_b\n",
584 | " }\n",
585 | "\n",
586 | " return features\n",
587 | "\n",
588 | " def predict_one_sample(\n",
589 | " self,\n",
590 | " text,\n",
591 | " topk=None,\n",
592 | " threshold=0.5,\n",
593 | " return_label_name=True,\n",
594 | " return_proba=False\n",
595 | " ):\n",
596 | " if topk is None:\n",
597 | " topk = len(self.cat2id) if len(self.cat2id) > 2 else 1\n",
598 | " text_a, text_b = text\n",
599 | " features = self._get_input_ids(text_a, text_b)\n",
600 | " self.module.eval()\n",
601 | "\n",
602 | " with torch.no_grad():\n",
603 | " inputs = self._get_module_one_sample_inputs(features)\n",
604 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n",
605 | "\n",
606 | " _proba = logits[0]\n",
607 | " \n",
608 | " if threshold is not None:\n",
609 | " _pred = self._threshold(_proba, threshold)\n",
610 | "\n",
611 | " if return_label_name and threshold is not None:\n",
612 | " _pred = self.id2cat[_pred]\n",
613 | "\n",
614 | " if threshold is not None:\n",
615 | " if return_proba:\n",
616 | " return [_pred, _proba]\n",
617 | " else:\n",
618 | " return _pred\n",
619 | "\n",
620 | " return _proba\n",
621 | "\n",
622 | " def predict_batch(\n",
623 | " self,\n",
624 | " test_data,\n",
625 | " batch_size=16,\n",
626 | " shuffle=False\n",
627 | " ):\n",
628 | " self.inputs_cols = test_data.dataset_cols\n",
629 | "\n",
630 | " preds = []\n",
631 | "\n",
632 | " self.module.eval()\n",
633 | " generator = DataLoader(test_data, batch_size=batch_size, shuffle=shuffle)\n",
634 | "\n",
635 | " with torch.no_grad():\n",
636 | " for step, inputs in enumerate(generator):\n",
637 | " inputs = self._get_module_batch_inputs(inputs)\n",
638 | "\n",
639 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n",
640 | "\n",
641 | " preds.extend(logits)\n",
642 | "\n",
643 | " return preds"
644 | ]
645 | },
646 | {
647 | "cell_type": "code",
648 | "execution_count": null,
649 | "id": "2d14331b-bf8f-4431-9f8f-e6947cbe7bfd",
650 | "metadata": {},
651 | "outputs": [],
652 | "source": [
653 | "cosent_predictor_instance = CoSENTPredictor(model.module, tokenizer, cosent_train_dataset.cat2id)"
654 | ]
655 | },
656 | {
657 | "cell_type": "code",
658 | "execution_count": null,
659 | "id": "ab835fab-8d52-4fad-9610-a9125c56b825",
660 | "metadata": {},
661 | "outputs": [],
662 | "source": [
663 | "cosent_predictor_instance.predict_one_sample(\n",
664 | " ['用微信都6年,微信没有微粒贷功能', \n",
665 | " '4。 号码来微粒贷'],\n",
666 | " threshold=None\n",
667 | ")"
668 | ]
669 | }
670 | ],
671 | "metadata": {
672 | "kernelspec": {
673 | "display_name": "Python 3 (ipykernel)",
674 | "language": "python",
675 | "name": "python3"
676 | },
677 | "language_info": {
678 | "codemirror_mode": {
679 | "name": "ipython",
680 | "version": 3
681 | },
682 | "file_extension": ".py",
683 | "mimetype": "text/x-python",
684 | "name": "python",
685 | "nbconvert_exporter": "python",
686 | "pygments_lexer": "ipython3",
687 | "version": "3.8.10"
688 | }
689 | },
690 | "nbformat": 4,
691 | "nbformat_minor": 5
692 | }
693 |
--------------------------------------------------------------------------------
/CoSENT_CHIP-STS.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "9ced17a4",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import torch\n",
11 | "import pandas as pd\n",
12 | "\n",
13 | "from ark_nlp.nn import BertConfig as ModuleConfig\n",
14 | "from ark_nlp.dataset import TwinTowersSentenceClassificationDataset as Dataset\n",
15 | "from ark_nlp.processor.tokenizer.transfomer import SentenceTokenizer as Tokenizer"
16 | ]
17 | },
18 | {
19 | "cell_type": "markdown",
20 | "id": "ccea726a",
21 | "metadata": {
22 | "tags": []
23 | },
24 | "source": [
25 | "### 一、数据读入与处理"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "id": "8d5c3337",
31 | "metadata": {
32 | "tags": []
33 | },
34 | "source": [
35 | "#### 1. 数据读入"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "id": "1e651e5f-8535-43c2-893b-10f3275517ba",
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "train_data_df = pd.read_json('../data/source_datasets/CHIP-STS/CHIP-STS_train.json')\n",
46 | "train_data_df = (train_data_df\n",
47 | " .rename(columns={'text1': 'text_a', 'text2': 'text_b', 'category': 'condition'})\n",
48 | " .loc[:,['text_a', 'text_b', 'condition', 'label']])\n",
49 | "\n",
50 | "dev_data_df = pd.read_json('../data/source_datasets/CHIP-STS/CHIP-STS_dev.json')\n",
51 | "dev_data_df = dev_data_df[dev_data_df['label'] != \"NA\"]\n",
52 | "dev_data_df = (dev_data_df\n",
53 | " .rename(columns={'text1': 'text_a', 'text2': 'text_b', 'category': 'condition'})\n",
54 | " .loc[:,['text_a', 'text_b', 'condition', 'label']])"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": null,
60 | "id": "90876062",
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "cosent_train_dataset = Dataset(train_data_df)\n",
65 | "cosent_dev_dataset = Dataset(dev_data_df)"
66 | ]
67 | },
68 | {
69 | "cell_type": "markdown",
70 | "id": "e061890a",
71 | "metadata": {},
72 | "source": [
73 | "#### 2. 词典创建和生成分词器"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "id": "be116454",
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "# 加载分词器\n",
84 | "tokenizer = Tokenizer(vocab='bert-base-chinese', max_seq_len=64)"
85 | ]
86 | },
87 | {
88 | "cell_type": "markdown",
89 | "id": "0d6c3b3d",
90 | "metadata": {},
91 | "source": [
92 | "#### 3. ID化"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": null,
98 | "id": "566dd6b6",
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "cosent_train_dataset.convert_to_ids(tokenizer)\n",
103 | "cosent_dev_dataset.convert_to_ids(tokenizer)"
104 | ]
105 | },
106 | {
107 | "cell_type": "markdown",
108 | "id": "981b4160",
109 | "metadata": {},
110 | "source": [
111 | "
\n",
112 | "\n",
113 | "### 二、模型构建"
114 | ]
115 | },
116 | {
117 | "cell_type": "markdown",
118 | "id": "72753ee8",
119 | "metadata": {},
120 | "source": [
121 | "#### 1. 模型参数设置"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "id": "527535d3",
128 | "metadata": {},
129 | "outputs": [],
130 | "source": [
131 | "from transformers import BertConfig\n",
132 | "\n",
133 | "bert_config = BertConfig.from_pretrained(\n",
134 | " 'bert-base-chinese',\n",
135 | " num_labels=len(cosent_train_dataset.cat2id)\n",
136 | ")"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": null,
142 | "id": "77d8b580",
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "torch.cuda.empty_cache()"
147 | ]
148 | },
149 | {
150 | "cell_type": "markdown",
151 | "id": "700d7752",
152 | "metadata": {},
153 | "source": [
154 | "#### 2. 模型创建"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "id": "58f380f4",
161 | "metadata": {},
162 | "outputs": [],
163 | "source": [
164 | "import torch\n",
165 | "import torch.nn.functional as F\n",
166 | "\n",
167 | "from torch import nn\n",
168 | "from transformers import BertModel\n",
169 | "from ark_nlp.nn import Bert\n",
170 | "\n",
171 | "\n",
172 | "class CoSENT(Bert):\n",
173 | " \"\"\"\n",
174 | " CoSENT模型\n",
175 | "\n",
176 | " Args:\n",
177 | " config:\n",
178 | " 模型的配置对象\n",
179 | " encoder_trained (:obj:`bool`, optional, defaults to True):\n",
180 | " bert参数是否可训练,默认可训练\n",
181 | " pooling (:obj:`str`, optional, defaults to \"last_avg\"):\n",
182 | " bert输出的池化方式,默认为\"last_avg\",\n",
183 | " 可选有[\"cls\", \"cls_with_pooler\", \"first_last_avg\", \"last_avg\", \"last_2_avg\"]\n",
184 | " dropout (:obj:`float` or :obj:`None`, optional, defaults to None):\n",
185 | " dropout比例,默认为None,实际设置时会设置成0\n",
186 | " output_emb_size (:obj:`int`, optional, defaults to 0):\n",
187 | " 输出的矩阵的维度,默认为0,即不进行矩阵维度变换\n",
188 | "\n",
189 | " Reference:\n",
190 | " [1] https://kexue.fm/archives/8847\n",
191 | " [2] https://github.com/bojone/CoSENT \n",
192 | " \"\"\" # noqa: ignore flake8\"\n",
193 | "\n",
194 | " def __init__(\n",
195 | " self,\n",
196 | " config,\n",
197 | " encoder_trained=True,\n",
198 | " pooling='last_avg',\n",
199 | " dropout=None,\n",
200 | " output_emb_size=0\n",
201 | " ):\n",
202 | "\n",
203 | " super(CoSENT, self).__init__(config)\n",
204 | "\n",
205 | " self.bert = BertModel(config)\n",
206 | " self.pooling = pooling\n",
207 | "\n",
208 | " self.dropout = nn.Dropout(dropout if dropout is not None else 0)\n",
209 | "\n",
210 | " # if output_emb_size is greater than 0, then add Linear layer to reduce embedding_size,\n",
211 | " # we recommend set output_emb_size = 256 considering the trade-off beteween\n",
212 | " # recall performance and efficiency\n",
213 | " self.output_emb_size = output_emb_size\n",
214 | " if self.output_emb_size > 0:\n",
215 | " self.emb_reduce_linear = nn.Linear(\n",
216 | " config.hidden_size,\n",
217 | " self.output_emb_size\n",
218 | " )\n",
219 | " torch.nn.init.trunc_normal_(\n",
220 | " self.emb_reduce_linear.weight,\n",
221 | " std=0.02\n",
222 | " )\n",
223 | "\n",
224 | " for param in self.bert.parameters():\n",
225 | " param.requires_grad = encoder_trained\n",
226 | "\n",
227 | " self.init_weights()\n",
228 | "\n",
229 | " def get_pooled_embedding(\n",
230 | " self,\n",
231 | " input_ids,\n",
232 | " token_type_ids=None,\n",
233 | " position_ids=None,\n",
234 | " attention_mask=None\n",
235 | " ):\n",
236 | " outputs = self.bert(\n",
237 | " input_ids,\n",
238 | " attention_mask=attention_mask,\n",
239 | " token_type_ids=token_type_ids,\n",
240 | " position_ids=position_ids,\n",
241 | " return_dict=True,\n",
242 | " output_hidden_states=True\n",
243 | " )\n",
244 | "\n",
245 | " encoder_feature = self.get_encoder_feature(\n",
246 | " outputs,\n",
247 | " attention_mask\n",
248 | " )\n",
249 | "\n",
250 | " if self.output_emb_size > 0:\n",
251 | " encoder_feature = self.emb_reduce_linear(encoder_feature)\n",
252 | "\n",
253 | " encoder_feature = self.dropout(encoder_feature)\n",
254 | " out = F.normalize(encoder_feature, p=2, dim=-1)\n",
255 | "\n",
256 | " return out\n",
257 | "\n",
258 | " def cosine_sim(\n",
259 | " self,\n",
260 | " input_ids_a,\n",
261 | " input_ids_b,\n",
262 | " token_type_ids_a=None,\n",
263 | " position_ids_ids_a=None,\n",
264 | " attention_mask_a=None,\n",
265 | " token_type_ids_b=None,\n",
266 | " position_ids_b=None,\n",
267 | " attention_mask_b=None,\n",
268 | " **kwargs\n",
269 | " ):\n",
270 | "\n",
271 | " query_cls_embedding = self.get_pooled_embedding(\n",
272 | " input_ids_a,\n",
273 | " token_type_ids_a,\n",
274 | " position_ids_ids_a,\n",
275 | " attention_mask_a\n",
276 | " )\n",
277 | "\n",
278 | " title_cls_embedding = self.get_pooled_embedding(\n",
279 | " input_ids_b,\n",
280 | " token_type_ids_b,\n",
281 | " position_ids_b,\n",
282 | " attention_mask_b\n",
283 | " )\n",
284 | "\n",
285 | " cosine_sim = torch.sum(\n",
286 | " query_cls_embedding * title_cls_embedding,\n",
287 | " axis=-1\n",
288 | " )\n",
289 | "\n",
290 | " return cosine_sim\n",
291 | "\n",
292 | " def forward(\n",
293 | " self,\n",
294 | " input_ids_a,\n",
295 | " input_ids_b,\n",
296 | " token_type_ids_a=None,\n",
297 | " position_ids_ids_a=None,\n",
298 | " attention_mask_a=None,\n",
299 | " token_type_ids_b=None,\n",
300 | " position_ids_b=None,\n",
301 | " attention_mask_b=None,\n",
302 | " label_ids=None,\n",
303 | " **kwargs\n",
304 | " ):\n",
305 | "\n",
306 | " cls_embedding_a = self.get_pooled_embedding(\n",
307 | " input_ids_a,\n",
308 | " token_type_ids_a,\n",
309 | " position_ids_ids_a,\n",
310 | " attention_mask_a\n",
311 | " )\n",
312 | "\n",
313 | " cls_embedding_b = self.get_pooled_embedding(\n",
314 | " input_ids_b,\n",
315 | " token_type_ids_b,\n",
316 | " position_ids_b,\n",
317 | " attention_mask_b\n",
318 | " )\n",
319 | "\n",
320 | " cosine_sim = torch.sum(cls_embedding_a * cls_embedding_b, dim=1) * 20\n",
321 | " cosine_sim = cosine_sim[:, None] - cosine_sim[None, :]\n",
322 | " \n",
323 | " labels = label_ids[:, None] < label_ids[None, :]\n",
324 | " labels = labels.long()\n",
325 | " \n",
326 | " cosine_sim = cosine_sim - (1 - labels) * 1e12\n",
327 | " cosine_sim = torch.cat((torch.zeros(1).to(cosine_sim.device), cosine_sim.view(-1)), dim=0)\n",
328 | " loss = torch.logsumexp(cosine_sim.view(-1), dim=0)\n",
329 | "\n",
330 | " return cosine_sim, loss\n"
331 | ]
332 | },
333 | {
334 | "cell_type": "code",
335 | "execution_count": null,
336 | "id": "e630530b",
337 | "metadata": {},
338 | "outputs": [],
339 | "source": [
340 | "dl_module = CoSENT.from_pretrained(\n",
341 | " 'bert-base-chinese', \n",
342 | " config=bert_config\n",
343 | ")"
344 | ]
345 | },
346 | {
347 | "cell_type": "markdown",
348 | "id": "13e3c8ac",
349 | "metadata": {},
350 | "source": [
351 | "
\n",
352 | "\n",
353 | "### 三、任务构建"
354 | ]
355 | },
356 | {
357 | "cell_type": "markdown",
358 | "id": "31d1f76c",
359 | "metadata": {},
360 | "source": [
361 | "#### 1. 任务参数和必要部件设定"
362 | ]
363 | },
364 | {
365 | "cell_type": "code",
366 | "execution_count": null,
367 | "id": "943bf64c",
368 | "metadata": {},
369 | "outputs": [],
370 | "source": [
371 | "# 设置运行次数\n",
372 | "num_epoches = 5\n",
373 | "batch_size = 32"
374 | ]
375 | },
376 | {
377 | "cell_type": "code",
378 | "execution_count": null,
379 | "id": "74641ede",
380 | "metadata": {},
381 | "outputs": [],
382 | "source": [
383 | "param_optimizer = list(dl_module.named_parameters())\n",
384 | "param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]\n",
385 | "no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
386 | "optimizer_grouped_parameters = [\n",
387 | " {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n",
388 | " 'weight_decay': 0.01},\n",
389 | " {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
390 | "] "
391 | ]
392 | },
393 | {
394 | "cell_type": "markdown",
395 | "id": "bd5a9361",
396 | "metadata": {},
397 | "source": [
398 | "#### 2. 任务创建"
399 | ]
400 | },
401 | {
402 | "cell_type": "code",
403 | "execution_count": null,
404 | "id": "5bc04465",
405 | "metadata": {},
406 | "outputs": [],
407 | "source": [
408 | "import torch\n",
409 | "import numpy as np\n",
410 | "from scipy import stats\n",
411 | "\n",
412 | "from ark_nlp.factory.task.base._sequence_classification import SequenceClassificationTask\n",
413 | "\n",
414 | "\n",
415 | "class CoSENTTask(SequenceClassificationTask):\n",
416 | " \"\"\"\n",
417 | " 用于CoSENT模型文本匹配任务的Task\n",
418 | " \n",
419 | " Args:\n",
420 | " module: 深度学习模型\n",
421 | " optimizer: 训练模型使用的优化器名或者优化器对象\n",
422 | " loss_function: 训练模型使用的损失函数名或损失函数对象\n",
423 | " class_num (:obj:`int` or :obj:`None`, optional, defaults to None): 标签数目\n",
424 | " scheduler (:obj:`class`, optional, defaults to None): scheduler对象\n",
425 | " n_gpu (:obj:`int`, optional, defaults to 1): GPU数目\n",
426 | " device (:obj:`class`, optional, defaults to None): torch.device对象,当device为None时,会自动检测是否有GPU\n",
427 | " cuda_device (:obj:`int`, optional, defaults to 0): GPU编号,当device为None时,根据cuda_device设置device\n",
428 | " ema_decay (:obj:`int` or :obj:`None`, optional, defaults to None): EMA的加权系数\n",
429 | " **kwargs (optional): 其他可选参数\n",
430 | " \"\"\" # noqa: ignore flake8\"\n",
431 | "\n",
432 | " def _on_evaluate_begin_record(self, **kwargs):\n",
433 | "\n",
434 | " self.evaluate_logs['eval_loss'] = 0\n",
435 | " self.evaluate_logs['eval_step'] = 0\n",
436 | " self.evaluate_logs['eval_example'] = 0\n",
437 | "\n",
438 | " self.evaluate_logs['labels'] = []\n",
439 | " self.evaluate_logs['eval_sim'] = []\n",
440 | "\n",
441 | " def _on_evaluate_step_end(self, inputs, outputs, **kwargs):\n",
442 | "\n",
443 | " with torch.no_grad():\n",
444 | " # compute loss\n",
445 | " logits, loss = self._get_evaluate_loss(inputs, outputs, **kwargs)\n",
446 | " self.evaluate_logs['eval_loss'] += loss.item()\n",
447 | "\n",
448 | " if 'label_ids' in inputs:\n",
449 | " cosine_sim = self.module.cosine_sim(**inputs).cpu().numpy()\n",
450 | " self.evaluate_logs['eval_sim'].append(cosine_sim)\n",
451 | " self.evaluate_logs['labels'].append(inputs['label_ids'].cpu().numpy())\n",
452 | "\n",
453 | " self.evaluate_logs['eval_example'] += logits.shape[0]\n",
454 | " self.evaluate_logs['eval_step'] += 1\n",
455 | "\n",
456 | " def _on_evaluate_epoch_end(\n",
457 | " self,\n",
458 | " validation_data,\n",
459 | " epoch=1,\n",
460 | " is_evaluate_print=True,\n",
461 | " **kwargs\n",
462 | " ):\n",
463 | "\n",
464 | " if is_evaluate_print:\n",
465 | " if 'labels' in self.evaluate_logs:\n",
466 | " _sims = np.concatenate(self.evaluate_logs['eval_sim'], axis=0)\n",
467 | " _labels = np.concatenate(self.evaluate_logs['labels'], axis=0)\n",
468 | " spearman_corr = stats.spearmanr(_labels, _sims).correlation\n",
469 | " print('evaluate spearman corr is:{:.4f}, evaluate loss is:{:.6f}'.format(\n",
470 | " spearman_corr,\n",
471 | " self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']\n",
472 | " )\n",
473 | " )\n",
474 | " else:\n",
475 | " print('evaluate loss is:{:.6f}'.format(self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']))"
476 | ]
477 | },
478 | {
479 | "cell_type": "code",
480 | "execution_count": null,
481 | "id": "3dfc61d9",
482 | "metadata": {},
483 | "outputs": [],
484 | "source": [
485 | "model = CoSENTTask(dl_module, 'adamw', None, cuda_device=0)"
486 | ]
487 | },
488 | {
489 | "cell_type": "markdown",
490 | "id": "35c96cf8",
491 | "metadata": {
492 | "tags": []
493 | },
494 | "source": [
495 | "#### 3. 训练"
496 | ]
497 | },
498 | {
499 | "cell_type": "code",
500 | "execution_count": null,
501 | "id": "62e3e9ff",
502 | "metadata": {},
503 | "outputs": [],
504 | "source": [
505 | "model.fit(\n",
506 | " cosent_train_dataset,\n",
507 | " cosent_dev_dataset,\n",
508 | " lr=2e-5,\n",
509 | " epochs=num_epoches,\n",
510 | " batch_size=batch_size,\n",
511 | " params=optimizer_grouped_parameters\n",
512 | ")"
513 | ]
514 | },
515 | {
516 | "cell_type": "markdown",
517 | "id": "9b27e57b",
518 | "metadata": {},
519 | "source": [
520 | "
\n",
521 | "\n",
522 | "### 四、模型验证"
523 | ]
524 | },
525 | {
526 | "cell_type": "code",
527 | "execution_count": null,
528 | "id": "971c58b0-23fe-45c2-a9e8-0c8d48ab9326",
529 | "metadata": {},
530 | "outputs": [],
531 | "source": [
532 | "import torch\n",
533 | "\n",
534 | "from torch.utils.data import DataLoader\n",
535 | "from ark_nlp.factory.predictor import SequenceClassificationPredictor\n",
536 | "\n",
537 | "\n",
538 | "class CoSENTPredictor(SequenceClassificationPredictor):\n",
539 | " \"\"\"\n",
540 | " CoSENT的预测器\n",
541 | " \n",
542 | " Args:\n",
543 | " module: 深度学习模型\n",
544 | " tokernizer: 分词器\n",
545 | " cat2id (:obj:`dict`): 标签映射\n",
546 | " \"\"\" # noqa: ignore flake8\"\n",
547 | "\n",
548 | " def _get_input_ids(\n",
549 | " self,\n",
550 | " text_a,\n",
551 | " text_b\n",
552 | " ):\n",
553 | " if self.tokenizer.tokenizer_type == 'vanilla':\n",
554 | " return self._convert_to_vanilla_ids(text_a, text_b)\n",
555 | " elif self.tokenizer.tokenizer_type == 'transfomer':\n",
556 | " return self._convert_to_transfomer_ids(text_a, text_b)\n",
557 | " elif self.tokenizer.tokenizer_type == 'customized':\n",
558 | " return self._convert_to_customized_ids(text_a, text_b)\n",
559 | " else:\n",
560 | " raise ValueError(\"The tokenizer type does not exist\")\n",
561 | "\n",
562 | " def _convert_to_transfomer_ids(\n",
563 | " self,\n",
564 | " text_a,\n",
565 | " text_b\n",
566 | " ):\n",
567 | " input_ids_a = self.tokenizer.sequence_to_ids(text_a)\n",
568 | " input_ids_b = self.tokenizer.sequence_to_ids(text_b)\n",
569 | "\n",
570 | " input_ids_a, input_mask_a, segment_ids_a = input_ids_a\n",
571 | " input_ids_b, input_mask_b, segment_ids_b = input_ids_b\n",
572 | "\n",
573 | " features = {\n",
574 | " 'input_ids_a': input_ids_a,\n",
575 | " 'attention_mask_a': input_mask_a,\n",
576 | " 'token_type_ids_a': segment_ids_a,\n",
577 | " 'input_ids_b': input_ids_b,\n",
578 | " 'attention_mask_b': input_mask_b,\n",
579 | " 'token_type_ids_b': segment_ids_b\n",
580 | " }\n",
581 | "\n",
582 | " return features\n",
583 | "\n",
584 | " def predict_one_sample(\n",
585 | " self,\n",
586 | " text,\n",
587 | " topk=None,\n",
588 | " threshold=0.5,\n",
589 | " return_label_name=True,\n",
590 | " return_proba=False\n",
591 | " ):\n",
592 | " if topk is None:\n",
593 | " topk = len(self.cat2id) if len(self.cat2id) > 2 else 1\n",
594 | " text_a, text_b = text\n",
595 | " features = self._get_input_ids(text_a, text_b)\n",
596 | " self.module.eval()\n",
597 | "\n",
598 | " with torch.no_grad():\n",
599 | " inputs = self._get_module_one_sample_inputs(features)\n",
600 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n",
601 | "\n",
602 | " _proba = logits[0]\n",
603 | " \n",
604 | " if threshold is not None:\n",
605 | " _pred = self._threshold(_proba, threshold)\n",
606 | "\n",
607 | " if return_label_name and threshold is not None:\n",
608 | " _pred = self.id2cat[_pred]\n",
609 | "\n",
610 | " if threshold is not None:\n",
611 | " if return_proba:\n",
612 | " return [_pred, _proba]\n",
613 | " else:\n",
614 | " return _pred\n",
615 | "\n",
616 | " return _proba\n",
617 | "\n",
618 | " def predict_batch(\n",
619 | " self,\n",
620 | " test_data,\n",
621 | " batch_size=16,\n",
622 | " shuffle=False\n",
623 | " ):\n",
624 | " self.inputs_cols = test_data.dataset_cols\n",
625 | "\n",
626 | " preds = []\n",
627 | "\n",
628 | " self.module.eval()\n",
629 | " generator = DataLoader(test_data, batch_size=batch_size, shuffle=shuffle)\n",
630 | "\n",
631 | " with torch.no_grad():\n",
632 | " for step, inputs in enumerate(generator):\n",
633 | " inputs = self._get_module_batch_inputs(inputs)\n",
634 | "\n",
635 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n",
636 | "\n",
637 | " preds.extend(logits)\n",
638 | "\n",
639 | " return preds"
640 | ]
641 | },
642 | {
643 | "cell_type": "code",
644 | "execution_count": null,
645 | "id": "4b2f71d7-8547-4d17-95f2-77952673b871",
646 | "metadata": {},
647 | "outputs": [],
648 | "source": [
649 | "cosent_predictor_instance = CoSENTPredictor(model.module, tokenizer, cosent_train_dataset.cat2id)"
650 | ]
651 | },
652 | {
653 | "cell_type": "code",
654 | "execution_count": null,
655 | "id": "ec9094a9-cdee-4709-80c7-35285290a992",
656 | "metadata": {},
657 | "outputs": [],
658 | "source": [
659 | "cosent_predictor_instance.predict_one_sample(\n",
660 | " ['糖尿病能吃减肥药吗?能治愈吗?', \n",
661 | " '糖尿病为什么不能吃减肥药'],\n",
662 | " threshold=None\n",
663 | ")"
664 | ]
665 | },
666 | {
667 | "cell_type": "markdown",
668 | "id": "fd6fdff6-1a80-4f6e-bebd-82511121365c",
669 | "metadata": {},
670 | "source": [
671 | "### 五、CBLUE打榜提交"
672 | ]
673 | },
674 | {
675 | "cell_type": "code",
676 | "execution_count": null,
677 | "id": "4e66ade4-b0f6-4e77-863c-6c62762ef793",
678 | "metadata": {},
679 | "outputs": [],
680 | "source": [
681 | "import pandas as pd\n",
682 | "\n",
683 | "from tqdm import tqdm\n",
684 | "\n",
685 | "test_df = pd.read_json('../data/source_datasets/CHIP-STS/CHIP-STS_test.json')\n",
686 | "\n",
687 | "submit = []\n",
688 | "for _id, _text_a, _text_b, _condition in tqdm(zip(\n",
689 | " test_df['id'],\n",
690 | " test_df['text1'],\n",
691 | " test_df['text2'],\n",
692 | " test_df['category']\n",
693 | ")):\n",
694 | " if _condition == 'daibetes':\n",
695 | " _condition = 'diabetes'\n",
696 | "\n",
697 | " predict_ = cosent_predictor_instance.predict_one_sample([_text_a, _text_b], threshold=0.6)\n",
698 | " \n",
699 | " submit.append({\n",
700 | " 'id': str(_id),\n",
701 | " 'text1': _text_a,\n",
702 | " 'text2': _text_b,\n",
703 | " 'label': predict_,\n",
704 | " 'category': _condition\n",
705 | " })"
706 | ]
707 | },
708 | {
709 | "cell_type": "code",
710 | "execution_count": null,
711 | "id": "1fa52f20-7f8f-4279-ae2b-fc1d54c5f0f2",
712 | "metadata": {},
713 | "outputs": [],
714 | "source": [
715 | "import json\n",
716 | "\n",
717 | "output_path = '../data/output_datasets/CHIP-STS_test.json'\n",
718 | "\n",
719 | "with open(output_path, 'w', encoding='utf-8') as f:\n",
720 | " f.write(json.dumps(submit, ensure_ascii=False))"
721 | ]
722 | }
723 | ],
724 | "metadata": {
725 | "kernelspec": {
726 | "display_name": "Python 3 (ipykernel)",
727 | "language": "python",
728 | "name": "python3"
729 | },
730 | "language_info": {
731 | "codemirror_mode": {
732 | "name": "ipython",
733 | "version": 3
734 | },
735 | "file_extension": ".py",
736 | "mimetype": "text/x-python",
737 | "name": "python",
738 | "nbconvert_exporter": "python",
739 | "pygments_lexer": "ipython3",
740 | "version": "3.8.10"
741 | }
742 | },
743 | "nbformat": 4,
744 | "nbformat_minor": 5
745 | }
746 |
--------------------------------------------------------------------------------
/CoSENT_LCQMC.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "9ced17a4",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import torch\n",
11 | "import pandas as pd\n",
12 | "\n",
13 | "from ark_nlp.nn import BertConfig as ModuleConfig\n",
14 | "from ark_nlp.dataset import TwinTowersSentenceClassificationDataset as Dataset\n",
15 | "from ark_nlp.processor.tokenizer.transfomer import SentenceTokenizer as Tokenizer"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": null,
21 | "id": "bfcea417",
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "# 目录地址\n",
26 | "train_data_path = '../data/source_datasets/LCQMC/LCQMC.train.data'\n",
27 | "dev_data_path = '../data/source_datasets/LCQMC/LCQMC.test.data'"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "id": "ccea726a",
33 | "metadata": {
34 | "tags": []
35 | },
36 | "source": [
37 | "### 一、数据读入与处理"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "id": "8d5c3337",
43 | "metadata": {
44 | "tags": []
45 | },
46 | "source": [
47 | "#### 1. 数据读入"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "id": "1e651e5f-8535-43c2-893b-10f3275517ba",
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "train_data_df = pd.read_csv(train_data_path, sep='\\t', header=None, names=['text_a', 'text_b', 'label'])\n",
58 | "dev_data_df = pd.read_csv(dev_data_path, sep='\\t', header=None, names=['text_a', 'text_b', 'label'])"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "id": "90876062",
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "cosent_train_dataset = Dataset(train_data_df)\n",
69 | "cosent_dev_dataset = Dataset(dev_data_df)"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "id": "e061890a",
75 | "metadata": {},
76 | "source": [
77 | "#### 2. 词典创建和生成分词器"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": null,
83 | "id": "be116454",
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "# 加载分词器\n",
88 | "tokenizer = Tokenizer(vocab='bert-base-chinese', max_seq_len=64)"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "id": "0d6c3b3d",
94 | "metadata": {},
95 | "source": [
96 | "#### 3. ID化"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "id": "566dd6b6",
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "cosent_train_dataset.convert_to_ids(tokenizer)\n",
107 | "cosent_dev_dataset.convert_to_ids(tokenizer)"
108 | ]
109 | },
110 | {
111 | "cell_type": "markdown",
112 | "id": "981b4160",
113 | "metadata": {},
114 | "source": [
115 | "
\n",
116 | "\n",
117 | "### 二、模型构建"
118 | ]
119 | },
120 | {
121 | "cell_type": "markdown",
122 | "id": "72753ee8",
123 | "metadata": {},
124 | "source": [
125 | "#### 1. 模型参数设置"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": null,
131 | "id": "527535d3",
132 | "metadata": {},
133 | "outputs": [],
134 | "source": [
135 | "from transformers import BertConfig\n",
136 | "\n",
137 | "bert_config = BertConfig.from_pretrained(\n",
138 | " 'bert-base-chinese',\n",
139 | " num_labels=2\n",
140 | ")"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": null,
146 | "id": "77d8b580",
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "torch.cuda.empty_cache()"
151 | ]
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "id": "700d7752",
156 | "metadata": {},
157 | "source": [
158 | "#### 2. 模型创建"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": null,
164 | "id": "58f380f4",
165 | "metadata": {},
166 | "outputs": [],
167 | "source": [
168 | "import torch\n",
169 | "import torch.nn.functional as F\n",
170 | "\n",
171 | "from torch import nn\n",
172 | "from transformers import BertModel\n",
173 | "from ark_nlp.nn import Bert\n",
174 | "\n",
175 | "\n",
176 | "class CoSENT(Bert):\n",
177 | " \"\"\"\n",
178 | " CoSENT模型\n",
179 | "\n",
180 | " Args:\n",
181 | " config:\n",
182 | " 模型的配置对象\n",
183 | " encoder_trained (:obj:`bool`, optional, defaults to True):\n",
184 | " bert参数是否可训练,默认可训练\n",
185 | " pooling (:obj:`str`, optional, defaults to \"last_avg\"):\n",
186 | " bert输出的池化方式,默认为\"last_avg\",\n",
187 | " 可选有[\"cls\", \"cls_with_pooler\", \"first_last_avg\", \"last_avg\", \"last_2_avg\"]\n",
188 | " dropout (:obj:`float` or :obj:`None`, optional, defaults to None):\n",
189 | " dropout比例,默认为None,实际设置时会设置成0\n",
190 | " output_emb_size (:obj:`int`, optional, defaults to 0):\n",
191 | " 输出的矩阵的维度,默认为0,即不进行矩阵维度变换\n",
192 | "\n",
193 | " Reference:\n",
194 | " [1] https://kexue.fm/archives/8847\n",
195 | " [2] https://github.com/bojone/CoSENT \n",
196 | " \"\"\" # noqa: ignore flake8\"\n",
197 | "\n",
198 | " def __init__(\n",
199 | " self,\n",
200 | " config,\n",
201 | " encoder_trained=True,\n",
202 | " pooling='last_avg',\n",
203 | " dropout=None,\n",
204 | " output_emb_size=0\n",
205 | " ):\n",
206 | "\n",
207 | " super(CoSENT, self).__init__(config)\n",
208 | "\n",
209 | " self.bert = BertModel(config)\n",
210 | " self.pooling = pooling\n",
211 | "\n",
212 | " self.dropout = nn.Dropout(dropout if dropout is not None else 0)\n",
213 | "\n",
214 | " # if output_emb_size is greater than 0, then add Linear layer to reduce embedding_size,\n",
215 | " # we recommend set output_emb_size = 256 considering the trade-off beteween\n",
216 | " # recall performance and efficiency\n",
217 | " self.output_emb_size = output_emb_size\n",
218 | " if self.output_emb_size > 0:\n",
219 | " self.emb_reduce_linear = nn.Linear(\n",
220 | " config.hidden_size,\n",
221 | " self.output_emb_size\n",
222 | " )\n",
223 | " torch.nn.init.trunc_normal_(\n",
224 | " self.emb_reduce_linear.weight,\n",
225 | " std=0.02\n",
226 | " )\n",
227 | "\n",
228 | " for param in self.bert.parameters():\n",
229 | " param.requires_grad = encoder_trained\n",
230 | "\n",
231 | " self.init_weights()\n",
232 | "\n",
233 | " def get_pooled_embedding(\n",
234 | " self,\n",
235 | " input_ids,\n",
236 | " token_type_ids=None,\n",
237 | " position_ids=None,\n",
238 | " attention_mask=None\n",
239 | " ):\n",
240 | " outputs = self.bert(\n",
241 | " input_ids,\n",
242 | " attention_mask=attention_mask,\n",
243 | " token_type_ids=token_type_ids,\n",
244 | " position_ids=position_ids,\n",
245 | " return_dict=True,\n",
246 | " output_hidden_states=True\n",
247 | " )\n",
248 | "\n",
249 | " encoder_feature = self.get_encoder_feature(\n",
250 | " outputs,\n",
251 | " attention_mask\n",
252 | " )\n",
253 | "\n",
254 | " if self.output_emb_size > 0:\n",
255 | " encoder_feature = self.emb_reduce_linear(encoder_feature)\n",
256 | "\n",
257 | " encoder_feature = self.dropout(encoder_feature)\n",
258 | " out = F.normalize(encoder_feature, p=2, dim=-1, eps=1e-8)\n",
259 | "\n",
260 | " return out\n",
261 | "\n",
262 | " def cosine_sim(\n",
263 | " self,\n",
264 | " input_ids_a,\n",
265 | " input_ids_b,\n",
266 | " token_type_ids_a=None,\n",
267 | " position_ids_ids_a=None,\n",
268 | " attention_mask_a=None,\n",
269 | " token_type_ids_b=None,\n",
270 | " position_ids_b=None,\n",
271 | " attention_mask_b=None,\n",
272 | " **kwargs\n",
273 | " ):\n",
274 | "\n",
275 | " query_cls_embedding = self.get_pooled_embedding(\n",
276 | " input_ids_a,\n",
277 | " token_type_ids_a,\n",
278 | " position_ids_ids_a,\n",
279 | " attention_mask_a\n",
280 | " )\n",
281 | "\n",
282 | " title_cls_embedding = self.get_pooled_embedding(\n",
283 | " input_ids_b,\n",
284 | " token_type_ids_b,\n",
285 | " position_ids_b,\n",
286 | " attention_mask_b\n",
287 | " )\n",
288 | "\n",
289 | " cosine_sim = torch.sum(\n",
290 | " query_cls_embedding * title_cls_embedding,\n",
291 | " axis=-1\n",
292 | " )\n",
293 | "\n",
294 | " return cosine_sim\n",
295 | "\n",
296 | " def forward(\n",
297 | " self,\n",
298 | " input_ids_a,\n",
299 | " input_ids_b,\n",
300 | " token_type_ids_a=None,\n",
301 | " position_ids_ids_a=None,\n",
302 | " attention_mask_a=None,\n",
303 | " token_type_ids_b=None,\n",
304 | " position_ids_b=None,\n",
305 | " attention_mask_b=None,\n",
306 | " label_ids=None,\n",
307 | " **kwargs\n",
308 | " ):\n",
309 | "\n",
310 | " cls_embedding_a = self.get_pooled_embedding(\n",
311 | " input_ids_a,\n",
312 | " token_type_ids_a,\n",
313 | " position_ids_ids_a,\n",
314 | " attention_mask_a\n",
315 | " )\n",
316 | "\n",
317 | " cls_embedding_b = self.get_pooled_embedding(\n",
318 | " input_ids_b,\n",
319 | " token_type_ids_b,\n",
320 | " position_ids_b,\n",
321 | " attention_mask_b\n",
322 | " )\n",
323 | "\n",
324 | " cosine_sim = torch.sum(cls_embedding_a * cls_embedding_b, dim=1) * 20\n",
325 | " cosine_sim = cosine_sim[:, None] - cosine_sim[None, :]\n",
326 | " \n",
327 | " labels = label_ids[:, None] < label_ids[None, :]\n",
328 | " labels = labels.long()\n",
329 | " \n",
330 | " cosine_sim = cosine_sim - (1 - labels) * 1e12\n",
331 | " cosine_sim = torch.cat((torch.zeros(1).to(cosine_sim.device), cosine_sim.view(-1)), dim=0)\n",
332 | " loss = torch.logsumexp(cosine_sim.view(-1), dim=0)\n",
333 | "\n",
334 | " return cosine_sim, loss\n"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": null,
340 | "id": "e630530b",
341 | "metadata": {},
342 | "outputs": [],
343 | "source": [
344 | "dl_module = CoSENT.from_pretrained(\n",
345 | " 'bert-base-chinese', \n",
346 | " config=bert_config\n",
347 | ")"
348 | ]
349 | },
350 | {
351 | "cell_type": "markdown",
352 | "id": "13e3c8ac",
353 | "metadata": {},
354 | "source": [
355 | "
\n",
356 | "\n",
357 | "### 三、任务构建"
358 | ]
359 | },
360 | {
361 | "cell_type": "markdown",
362 | "id": "31d1f76c",
363 | "metadata": {},
364 | "source": [
365 | "#### 1. 任务参数和必要部件设定"
366 | ]
367 | },
368 | {
369 | "cell_type": "code",
370 | "execution_count": null,
371 | "id": "943bf64c",
372 | "metadata": {},
373 | "outputs": [],
374 | "source": [
375 | "# 设置运行次数\n",
376 | "num_epoches = 5\n",
377 | "batch_size = 32"
378 | ]
379 | },
380 | {
381 | "cell_type": "code",
382 | "execution_count": null,
383 | "id": "74641ede",
384 | "metadata": {},
385 | "outputs": [],
386 | "source": [
387 | "param_optimizer = list(dl_module.named_parameters())\n",
388 | "param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]\n",
389 | "no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
390 | "optimizer_grouped_parameters = [\n",
391 | " {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n",
392 | " 'weight_decay': 0.01},\n",
393 | " {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
394 | "] "
395 | ]
396 | },
397 | {
398 | "cell_type": "markdown",
399 | "id": "bd5a9361",
400 | "metadata": {},
401 | "source": [
402 | "#### 2. 任务创建"
403 | ]
404 | },
405 | {
406 | "cell_type": "code",
407 | "execution_count": null,
408 | "id": "5bc04465",
409 | "metadata": {},
410 | "outputs": [],
411 | "source": [
412 | "import torch\n",
413 | "import numpy as np\n",
414 | "from scipy import stats\n",
415 | "\n",
416 | "from ark_nlp.factory.task.base._sequence_classification import SequenceClassificationTask\n",
417 | "\n",
418 | "\n",
419 | "class CoSENTTask(SequenceClassificationTask):\n",
420 | " \"\"\"\n",
421 | " 用于CoSENT模型文本匹配任务的Task\n",
422 | " \n",
423 | " Args:\n",
424 | " module: 深度学习模型\n",
425 | " optimizer: 训练模型使用的优化器名或者优化器对象\n",
426 | " loss_function: 训练模型使用的损失函数名或损失函数对象\n",
427 | " class_num (:obj:`int` or :obj:`None`, optional, defaults to None): 标签数目\n",
428 | " scheduler (:obj:`class`, optional, defaults to None): scheduler对象\n",
429 | " n_gpu (:obj:`int`, optional, defaults to 1): GPU数目\n",
430 | " device (:obj:`class`, optional, defaults to None): torch.device对象,当device为None时,会自动检测是否有GPU\n",
431 | " cuda_device (:obj:`int`, optional, defaults to 0): GPU编号,当device为None时,根据cuda_device设置device\n",
432 | " ema_decay (:obj:`int` or :obj:`None`, optional, defaults to None): EMA的加权系数\n",
433 | " **kwargs (optional): 其他可选参数\n",
434 | " \"\"\" # noqa: ignore flake8\"\n",
435 | "\n",
436 | " def _on_evaluate_begin_record(self, **kwargs):\n",
437 | "\n",
438 | " self.evaluate_logs['eval_loss'] = 0\n",
439 | " self.evaluate_logs['eval_step'] = 0\n",
440 | " self.evaluate_logs['eval_example'] = 0\n",
441 | "\n",
442 | " self.evaluate_logs['labels'] = []\n",
443 | " self.evaluate_logs['eval_sim'] = []\n",
444 | "\n",
445 | " def _on_evaluate_step_end(self, inputs, outputs, **kwargs):\n",
446 | "\n",
447 | " with torch.no_grad():\n",
448 | " # compute loss\n",
449 | " logits, loss = self._get_evaluate_loss(inputs, outputs, **kwargs)\n",
450 | " self.evaluate_logs['eval_loss'] += loss.item()\n",
451 | "\n",
452 | " if 'label_ids' in inputs:\n",
453 | " cosine_sim = self.module.cosine_sim(**inputs).cpu().numpy()\n",
454 | " self.evaluate_logs['eval_sim'].append(cosine_sim)\n",
455 | " self.evaluate_logs['labels'].append(inputs['label_ids'].cpu().numpy())\n",
456 | "\n",
457 | " self.evaluate_logs['eval_example'] += logits.shape[0]\n",
458 | " self.evaluate_logs['eval_step'] += 1\n",
459 | "\n",
460 | " def _on_evaluate_epoch_end(\n",
461 | " self,\n",
462 | " validation_data,\n",
463 | " epoch=1,\n",
464 | " is_evaluate_print=True,\n",
465 | " **kwargs\n",
466 | " ):\n",
467 | "\n",
468 | " if is_evaluate_print:\n",
469 | " if 'labels' in self.evaluate_logs:\n",
470 | " _sims = np.concatenate(self.evaluate_logs['eval_sim'], axis=0)\n",
471 | " _labels = np.concatenate(self.evaluate_logs['labels'], axis=0)\n",
472 | " spearman_corr = stats.spearmanr(_labels, _sims).correlation\n",
473 | " print('evaluate spearman corr is:{:.4f}, evaluate loss is:{:.6f}'.format(\n",
474 | " spearman_corr,\n",
475 | " self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']\n",
476 | " )\n",
477 | " )\n",
478 | " else:\n",
479 | " print('evaluate loss is:{:.6f}'.format(self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step']))"
480 | ]
481 | },
482 | {
483 | "cell_type": "code",
484 | "execution_count": null,
485 | "id": "3dfc61d9",
486 | "metadata": {},
487 | "outputs": [],
488 | "source": [
489 | "model = CoSENTTask(dl_module, 'adamw', None, cuda_device=0)"
490 | ]
491 | },
492 | {
493 | "cell_type": "markdown",
494 | "id": "35c96cf8",
495 | "metadata": {
496 | "tags": []
497 | },
498 | "source": [
499 | "#### 3. 训练"
500 | ]
501 | },
502 | {
503 | "cell_type": "code",
504 | "execution_count": null,
505 | "id": "62e3e9ff",
506 | "metadata": {},
507 | "outputs": [],
508 | "source": [
509 | "model.fit(\n",
510 | " cosent_train_dataset,\n",
511 | " cosent_dev_dataset,\n",
512 | " lr=2e-5,\n",
513 | " epochs=num_epoches,\n",
514 | " batch_size=batch_size,\n",
515 | " params=optimizer_grouped_parameters\n",
516 | ")"
517 | ]
518 | },
519 | {
520 | "cell_type": "markdown",
521 | "id": "9b27e57b",
522 | "metadata": {},
523 | "source": [
524 | "
\n",
525 | "\n",
526 | "### 四、模型验证"
527 | ]
528 | },
529 | {
530 | "cell_type": "code",
531 | "execution_count": null,
532 | "id": "6b296f3f-7adb-41df-9a4e-43847b17a900",
533 | "metadata": {},
534 | "outputs": [],
535 | "source": [
536 | "import torch\n",
537 | "\n",
538 | "from torch.utils.data import DataLoader\n",
539 | "from ark_nlp.factory.predictor import SequenceClassificationPredictor\n",
540 | "\n",
541 | "\n",
542 | "class CoSENTPredictor(SequenceClassificationPredictor):\n",
543 | " \"\"\"\n",
544 | " CoSENT的预测器\n",
545 | " \n",
546 | " Args:\n",
547 | " module: 深度学习模型\n",
548 | " tokernizer: 分词器\n",
549 | " cat2id (:obj:`dict`): 标签映射\n",
550 | " \"\"\" # noqa: ignore flake8\"\n",
551 | "\n",
552 | " def _get_input_ids(\n",
553 | " self,\n",
554 | " text_a,\n",
555 | " text_b\n",
556 | " ):\n",
557 | " if self.tokenizer.tokenizer_type == 'vanilla':\n",
558 | " return self._convert_to_vanilla_ids(text_a, text_b)\n",
559 | " elif self.tokenizer.tokenizer_type == 'transfomer':\n",
560 | " return self._convert_to_transfomer_ids(text_a, text_b)\n",
561 | " elif self.tokenizer.tokenizer_type == 'customized':\n",
562 | " return self._convert_to_customized_ids(text_a, text_b)\n",
563 | " else:\n",
564 | " raise ValueError(\"The tokenizer type does not exist\")\n",
565 | "\n",
566 | " def _convert_to_transfomer_ids(\n",
567 | " self,\n",
568 | " text_a,\n",
569 | " text_b\n",
570 | " ):\n",
571 | " input_ids_a = self.tokenizer.sequence_to_ids(text_a)\n",
572 | " input_ids_b = self.tokenizer.sequence_to_ids(text_b)\n",
573 | "\n",
574 | " input_ids_a, input_mask_a, segment_ids_a = input_ids_a\n",
575 | " input_ids_b, input_mask_b, segment_ids_b = input_ids_b\n",
576 | "\n",
577 | " features = {\n",
578 | " 'input_ids_a': input_ids_a,\n",
579 | " 'attention_mask_a': input_mask_a,\n",
580 | " 'token_type_ids_a': segment_ids_a,\n",
581 | " 'input_ids_b': input_ids_b,\n",
582 | " 'attention_mask_b': input_mask_b,\n",
583 | " 'token_type_ids_b': segment_ids_b\n",
584 | " }\n",
585 | "\n",
586 | " return features\n",
587 | "\n",
588 | " def predict_one_sample(\n",
589 | " self,\n",
590 | " text,\n",
591 | " topk=None,\n",
592 | " threshold=0.5,\n",
593 | " return_label_name=True,\n",
594 | " return_proba=False\n",
595 | " ):\n",
596 | " if topk is None:\n",
597 | " topk = len(self.cat2id) if len(self.cat2id) > 2 else 1\n",
598 | " text_a, text_b = text\n",
599 | " features = self._get_input_ids(text_a, text_b)\n",
600 | " self.module.eval()\n",
601 | "\n",
602 | " with torch.no_grad():\n",
603 | " inputs = self._get_module_one_sample_inputs(features)\n",
604 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n",
605 | "\n",
606 | " _proba = logits[0]\n",
607 | " \n",
608 | " if threshold is not None:\n",
609 | " _pred = self._threshold(_proba, threshold)\n",
610 | "\n",
611 | " if return_label_name and threshold is not None:\n",
612 | " _pred = self.id2cat[_pred]\n",
613 | "\n",
614 | " if threshold is not None:\n",
615 | " if return_proba:\n",
616 | " return [_pred, _proba]\n",
617 | " else:\n",
618 | " return _pred\n",
619 | "\n",
620 | " return _proba\n",
621 | "\n",
622 | " def predict_batch(\n",
623 | " self,\n",
624 | " test_data,\n",
625 | " batch_size=16,\n",
626 | " shuffle=False\n",
627 | " ):\n",
628 | " self.inputs_cols = test_data.dataset_cols\n",
629 | "\n",
630 | " preds = []\n",
631 | "\n",
632 | " self.module.eval()\n",
633 | " generator = DataLoader(test_data, batch_size=batch_size, shuffle=shuffle)\n",
634 | "\n",
635 | " with torch.no_grad():\n",
636 | " for step, inputs in enumerate(generator):\n",
637 | " inputs = self._get_module_batch_inputs(inputs)\n",
638 | "\n",
639 | " logits = self.module.cosine_sim(**inputs).cpu().numpy()\n",
640 | "\n",
641 | " preds.extend(logits)\n",
642 | "\n",
643 | " return preds"
644 | ]
645 | },
646 | {
647 | "cell_type": "code",
648 | "execution_count": null,
649 | "id": "b1df061b",
650 | "metadata": {},
651 | "outputs": [],
652 | "source": [
653 | "cosent_predictor_instance = CoSENTPredictor(model.module, tokenizer, cosent_train_dataset.cat2id)"
654 | ]
655 | },
656 | {
657 | "cell_type": "code",
658 | "execution_count": null,
659 | "id": "24994fe9",
660 | "metadata": {},
661 | "outputs": [],
662 | "source": [
663 | "cosent_predictor_instance.predict_one_sample(\n",
664 | " ['喜欢打篮球的男生喜欢什么样的女生', \n",
665 | " '爱打篮球的男生喜欢什么样的女生'],\n",
666 | " threshold=None\n",
667 | ")"
668 | ]
669 | }
670 | ],
671 | "metadata": {
672 | "kernelspec": {
673 | "display_name": "Python 3 (ipykernel)",
674 | "language": "python",
675 | "name": "python3"
676 | },
677 | "language_info": {
678 | "codemirror_mode": {
679 | "name": "ipython",
680 | "version": 3
681 | },
682 | "file_extension": ".py",
683 | "mimetype": "text/x-python",
684 | "name": "python",
685 | "nbconvert_exporter": "python",
686 | "pygments_lexer": "ipython3",
687 | "version": "3.8.10"
688 | }
689 | },
690 | "nbformat": 4,
691 | "nbformat_minor": 5
692 | }
693 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch_CoSENT
2 |
3 | ## Introduction
4 | 
5 | A sentence-embedding approach that outperforms Sentence-BERT: a PyTorch reproduction of [CoSENT](https://github.com/bojone/CoSENT), proposed by Su Jianlin (苏神). For the model details, see his article: https://kexue.fm/archives/8847
6 |
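7 | For orientation, the training objective implemented in these notebooks is the pairwise ranking loss from the article above (sketched here; Ω_pos and Ω_neg denote the positive and negative sentence pairs within a batch, u the sentence embeddings, and the scale 20 matches the value hard-coded in the notebooks):
8 | 
9 | $$\log\Big(1+\sum_{(i,j)\in\Omega_{pos},\,(k,l)\in\Omega_{neg}}e^{20\left(\cos(u_k,u_l)-\cos(u_i,u_j)\right)}\Big)$$
10 | 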
7 | **Data download**
8 | 
9 | * ATEC, BQ, LCQMC and PAWSX: https://github.com/bojone/BERT-whitening/tree/main/chn
10 | * CHIP-STS (Ping An Medical Technology disease QA transfer learning): https://tianchi.aliyun.com/dataset/dataDetail?dataId=95414
11 |
12 |
13 | ## Environment
14 |
15 | ```
16 | pip install ark-nlp
17 | pip install pandas
18 | ```
19 |
20 | ## Usage
21 | 
22 | Lay out the project directory as follows:
23 |
24 | ```shell
25 | │
26 | ├── data                    # datasets
27 | │   ├── source_datasets
28 | │   ├── task_datasets
29 | │   └── output_datasets
30 | │
31 | ├── checkpoint              # saved trained models
32 | │   ├── ...
33 | │   └── ...
34 | │
35 | └── code                    # code
36 | ```
37 | Download the data and unzip it into `data/source_datasets`, then run the `.ipynb` notebooks in the `code` folder; the final submission file is written to `data/output_datasets`.
38 |
39 | ## Parameter settings
40 | 
41 | The code uses the following parameters:
42 |
43 | ```
44 | max sequence length: 64 (128 for the PAWSX dataset)
45 | batch_size: 32
46 | epochs: 5
47 | ```
48 |
49 |
50 | ## Results
51 | 
52 | Spearman correlation is used as the evaluation metric. ATEC, BQ, LCQMC and PAWSX are evaluated on their test sets, while CHIP-STS uses its dev set.
53 |
54 | | | ATEC | BQ | LCQMC | PAWSX | CHIP-STS |
55 | | :-: | :-: | :-: | :-: | :-: | :-: |
56 | | BERT+CoSENT(ark-nlp) | 49.80 | 72.46 | 79.00 | 59.17 | 76.22 |
57 | | BERT+CoSENT(bert4keras) | 49.74 | 72.38 | 78.69 | 60.00 | |
58 | | Sentence-BERT(bert4keras)| 46.36 | 70.36 | 78.72 | 46.86 | |
59 |
60 | Note: the ark-nlp row shows the best result over 5 epochs; since bert4keras was not studied in depth, the hyper-parameter settings may still differ, so the comparison is for reference only.
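61 | 
62 | The Spearman score is computed the same way as in the notebooks' `CoSENTTask`, via `scipy` (a minimal sketch; the two arrays are hypothetical gold labels and predicted cosine similarities):
63 | 
64 | ```python
65 | import numpy as np
66 | from scipy import stats
67 | 
68 | labels = np.array([1, 0, 1, 0])            # hypothetical gold labels
69 | sims = np.array([0.82, 0.35, 0.66, 0.41])  # hypothetical cosine similarities
70 | 
71 | print(stats.spearmanr(labels, sims).correlation)
72 | ```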
61 |
62 | For the CHIP-STS test set, predictions were binarized at a threshold of 0.6 and submitted, with the following results:
63 |
64 | | | Precision | Recall | Macro-F1|
65 | | :-: | :-: | :-: | :-:|
66 | | BERT+CoSENT| 82.89 | 81.528 | 81.688 |
67 | | BERT sentence-pair classification | 84.331 | 83.799 | 83.924 |
68 |
69 |
70 | ## Acknowledgements
71 | 
72 | Thanks to Su Jianlin (苏神) for generously sharing his work.
--------------------------------------------------------------------------------
/text_match_CHIP-STS.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import warnings\n",
10 | "warnings.filterwarnings(\"ignore\")\n",
11 | "\n",
12 | "import os\n",
13 | "import jieba\n",
14 | "import torch\n",
15 | "import pickle\n",
16 | "import torch.nn as nn\n",
17 | "import torch.optim as optim\n",
18 | "import pandas as pd\n",
19 | "\n",
20 | "from ark_nlp.model.tm.bert import Bert\n",
21 | "from ark_nlp.model.tm.bert import BertConfig\n",
22 | "from ark_nlp.model.tm.bert import Dataset\n",
23 | "from ark_nlp.model.tm.bert import Task\n",
24 | "from ark_nlp.model.tm.bert import get_default_model_optimizer\n",
25 | "from ark_nlp.model.tm.bert import Tokenizer"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "### 一、数据读入与处理"
33 | ]
34 | },
35 | {
36 | "cell_type": "markdown",
37 | "metadata": {},
38 | "source": [
39 | "#### 1. 数据读入"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "train_data_df = pd.read_json('../data/source_datasets/CHIP-STS/CHIP-STS_train.json')\n",
49 | "train_data_df = (train_data_df\n",
50 | " .rename(columns={'text1': 'text_a', 'text2': 'text_b', 'category': 'condition'})\n",
51 | " .loc[:,['text_a', 'text_b', 'condition', 'label']])\n",
52 | "\n",
53 | "dev_data_df = pd.read_json('../data/source_datasets/CHIP-STS/CHIP-STS_dev.json')\n",
54 | "dev_data_df = dev_data_df[dev_data_df['label'] != \"NA\"]\n",
55 | "dev_data_df = (dev_data_df\n",
56 | " .rename(columns={'text1': 'text_a', 'text2': 'text_b', 'category': 'condition'})\n",
57 | " .loc[:,['text_a', 'text_b', 'condition', 'label']])"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "tm_train_dataset = Dataset(train_data_df)\n",
67 | "tm_dev_dataset = Dataset(dev_data_df)"
68 | ]
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {},
73 | "source": [
74 | "#### 2. 词典创建和生成分词器"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "tokenizer = Tokenizer(vocab='bert-base-chinese', max_seq_len=64)"
84 | ]
85 | },
86 | {
87 | "cell_type": "markdown",
88 | "metadata": {},
89 | "source": [
90 | "#### 3. ID化"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": null,
96 | "metadata": {},
97 | "outputs": [],
98 | "source": [
99 | "tm_train_dataset.convert_to_ids(tokenizer)\n",
100 | "tm_dev_dataset.convert_to_ids(tokenizer)"
101 | ]
102 | },
103 | {
104 | "cell_type": "markdown",
105 | "metadata": {},
106 | "source": [
107 | "
\n",
108 | "\n",
109 | "### 二、模型构建"
110 | ]
111 | },
112 | {
113 | "cell_type": "markdown",
114 | "metadata": {},
115 | "source": [
116 | "#### 1. 模型参数设置"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": null,
122 | "metadata": {},
123 | "outputs": [],
124 | "source": [
125 | "config = BertConfig.from_pretrained('bert-base-chinese',\n",
126 | " num_labels=len(tm_train_dataset.cat2id))"
127 | ]
128 | },
129 | {
130 | "cell_type": "markdown",
131 | "metadata": {},
132 | "source": [
133 | "#### 2. 模型创建"
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": null,
139 | "metadata": {},
140 | "outputs": [],
141 | "source": [
142 | "torch.cuda.empty_cache()"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": null,
148 | "metadata": {},
149 | "outputs": [],
150 | "source": [
151 | "dl_module = Bert.from_pretrained('bert-base-chinese', \n",
152 | " config=config)"
153 | ]
154 | },
155 | {
156 | "cell_type": "markdown",
157 | "metadata": {},
158 | "source": [
159 | "
\n",
160 | "\n",
161 | "### 三、任务构建"
162 | ]
163 | },
164 | {
165 | "cell_type": "markdown",
166 | "metadata": {},
167 | "source": [
168 | "#### 1. 任务参数和必要部件设定"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": null,
174 | "metadata": {},
175 | "outputs": [],
176 | "source": [
177 | "# 设置运行次数\n",
178 | "num_epoches = 5\n",
179 | "batch_size = 32"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": null,
185 | "metadata": {},
186 | "outputs": [],
187 | "source": [
188 | "optimizer = get_default_model_optimizer(dl_module)"
189 | ]
190 | },
191 | {
192 | "cell_type": "markdown",
193 | "metadata": {},
194 | "source": [
195 | "#### 2. 任务创建"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": null,
201 | "metadata": {},
202 | "outputs": [],
203 | "source": [
204 | "model = Task(dl_module, optimizer, 'ce', cuda_device=0)"
205 | ]
206 | },
207 | {
208 | "cell_type": "markdown",
209 | "metadata": {},
210 | "source": [
211 | "#### 3. 训练"
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "execution_count": null,
217 | "metadata": {},
218 | "outputs": [],
219 | "source": [
220 | "model.fit(tm_train_dataset, \n",
221 | " tm_dev_dataset,\n",
222 | " lr=2e-5,\n",
223 | " epochs=num_epoches, \n",
224 | " batch_size=batch_size\n",
225 | " )"
226 | ]
227 | },
228 | {
229 | "cell_type": "markdown",
230 | "metadata": {},
231 | "source": [
232 | "
\n",
233 | "\n",
234 | "### 四、CBLUE提交生成"
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": null,
240 | "metadata": {},
241 | "outputs": [],
242 | "source": [
243 | "from ark_nlp.model.tm.bert import Predictor"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": null,
249 | "metadata": {},
250 | "outputs": [],
251 | "source": [
252 | "tm_predictor_instance = Predictor(model.module, tokenizer, tm_train_dataset.cat2id)"
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": null,
258 | "metadata": {},
259 | "outputs": [],
260 | "source": [
261 | "import pandas as pd\n",
262 | "\n",
263 | "from tqdm import tqdm\n",
264 | "\n",
265 | "test_df = pd.read_json('../data/source_datasets/CHIP-STS/CHIP-STS_test.json')\n",
266 | "\n",
267 | "submit = []\n",
268 | "for _id, _text_a, _text_b, _condition in tqdm(zip(\n",
269 | " test_df['id'],\n",
270 | " test_df['text1'],\n",
271 | " test_df['text2'],\n",
272 | " test_df['category']\n",
273 | ")):\n",
274 | " if _condition == 'daibetes':\n",
275 | " _condition = 'diabetes'\n",
276 | "\n",
277 | " predict_ = tm_predictor_instance.predict_one_sample([_text_a, _text_b])[0]\n",
278 | " \n",
279 | " submit.append({\n",
280 | " 'id': str(_id),\n",
281 | " 'text1': _text_a,\n",
282 | " 'text2': _text_b,\n",
283 | " 'label': predict_,\n",
284 | " 'category': _condition\n",
285 | " })"
286 | ]
287 | },
288 | {
289 | "cell_type": "code",
290 | "execution_count": null,
291 | "metadata": {},
292 | "outputs": [],
293 | "source": [
294 | "import json\n",
295 | "\n",
296 | "output_path = '../data/output_datasets/CHIP-STS_test.json'\n",
297 | "\n",
298 | "with open(output_path, 'w', encoding='utf-8') as f:\n",
299 | " f.write(json.dumps(submit, ensure_ascii=False))"
300 | ]
301 | }
302 | ],
303 | "metadata": {
304 | "kernelspec": {
305 | "display_name": "Python 3 (ipykernel)",
306 | "language": "python",
307 | "name": "python3"
308 | },
309 | "language_info": {
310 | "codemirror_mode": {
311 | "name": "ipython",
312 | "version": 3
313 | },
314 | "file_extension": ".py",
315 | "mimetype": "text/x-python",
316 | "name": "python",
317 | "nbconvert_exporter": "python",
318 | "pygments_lexer": "ipython3",
319 | "version": "3.8.10"
320 | }
321 | },
322 | "nbformat": 4,
323 | "nbformat_minor": 4
324 | }
325 |
--------------------------------------------------------------------------------