├── data └── .gitkeep ├── logs └── .gitkeep ├── models └── .gitkeep ├── raw_data └── .gitkeep ├── .gitignore ├── src ├── others │ ├── logging.py │ ├── metrics.py │ └── data_collator.py ├── preprocess.py ├── test.py ├── main.py ├── train.py ├── prepro │ └── json_to_data.py └── models │ └── sm_model.py ├── LICENSE └── README.md /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /logs/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /raw_data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.log 3 | *.json 4 | *.arrow 5 | *.txt 6 | *.pt 7 | *.bin 8 | *.pth -------------------------------------------------------------------------------- /src/others/logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import 3 | 4 | import logging 5 | 6 | logger = logging.getLogger() 7 | 8 | 9 | def init_logger(log_file=None, log_file_level=logging.NOTSET): 10 | log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s") 11 | logger = logging.getLogger() 12 | logger.setLevel(logging.INFO) 13 | 14 | console_handler = logging.StreamHandler() 15 | console_handler.setFormatter(log_format) 16 | logger.handlers = [console_handler] 17 | 18 | if log_file and log_file != '': 19 | file_handler = logging.FileHandler(log_file) 20 | file_handler.setLevel(log_file_level) 21 | file_handler.setFormatter(log_format) 22 | logger.addHandler(file_handler) 23 | 24 | return logger 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Yicheng Zou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /src/others/metrics.py: -------------------------------------------------------------------------------- 1 | 2 | from datasets import load_metric 3 | from sklearn.metrics import precision_recall_fscore_support, accuracy_score 4 | 5 | 6 | class Metric(): 7 | def __init__(self): 8 | self.metric = load_metric('glue', 'qqp') 9 | 10 | def compute_metrics_f1(self, pred): 11 | predictions, labels = pred 12 | if isinstance(predictions, tuple): 13 | preds = predictions[0].argmax(-1) 14 | else: 15 | preds = predictions.argmax(-1) 16 | return self.metric.compute(predictions=preds, references=labels) 17 | 18 | def compute_metrics_macro_f1(self, pred): 19 | predictions, labels = pred 20 | if isinstance(predictions, tuple): 21 | preds = predictions[0].argmax(-1) 22 | else: 23 | preds = predictions.argmax(-1) 24 | precision, recall, f1, _ = precision_recall_fscore_support( 25 | labels, preds, average='macro' 26 | ) 27 | acc = accuracy_score(labels, preds) 28 | return { 29 | 'accuracy': acc, 30 | 'macro-f1': f1, 31 | 'precision': precision, 32 | 'recall': recall 33 | } 34 | -------------------------------------------------------------------------------- /src/others/data_collator.py: -------------------------------------------------------------------------------- 1 | from transformers.data.data_collator import DataCollatorWithPadding 2 | from typing import Any, Dict, List 3 | import torch 4 | 5 | 6 | class DataCollator(DataCollatorWithPadding): 7 | 8 | def __init__(self, args, tokenizer, padding=True): 9 | super(DataCollator, self).__init__(tokenizer, padding) 10 | self.args = args 11 | self.pad_id = 0 12 | 13 | def _pad(self, data, width=-1, dtype=torch.long): 14 | if (width == -1): 15 | width = max(len(d) for d in data) 16 | rtn_data = [d + [self.pad_id] * (width - len(d)) for d in data] 17 | return torch.tensor(rtn_data, dtype=dtype) 18 | 19 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: 20 | 21 | """ 22 | features: 23 | input_ids, token_type_ids, attention_mask, labels, 24 | keyword_mask, context_mask, special_mask, 25 | origin_str, keywords 26 | """ 27 | batch = {} 28 | 29 | # process entity-masked sentence pairs 30 | features_new = list(map(lambda x: {"input_ids": x['input_ids'], 31 | "token_type_ids": x['token_type_ids'], 32 | "labels": x['labels'] if x.get('labels', 'no') != 'no' else x['label']}, features)) 33 | 34 | batch = self.tokenizer.pad( 35 | features_new, 36 | padding=self.padding, 37 | max_length=self.max_length, 38 | pad_to_multiple_of=self.pad_to_multiple_of, 39 | return_tensors=self.return_tensors, 40 | ) 41 | 42 | batch['attention_mask'] = self._pad([x['attention_mask'] for x in features]) 43 | if "keyword_mask" in features[0].keys(): 44 | batch['keyword_mask'] = self._pad([x['keyword_mask'] for x in features]) 45 | else: 46 | batch['keyword_mask'] = [] 47 | if "context_mask" in features[0].keys(): 48 | batch['context_mask'] = self._pad([x['context_mask'] for x in features]) 49 | else: 50 | batch['context_mask'] = [] 51 | if "special_mask" in features[0].keys(): 52 | batch['special_mask'] = self._pad([x['special_mask'] for x in features]) 53 | else: 54 | batch['special_mask'] = [] 55 | 56 | return batch 57 | -------------------------------------------------------------------------------- /src/preprocess.py: -------------------------------------------------------------------------------- 1 | # encoding=utf-8 2 | 3 | import argparse 4 | from others.logging import init_logger 5 | 
from prepro import json_to_data as data_builder 6 | 7 | 8 | def str2bool(v): 9 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 10 | return True 11 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 12 | return False 13 | else: 14 | raise argparse.ArgumentTypeError('Boolean value expected.') 15 | 16 | 17 | if __name__ == '__main__': 18 | # ['roberta-base', 'roberta-large', 'bert-base-uncased', 'bert-large-uncased','albert-base-v2','albert-large-v2','microsoft/deberta-large','microsoft/deberta-base', 'funnel-transformer/medium'] 19 | # ['hfl/chinese-macbert-base','hfl/chinese-macbert-large'] 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("-raw_path", default='raw_data/mrpc', type=str) 23 | parser.add_argument("-save_path", default='data', type=str) 24 | parser.add_argument("-n_cpus", default=4, type=int) 25 | parser.add_argument("-random_seed", default=666, type=int) 26 | 27 | # json_to_data args 28 | parser.add_argument('-num_class', default=2, type=int) 29 | parser.add_argument('-log_file', default='logs/json_to_data.log') 30 | parser.add_argument("-tokenizer", default="") 31 | parser.add_argument('-min_length', default=1, type=int) 32 | parser.add_argument('-max_length', default=150, type=int) 33 | parser.add_argument("-truncated", nargs='?', const=True, default=True) 34 | parser.add_argument("-shard_size", default=5000, type=int) 35 | 36 | args = parser.parse_args() 37 | init_logger(args.log_file) 38 | 39 | model_names_english = [ 40 | 'roberta-base', 41 | 'roberta-large', 42 | 'bert-base-uncased', 43 | 'bert-large-uncased', 44 | 'albert-base-v2', 45 | 'albert-large-v2', 46 | 'microsoft/deberta-large', 47 | 'microsoft/deberta-base', 48 | 'funnel-transformer/medium' 49 | ] 50 | 51 | model_names_chinese = [ 52 | 'hfl/chinese-macbert-base', 53 | 'hfl/chinese-macbert-large', 54 | 'hfl/chinese-roberta-wwm-ext', 55 | 'hfl/chinese-roberta-wwm-ext-large' 56 | ] 57 | data_name = args.raw_path.split('/')[-1] 58 | if data_name == 'medical': 59 | model_names = model_names_chinese 60 | else: 61 | model_names = model_names_english 62 | data_saved_path = args.save_path 63 | for raw_name in model_names: 64 | name = raw_name.replace('/', '-') 65 | args.save_path = data_saved_path + '/' + data_name + '/' + name 66 | args.tokenizer = raw_name 67 | data_builder.format_json_to_data(args) 68 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from datasets import load_from_disk 3 | from others.metrics import Metric 4 | from others.data_collator import DataCollator 5 | from models.sm_model import Model 6 | from transformers import TrainingArguments, Trainer 7 | from transformers import AutoModelForSequenceClassification 8 | from transformers import AutoTokenizer 9 | 10 | 11 | def test(args): 12 | print(args) 13 | if args.task == 'qqp' or args.task == 'medical': 14 | trainer_args = TrainingArguments( 15 | output_dir=args.model_path, 16 | evaluation_strategy="steps", 17 | per_device_train_batch_size=args.batch_size, 18 | per_device_eval_batch_size=args.test_batch_size, 19 | gradient_accumulation_steps=args.accum_count, 20 | learning_rate=args.lr, 21 | weight_decay=args.weight_decay, 22 | max_grad_norm=args.max_grad_norm, 23 | max_steps=args.train_steps, 24 | warmup_steps=0 if not args.warmup else args.warmup_steps, 25 | logging_steps=args.report_every, 26 | save_strategy="steps", 27 | save_steps=args.save_checkpoint_steps, 
28 | eval_steps=args.save_checkpoint_steps, 29 | no_cuda=True if args.visible_gpus == '-1' else False, 30 | seed=args.seed, 31 | load_best_model_at_end=True, 32 | metric_for_best_model="accuracy" 33 | ) 34 | elif args.task == 'mrpc': 35 | trainer_args = TrainingArguments( 36 | output_dir=args.model_path, 37 | evaluation_strategy="epoch", 38 | per_device_train_batch_size=args.batch_size, 39 | per_device_eval_batch_size=args.test_batch_size, 40 | gradient_accumulation_steps=args.accum_count, 41 | learning_rate=args.lr, 42 | weight_decay=args.weight_decay, 43 | max_grad_norm=args.max_grad_norm, 44 | warmup_steps=0, 45 | logging_steps=args.report_every, 46 | num_train_epochs=20, 47 | save_strategy="epoch", 48 | no_cuda=True if args.visible_gpus == '-1' else False, 49 | seed=args.seed, 50 | load_best_model_at_end=True, 51 | metric_for_best_model="accuracy", 52 | ) 53 | 54 | dataset = load_from_disk(args.data_path) 55 | tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True) 56 | data_collator = DataCollator(args, tokenizer) 57 | 58 | if args.baseline: 59 | model = AutoModelForSequenceClassification.from_pretrained( 60 | args.test_from, num_labels=args.num_labels 61 | ) 62 | else: 63 | model = Model(args.model, args.num_labels, checkpoint=args.test_from, debug=args.debug) 64 | 65 | metric = Metric() 66 | if args.num_labels > 2: 67 | metric_fct = metric.compute_metrics_macro_f1 68 | else: 69 | metric_fct = metric.compute_metrics_f1 70 | 71 | trainer = Trainer( 72 | model, 73 | trainer_args, 74 | train_dataset=dataset["train"], 75 | eval_dataset=dataset['test'], 76 | data_collator=data_collator if not args.baseline else None, 77 | tokenizer=tokenizer, 78 | compute_metrics=metric_fct 79 | ) 80 | eval_result = trainer.evaluate() 81 | print(eval_result) 82 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import division 4 | 5 | import argparse 6 | import os 7 | 8 | 9 | def str2bool(v): 10 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 11 | return True 12 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 13 | return False 14 | else: 15 | raise argparse.ArgumentTypeError('Boolean value expected.') 16 | 17 | 18 | if __name__ == '__main__': 19 | 20 | # ['roberta-base', 'roberta-large', 'bert-base-uncased', 'bert-large-uncased','albert-base-v2','albert-large-v2','microsoft/deberta-large','microsoft/deberta-base', 'funnel-transformer/medium'] 21 | # ['hfl/chinese-macbert-base','hfl/chinese-macbert-large'] 22 | parser = argparse.ArgumentParser() 23 | # Basic args_ 24 | parser.add_argument("-mode", default='train', type=str, choices=['train', 'test']) 25 | parser.add_argument("-data_path", default='data', type=str) 26 | parser.add_argument("-model_path", default='models', type=str) 27 | parser.add_argument("-result_path", default='results', type=str) 28 | parser.add_argument('-task', default='qqp', type=str, choices=['qqp', 'mrpc', 'medical']) 29 | parser.add_argument('-visible_gpus', default='0', type=str) 30 | parser.add_argument('-seed', default=666, type=int) 31 | parser.add_argument('-loss_type', default=1, type=int) # For ablation 32 | 33 | # Batch sizes 34 | parser.add_argument("-batch_size", default=16, type=int) 35 | parser.add_argument("-test_batch_size", default=64, type=int) 36 | 37 | # Model args 38 | parser.add_argument("-baseline", type=str2bool, nargs='?', const=True, default=False) 39 | 
parser.add_argument("-model", default="", type=str) 40 | parser.add_argument("-num_labels", default=2, type=int) 41 | 42 | # Training process args 43 | parser.add_argument("-save_checkpoint_steps", default=2000, type=int) 44 | parser.add_argument("-accum_count", default=4, type=int) 45 | parser.add_argument("-report_every", default=5, type=int) 46 | parser.add_argument("-train_steps", default=50000, type=int) 47 | 48 | # Optim args 49 | parser.add_argument("-lr", default=2e-05, type=float) 50 | parser.add_argument("-warmup", type=str2bool, nargs='?', const=True, default=False) 51 | parser.add_argument("-warmup_steps", default=1000, type=int) 52 | parser.add_argument("-weight_decay", default=0.01, type=float) 53 | parser.add_argument("-max_grad_norm", default=1.0, type=float) 54 | 55 | # Utility args 56 | parser.add_argument("-test_from", default='', type=str) 57 | parser.add_argument("-train_from", default='', type=str) 58 | parser.add_argument("-debug", type=str2bool, nargs='?', const=True, default=False) 59 | 60 | args = parser.parse_args() 61 | args.gpu_ranks = [int(i) for i in range(len(args.visible_gpus.split(',')))] 62 | args.world_size = len(args.gpu_ranks) 63 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpus 64 | 65 | model_name = args.model.replace('/', '-') 66 | args.model_path = args.model_path + '/' + args.task + '/' + model_name 67 | args.result_path = args.result_path + '/' + args.task + '/' + model_name + '.txt' 68 | args.data_path = args.data_path + '/' + args.task + '/' + model_name + '.save' 69 | 70 | from train import train 71 | from test import test 72 | 73 | if (args.mode == 'train'): 74 | train(args) 75 | elif (args.mode == 'test'): 76 | test(args) 77 | else: 78 | print("Undefined mode! Please check input.") 79 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Main training workflow 4 | """ 5 | from __future__ import division 6 | 7 | from datasets import load_from_disk 8 | from others.metrics import Metric 9 | from others.data_collator import DataCollator 10 | from models.sm_model import Model 11 | from transformers import TrainingArguments, Trainer 12 | from transformers import AutoModelForSequenceClassification 13 | from transformers import AutoTokenizer 14 | 15 | 16 | def train(args): 17 | print(args) 18 | if args.task == 'qqp' or args.task == 'medical': 19 | trainer_args = TrainingArguments( 20 | output_dir=args.model_path, 21 | evaluation_strategy="steps", 22 | per_device_train_batch_size=args.batch_size, 23 | per_device_eval_batch_size=args.test_batch_size, 24 | gradient_accumulation_steps=args.accum_count, 25 | learning_rate=args.lr, 26 | weight_decay=args.weight_decay, 27 | max_grad_norm=args.max_grad_norm, 28 | max_steps=args.train_steps, 29 | warmup_steps=0 if not args.warmup else args.warmup_steps, 30 | logging_steps=args.report_every, 31 | save_strategy="steps", 32 | save_steps=args.save_checkpoint_steps, 33 | eval_steps=args.save_checkpoint_steps, 34 | no_cuda=True if args.visible_gpus == '-1' else False, 35 | seed=args.seed, 36 | load_best_model_at_end=True, 37 | metric_for_best_model="accuracy" 38 | ) 39 | elif args.task == 'mrpc': 40 | trainer_args = TrainingArguments( 41 | output_dir=args.model_path, 42 | evaluation_strategy="epoch", 43 | per_device_train_batch_size=args.batch_size, 44 | per_device_eval_batch_size=args.test_batch_size, 45 | 
gradient_accumulation_steps=args.accum_count, 46 | learning_rate=args.lr, 47 | weight_decay=args.weight_decay, 48 | max_grad_norm=args.max_grad_norm, 49 | warmup_steps=0, 50 | logging_steps=args.report_every, 51 | num_train_epochs=20, 52 | save_strategy="epoch", 53 | no_cuda=True if args.visible_gpus == '-1' else False, 54 | seed=args.seed, 55 | load_best_model_at_end=True, 56 | metric_for_best_model="accuracy", 57 | ) 58 | dataset = load_from_disk(args.data_path) 59 | tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True) 60 | data_collator = DataCollator(args, tokenizer) 61 | 62 | if args.baseline: 63 | model = AutoModelForSequenceClassification.from_pretrained( 64 | args.model, num_labels=args.num_labels 65 | ) 66 | else: 67 | model = Model(args.model, args.num_labels, loss_type=args.loss_type) 68 | 69 | metric = Metric() 70 | if args.num_labels > 2: 71 | metric_fct = metric.compute_metrics_macro_f1 72 | else: 73 | metric_fct = metric.compute_metrics_f1 74 | 75 | trainer = Trainer( 76 | model, 77 | trainer_args, 78 | train_dataset=dataset["train"], 79 | eval_dataset=dataset['validation'], 80 | data_collator=data_collator if not args.baseline else None, 81 | tokenizer=tokenizer, 82 | compute_metrics=metric_fct 83 | ) 84 | 85 | trainer.train() 86 | eval_result = trainer.evaluate() 87 | predict_result = trainer.predict(dataset['test']) 88 | print(eval_result) 89 | print(predict_result) 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DC-Match 2 | 3 | PyTorch implementation of the ACL-2022 (findings) paper: [Divide and Conquer: Text Semantic Matching with Disentangled Keywords and Intents](https://arxiv.org/abs/2203.02898). 4 | 5 | ## Environments 6 | 7 | * Python 3.9.7 8 | 9 | * pytorch 1.10.2 10 | 11 | * transformers 4.17.0 12 | 13 | * datasets 2.0.0 14 | 15 | * RTX 3090 GPU & Titan RTX 16 | 17 | * CUDA 11.4 18 | 19 | ## Data 20 | 21 | All the processed datasets used in our work are available at [Google Drive](https://drive.google.com/file/d/1OugsTLxqdoxAWaC93hn8xAoiMzjdxjXg/view?usp=sharing) or [Baidu Pan (extract code: w2op)](https://pan.baidu.com/s/1H1fpJJL9wZEMDicdBOemwg?pwd=w2op), including QQP, MRPC, and Medical-SM. 22 | 23 | ## Usage 24 | 25 | * Download raw datasets from the above data links and put them into the directory **raw_data** like this: 26 | 27 | ``` 28 | --- raw_data 29 | | 30 | |--- medical 31 | | 32 | |--- mrpc 33 | | 34 | |--- qqp 35 | ``` 36 | 37 | * We have tried various pre-trained models. The following models work fine with our code: 38 | 39 | * model names for QQP and MRPC: 40 | - roberta-base 41 | - roberta-large 42 | - bert-base-uncased 43 | - bert-large-uncased 44 | - albert-base-v2 45 | - albert-large-v2 46 | - microsoft/deberta-large 47 | - microsoft/deberta-base 48 | - funnel-transformer/medium 49 | 50 | * model names for Medical: 51 | - hfl/chinese-macbert-base 52 | - hfl/chinese-macbert-large 53 | - hfl/chinese-roberta-wwm-ext 54 | - hfl/chinese-roberta-wwm-ext-large 55 | 56 | * Pre-process datasets. 57 | 58 | ``` 59 | PYTHONPATH=. python ./src/preprocess.py -raw_path raw_data/mrpc 60 | ``` 61 | ``` 62 | PYTHONPATH=. python ./src/preprocess.py -raw_path raw_data/qqp 63 | ``` 64 | ``` 65 | PYTHONPATH=. python ./src/preprocess.py -raw_path raw_data/medical 66 | ``` 67 | 68 | * Training and Evaluation (Baseline) 69 | 70 | * MRPC 71 | ``` 72 | PYTHONPATH=.
python -u src/main.py \ 73 | -baseline \ 74 | -task mrpc \ 75 | -model roberta-large \ 76 | -num_labels 2 \ 77 | -batch_size 16 \ 78 | -accum_count 1 \ 79 | -test_batch_size 128 \ 80 | >> logs/mrpc.roberta_large.baseline.log 81 | ``` 82 | 83 | * QQP 84 | ``` 85 | PYTHONPATH=. python -u src/main.py \ 86 | -baseline \ 87 | -task qqp \ 88 | -model roberta-large \ 89 | -num_labels 2 \ 90 | -batch_size 16 \ 91 | -accum_count 4 \ 92 | -test_batch_size 128 \ 93 | >> logs/qqp.roberta_large.baseline.log 94 | ``` 95 | 96 | * Medical 97 | ``` 98 | PYTHONPATH=. python -u src/main.py \ 99 | -baseline \ 100 | -task medical \ 101 | -model hfl/chinese-roberta-wwm-ext-large \ 102 | -num_labels 3 \ 103 | -batch_size 16 \ 104 | -accum_count 4 \ 105 | -test_batch_size 128 \ 106 | >> logs/medical.roberta_large.baseline.log 107 | ``` 108 | 109 | * Training and Evaluation (DC-Match) 110 | 111 | * MRPC 112 | ``` 113 | PYTHONPATH=. python -u src/main.py \ 114 | -task mrpc \ 115 | -model roberta-large \ 116 | -num_labels 2 \ 117 | -batch_size 16 \ 118 | -accum_count 1 \ 119 | -test_batch_size 128 \ 120 | >> logs/mrpc.roberta_large.log 121 | ``` 122 | 123 | * QQP 124 | ``` 125 | PYTHONPATH=. python -u src/main.py \ 126 | -task qqp \ 127 | -model roberta-large \ 128 | -num_labels 2 \ 129 | -batch_size 16 \ 130 | -accum_count 4 \ 131 | -test_batch_size 128 \ 132 | >> logs/qqp.roberta_large.log 133 | ``` 134 | 135 | * Medical 136 | ``` 137 | PYTHONPATH=. python -u src/main.py \ 138 | -task medical \ 139 | -model hfl/chinese-roberta-wwm-ext-large \ 140 | -num_labels 3 \ 141 | -batch_size 16 \ 142 | -accum_count 4 \ 143 | -test_batch_size 128 \ 144 | >> logs/medical.roberta_large.log 145 | ``` 146 | 147 | ## Citation 148 | 149 | @article{zou2022divide, 150 | title={Divide and Conquer: Text Semantic Matching with Disentangled Keywords and Intents}, 151 | author={Zou, Yicheng and Liu, Hongwei and Gui, Tao and Wang, Junzhe and Zhang, Qi and Tang, Meng and Li, Haixiang and Wang, Daniel}, 152 | journal={arXiv preprint arXiv:2203.02898}, 153 | year={2022} 154 | } 155 | -------------------------------------------------------------------------------- /src/prepro/json_to_data.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import glob 3 | from os.path import join as pjoin 4 | 5 | from others.logging import logger 6 | from transformers import AutoTokenizer 7 | 8 | 9 | class Processor(): 10 | def __init__(self, args): 11 | self.args = args 12 | self.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) 13 | self.num_class = args.num_class 14 | 15 | def handle_special_token(self, token_list): 16 | return [item.replace("Ġ", "") for item in token_list] 17 | 18 | def process(self, example): 19 | 20 | encoded_dict = {} 21 | 22 | encoded_origin_dict = self.tokenizer(example['origin_a'], example['origin_b'], 23 | truncation=True, 24 | max_length=self.args.max_length, 25 | return_token_type_ids=True) 26 | 27 | origin_a_b = self.tokenizer.convert_ids_to_tokens(encoded_origin_dict['input_ids']) 28 | 29 | if "chinese" not in self.args.tokenizer: 30 | origin_a_b = self.handle_special_token(origin_a_b) 31 | 32 | sorted_keywords = sorted(example['keyword_a'] + example['keyword_b'], 33 | key=lambda x: len(x['entity']), 34 | reverse=True) 35 | 36 | replaced_a = example['origin_a'] 37 | replaced_b = example['origin_b'] 38 | keyword_mask = [0] * len(origin_a_b) 39 | special_mask = [1 if token in (self.tokenizer.cls_token, 40 | self.tokenizer.sep_token) else 0 41 | for token in 
origin_a_b] 42 | context_mask = [0 if token in (self.tokenizer.cls_token, 43 | self.tokenizer.sep_token) else 1 44 | for token in origin_a_b] 45 | 46 | keywords = [] 47 | for item in sorted_keywords: 48 | 49 | if not (item['entity'] in replaced_a or item['entity'] in replaced_b): 50 | continue 51 | 52 | encoded_kw = self.tokenizer.tokenize(item['entity']) 53 | if "chinese" not in self.args.tokenizer: 54 | encoded_kw = self.handle_special_token(encoded_kw) 55 | for idx in range(len(origin_a_b)): 56 | if origin_a_b[idx] == encoded_kw[0] and \ 57 | origin_a_b[idx: idx+len(encoded_kw)] == encoded_kw: 58 | keyword_mask[idx: idx+len(encoded_kw)] = [1] * len(encoded_kw) 59 | context_mask[idx: idx+len(encoded_kw)] = [0] * len(encoded_kw) 60 | keywords.append(item['entity']) 61 | replaced_a = replaced_a.replace(item['entity'], '#') 62 | replaced_b = replaced_b.replace(item['entity'], '#') 63 | 64 | if 'chinese' in self.args.tokenizer: 65 | origin_a_b = ''.join(origin_a_b) 66 | else: 67 | origin_a_b = ' '.join(origin_a_b) 68 | 69 | encoded_dict['input_ids'] = encoded_origin_dict['input_ids'] 70 | encoded_dict['attention_mask'] = encoded_origin_dict['attention_mask'] 71 | encoded_dict['token_type_ids'] = encoded_origin_dict['token_type_ids'] 72 | encoded_dict['keyword_mask'] = keyword_mask 73 | encoded_dict['context_mask'] = context_mask 74 | encoded_dict['special_mask'] = special_mask 75 | encoded_dict['origin_str'] = origin_a_b 76 | encoded_dict['keywords'] = keywords 77 | 78 | # mask assertion 79 | for idx in range(len(encoded_dict['input_ids'])): 80 | assert encoded_dict['keyword_mask'][idx] + \ 81 | encoded_dict['context_mask'][idx] + \ 82 | encoded_dict['special_mask'][idx] == \ 83 | encoded_dict['attention_mask'][idx] 84 | 85 | return encoded_dict 86 | 87 | 88 | def format_json_to_data(args): 89 | 90 | train_lst = [] 91 | dev_lst = [] 92 | test_lst = [] 93 | for json_f in glob.glob(pjoin(args.raw_path, '*.json')): 94 | real_name = json_f.split('/')[-1] 95 | corpus_type = real_name.split('.')[-2] 96 | if corpus_type == 'train': 97 | train_lst.append(json_f) 98 | elif corpus_type == 'test': 99 | test_lst.append(json_f) 100 | else: 101 | dev_lst.append(json_f) 102 | 103 | dataset = datasets.load_dataset( 104 | 'json', 105 | data_files={'train': train_lst, 106 | 'validation': dev_lst if len(dev_lst) > 0 else test_lst, 107 | 'test': test_lst} 108 | ) 109 | 110 | processor = Processor(args) 111 | encoded_dataset = dataset.map( 112 | processor.process, 113 | load_from_cache_file=False, 114 | num_proc=8 115 | ) 116 | for corpus_type in ['train', 'validation', 'test']: 117 | total_statistic = { 118 | "instances": 0, 119 | "exceed_length_num": 0, 120 | "total_length": 0., 121 | "src_length_dist": [0] * 11, 122 | } 123 | for item in encoded_dataset[corpus_type]: 124 | total_statistic['instances'] += 1 125 | if len(item['input_ids']) > args.max_length: 126 | total_statistic['exceed_length_num'] += 1 127 | total_statistic['total_length'] += len(item['origin_a']) + len(item['origin_b']) 128 | total_statistic['src_length_dist'][min(len(item['origin_a']) // 30, 10)] += 1 129 | total_statistic['src_length_dist'][min(len(item['origin_b']) // 30, 10)] += 1 130 | 131 | dataset[corpus_type] 132 | if total_statistic["instances"] > 0: 133 | logger.info("Total %s examples: %d" % 134 | (corpus_type, total_statistic["instances"])) 135 | logger.info("Number of samples that exceed maximum length: %d" % 136 | total_statistic["exceed_length_num"]) 137 | logger.info("Average length of src sentence: %f" % 138 | 
(total_statistic["total_length"] / (2. * total_statistic["instances"]))) 139 | for idx, num in enumerate(total_statistic["src_length_dist"]): 140 | logger.info("token num %d ~ %d: %d, %.2f%%" % 141 | (idx * 30, (idx+1) * 30, num, (num / (2. * total_statistic["instances"])))) 142 | 143 | encoded_dataset.save_to_disk(args.save_path + '.save') 144 | -------------------------------------------------------------------------------- /src/models/sm_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from transformers import AutoModel, AutoConfig 5 | from transformers.modeling_outputs import SequenceClassifierOutput 6 | 7 | 8 | class Model(nn.Module): 9 | def __init__(self, model_name, labels, loss_type=1, checkpoint=None, debug=False): 10 | super(Model, self).__init__() 11 | self.label_num = labels 12 | self.debug = debug 13 | self.config = AutoConfig.from_pretrained(model_name) 14 | self.loss_type = loss_type 15 | 16 | if "deberta" in model_name.lower() or "funnel" in model_name.lower(): 17 | self.pooler = nn.Sequential( 18 | nn.Linear(self.config.hidden_size, self.config.hidden_size), 19 | nn.Tanh() 20 | ) 21 | self.classifier = nn.Linear(self.config.hidden_size, labels) 22 | self.kw_con_classifier = nn.Linear(self.config.hidden_size, 1) 23 | if "funnel" in model_name.lower(): 24 | self.dropout = nn.Dropout(self.config.hidden_dropout) 25 | else: 26 | self.dropout = nn.Dropout(self.config.hidden_dropout_prob) 27 | 28 | self.apply(self._init_weights) 29 | 30 | self.encoder = AutoModel.from_pretrained(model_name) 31 | 32 | if checkpoint is not None: 33 | cp = torch.load(checkpoint, map_location=lambda storage, loc: storage) 34 | self.load_state_dict(cp, strict=True) 35 | 36 | def _init_weights(self, module): 37 | """Initialize the weights""" 38 | if isinstance(module, nn.Linear): 39 | # Slightly different from the TF version which uses truncated_normal for initialization 40 | # cf https://github.com/pytorch/pytorch/pull/5617 41 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 42 | if module.bias is not None: 43 | module.bias.data.zero_() 44 | elif isinstance(module, nn.Embedding): 45 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 46 | if module.padding_idx is not None: 47 | module.weight.data[module.padding_idx].zero_() 48 | elif isinstance(module, nn.LayerNorm): 49 | module.bias.data.zero_() 50 | module.weight.data.fill_(1.0) 51 | 52 | def apply(self, fn): 53 | for module in self.children(): 54 | module.apply(fn) 55 | fn(self) 56 | return self 57 | 58 | def forward(self, input_ids, token_type_ids, attention_mask, labels, 59 | keyword_mask, context_mask, special_mask): 60 | 61 | if not self.training and not self.debug: 62 | output_all = self.encoder(input_ids, attention_mask, token_type_ids, return_dict=True) 63 | if "pooler_output" in output_all.keys(): 64 | logits_all = self.classifier(self.dropout(output_all.pooler_output)) 65 | else: 66 | pooler_out_all = self.pooler(output_all.last_hidden_state[:, 0]) 67 | logits_all = self.classifier(self.dropout(pooler_out_all)) 68 | cls_loss = F.cross_entropy(logits_all.view(-1, self.label_num), labels.view(-1)) 69 | 70 | return SequenceClassifierOutput( 71 | loss=cls_loss, 72 | logits=logits_all 73 | ) 74 | 75 | # build mask 76 | kw_mask = keyword_mask + special_mask 77 | con_mask = context_mask + special_mask 78 | 79 | # encoding 80 | output_all = self.encoder(input_ids, 
attention_mask, token_type_ids, return_dict=True) 81 | output_kw = self.encoder(input_ids, kw_mask, token_type_ids, return_dict=True) 82 | output_con = self.encoder(input_ids, con_mask, token_type_ids, return_dict=True) 83 | 84 | # cls logits 85 | if "pooler_output" in output_all.keys(): 86 | logits_all = self.classifier(self.dropout(output_all.pooler_output)) 87 | logits_kw = self.classifier(self.dropout(output_kw.pooler_output)) 88 | logits_con = self.classifier(self.dropout(output_con.pooler_output)) 89 | else: 90 | pooler_out_all = self.pooler(output_all.last_hidden_state[:, 0]) 91 | pooler_out_kw = self.pooler(output_kw.last_hidden_state[:, 0]) 92 | pooler_out_con = self.pooler(output_con.last_hidden_state[:, 0]) 93 | 94 | logits_all = self.classifier(self.dropout(pooler_out_all)) 95 | logits_kw = self.classifier(self.dropout(pooler_out_kw)) 96 | logits_con = self.classifier(self.dropout(pooler_out_con)) 97 | 98 | # get mean pooling states 99 | all_kw = (output_all.last_hidden_state * kw_mask.unsqueeze(-1).float())\ 100 | .sum(1).div(kw_mask.float().sum(-1).unsqueeze_(-1)) 101 | all_con = (output_all.last_hidden_state * con_mask.unsqueeze(-1).float())\ 102 | .sum(1).div(con_mask.float().sum(-1).unsqueeze_(-1)) 103 | sep_kw = (output_kw.last_hidden_state * kw_mask.unsqueeze(-1).float())\ 104 | .sum(1).div(kw_mask.float().sum(-1).unsqueeze_(-1)) 105 | sep_con = (output_con.last_hidden_state * con_mask.unsqueeze(-1).float())\ 106 | .sum(1).div(con_mask.float().sum(-1).unsqueeze_(-1)) 107 | 108 | # kw_con mean pooling logits 109 | kw_con_logits = self.kw_con_classifier( 110 | self.dropout(torch.cat([all_kw, sep_kw, all_con, sep_con], 0)) 111 | ) 112 | 113 | # kw_con labels 114 | kw_con_labels = torch.cat([labels.new_ones(all_kw.size(0) * 2), 115 | labels.new_zeros(all_con.size(0) * 2)], 0).float() 116 | 117 | # joint probability distribution 118 | prob_all = F.log_softmax(logits_all, -1).view(-1, self.label_num) 119 | prob_kw = F.log_softmax(logits_kw, -1).view(-1, self.label_num) 120 | prob_con = F.log_softmax(logits_con, -1).view(-1, self.label_num) 121 | prob_joint = prob_kw.unsqueeze(-1).expand(-1, self.label_num, self.label_num) + \ 122 | prob_con.unsqueeze(-2).expand(-1, self.label_num, self.label_num) 123 | prob_joint_list = [] 124 | for idx in range(self.label_num): 125 | prob_dim = prob_joint[:, idx:, idx:].exp().sum((1, 2)) - prob_joint[:, idx+1:, idx+1:].exp().sum((1, 2)) 126 | prob_joint_list.append(prob_dim.unsqueeze(-1)) 127 | prob_kw_con = (torch.cat(prob_joint_list, -1)+1e-20).log() 128 | 129 | cls_loss = F.cross_entropy(logits_all.view(-1, self.label_num), labels.view(-1)) 130 | kw_con_loss = F.binary_cross_entropy_with_logits(kw_con_logits.view(-1), kw_con_labels.view(-1)) 131 | # kl_loss = F.kl_div(prob_kw_con, prob_all, reduction='batchmean', log_target=True) 132 | kl_loss = 0.5 * ( 133 | F.kl_div(prob_kw_con, prob_all, reduction='batchmean', log_target=True) + 134 | F.kl_div(prob_all, prob_kw_con, reduction='batchmean', log_target=True) 135 | ) 136 | if self.loss_type == 1: 137 | loss = cls_loss + kw_con_loss + kl_loss 138 | elif self.loss_type == 2: 139 | loss = cls_loss + kl_loss 140 | else: 141 | loss = cls_loss + kw_con_loss 142 | 143 | if self.training: 144 | return SequenceClassifierOutput( 145 | loss=loss, 146 | logits=logits_all 147 | ) 148 | else: 149 | return { 150 | "loss": loss, 151 | "logits": logits_all, 152 | "kw_logits": logits_kw, 153 | "con_logits": logits_con 154 | } 155 | 
--------------------------------------------------------------------------------
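A note on the joint label distribution in `src/models/sm_model.py`: during training, `Model.forward` scores each sentence pair three times (on the full input, on the keyword-only view via `kw_mask`, and on the context-only view via `con_mask`) and pulls the full-input distribution towards a composition of the other two with a symmetric KL term. The loop over `prob_joint` computes, for each class `idx`, the probability that both views predict at least `idx` minus the probability that both predict at least `idx + 1`, which equals P(min(keyword label, context label) = idx) under an independence assumption between the two views. The snippet below is a minimal stand-alone sketch of that composition, not part of the repository; the logits and the two-class setting are made-up values for illustration.

```
# Toy illustration of the joint-distribution composition used in
# Model.forward (src/models/sm_model.py). Logits are made-up numbers
# for a batch of one example and two classes.
import torch
import torch.nn.functional as F

label_num = 2
logits_kw = torch.tensor([[1.5, 0.3]])    # classifier output on the keyword-only view
logits_con = torch.tensor([[0.2, 2.0]])   # classifier output on the context-only view

prob_kw = F.log_softmax(logits_kw, -1)    # (batch, label_num) log-probabilities
prob_con = F.log_softmax(logits_con, -1)

# prob_joint[b, i, j] = log p_kw(i) + log p_con(j): the two views are
# combined under an independence assumption, as in sm_model.py.
prob_joint = prob_kw.unsqueeze(-1).expand(-1, label_num, label_num) + \
             prob_con.unsqueeze(-2).expand(-1, label_num, label_num)

# P(final label = idx) = P(min(keyword label, context label) = idx):
# both views predict at least idx, minus both predicting at least idx + 1.
prob_list = []
for idx in range(label_num):
    p_at_least_idx = prob_joint[:, idx:, idx:].exp().sum((1, 2))
    p_at_least_next = prob_joint[:, idx + 1:, idx + 1:].exp().sum((1, 2))
    prob_list.append((p_at_least_idx - p_at_least_next).unsqueeze(-1))
prob_kw_con = torch.cat(prob_list, -1)

print(prob_kw_con)            # composed distribution over the final labels
print(prob_kw_con.sum(-1))    # sums to 1; sm_model.py takes its log and applies a
                              # symmetric KL term against the full-input prediction
```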