├── .gitignore
├── README.md
├── cluener_dataset.py
├── global_pointer.py
├── main.py
├── predict.py
└── requirements.txt

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea
3 | __pycache__
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🤗 GlobalPointer
2 | 
3 | - Jianlin Su's blog posts:
4 |   - [Global Pointer](https://kexue.fm/archives/8373)
5 |   - [Efficient GlobalPointer](https://spaces.ac.cn/archives/8877)
6 | - Original Keras implementation: https://github.com/bojone/GlobalPointer/blob/main/CLUENER_GlobalPointer.py
7 | - Official CLUENER test-set submission: https://www.cluebenchmarks.com/
8 | 
9 | CLUENER results comparison
10 | 
11 | | Method | Dev F1 | Test F1 | Params |
12 | |--------|--------|---------|--------|
13 | | CRF (from [Global Pointer](https://kexue.fm/archives/8373)) | 79.51% | 78.70% | |
14 | | GlobalPointer (from [Global Pointer](https://kexue.fm/archives/8373)) | 80.03% | 79.44% | |
15 | | Efficient GlobalPointer (from [Efficient GlobalPointer](https://spaces.ac.cn/archives/8877)) | 80.66% | 80.04% | |
16 | | GlobalPointer (w/ RoPE) | 80.26% | | 102661376 |
17 | | GlobalPointer (w/o RoPE) | 79.3% | | 102661376 |
18 | | Efficient GlobalPointer (w/ RoPE) | 80.64% | | 101790868 |
19 | | Efficient GlobalPointer (w/o RoPE) | 79.57% | | 101778068 |
20 | 
21 | Training script:
22 | 
23 | - Use `--global_pointer_head` to switch between `GlobalPointer` and `EfficientGlobalPointer`
24 | - Use `--rope` to toggle rotary position embeddings (`RoPE`)
25 | 
26 | ```bash
27 | python3 main.py \
28 | --model_name_or_path bert-base-chinese \
29 | --dataset_name ./cluener_dataset.py \
30 | --output_dir ./model/efficient_global_pointer \
31 | --save_total_limit 1 \
32 | --per_device_train_batch_size 16 \
33 | --learning_rate 2e-5 \
34 | --lr_scheduler_type constant \
35 | --global_pointer_head EfficientGlobalPointer \
36 | --weight_decay 0.05 \
37 | --num_train_epochs 10 \
38 | --dataloader_num_workers 8 \
39 | --load_best_model_at_end True \
40 | --metric_for_best_model f1 \
41 | --evaluation_strategy epoch \
42 | --save_strategy epoch \
43 | --logging_steps 100 \
44 | --rope True \
45 | --fp16 \
46 | --do_train \
47 | --do_eval
48 | ```
49 | 
50 | Evaluate on the dev set:
51 | 
52 | ```bash
53 | python3 main.py \
54 | --model_name_or_path ./model/global_pointer \
55 | --output_dir ./model/global_pointer \
56 | --dataset_name ./cluener_dataset.py \
57 | --fp16 \
58 | --do_eval
59 | ```
60 | 
61 | Run inference on the test set and save the predictions as JSON:
62 | 
63 | ```bash
64 | python3 predict.py test ./model/global_pointer gp_test.json
65 | ```
66 | 
67 | Type text interactively and view the predicted entities:
68 | 
69 | ```bash
70 | python3 predict.py predict ./model/global_pointer
71 | ```
72 | 
--------------------------------------------------------------------------------
/cluener_dataset.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import json
3 | import os
4 | 
5 | import datasets
6 | 
7 | logger = datasets.logging.get_logger(__name__)
8 | 
9 | _URL = "https://storage.googleapis.com/cluebenchmark/tasks/cluener_public.zip"
10 | _TRAINING_FILE = "train.json"
11 | _DEV_FILE = "dev.json"
12 | _TEST_FILE = "test.json"
13 | 
14 | 
15 | class CluenerConfig(datasets.BuilderConfig):
16 |     def __init__(self, **kwargs):
17 |         super(CluenerConfig, self).__init__(**kwargs)
18 | 
19 | 
20 | class Cluener(datasets.GeneratorBasedBuilder):
21 |     """CLUENER fine-grained NER dataset."""
22 | 
23 |     BUILDER_CONFIGS = [
24 | CluenerConfig(name="cluener", version=datasets.Version("1.0.0")), 25 | ] 26 | 27 | def _info(self): 28 | return datasets.DatasetInfo( 29 | description="", 30 | features=datasets.Features( 31 | { 32 | "id": datasets.Value("string"), 33 | "text": datasets.Value("string"), 34 | "span_tags": datasets.Sequence( 35 | { 36 | "tag": datasets.features.ClassLabel( 37 | names=[ 38 | "address", 39 | "book", 40 | "company", 41 | "game", 42 | "government", 43 | "movie", 44 | "name", 45 | "organization", 46 | "position", 47 | "scene" 48 | ] 49 | ), 50 | "start": datasets.Value("int16"), 51 | "end": datasets.Value("int16"), 52 | } 53 | ), 54 | } 55 | ), 56 | supervised_keys=None, 57 | homepage="https://www.aclweb.org/anthology/W03-0419/", 58 | ) 59 | 60 | def _split_generators(self, dl_manager): 61 | """Returns SplitGenerators.""" 62 | downloaded_file = dl_manager.download_and_extract(_URL) 63 | data_files = { 64 | "train": os.path.join(downloaded_file, _TRAINING_FILE), 65 | "dev": os.path.join(downloaded_file, _DEV_FILE), 66 | "test": os.path.join(downloaded_file, _TEST_FILE), 67 | } 68 | 69 | return [ 70 | datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": data_files["train"]}), 71 | datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": data_files["dev"]}), 72 | datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepath": data_files["test"]}), 73 | ] 74 | 75 | def _generate_examples(self, filepath): 76 | logger.info("⏳ Generating examples from = %s", filepath) 77 | with open(filepath, encoding="utf-8") as f: 78 | guid = 0 79 | for line in f: 80 | """ 81 | { 82 | 'text': '索尼《GT赛车》新作可能会发行PC版?', 83 | 'label': { 84 | 'game': {'《GT赛车》': [[2, 7]]}, 85 | 'company': {'索尼': [[0, 1]]} 86 | } 87 | } 88 | """ 89 | data = json.loads(line) 90 | span_tags = [] 91 | 92 | if 'label' in data: 93 | for tag, labels in data['label'].items(): 94 | for spans in labels.values(): 95 | for span in spans: 96 | span_tags.append({ 97 | "tag": tag, 98 | "start": span[0], 99 | "end": span[1], 100 | }) 101 | 102 | yield guid, { 103 | "id": str(data.get("id", guid)), 104 | "text": data['text'].lower(), 105 | "span_tags": span_tags, 106 | } 107 | guid += 1 108 | -------------------------------------------------------------------------------- /global_pointer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import nn 5 | from transformers import BertPreTrainedModel, BertModel 6 | from transformers.modeling_outputs import TokenClassifierOutput 7 | 8 | INFINITY = 1e12 9 | 10 | from transformers.utils import logging 11 | 12 | logging.set_verbosity_info() 13 | logger = logging.get_logger("transformers") 14 | 15 | 16 | class SinusoidalPositionEmbedding(nn.Module): 17 | def __init__(self, output_dim, merge_mode='add', custom_position_ids=False): 18 | super().__init__() 19 | self.output_dim = output_dim 20 | self.merge_mode = merge_mode 21 | self.custom_position_ids = custom_position_ids 22 | 23 | def forward(self, inputs): 24 | input_shape = inputs.shape 25 | 26 | batch_size, seq_len = input_shape[0], input_shape[1] 27 | position_ids = torch.arange(seq_len).type(torch.float)[None] 28 | indices = torch.arange(self.output_dim // 2).type(torch.float) 29 | indices = torch.pow(10000.0, -2 * indices / self.output_dim) 30 | embeddings = torch.einsum('bn,d->bnd', position_ids, indices) 31 | embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1) 
32 |         embeddings = torch.reshape(embeddings, (-1, seq_len, self.output_dim))
33 |         if self.merge_mode == 'add':
34 |             return inputs + embeddings.to(inputs.device)
35 |         elif self.merge_mode == 'mul':
36 |             return inputs * (embeddings + 1.0).to(inputs.device)
37 |         elif self.merge_mode == 'zero':
38 |             return embeddings.to(inputs.device)
39 | 
40 | 
41 | def apply_rotary_position_embeddings(pos, qw, kw):
42 |     ndim = qw.ndim
43 | 
44 |     if ndim == 4:
45 |         cos_pos = pos[..., None, 1::2].repeat_interleave(2, dim=-1)
46 |         sin_pos = pos[..., None, ::2].repeat_interleave(2, dim=-1)
47 |     else:
48 |         cos_pos = pos[..., 1::2].repeat_interleave(2, dim=-1)
49 |         sin_pos = pos[..., ::2].repeat_interleave(2, dim=-1)
50 | 
51 |     qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], ndim)
52 |     qw2 = torch.reshape(qw2, qw.shape)
53 |     qw = qw * cos_pos + qw2 * sin_pos
54 | 
55 |     kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], ndim)
56 |     kw2 = torch.reshape(kw2, kw.shape)
57 |     kw = kw * cos_pos + kw2 * sin_pos
58 | 
59 |     return qw, kw
60 | 
61 | 
62 | def sequence_masking(x, mask, value, dim):
63 |     if mask is None:
64 |         return x
65 | 
66 |     assert dim > 0, 'dim must be > 0'
67 |     for _ in range(dim - 1):
68 |         mask = torch.unsqueeze(mask, 1)
69 |     for _ in range(x.ndim - mask.ndim):
70 |         mask = torch.unsqueeze(mask, mask.ndim)
71 |     return x * mask + value * (1 - mask)
72 | 
73 | 
74 | class GlobalPointer(nn.Module):
75 |     """Global pointer module.
76 |     Scores every (start, end) span of the sequence as a single unit.
77 |     Reference: https://kexue.fm/archives/8373
78 |     """
79 | 
80 |     def __init__(
81 |             self,
82 |             heads,
83 |             head_size,
84 |             hidden_size,
85 |             RoPE=True,
86 |     ):
87 |         super(GlobalPointer, self).__init__()
88 |         self.heads = heads
89 |         self.head_size = head_size
90 |         self.hidden_size = hidden_size
91 |         self.RoPE = RoPE
92 |         self.dense = nn.Linear(hidden_size, heads * head_size * 2)
93 | 
94 |     def forward(self, inputs, mask=None):
95 |         # input transformation
96 |         inputs = self.dense(inputs)
97 |         inputs = torch.split(inputs, self.head_size * 2, dim=-1)
98 |         inputs = torch.stack(inputs, dim=-2)
99 |         qw, kw = inputs[..., :self.head_size], inputs[..., self.head_size:]
100 |         # RoPE encoding
101 |         if self.RoPE:
102 |             pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs)
103 |             qw, kw = apply_rotary_position_embeddings(pos, qw, kw)
104 |         # compute the inner products
105 |         logits = torch.einsum('bmhd,bnhd->bhmn', qw, kw)
106 |         # mask out padding
107 |         logits = sequence_masking(logits, mask, -INFINITY, 2)
108 |         logits = sequence_masking(logits, mask, -INFINITY, 3)
109 | 
110 |         # mask out the lower triangle (start must not exceed end)
111 |         mask = torch.tril(torch.ones_like(logits), diagonal=-1)
112 |         logits = logits - mask * INFINITY
113 |         return logits / self.head_size ** 0.5
114 | 
115 | 
116 | class EfficientGlobalPointer(nn.Module):
117 |     """A more parameter-efficient GlobalPointer.
118 |     Reference: https://kexue.fm/archives/8877
119 |     """
120 | 
121 |     def __init__(self, heads, head_size, hidden_size, RoPE=True):
122 |         super().__init__()
123 |         self.heads = heads
124 |         self.head_size = head_size
125 |         self.hidden_size = hidden_size
126 |         self.RoPE = RoPE
127 | 
128 |         self.p_dense = nn.Linear(
129 |             in_features=hidden_size,
130 |             out_features=self.head_size * 2,
131 |         )
132 |         self.q_dense = nn.Linear(
133 |             in_features=self.head_size * 2,
134 |             out_features=self.heads * 2,
135 |         )
136 | 
137 |     def forward(self, inputs, mask=None):
138 |         # input transformation
139 |         inputs = self.p_dense(inputs)
140 |         qw, kw = inputs[..., ::2], inputs[..., 1::2]
141 |         # RoPE encoding
142 |         if self.RoPE:
143 |             pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs)
144 |             qw, kw = apply_rotary_position_embeddings(pos, qw, kw)
145 |         # compute the inner products
146 |         logits = torch.einsum('bmd,bnd->bmn', qw, kw) / self.head_size ** 0.5
147 |         bias = torch.einsum('bnh->bhn', self.q_dense(inputs)) / 2
148 |         logits = logits[:, None] + bias[:, ::2, None] + bias[:, 1::2, :, None]
149 |         # mask out padding
150 |         logits = sequence_masking(logits, mask, -INFINITY, 2)
151 |         logits = sequence_masking(logits, mask, -INFINITY, 3)
152 |         # mask out the lower triangle (start must not exceed end)
153 |         mask = torch.tril(torch.ones_like(logits), diagonal=-1)
154 |         logits = logits - mask * INFINITY
155 |         return logits
156 | 
157 | 
158 | def multilabel_categorical_crossentropy(y_true, y_pred):
159 |     y_pred = (1 - 2 * y_true) * y_pred
160 |     y_pred_neg = y_pred - y_true * INFINITY
161 |     y_pred_pos = y_pred - (1 - y_true) * INFINITY
162 |     zeros = torch.zeros_like(y_pred[..., :1])
163 |     y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
164 |     y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
165 |     neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
166 |     pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
167 |     return neg_loss + pos_loss
168 | 
169 | 
170 | def global_pointer_crossentropy(y_true, y_pred):
171 |     """Multilabel categorical cross-entropy over all candidate spans.
172 | 
173 |     :param y_true: [batch_size, num_classes, max_length, max_length]
174 |     :param y_pred: [batch_size, num_classes, max_length, max_length]
175 |     :return: scalar loss
176 |     """
177 |     bh = y_pred.shape[0] * y_pred.shape[1]
178 |     y_true = torch.reshape(y_true, (bh, -1))
179 |     y_pred = torch.reshape(y_pred, (bh, -1))
180 |     return torch.mean(multilabel_categorical_crossentropy(y_true, y_pred))
181 | 
182 | 
183 | def global_pointer_f1_score(y_true, y_pred):
184 |     """F1 score designed for GlobalPointer outputs.
185 |     """
186 |     y_pred[y_pred > 0] = 1
187 |     y_pred[y_pred <= 0] = 0
188 |     return 2 * (y_true * y_pred).sum() / (y_true + y_pred).sum()
189 | 
190 | 
191 | class BertGPForTokenClassification(BertPreTrainedModel):
192 |     _keys_to_ignore_on_load_unexpected = [r"pooler"]
193 | 
194 |     def __init__(self, config):
195 |         super().__init__(config)
196 |         self.num_labels = config.num_labels
197 | 
198 |         self.bert = BertModel(config, add_pooling_layer=False)
199 |         classifier_dropout = (
200 |             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
201 |         )
202 |         self.dropout = nn.Dropout(classifier_dropout)
203 | 
204 |         global_pointer_head = config.global_pointer_head
205 |         RoPE = config.RoPE
206 | 
207 |         if global_pointer_head == "GlobalPointer":
208 |             self.classifier = GlobalPointer(config.num_labels, 64, config.hidden_size, RoPE=RoPE)
209 |         else:
210 |             self.classifier = EfficientGlobalPointer(config.num_labels, 64, config.hidden_size, RoPE=RoPE)
211 | 
212 |         self.post_init()
213 | 
214 |     def forward(
215 |         self,
216 |         input_ids: Optional[torch.Tensor] = None,
217 |         attention_mask: Optional[torch.Tensor] = None,
218 |         token_type_ids: Optional[torch.Tensor] = None,
219 |         position_ids: Optional[torch.Tensor] = None,
220 |         head_mask: Optional[torch.Tensor] = None,
221 |         inputs_embeds: Optional[torch.Tensor] = None,
222 |         labels: Optional[torch.Tensor] = None,
223 |         output_attentions: Optional[bool] = None,
224 |         output_hidden_states: Optional[bool] = None,
225 |         return_dict: Optional[bool] = None,
226 |     ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
227 |         r"""
228 |         labels (`torch.Tensor` of shape `(batch_size, num_labels, seq_len, seq_len)`, *optional*):
229 |             Binary span labels: `labels[b, t, i, j] = 1` marks an entity of tag `t` spanning tokens `i` to `j` (inclusive).
230 | """ 231 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 232 | 233 | outputs = self.bert( 234 | input_ids, 235 | attention_mask=attention_mask, 236 | token_type_ids=token_type_ids, 237 | position_ids=position_ids, 238 | head_mask=head_mask, 239 | inputs_embeds=inputs_embeds, 240 | output_attentions=output_attentions, 241 | output_hidden_states=output_hidden_states, 242 | return_dict=return_dict, 243 | ) 244 | 245 | sequence_output = outputs[0] 246 | sequence_output = self.dropout(sequence_output) 247 | logits = self.classifier(sequence_output, mask=attention_mask) 248 | 249 | loss = None 250 | if labels is not None: 251 | loss = global_pointer_crossentropy(labels, logits) 252 | 253 | if not return_dict: 254 | output = (logits,) + outputs[2:] 255 | return ((loss,) + output) if loss is not None else output 256 | 257 | return TokenClassifierOutput( 258 | loss=loss, 259 | logits=logits, 260 | hidden_states=outputs.hidden_states, 261 | attentions=outputs.attentions, 262 | ) 263 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import logging 3 | import os 4 | import sys 5 | from dataclasses import dataclass, field 6 | from typing import Optional 7 | 8 | import datasets 9 | import numpy as np 10 | import rich 11 | import transformers 12 | from datasets import load_dataset 13 | from transformers import ( 14 | AutoConfig, 15 | AutoTokenizer, 16 | HfArgumentParser, 17 | Trainer, 18 | TrainingArguments, 19 | set_seed, 20 | ) 21 | from transformers.trainer_utils import get_last_checkpoint 22 | 23 | from global_pointer import BertGPForTokenClassification, global_pointer_f1_score 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | tags = [ 28 | "address", 29 | "book", 30 | "company", 31 | "game", 32 | "government", 33 | "movie", 34 | "name", 35 | "organization", 36 | "position", 37 | "scene" 38 | ] 39 | 40 | 41 | @dataclass 42 | class ModelArguments: 43 | """ 44 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 45 | """ 46 | 47 | model_name_or_path: str = field( 48 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 49 | ) 50 | global_pointer_head: str = field( 51 | default="GlobalPointer", 52 | metadata={"help": "GlobalPointer or EfficientGlobalPointer"} 53 | ) 54 | rope: bool = field( 55 | default=True, 56 | metadata={"help": "Whether or not to add SinusoidalPositionEmbedding"} 57 | ) 58 | 59 | 60 | @dataclass 61 | class DataTrainingArguments: 62 | """ 63 | Arguments pertaining to what data we are going to input our model for training and eval. 64 | """ 65 | 66 | dataset_name: Optional[str] = field( 67 | default="./cluener_dataset.py", metadata={"help": "The name of the dataset to use (via the datasets library)."} 68 | ) 69 | overwrite_cache: bool = field( 70 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 71 | ) 72 | preprocessing_num_workers: Optional[int] = field( 73 | default=None, 74 | metadata={"help": "The number of processes to use for the preprocessing."}, 75 | ) 76 | max_seq_length: int = field( 77 | default=256, 78 | metadata={ 79 | "help": ( 80 | "The maximum total input sequence length after tokenization. If set, sequences longer " 81 | "than this will be truncated, sequences shorter will be padded." 
82 | ) 83 | }, 84 | ) 85 | max_train_samples: Optional[int] = field( 86 | default=None, 87 | metadata={ 88 | "help": ( 89 | "For debugging purposes or quicker training, truncate the number of training examples to this " 90 | "value if set." 91 | ) 92 | }, 93 | ) 94 | max_eval_samples: Optional[int] = field( 95 | default=None, 96 | metadata={ 97 | "help": ( 98 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 99 | "value if set." 100 | ) 101 | }, 102 | ) 103 | 104 | 105 | def main(): 106 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 107 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 108 | assert model_args.global_pointer_head in ["GlobalPointer", "EfficientGlobalPointer"] 109 | 110 | # Setup logging 111 | logging.basicConfig( 112 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 113 | datefmt="%m/%d/%Y %H:%M:%S", 114 | handlers=[logging.StreamHandler(sys.stdout)], 115 | ) 116 | 117 | log_level = training_args.get_process_log_level() 118 | logger.setLevel(log_level) 119 | datasets.utils.logging.set_verbosity(log_level) 120 | transformers.utils.logging.set_verbosity(log_level) 121 | transformers.utils.logging.enable_default_handler() 122 | transformers.utils.logging.enable_explicit_format() 123 | 124 | # Log on each process the small summary: 125 | logger.warning( 126 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 127 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 128 | ) 129 | logger.info(f"Training/evaluation parameters {training_args}") 130 | 131 | # Detecting last checkpoint. 132 | last_checkpoint = None 133 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 134 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 135 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 136 | raise ValueError( 137 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 138 | "Use --overwrite_output_dir to overcome." 139 | ) 140 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 141 | logger.info( 142 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 143 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 144 | ) 145 | 146 | set_seed(training_args.seed) 147 | 148 | raw_datasets = load_dataset( 149 | data_args.dataset_name, 150 | ) 151 | 152 | label_list = tags 153 | num_labels = len(label_list) 154 | 155 | config = AutoConfig.from_pretrained( 156 | model_args.model_name_or_path, 157 | num_labels=num_labels, 158 | ) 159 | config.global_pointer_head = model_args.global_pointer_head 160 | config.RoPE = model_args.rope 161 | 162 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, use_fast=True) 163 | 164 | model = BertGPForTokenClassification.from_pretrained( 165 | model_args.model_name_or_path, 166 | config=config, 167 | ) 168 | 169 | logger.info(f"Model parameters: {model.num_parameters()}") 170 | 171 | # Tokenize all texts and align the labels with them. 
172 | def tokenize_and_align_labels(examples): 173 | tokenized_inputs = tokenizer( 174 | examples["text"], 175 | padding="max_length", 176 | truncation=True, 177 | max_length=data_args.max_seq_length, 178 | return_offsets_mapping=True, 179 | return_tensors="pt", 180 | ) 181 | labels = [] 182 | for i, label in enumerate(examples["span_tags"]): 183 | offset_mapping = tokenized_inputs["offset_mapping"][i] 184 | mapping = {} 185 | for token_idx, val in enumerate(offset_mapping): 186 | start = val[0].item() 187 | end = val[-1].item() 188 | if start == end == 0: # [CLS] 189 | continue 190 | for ci in range(start, end): 191 | mapping[ci] = token_idx 192 | 193 | label_ids = np.zeros((num_labels, data_args.max_seq_length, data_args.max_seq_length)) 194 | 195 | text = examples["text"][i] 196 | # rich.print("-" * 80) 197 | # rich.print(text) 198 | # rich.print(label) 199 | # tokens = tokenizer.convert_ids_to_tokens(tokenized_inputs.input_ids[i]) 200 | for tag_id, start, end in zip(label['tag'], label['start'], label['end']): 201 | if start in mapping and end in mapping: 202 | token_start = mapping[start] 203 | token_end = mapping[end] 204 | # if "".join(tokens[token_start: token_end + 1]).replace("#", "") != text[start:end + 1].replace("#", ""): 205 | # rich.print(tokens[token_start: token_end + 1]) 206 | # rich.print(text[start:end + 1]) 207 | # import ipdb; 208 | # ipdb.set_trace() 209 | label_ids[tag_id, token_start, token_end] = 1 210 | 211 | labels.append(label_ids) 212 | tokenized_inputs["labels"] = labels 213 | 214 | tokenized_inputs.pop("offset_mapping") 215 | return tokenized_inputs 216 | 217 | if training_args.do_train: 218 | if "train" not in raw_datasets: 219 | raise ValueError("--do_train requires a train dataset") 220 | train_dataset = raw_datasets["train"] 221 | if data_args.max_train_samples is not None: 222 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 223 | train_dataset = train_dataset.select(range(max_train_samples)) 224 | with training_args.main_process_first(desc="train dataset map pre-processing"): 225 | train_dataset = train_dataset.with_transform(tokenize_and_align_labels) 226 | 227 | if training_args.do_eval: 228 | if "validation" not in raw_datasets: 229 | raise ValueError("--do_eval requires a validation dataset") 230 | eval_dataset = raw_datasets["validation"] 231 | if data_args.max_eval_samples is not None: 232 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 233 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 234 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 235 | eval_dataset = eval_dataset.with_transform(tokenize_and_align_labels) 236 | 237 | def compute_metrics(p): 238 | predictions, labels = p 239 | f1 = global_pointer_f1_score(labels, predictions) 240 | return {"f1": f1} 241 | 242 | training_args.remove_unused_columns = False 243 | 244 | trainer = Trainer( 245 | model=model, 246 | args=training_args, 247 | train_dataset=train_dataset if training_args.do_train else None, 248 | eval_dataset=eval_dataset if training_args.do_eval else None, 249 | tokenizer=tokenizer, 250 | compute_metrics=compute_metrics, 251 | ) 252 | 253 | # Training 254 | if training_args.do_train: 255 | checkpoint = None 256 | if training_args.resume_from_checkpoint is not None: 257 | checkpoint = training_args.resume_from_checkpoint 258 | elif last_checkpoint is not None: 259 | checkpoint = last_checkpoint 260 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 261 | metrics = 
train_result.metrics 262 | trainer.save_model() # Saves the tokenizer too for easy upload 263 | 264 | max_train_samples = ( 265 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 266 | ) 267 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 268 | 269 | trainer.log_metrics("train", metrics) 270 | trainer.save_metrics("train", metrics) 271 | trainer.save_state() 272 | 273 | if training_args.do_eval: 274 | logger.info("*** Evaluate ***") 275 | 276 | metrics = trainer.evaluate() 277 | 278 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 279 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 280 | 281 | trainer.log_metrics("eval", metrics) 282 | trainer.save_metrics("eval", metrics) 283 | 284 | 285 | if __name__ == "__main__": 286 | main() 287 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import logging 3 | import json 4 | from pathlib import Path 5 | 6 | import datasets 7 | import rich 8 | import numpy as np 9 | import torch 10 | import typer 11 | from datasets import load_dataset 12 | from transformers import AutoTokenizer 13 | 14 | from global_pointer import BertGPForTokenClassification 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | tags = [ 19 | "address", 20 | "book", 21 | "company", 22 | "game", 23 | "government", 24 | "movie", 25 | "name", 26 | "organization", 27 | "position", 28 | "scene" 29 | ] 30 | 31 | app = typer.Typer(add_completion=False) 32 | 33 | 34 | @app.command() 35 | def test(model_path: str, save_path: Path, device: str = 'cuda', dataset_name: str = './cluener_dataset.py'): 36 | device = torch.device(device) 37 | raw_datasets = load_dataset(dataset_name, split=datasets.Split.TEST) 38 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 39 | model = BertGPForTokenClassification.from_pretrained(model_path).to(device) 40 | 41 | entities = [] 42 | for it in raw_datasets: 43 | text = it['text'] 44 | tokenized_inputs = tokenizer( 45 | text, 46 | padding="max_length", 47 | truncation=True, 48 | max_length=256, 49 | return_offsets_mapping=True, 50 | return_tensors="pt" 51 | ) 52 | offsets_mapping = tokenized_inputs.pop('offset_mapping')[0] 53 | inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in tokenized_inputs.items()} 54 | pred = model(**inputs) 55 | 56 | logits = pred.logits[0].detach().cpu().numpy() 57 | 58 | logits[:, [0, -1]] -= np.inf 59 | logits[:, :, [0, -1]] -= np.inf 60 | 61 | # { 62 | # 'text': '索尼《GT赛车》新作可能会发行PC版?', 63 | # 'label': { 64 | # 'game': {'《GT赛车》': [[2, 7]]}, 65 | # 'company': {'索尼': [[0, 1]]} 66 | # } 67 | # } 68 | labels = {} 69 | for tag_idx, token_start_index, token_end_index in zip(*np.where(logits > 0)): 70 | start = offsets_mapping[token_start_index][0].item() 71 | end = offsets_mapping[token_end_index][-1].item() 72 | tag = tags[tag_idx] 73 | entity = text[start:end] 74 | 75 | if tag not in labels: 76 | labels[tag] = {entity: [[start, end]]} 77 | else: 78 | if entity in labels[tag]: 79 | labels[tag][entity].append([start, end]) 80 | else: 81 | labels[tag][entity] = [[start, end]] 82 | 83 | entities.append({ 84 | "id": it['id'], 85 | "text": text, 86 | 'label': labels 87 | }) 88 | 89 | rich.print(labels) 90 | 91 | with open(save_path, 'w', encoding='utf-8') as f: 92 | for e in entities: 93 | 
f.write(f"{json.dumps(e, ensure_ascii=False)}\n") 94 | 95 | 96 | @app.command() 97 | def predict(model_path: str, device: str = 'cuda'): 98 | device = torch.device(device) 99 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 100 | model = BertGPForTokenClassification.from_pretrained(model_path).to(device) 101 | 102 | entities = [] 103 | while True: 104 | text = input("Input text:") 105 | tokenized_inputs = tokenizer( 106 | text, 107 | padding="max_length", 108 | truncation=True, 109 | max_length=256, 110 | return_offsets_mapping=True, 111 | return_tensors="pt" 112 | ) 113 | offsets_mapping = tokenized_inputs.pop('offset_mapping')[0] 114 | inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in tokenized_inputs.items()} 115 | pred = model(**inputs) 116 | 117 | logits = pred.logits[0].detach().cpu().numpy() 118 | logits[:, [0, -1]] -= np.inf 119 | logits[:, :, [0, -1]] -= np.inf 120 | 121 | labels = {} 122 | for tag_idx, token_start_index, token_end_index in zip(*np.where(logits > 0)): 123 | start = offsets_mapping[token_start_index][0].item() 124 | end = offsets_mapping[token_end_index][-1].item() 125 | tag = tags[tag_idx] 126 | entity = text[start:end] 127 | 128 | if tag not in labels: 129 | labels[tag] = {entity: [[start, end]]} 130 | else: 131 | if entity in labels[tag]: 132 | labels[tag][entity].append([start, end]) 133 | else: 134 | labels[tag][entity] = [[start, end]] 135 | 136 | entities.append({ 137 | "text": text, 138 | 'label': labels 139 | }) 140 | 141 | rich.print(labels) 142 | 143 | 144 | if __name__ == "__main__": 145 | app() 146 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | rich 2 | typer 3 | transformers==4.18.0 4 | datasets==2.3.0 --------------------------------------------------------------------------------