├── .gitignore ├── README.md ├── __init__.py ├── pretraining ├── README.MD ├── lm_pretrain │ ├── data_utils.py │ ├── gpt2_base_config │ │ ├── config.json │ │ ├── tokenizer_config.json │ │ └── vocab.txt │ ├── gpt2_large_config │ │ ├── config.json │ │ ├── tokenizer_config.json │ │ └── vocab.txt │ ├── gpt2_xl_config │ │ ├── config.json │ │ ├── tokenizer_config.json │ │ └── vocab.txt │ └── train.py ├── mlm_pretrain │ ├── data_utils.py │ └── train.py ├── prompt_t5_pretrain │ ├── README.md │ ├── data_utils.py │ ├── evaluate_pclue.py │ ├── t5_base_config │ │ ├── config.json │ │ ├── tokenizer_config.json │ │ └── vocab.txt │ ├── t5_large_config │ │ ├── config.json │ │ ├── tokenizer_config.json │ │ └── vocab.txt │ ├── t5_small_config │ │ ├── config.json │ │ ├── tokenizer_config.json │ │ └── vocab.txt │ ├── t5_xl_config │ │ ├── config.json │ │ ├── tokenizer_config.json │ │ └── vocab.txt │ ├── t5_xxl_config │ │ ├── config.json │ │ ├── tokenizer_config.json │ │ └── vocab.txt │ └── task_prompt_t5.py ├── seq2seq_pretrain │ ├── data_utils.py │ ├── t5_base_config │ │ ├── config.json │ │ ├── tokenizer_config.json │ │ └── vocab.txt │ ├── t5_large_config │ │ ├── config.json │ │ ├── tokenizer_config.json │ │ └── vocab.txt │ ├── t5_small_config │ │ ├── config.json │ │ ├── tokenizer_config.json │ │ └── vocab.txt │ ├── t5_xl_config │ │ ├── config.json │ │ ├── tokenizer_config.json │ │ └── vocab.txt │ ├── t5_xxl_config │ │ ├── config.json │ │ ├── tokenizer_config.json │ │ └── vocab.txt │ └── train.py ├── simbert-v2_pretrain │ ├── data_utils.py │ └── task_simsce_unilm.py └── t5encoder_mlm_pretrain │ ├── data_utils.py │ ├── t5_base_config │ ├── config.json │ ├── tokenizer_config.json │ └── vocab.txt │ ├── t5_large_config │ ├── config.json │ ├── tokenizer_config.json │ └── vocab.txt │ ├── t5_xl_config │ ├── config.json │ ├── tokenizer_config.json │ └── vocab.txt │ ├── t5_xxl_config │ ├── config.json │ ├── tokenizer_config.json │ └── vocab.txt │ └── train.py ├── task_classify ├── task_tnews.py ├── task_tnews_adversarial.py ├── task_tnews_hierarchical_position.py ├── task_tnews_prefixprompt.py └── task_tnews_prefixtuning.py ├── task_extract_event └── task_event_gplinker.py ├── task_extract_ner ├── task_cluener_cascad_crf.py ├── task_cluener_crf.py ├── task_cluener_crf_adversarial.py ├── task_cluener_crf_prefixtuning.py ├── task_cluener_mhs_ner.py ├── task_cluener_pointer.py ├── task_cluener_pointer_adversarial.py ├── task_cluener_pointer_prefixtuning.py ├── task_cluener_pure.py ├── task_cluener_span_ner.py ├── task_cluener_tplinkerplus.py └── task_cluener_w2ner.py ├── task_extract_relation ├── task_relation_casrel.py ├── task_relation_gplinker.py ├── task_relation_gplinker_adversarial.py ├── task_relation_mhslinker.py ├── task_relation_onerel.py ├── task_relation_prgc.py ├── task_relation_splinker.py ├── task_relation_spn4re.py ├── task_relation_tplinker.py └── task_relation_tplinkerplus.py ├── task_grammatical_error_correction ├── 1.png ├── README.md ├── task_ctc_gector │ ├── data_utils.py │ └── task_ctc_gector.py └── task_ctc_seq2seq │ ├── data_utils.py │ ├── t5_base_config │ ├── config.json │ ├── tokenizer_config.json │ └── vocab.txt │ ├── t5_large_config │ ├── config.json │ ├── tokenizer_config.json │ └── vocab.txt │ ├── t5_small_config │ ├── config.json │ ├── tokenizer_config.json │ └── vocab.txt │ ├── t5_xl_config │ ├── config.json │ ├── tokenizer_config.json │ └── vocab.txt │ ├── t5_xxl_config │ ├── config.json │ ├── tokenizer_config.json │ └── vocab.txt │ └── task_ctc_seq2seq.py ├── task_sentence_vector ├── 
task_classify_vector │ ├── task_tnews_arcface.py │ ├── task_tnews_circle_loss.py │ └── task_tnews_cosface.py ├── task_classify_vector_record │ ├── convert_train_pos_neg_for_infonce.py │ ├── corpus_process │ │ ├── jieba_process_corpus.py │ │ ├── split_corpus.py │ │ └── stopwards.txt │ ├── load_record.py │ ├── make_record_for_classify.py │ ├── merge_record.py │ ├── shuffle_record.py │ ├── split_record.py │ ├── split_record_and_modify.py │ ├── task_my_arcface.py │ ├── task_my_circleloss.py │ ├── task_my_cosface.py │ └── task_my_infonce.py ├── task_sup_vector │ ├── task_afqmc_contrastiveloss.py │ ├── task_afqmc_cosent.py │ ├── task_diffcse_sup.py │ ├── task_infonce_sup.py │ ├── task_promptbertcse_sup.py │ └── task_simsce_sup.py └── task_unsup_vector │ ├── task_diffcse.py │ ├── task_esimsce.py │ ├── task_promptbertcse.py │ ├── task_simsce.py │ ├── task_simsce_mlm.py │ └── task_tsdae.py └── task_text_generate ├── task_autotitle_unilm.py └── task_autotitle_unilm_distillation.py /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea 2 | /test 3 | /tests 4 | /build 5 | /deep_training.egg-info 6 | /dist 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## 安装 2 | 3 | - pip install -U deep_training >= 0.1.0 4 | - 当前文档版本pypi 0.1.0 5 | 6 | ## 更新详情 7 | 8 | - [deep_training](https://github.com/ssbuild/deep_training) 9 | - 10 | ## 其他训练 11 | - [pytorch-task-example](https://github.com/ssbuild/pytorch-task-example) 12 | - [tf-task-example](https://github.com/ssbuild/tf-task-example) 13 | - [poetry_training](https://github.com/ssbuild/poetry_training) 14 | 15 | 16 | ## 目录 17 | - pretraining 主流预训练模型 18 | - task_classify 分类模型 19 | - task_extract_ner 序列抽取模型 20 | - tast_extract_relation 关系抽取模型 21 | - tast_extract_event 事件抽取模型 22 | - task_generate 文本生成模型 23 | - task_grammatical_error_correction 文本纠错模型 24 | - task_sentence_vector 句向量模型 25 | - task_custom_muti_gpu 更多自定义训练操作,例如多卡训练例子, 模型转换onnx 等一些列自定义操作 26 | 27 | ## 对抗训练就在配置里增加一个选项 28 | 'adv': { 29 | 'mode': 'fgm', # None, fgm, fgsm_local, fgsm(不推荐), pgd, free_local, free(不推荐) 30 | 'emb_name': 'embedding', 31 | 'attack_iters': 2, # pgd 32 | 'minibatch_replays': 2, # free 33 | 'alpha': 0.5, # pgd,fgsm 34 | 'epsilon': 0.5, # pgd,fgm 35 | } 36 | 37 | ## 层次分解位置编码,让BERT可以处理超长文本 38 | 'hierarchical_position': 0.4 39 | 40 | ## 导出onnx模型 通常只需要三步 41 | 42 | 第一步,参数配置 convert_onnx = True 43 | 第二步 加载权重例子 44 | model = MyTransformer.load_from_checkpoint('./best.pt', config=config, model_args=model_args, 45 | training_args=training_args) 46 | 第三步 #导出onnx模型 47 | model.convert_to_onnx('./best.onnx') 48 | 49 | ## 多卡训练策略 strategy , 通常只需要一步 50 | 修改参数配置 devices = N 51 | 52 | # Available names: bagua, colossalai, ddp, ddp_find_unused_parameters_false, ddp_fork, 53 | # ddp_fork_find_unused_parameters_false, ddp_fully_sharded, 54 | # ddp_notebook, ddp_notebook_find_unused_parameters_false, ddp_sharded, 55 | # ddp_sharded_find_unused_parameters_false, ddp_sharded_spawn, 56 | # ddp_sharded_spawn_find_unused_parameters_false, 57 | # ddp_spawn, ddp_spawn_find_unused_parameters_false, 58 | # deepspeed, deepspeed_stage_1, deepspeed_stage_2, deepspeed_stage_2_offload, 59 | # deepspeed_stage_3, deepspeed_stage_3_offload, deepspeed_stage_3_offload_nvme, 60 | # dp, fsdp, fsdp_native, fsdp_native_full_shard_offload, horovod, hpu_parallel, 61 | # hpu_single, ipu_strategy, single_device, single_tpu, tpu_spawn, tpu_spawn_debug" 
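In the example scripts of this repository, all of the switches above live in the same train_info_args dictionary that is handed to HfArgumentParser, so enabling them is a matter of adding keys next to 'devices', 'output_dir', etc. A minimal sketch — the surrounding keys and values are illustrative, and the placement of 'adv' / 'hierarchical_position' follows the description above rather than a fixed API:

train_info_args = {
    # ... the usual keys: 'model_type', 'train_file', 'output_dir', ...
    'devices': 2,                     # >1 device plus multiple visible GPUs makes the example Trainers use ddp
    'convert_onnx': False,            # set True to take the ONNX export path described above
    'hierarchical_position': 0.4,     # hierarchical position decomposition for long inputs
    'adv': {
        'mode': 'fgm',                # one of: None, fgm, fgsm_local, pgd, free_local, ...
        'emb_name': 'embedding',
        'epsilon': 0.5,
    },
}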
62 | 63 | ## 大模型Lora训练 64 | 65 | [chatyuan_finetuning](https://github.com/ssbuild/chatyuan_finetuning) 66 | [prompt_finetuning](https://github.com/ssbuild/prompt_finetuning) 67 | 68 | ## 愿景 69 | 70 | 创建一个模型工厂, 轻量且高效的训练程序,让训练模型更容易,更轻松上手。 71 | 72 | ## 交流 73 | 74 | QQ交流群:185144988 75 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ssbuild/pytorch-task-example/7d2341562c4ae3070fc7fc18b3b1886a74391ca2/__init__.py -------------------------------------------------------------------------------- /pretraining/README.MD: -------------------------------------------------------------------------------- 1 | ## 预训练模型 2 | 3 | ## 更多生成任务,参考poetry_training 4 | 5 | - [poetry_training](https://github.com/ssbuild/poetry_training) (https://github.com/ssbuild/poetry_training) -------------------------------------------------------------------------------- /pretraining/lm_pretrain/data_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time: 3:02 3 | # @Author:XIE392 4 | # @File:data_utils.py 5 | import json 6 | 7 | import numpy as np 8 | import torch 9 | import typing 10 | 11 | from deep_training.data_helper import DataHelper, ModelArguments, TrainingArguments, DataArguments 12 | from transformers import BertTokenizer, HfArgumentParser 13 | 14 | train_info_args = { 15 | 'devices': 1, 16 | 'data_backend': 'record', 17 | 'model_type': 'gpt2', 18 | # 预训练模型路径 , 从0训练,则置空 19 | # 'model_name_or_path': '/data/nlp/pre_models/torch/', 20 | 'tokenizer_name': './gpt2_base_config', 21 | 'config_name': './gpt2_base_config/config.json', 22 | 'convert_onnx': False, # 转换onnx模型 23 | 'do_train': True, 24 | 'train_file': [ '/data/nlp/nlp_train_data/thucnews/train.json'], 25 | 'learning_rate': 5e-5, 26 | 'max_epochs': 3, 27 | 'train_batch_size': 8, 28 | 'test_batch_size': 2, 29 | 'adam_epsilon': 1e-8, 30 | 'gradient_accumulation_steps': 1, 31 | 'max_grad_norm': 1.0, 32 | 'weight_decay': 0, 33 | 'warmup_steps': 0, 34 | 'output_dir': './output', 35 | 'train_max_seq_length': 400, 36 | 'eval_max_seq_length': 512, 37 | 'test_max_seq_length': 512, 38 | } 39 | 40 | class NN_DataHelper(DataHelper): 41 | # 切分词 42 | def on_data_process(self, data: typing.Any, mode: str): 43 | tokenizer: BertTokenizer 44 | max_seq_length = self.max_seq_length_dict[mode] 45 | tokenizer = self.tokenizer 46 | 47 | x = data 48 | if isinstance(x, tuple): 49 | o = tokenizer(text=x[0], text_pair=x[1], max_length=max_seq_length, truncation=True, 50 | add_special_tokens=True) 51 | else: 52 | o = tokenizer(x, max_length=max_seq_length, truncation=True, add_special_tokens=True, ) 53 | 54 | input_ids = np.asarray(o['input_ids'], dtype=np.int64) 55 | attention_mask = np.asarray(o['attention_mask'], dtype=np.int64) 56 | token_type_ids = np.asarray(o['token_type_ids'], dtype=np.int64) 57 | 58 | seqlen = np.asarray(len(input_ids), dtype=np.int64) 59 | pad_len = max_seq_length - len(input_ids) 60 | if pad_len > 0: 61 | pad_val = tokenizer.pad_token_id 62 | input_ids = np.pad(input_ids, (0, pad_len), 'constant', constant_values=(pad_val, pad_val)) 63 | attention_mask = np.pad(attention_mask, (0, pad_len), 'constant', constant_values=(0, 0)) 64 | token_type_ids = np.pad(token_type_ids, (0, pad_len), 'constant', constant_values=(0, 0)) 65 | d = { 66 | 'input_ids': input_ids, 67 | 'attention_mask': attention_mask, 68 | 'token_type_ids': 
token_type_ids, 69 | 'labels': np.where(input_ids != tokenizer.pad_token_id, input_ids, np.ones_like(input_ids) * -100), 70 | 'seqlen': seqlen 71 | } 72 | return d 73 | 74 | # 读取文件 75 | def on_get_corpus(self, files: typing.List, mode: str): 76 | D = [] 77 | for filename in files: 78 | with open(filename, mode='r', encoding='utf-8') as f: 79 | lines = f.readlines() 80 | for i, line in enumerate(lines): 81 | jd = json.loads(line) 82 | D.append((jd['content'], jd['title'])) 83 | if i > 1000: 84 | break 85 | return D 86 | 87 | def collate_fn(self,batch): 88 | o = {} 89 | for i, b in enumerate(batch): 90 | if i == 0: 91 | for k in b: 92 | o[k] = [torch.tensor(b[k])] 93 | else: 94 | for k in b: 95 | o[k].append(torch.tensor(b[k])) 96 | for k in o: 97 | o[k] = torch.stack(o[k]) 98 | 99 | max_len = torch.max(o.pop('seqlen')) 100 | 101 | o['input_ids'] = o['input_ids'][:, :max_len] 102 | o['attention_mask'] = o['attention_mask'][:, :max_len] 103 | if 'token_type_ids' in o: 104 | o['token_type_ids'] = o['token_type_ids'][:, :max_len] 105 | o['labels'] = o['labels'][:, :max_len] 106 | return o 107 | 108 | if __name__ == '__main__': 109 | 110 | 111 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments)) 112 | model_args, training_args, data_args = parser.parse_dict(train_info_args) 113 | 114 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 115 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 116 | 117 | # 缓存数据集 118 | if data_args.do_train: 119 | dataHelper.make_dataset_with_args(data_args.train_file,mixed_data=False, shuffle=True,mode='train') 120 | if data_args.do_eval: 121 | dataHelper.make_dataset_with_args(data_args.eval_file,shuffle=False,mode='eval') 122 | if data_args.do_test: 123 | dataHelper.make_dataset_with_args(data_args.test_file, shuffle=False, mode='test') 124 | -------------------------------------------------------------------------------- /pretraining/lm_pretrain/gpt2_base_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "gpt2", 3 | "architectures": [ 4 | "GPT2LMHeadModel" 5 | ], 6 | "return_dict": false, 7 | "resid_pdrop": 0.1, 8 | "embd_pdrop": 0.1, 9 | "attn_pdrop": 0.1, 10 | "initializer_range": 0.02, 11 | "layer_norm_epsilon": 1e-8, 12 | "n_ctx": 1024, 13 | "n_embd": 768, 14 | "n_head": 12, 15 | "n_layer": 12, 16 | "n_positions": 1024, 17 | "tokenizer_class": "BertTokenizer", 18 | "vocab_size": 16448 19 | } -------------------------------------------------------------------------------- /pretraining/lm_pretrain/gpt2_base_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /pretraining/lm_pretrain/gpt2_large_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "gpt2", 3 | "architectures": [ 4 | "GPT2LMHeadModel" 5 | ], 6 | "return_dict": false, 7 | "resid_pdrop": 0.1, 8 | "embd_pdrop": 0.1, 9 | "attn_pdrop": 0.1, 10 | "initializer_range": 0.02, 11 | 
"layer_norm_epsilon": 1e-8, 12 | "n_ctx": 1024, 13 | "n_embd": 1280, 14 | "n_head": 20, 15 | "n_layer": 36, 16 | "n_positions": 1024, 17 | "tokenizer_class": "BertTokenizer", 18 | "vocab_size": 16448 19 | } -------------------------------------------------------------------------------- /pretraining/lm_pretrain/gpt2_large_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /pretraining/lm_pretrain/gpt2_xl_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "gpt2", 3 | "architectures": [ 4 | "GPT2LMHeadModel" 5 | ], 6 | "return_dict": false, 7 | "resid_pdrop": 0.1, 8 | "embd_pdrop": 0.1, 9 | "attn_pdrop": 0.1, 10 | "initializer_range": 0.02, 11 | "layer_norm_epsilon": 1e-8, 12 | "n_ctx": 1024, 13 | "n_embd": 1600, 14 | "n_head": 25, 15 | "n_layer": 48, 16 | "n_positions": 1024, 17 | "tokenizer_class": "BertTokenizer", 18 | "vocab_size": 16448 19 | } -------------------------------------------------------------------------------- /pretraining/lm_pretrain/gpt2_xl_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /pretraining/lm_pretrain/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from deep_training.data_helper import ModelArguments, TrainingArguments, DataArguments 5 | from deep_training.nlp.models.transformer import TransformerForCausalLM 6 | from lightning import Trainer 7 | from lightning.pytorch.callbacks import ModelCheckpoint 8 | from torch.utils.data import DataLoader, IterableDataset 9 | from transformers import HfArgumentParser 10 | from data_utils import NN_DataHelper,train_info_args 11 | 12 | 13 | class MyTransformer(TransformerForCausalLM, with_pl=True): 14 | def __init__(self, *args, **kwargs): 15 | super(MyTransformer, self).__init__(*args, **kwargs) 16 | 17 | 18 | if __name__ == '__main__': 19 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments)) 20 | model_args, training_args, data_args = parser.parse_dict(train_info_args) 21 | 22 | checkpoint_callback = ModelCheckpoint(monitor="loss", save_top_k=5, 23 | every_n_train_steps=2000 // training_args.gradient_accumulation_steps) 24 | trainer = Trainer( 25 | callbacks=[checkpoint_callback], 26 | max_epochs=training_args.max_epochs, 27 | max_steps=training_args.max_steps, 28 | accelerator="gpu", 29 | devices=data_args.devices, 30 | enable_progress_bar=True, 31 | default_root_dir=data_args.output_dir, 32 | 
gradient_clip_val=training_args.max_grad_norm, 33 | accumulate_grad_batches=training_args.gradient_accumulation_steps, 34 | num_sanity_val_steps=0, 35 | strategy='ddp' if torch.cuda.device_count() > 1 else 'auto', 36 | ) 37 | 38 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 39 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 40 | 41 | # 缓存数据集 42 | if data_args.do_train: 43 | dataHelper.make_dataset_with_args(data_args.train_file,mixed_data=False, shuffle=True,mode='train') 44 | if data_args.do_eval: 45 | dataHelper.make_dataset_with_args(data_args.eval_file, mode='eval') 46 | if data_args.do_test: 47 | dataHelper.make_dataset_with_args(data_args.test_file,mode='test') 48 | 49 | 50 | model = MyTransformer(config=config, model_args=model_args, training_args=training_args) 51 | 52 | if not data_args.convert_onnx: 53 | train_datasets = dataHelper.load_random_sampler(dataHelper.train_files, 54 | with_load_memory=False, 55 | with_record_iterable_dataset=True, 56 | collate_fn=dataHelper.collate_fn, 57 | batch_size=training_args.train_batch_size, 58 | shuffle=True, infinite=True, num_processes=trainer.world_size, 59 | process_index=trainer.global_rank) 60 | if train_datasets is not None: 61 | trainer.fit(model, train_dataloaders=train_datasets) 62 | else: 63 | eval_datasets = dataHelper.load_sequential_sampler(dataHelper.eval_files,batch_size=training_args.eval_batch_size,collate_fn=dataHelper.collate_fn) 64 | test_datasets = dataHelper.load_sequential_sampler(dataHelper.test_files,batch_size=training_args.test_batch_size,collate_fn=dataHelper.collate_fn) 65 | if eval_datasets is not None: 66 | trainer.validate(model, dataloaders=eval_datasets, ckpt_path='./best.pt') 67 | 68 | if test_datasets is not None: 69 | trainer.test(model, dataloaders=test_datasets, ckpt_path='best.pt') 70 | -------------------------------------------------------------------------------- /pretraining/mlm_pretrain/data_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time: 3:09 3 | # @Author:XIE392 4 | # @File:data_utils.py 5 | import copy 6 | import json 7 | import random 8 | 9 | import torch 10 | import typing 11 | from deep_training.data_helper import DataHelper, ModelArguments, TrainingArguments, MlmDataArguments, DataArguments 12 | from deep_training.utils.maskedlm import make_mlm_wwm_sample 13 | from transformers import BertTokenizer, HfArgumentParser 14 | from fastdatasets import gfile 15 | 16 | 17 | train_info_args = { 18 | 'devices': 1, 19 | 'data_backend': 'record', 20 | 'model_type': 'bert', 21 | 'model_name_or_path': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 22 | 'tokenizer_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 23 | 'config_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese/config.json', 24 | 'convert_onnx': False, # 转换onnx模型 25 | 'do_train': True, 26 | 'train_file': [ '/data/nlp/nlp_train_data/thucnews/train.json'], 27 | 'learning_rate': 5e-5, 28 | 'max_epochs': None, 29 | 'max_steps': 300000, 30 | 'train_batch_size': 10, 31 | 'test_batch_size': 2, 32 | 'adam_epsilon': 1e-8, 33 | 'gradient_accumulation_steps': 1, 34 | 'max_grad_norm': 1.0, 35 | 'weight_decay': 0.01, 36 | 'warmup_steps': 10000, 37 | 'output_dir': './output', 38 | 'train_max_seq_length': 512, 39 | 'eval_max_seq_length': 512, 40 | 'test_max_seq_length': 512, 41 | 'do_lower_case': True, 42 | 'do_whole_word_mask': True, 43 | 'max_predictions_per_seq': 20, 44 | 'dupe_factor': 5, 45 | 
'masked_lm_prob': 0.15 46 | } 47 | 48 | 49 | data_conf = { 50 | 'count_per_group': 1, 51 | } 52 | 53 | class NN_DataHelper(DataHelper): 54 | index = -1 55 | def on_data_ready(self): 56 | self.index = -1 57 | 58 | # 切分词 59 | def on_data_process(self, data: typing.Any, mode: typing.Any): 60 | self.index += 1 61 | 62 | tokenizer: BertTokenizer 63 | max_seq_length = self.max_seq_length_dict[mode] 64 | tokenizer = self.tokenizer 65 | 66 | 67 | rng, do_whole_word_mask, max_predictions_per_seq, masked_lm_prob = self.external_kwargs['mlm_args'] 68 | 69 | group_documents = data 70 | 71 | document_text_string = '' 72 | for documents in group_documents: 73 | document_text_string += ''.join(documents) 74 | 75 | document_texts = [] 76 | pos = 0 77 | slide_window = int(max_seq_length * 1.0) 78 | while pos < len(document_text_string): 79 | text = document_text_string[pos:pos + slide_window - 2] 80 | pos += len(text) 81 | document_texts.append(text) 82 | # 返回多个文档 83 | document_nodes = [] 84 | for text in document_texts: 85 | node = make_mlm_wwm_sample(text, tokenizer, max_seq_length, rng, do_whole_word_mask, 86 | max_predictions_per_seq, masked_lm_prob) 87 | document_nodes.append(node) 88 | 89 | if self.index < 3: 90 | print(document_nodes[0]) 91 | return document_nodes 92 | 93 | # 读取文件 94 | def on_get_corpus(self, files: typing.List, mode: str): 95 | COUNT_PER_GROUP = data_conf['count_per_group'] 96 | D = [] 97 | sub = [] 98 | line_no = 0 99 | for input_file in files: 100 | with open(input_file, 'r', encoding='utf-8') as f: 101 | lines = f.readlines() 102 | for line in lines: 103 | jd = json.loads(line) 104 | if not jd: 105 | continue 106 | text = jd['text'] 107 | docs = text.split('\n\n') 108 | 109 | d = [doc for doc in docs if doc] 110 | sub.append(d) 111 | if len(sub) >= COUNT_PER_GROUP: 112 | D.append(copy.deepcopy(sub)) 113 | sub.clear() 114 | 115 | line_no += 1 116 | if line_no % 10000 == 0: 117 | print('read_line', line_no) 118 | print(d) 119 | if len(sub): 120 | D.append(copy.deepcopy(sub)) 121 | sub.clear() 122 | 123 | return D 124 | 125 | def collate_fn(self, batch): 126 | o = {} 127 | for i, b in enumerate(batch): 128 | if i == 0: 129 | for k in b: 130 | o[k] = [torch.tensor(b[k])] 131 | else: 132 | for k in b: 133 | o[k].append(torch.tensor(b[k])) 134 | for k in o: 135 | o[k] = torch.stack(o[k]) 136 | 137 | max_len = torch.max(o.pop('seqlen')) 138 | 139 | o['input_ids'] = o['input_ids'][:, :max_len] 140 | o['attention_mask'] = o['attention_mask'][:, :max_len] 141 | if 'token_type_ids' in o: 142 | o['token_type_ids'] = o['token_type_ids'][:, :max_len] 143 | 144 | input_ids = o['input_ids'] 145 | masked_lm_positions = o.pop('masked_lm_positions') 146 | masked_lm_ids = o.pop('masked_lm_ids') 147 | masked_lm_weights = o.pop('masked_lm_weights') 148 | labels = torch.clone(input_ids) 149 | mask = torch.zeros_like(input_ids) 150 | for i, (index, value, weight) in enumerate(zip(masked_lm_positions, masked_lm_ids, masked_lm_weights.long())): 151 | s = torch.sum(weight) 152 | labels[i, index[:s]] = value[:s] 153 | mask[i, index[:s]] = 1 154 | o['labels'] = labels 155 | o['mask'] = mask 156 | return o 157 | 158 | if __name__ == '__main__': 159 | 160 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments, MlmDataArguments)) 161 | model_args, training_args, data_args, mlm_data_args = parser.parse_dict(train_info_args) 162 | 163 | rng = random.Random(training_args.seed) 164 | dataHelper = NN_DataHelper(model_args, training_args, data_args, mlm_args=( 165 | rng, 
mlm_data_args.do_whole_word_mask, mlm_data_args.max_predictions_per_seq, mlm_data_args.masked_lm_prob)) 166 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 167 | 168 | # 缓存数据集 169 | if data_args.do_train: 170 | dataHelper.make_dataset_with_args(data_args.train_file,mixed_data=False,shuffle=True,mode='train',dupe_factor=mlm_data_args.dupe_factor, 171 | num_process_worker=20) 172 | if data_args.do_eval: 173 | dataHelper.make_dataset_with_args(data_args.eval_file,shuffle=False,mode='eval') 174 | if data_args.do_test: 175 | dataHelper.make_dataset_with_args(data_args.test_file,mode='test') -------------------------------------------------------------------------------- /pretraining/mlm_pretrain/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import random 4 | 5 | import torch 6 | from deep_training.data_helper import ModelArguments, TrainingArguments, DataArguments, MlmDataArguments 7 | from deep_training.nlp.models.transformer import TransformerForMaskLM 8 | from lightning import Trainer 9 | from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor 10 | from torch.nn import CrossEntropyLoss 11 | from torch.utils.data import DataLoader, IterableDataset 12 | from transformers import HfArgumentParser 13 | 14 | from data_utils import NN_DataHelper, train_info_args 15 | from torch.nn.functional import one_hot 16 | 17 | mask_token_id = None 18 | 19 | 20 | class MyTransformer(TransformerForMaskLM, with_pl=True): 21 | def __init__(self, *args, **kwargs): 22 | super(MyTransformer, self).__init__(*args, **kwargs) 23 | self.loss_fct = CrossEntropyLoss(reduction='none') 24 | 25 | def compute_loss_mlm(self, y_trues, y_preds, mask): 26 | y_preds = torch.transpose(y_preds, 1, 2) 27 | masked_lm_loss = self.loss_fct(y_preds, y_trues) 28 | masked_lm_loss = torch.sum(mask * masked_lm_loss) / (torch.sum(mask) + 1e-8) 29 | return masked_lm_loss 30 | 31 | def compute_acc(self, y_trues, y_preds, mask): 32 | acc = torch.eq(torch.argmax(y_preds, dim=-1), y_trues) 33 | acc = torch.sum(mask * acc) / (torch.sum(mask) + 1e-8) 34 | return acc 35 | 36 | def compute_loss(self, *args, **batch) -> tuple: 37 | labels = None 38 | mask = None 39 | if 'labels' in batch: 40 | labels = batch.pop('labels') 41 | mask = batch.pop('mask') 42 | 43 | outputs = self.model(*args, **batch) 44 | logits = outputs[0] 45 | if labels is not None: 46 | loss = self.compute_loss_mlm(labels, logits, mask) 47 | acc = self.compute_acc(labels, logits, batch['attention_mask']) 48 | mlm_acc = self.compute_acc(labels, logits, mask) 49 | loss = { 50 | 'loss': loss, 51 | 'acc': acc, 52 | 'mlm_acc': mlm_acc, 53 | } 54 | outputs = (loss, logits, labels) 55 | else: 56 | outputs = (logits,) 57 | return outputs 58 | 59 | 60 | if __name__ == '__main__': 61 | 62 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments, MlmDataArguments)) 63 | model_args, training_args, data_args, mlm_data_args = parser.parse_dict(train_info_args) 64 | 65 | checkpoint_callback = ModelCheckpoint(save_last=True, 66 | verbose=True, 67 | monitor="loss", 68 | save_top_k=5, 69 | every_n_train_steps=2000 // training_args.gradient_accumulation_steps) 70 | trainer = Trainer( 71 | callbacks=[checkpoint_callback, LearningRateMonitor(logging_interval='step')], 72 | max_epochs=training_args.max_epochs, 73 | max_steps=training_args.max_steps, 74 | accelerator="gpu", 75 | devices=data_args.devices, 76 | enable_progress_bar=True, 77 | 
default_root_dir=data_args.output_dir, 78 | gradient_clip_val=training_args.max_grad_norm, 79 | accumulate_grad_batches=training_args.gradient_accumulation_steps, 80 | num_sanity_val_steps=0, 81 | strategy='ddp' if torch.cuda.device_count() > 1 else 'auto', 82 | ) 83 | 84 | rng = random.Random(training_args.seed) 85 | dataHelper = NN_DataHelper(model_args, training_args, data_args, mlm_args=( 86 | rng, mlm_data_args.do_whole_word_mask, mlm_data_args.max_predictions_per_seq, mlm_data_args.masked_lm_prob)) 87 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 88 | mask_token_id = tokenizer.mask_token_id 89 | # 缓存数据集 90 | if data_args.do_train: 91 | dataHelper.make_dataset_with_args(data_args.train_file, mixed_data=False, shuffle=True, mode='train', 92 | dupe_factor=mlm_data_args.dupe_factor, num_process_worker=10) 93 | if data_args.do_eval: 94 | dataHelper.make_dataset_with_args(data_args.eval_file, shuffle=False, mode='eval') 95 | if data_args.do_test: 96 | dataHelper.make_dataset_with_args(data_args.test_file, mode='test') 97 | 98 | model = MyTransformer(config=config, model_args=model_args, training_args=training_args) 99 | 100 | if not data_args.convert_onnx: 101 | train_datasets = dataHelper.load_random_sampler(dataHelper.train_files, 102 | with_load_memory=False, 103 | with_record_iterable_dataset=True, 104 | collate_fn=dataHelper.collate_fn, 105 | batch_size=training_args.train_batch_size, 106 | shuffle=True, infinite=True, num_processes=trainer.world_size, 107 | process_index=trainer.global_rank) 108 | # 恢复断点训练 109 | resume_ckpt_path = r'./epoch=0-step=4200.ckpt' 110 | if not os.path.exists(resume_ckpt_path): 111 | resume_ckpt_path = None 112 | 113 | if train_datasets is not None: 114 | trainer.fit(model, train_dataloaders=train_datasets, ckpt_path=resume_ckpt_path) 115 | else: 116 | eval_datasets = dataHelper.load_sequential_sampler(dataHelper.eval_files, 117 | batch_size=training_args.eval_batch_size, 118 | collate_fn=dataHelper.collate_fn) 119 | test_datasets = dataHelper.load_sequential_sampler(dataHelper.test_files, 120 | batch_size=training_args.test_batch_size, 121 | collate_fn=dataHelper.collate_fn) 122 | if eval_datasets is not None: 123 | trainer.validate(model, dataloaders=eval_datasets, ckpt_path='./best.pt') 124 | 125 | if test_datasets is not None: 126 | trainer.test(model, dataloaders=test_datasets, ckpt_path='best.pt') 127 | -------------------------------------------------------------------------------- /pretraining/prompt_t5_pretrain/README.md: -------------------------------------------------------------------------------- 1 | # 数据示例 2 | 3 | {"input": "我可以用以下的句子:“花呗在什么时间段可以用”,来替换这个句子:“什么时候用了花贝”,并且它们有相同的意思?。选项:是的,不是。答案:", "target": "不是", "type": "classify"} 4 | {"input": "摘要:针对水平受荷桩在不同的长径比和桩土刚度比条件下可以表现出刚性桩、半刚性桩或柔性桩变形特性的特点,运用刚性桩和柔性桩的临界桩长计算公式,结合相似原理,推导了重力加速度为1g条件下缩尺模型桩与原型桩的临界桩长相似比,与几何相似比进行对比,评判模型桩和原型桩的变形特性,分析桩身材料模量相似比与几何相似比对模型桩和原型桩变形特性相似性的影响,并通过有限元方法进行数值模拟验证.研究结果表明:桩身材料模量是控制模型桩与原型桩满足变形特性相似的主要参数;直接采用原型桩材和原型土开展的模型试验与原型桩的变形特性不具有相似性,但通过选择模量相似比接近于几何相似比的模型桩材可以使得模型试验结果与原型相似.\n 以下的关键词都是这篇摘要合适的关键词吗?关键词:几何,模型试验,特性,相似性。答案是:\n选项:是的,不是\n答案:", "target": "不是", "type": "classify"} 5 | {"input": "下面两个句子语义是“相同”或“不同”?“我买商品怎么用不了花呗”,“我的花呗怎么用不了”。选项:相同,不同。答案:", "target": "不同", "type": "classify"} 6 | 7 | 8 | 9 | # 使用方法 10 | 11 | ## 生成训练record 12 | 13 | python data_utils.py 14 | 15 | ## 训练 16 | 17 | python task_poery_unilm.py 18 | 19 | ## 字典 16448 -------------------------------------------------------------------------------- 
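As a quick smoke test of the record pipeline, a corpus file in the layout above can be written by hand before pointing train_file at the real pCLUE data; a minimal sketch (the output path is only a placeholder):

import json

# One pCLUE-style sample per line, matching the {"input", "target", "type"}
# fields that data_utils.py reads in on_get_corpus. The path is a placeholder.
samples = [
    {"input": "下面两个句子语义是“相同”或“不同”?“我买商品怎么用不了花呗”,“我的花呗怎么用不了”。选项:相同,不同。答案:",
     "target": "不同", "type": "classify"},
]
with open('./finetune_train_examples.json', 'w', encoding='utf-8') as f:
    for jd in samples:
        f.write(json.dumps(jd, ensure_ascii=False) + '\n')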
/pretraining/prompt_t5_pretrain/data_utils.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/1/22 16:22 2 | # @Author : tk 3 | # @FileName: data_utils.py 4 | #reference: https://github.com/clue-ai/PromptCLUE/blob/main/Fine_tuning_PyTorch.ipynb 5 | 6 | import copy 7 | import json 8 | import os 9 | import random 10 | import typing 11 | 12 | import numpy as np 13 | import torch 14 | from deep_training.data_helper import DataHelper, ModelArguments, TrainingArguments, DataArguments 15 | from deep_training.utils.func import is_chinese_char 16 | from fastdatasets.record import load_dataset as Loader, RECORD, WriterObject, gfile 17 | from tqdm import tqdm 18 | from transformers import BertTokenizer, HfArgumentParser 19 | 20 | train_info_args = { 21 | 'devices': 1, 22 | 'data_backend': 'record', 23 | 'model_type': 't5', 24 | # 预训练模型路径 , 从0训练,则置空 25 | # 'model_name_or_path': '/data/nlp/pre_models/torch/', 26 | 'tokenizer_name': './t5_small_config', 27 | 'config_name': './t5_small_config/config.json', 28 | 'convert_onnx': False, # 转换onnx模型 29 | 'do_train': True, 30 | 'train_file': [ '/data/nlp/nlp_train_data/clueprompt/finetune_train_examples.json'], 31 | 'max_epochs': 3, 32 | 'train_batch_size': 10, 33 | 'eval_batch_size': 2, 34 | 'test_batch_size': 2, 35 | 'learning_rate': 5e-5, 36 | 'adam_epsilon': 1e-8, 37 | 'gradient_accumulation_steps': 1, 38 | 'max_grad_norm': 1.0, 39 | 'weight_decay': 0, 40 | 'warmup_steps': 0, 41 | 'output_dir': './output', 42 | 'max_seq_length': 512, 43 | 'max_target_length': 100 # 预测最大长度 44 | } 45 | 46 | 47 | 48 | class NN_DataHelper(DataHelper): 49 | index = 1 50 | 51 | def on_data_ready(self): 52 | self.index = -1 53 | 54 | # 切分词 55 | def on_data_process(self, data: typing.Any, mode: str): 56 | self.index += 1 57 | 58 | tokenizer: BertTokenizer 59 | max_seq_length = self.max_seq_length_dict[mode] 60 | tokenizer = self.tokenizer 61 | 62 | doc_type,src_text,tgt_text = data 63 | 64 | o1 = tokenizer.encode_plus(text=src_text, truncation=True,max_length=max_seq_length) 65 | o2 = tokenizer.encode_plus(text=tgt_text, truncation=True,max_length=max_seq_length) 66 | 67 | input_ids = np.asarray(o1['input_ids'], dtype=np.int64) 68 | attention_mask = np.asarray(o1['attention_mask'], dtype=np.int64) 69 | seqlen = np.asarray(len(input_ids), dtype=np.int64) 70 | pad_len = max_seq_length - seqlen 71 | if pad_len > 0: 72 | pad_val = tokenizer.pad_token_id 73 | input_ids = np.pad(input_ids, (0, pad_len), 'constant', constant_values=(pad_val, pad_val)) 74 | attention_mask = np.pad(attention_mask, (0, pad_len), 'constant', constant_values=(0, 0)) 75 | 76 | decoder_input_ids = np.asarray(o2['input_ids'], dtype=np.int64) 77 | decoder_attention_mask = np.asarray(o2['attention_mask'], dtype=np.int64) 78 | labels = np.asarray(o2['input_ids'][1:], dtype=np.int64) 79 | decoder_seqlen = np.asarray(len(decoder_input_ids), dtype=np.int64) 80 | pad_len = max_seq_length - decoder_seqlen 81 | if pad_len > 0: 82 | pad_val = tokenizer.pad_token_id 83 | decoder_input_ids = np.pad(decoder_input_ids, (0, pad_len), 'constant', constant_values=(pad_val, pad_val)) 84 | decoder_attention_mask = np.pad(decoder_attention_mask, (0, pad_len), 'constant', constant_values=(0, 0)) 85 | labels = np.pad(labels, (0, pad_len+1), 'constant', constant_values=(-100, -100)) 86 | 87 | d = { 88 | 'input_ids': input_ids, 89 | 'attention_mask': attention_mask, 90 | 'seqlen': seqlen, 91 | 'decoder_input_ids': decoder_input_ids, 92 | 'decoder_attention_mask': 
decoder_attention_mask, 93 | 'decoder_seqlen': decoder_seqlen, 94 | 'labels':labels 95 | } 96 | return d 97 | 98 | # 读取文件 99 | def on_get_corpus(self, files: typing.List, mode: str): 100 | D = [] 101 | #{"input": "我可以用以下的句子:“花呗在什么时间段可以用”,来替换这个句子:“什么时候用了花贝”,并且它们有相同的意思?。选项:是的,不是。答案:", "target": "不是", "type": "classify"} 102 | for file in files: 103 | with open(file,mode='r',encoding='utf-8',newline='\n') as f: 104 | lines = f.readlines() 105 | 106 | for line in lines: 107 | jd = json.loads(line) 108 | if not jd: 109 | continue 110 | doc_type = jd.get('type', '') 111 | src_text = jd['input'] 112 | tgt_text = jd['target'] 113 | D.append((doc_type,src_text,tgt_text)) 114 | return D 115 | 116 | def collate_fn(self, batch): 117 | o = {} 118 | for i, b in enumerate(batch): 119 | if i == 0: 120 | for k in b: 121 | o[k] = [torch.tensor(b[k])] 122 | else: 123 | for k in b: 124 | o[k].append(torch.tensor(b[k])) 125 | for k in o: 126 | o[k] = torch.stack(o[k]) 127 | 128 | max_len = torch.max(o.pop('seqlen')).numpy().tolist() 129 | decoder_seqlen = torch.max(o.pop('decoder_seqlen')).numpy().tolist() 130 | 131 | 132 | o['input_ids'] = o['input_ids'][:, :max_len] 133 | o['attention_mask'] = o['attention_mask'][:, :max_len] 134 | o['decoder_input_ids'] = o['decoder_input_ids'][:, :decoder_seqlen] 135 | o['decoder_attention_mask'] = o['decoder_attention_mask'][:, :decoder_seqlen] 136 | o['labels'] = o['labels'][:, :decoder_seqlen] 137 | return o 138 | 139 | 140 | if __name__ == '__main__': 141 | 142 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments)) 143 | model_args, training_args, data_args = parser.parse_dict(train_info_args) 144 | 145 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 146 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 147 | config.decoder_start_token_id = tokenizer.cls_token_id 148 | # 缓存数据集 149 | # 检测是否存在 output/dataset_0-train.record ,不存在则制作数据集 150 | if data_args.do_train: 151 | dataHelper.make_dataset_with_args(data_args.train_file,mixed_data=False, shuffle=True,mode='train') 152 | if data_args.do_eval: 153 | dataHelper.make_dataset_with_args(data_args.eval_file, shuffle=False,mode='eval') 154 | if data_args.do_test: 155 | dataHelper.make_dataset_with_args(data_args.test_file, shuffle=False,mode='test') 156 | 157 | -------------------------------------------------------------------------------- /pretraining/prompt_t5_pretrain/evaluate_pclue.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/1/29 15:02 3 | import json,pylcs 4 | from rouge import Rouge 5 | import numpy as np 6 | import logging 7 | 8 | """ 9 | 计算pCLUE任务总分,及子分数 10 | """ 11 | 12 | def f1_sim(text_a, text_b): 13 | """F1相似度 14 | 说明:算出两个文本的最长公共子序列长度,然后乘2并处以两者 15 | 长度之和。推荐用pylcs算,速度较快。 16 | """ 17 | if not text_a and not text_b: 18 | return 0. 19 | else: 20 | lcs = pylcs.lcs(text_a, text_b) 21 | return 2. 
* lcs / (len(text_a) + len(text_b)) 22 | 23 | def rouge_l_zh(target, pred): 24 | """计算Rouge-l得分,Rouge-l指标常用于评估自动文本摘要及翻译任务 25 | target: 真实标签 26 | pred: 预测标签""" 27 | if not(isinstance(target, str) or isinstance(pred, str)): 28 | logging.info("target或pred为非字符串!请检查!") 29 | return 30 | else: 31 | rouge = Rouge() 32 | scores = rouge.get_scores(" ".join(list(pred)), " ".join(list(target))) 33 | score = scores[0]["rouge-l"] 34 | return score["f"] 35 | 36 | def normalize(text): 37 | """简单的文本标准化 38 | """ 39 | return ' '.join(text.lower().split()) 40 | 41 | 42 | def evaluate_pclue_file_fn(predict_file,target_file): 43 | predict_lines = open(predict_file, 'r').readlines() 44 | target_lines = open(target_file, 'r').readlines() 45 | 46 | return evaluate_pclue_fn(predict_file,target_file) 47 | 48 | def evaluate_pclue_fn(predict_lines,target_lines): 49 | """ 50 | 计算pclue的成绩 51 | :param predict_file: 预测文件 52 | :param target_file: 正确的文件 53 | :return: 一个dict,包括总分score,以及各个部分的分数(mrc, generate, classify, nli) 54 | """ 55 | # 1.记录 56 | classify_list=[] 57 | mrc_list=[] 58 | generate_list=[] 59 | nli_list=[] 60 | for i, target_line in enumerate(target_lines): 61 | # e.g. target_line = {"target": "不同"} 62 | predict_line=predict_lines[i] 63 | target_answer=json.loads(target_line.replace(",",","))["target"] # 正确的标签 64 | if isinstance(target_answer, list): # 将列表转换为字符串,如关键词生成 65 | target_answer = ",".join(target_answer) 66 | target_answer=normalize(target_answer) 67 | predict_answer=json.loads(predict_line)["target"] # 预测的标签 68 | predict_answer=normalize(predict_answer) 69 | # print(i,"target_answer:",target_answer,";predict_answer:",predict_answer) 70 | 71 | type=json.loads(target_line.replace(",",","))["type"] # 替换可能存在问题的数据,如有,以便能加载为json 72 | if type=='classify' or type=='anaphora_resolution': # 分类 73 | label_temp=True if target_answer==predict_answer else False 74 | classify_list.append(label_temp) 75 | elif type=='mrc': # 阅读理解 76 | em=1 if target_answer==predict_answer else 0 77 | f1=f1_sim(predict_answer,target_answer) 78 | mrc_list.append((em, f1)) 79 | elif type=='generate': # 生成 80 | rouge_l=rouge_l_zh(target_answer, predict_answer) 81 | generate_list.append(rouge_l) 82 | elif type=='nli': # 推理 83 | label_temp = True if target_answer == predict_answer else False 84 | nli_list.append(label_temp) 85 | else: 86 | print("error...predict_line:",predict_line,";target_line:",target_line) 87 | break # 中断运行 88 | # if predict_answer==target_answer: count_right=count_right+1 89 | if i<10: print(i, 'target_answer:',target_answer,";predict_answer:",predict_answer) # 显示部分内容 90 | 91 | 92 | # 2.计算最后的得分 93 | classify_score=np.average(classify_list) 94 | nli_score=np.average(nli_list) 95 | generate_score=np.average(generate_list) 96 | mrc_em_score=np.average([x[0] for x in mrc_list]) 97 | mrc_f1_score=np.average([x[1] for x in mrc_list]) 98 | mrc_score=np.average([mrc_em_score,mrc_f1_score]) 99 | # 计算总分 100 | score=np.average([classify_score,nli_score,generate_score,mrc_score]) 101 | # 保存分数 102 | result_dict={"score":score,"classify_score":classify_score,"nli_score":nli_score,"generate_score":generate_score, 103 | "mrc_em_score":mrc_em_score,"mrc_f1_score":mrc_f1_score} 104 | return result_dict 105 | 106 | 107 | if __name__ == '__main__': 108 | # 预测的文件,以及正确的文件 109 | target_file='datasets/pCLUE_test_public_1.json' 110 | predict_file='datasets/pCLUE_test_public_1.json' 111 | result=evaluate_pclue_fn(predict_file,target_file) 112 | print("result:",result) 
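For reference, evaluate_pclue_fn expects two aligned lists of JSON lines (one prediction and one gold sample per line), so when scoring files the lines need to be read first. A minimal usage sketch; the file names are placeholders:

from evaluate_pclue import evaluate_pclue_fn

# Read predictions and gold answers as raw JSON lines, then score.
# Both files must be line-aligned; the paths below are placeholders.
with open('datasets/predictions.json', 'r', encoding='utf-8') as f:
    predict_lines = f.readlines()
with open('datasets/pCLUE_test_public_1.json', 'r', encoding='utf-8') as f:
    target_lines = f.readlines()

result = evaluate_pclue_fn(predict_lines, target_lines)
print(result['score'], result['classify_score'], result['mrc_f1_score'])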
-------------------------------------------------------------------------------- /pretraining/prompt_t5_pretrain/t5_base_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5WithLMHeadModel" 4 | ], 5 | "d_ff": 2048, 6 | "d_kv": 64, 7 | "d_model": 768, 8 | "decoder_start_token_id": 101, 9 | "dropout_rate": 0.1, 10 | "initializer_factor": 1.0, 11 | "is_encoder_decoder": true, 12 | "layer_norm_epsilon": 1e-8, 13 | "model_type": "t5", 14 | "feed_forward_proj": "gated-gelu", 15 | "n_positions": 512, 16 | "num_heads": 12, 17 | "num_layers": 12, 18 | "num_decoder_layers": 12, 19 | "output_past": true, 20 | "pad_token_id": 0, 21 | "relative_attention_num_buckets": 32, 22 | "tokenizer_class": "BertTokenizer", 23 | "vocab_size": 16448 24 | } -------------------------------------------------------------------------------- /pretraining/prompt_t5_pretrain/t5_base_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /pretraining/prompt_t5_pretrain/t5_large_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5WithLMHeadModel" 4 | ], 5 | "d_ff": 2816, 6 | "d_kv": 64, 7 | "d_model": 1024, 8 | "decoder_start_token_id": 101, 9 | "dropout_rate": 0.1, 10 | "initializer_factor": 1.0, 11 | "is_encoder_decoder": true, 12 | "layer_norm_epsilon": 1e-8, 13 | "model_type": "t5", 14 | "feed_forward_proj": "gated-gelu", 15 | "n_positions": 512, 16 | "num_heads": 16, 17 | "num_layers": 24, 18 | "num_decoder_layers": 24, 19 | "output_past": true, 20 | "pad_token_id": 0, 21 | "relative_attention_num_buckets": 32, 22 | "tokenizer_class": "BertTokenizer", 23 | "vocab_size": 16448 24 | } -------------------------------------------------------------------------------- /pretraining/prompt_t5_pretrain/t5_large_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /pretraining/prompt_t5_pretrain/t5_small_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5WithLMHeadModel" 4 | ], 5 | "d_ff": 1024, 6 | "d_kv": 64, 7 | "d_model": 512, 8 | "decoder_start_token_id": 101, 9 | "dropout_rate": 0.1, 10 | "initializer_factor": 1.0, 11 | "is_encoder_decoder": true, 12 | "layer_norm_epsilon": 1e-8, 13 | "model_type": "t5", 14 | "feed_forward_proj": "gated-gelu", 15 | "n_positions": 512, 16 | "num_heads": 6, 17 | "num_layers": 8, 18 | "num_decoder_layers": 8, 19 | "output_past": true, 20 | 
"pad_token_id": 0, 21 | "relative_attention_num_buckets": 32, 22 | "tokenizer_class": "BertTokenizer", 23 | "vocab_size": 16448 24 | } -------------------------------------------------------------------------------- /pretraining/prompt_t5_pretrain/t5_small_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /pretraining/prompt_t5_pretrain/t5_xl_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5WithLMHeadModel" 4 | ], 5 | "d_ff": 5120, 6 | "d_kv": 64, 7 | "d_model": 2048, 8 | "decoder_start_token_id": 101, 9 | "dropout_rate": 0.1, 10 | "initializer_factor": 1.0, 11 | "is_encoder_decoder": true, 12 | "layer_norm_epsilon": 1e-8, 13 | "model_type": "t5", 14 | "feed_forward_proj": "gated-gelu", 15 | "n_positions": 512, 16 | "num_heads": 32, 17 | "num_layers": 24, 18 | "num_decoder_layers": 24, 19 | "output_past": true, 20 | "pad_token_id": 0, 21 | "relative_attention_num_buckets": 32, 22 | "tokenizer_class": "BertTokenizer", 23 | "vocab_size": 16448 24 | } -------------------------------------------------------------------------------- /pretraining/prompt_t5_pretrain/t5_xl_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /pretraining/prompt_t5_pretrain/t5_xxl_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5WithLMHeadModel" 4 | ], 5 | "d_ff": 10240, 6 | "d_kv": 64, 7 | "d_model": 4096, 8 | "decoder_start_token_id": 101, 9 | "dropout_rate": 0.1, 10 | "initializer_factor": 1.0, 11 | "is_encoder_decoder": true, 12 | "layer_norm_epsilon": 1e-8, 13 | "model_type": "t5", 14 | "feed_forward_proj": "gated-gelu", 15 | "n_positions": 512, 16 | "num_heads": 64, 17 | "num_layers": 24, 18 | "num_decoder_layers": 24, 19 | "output_past": true, 20 | "pad_token_id": 0, 21 | "relative_attention_num_buckets": 32, 22 | "tokenizer_class": "BertTokenizer", 23 | "vocab_size": 16448 24 | } -------------------------------------------------------------------------------- /pretraining/prompt_t5_pretrain/t5_xxl_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": 
"BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /pretraining/prompt_t5_pretrain/task_prompt_t5.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #reference: https://github.com/clue-ai/PromptCLUE/blob/main/Fine_tuning_PyTorch.ipynb 3 | 4 | import numpy as np 5 | import torch 6 | from deep_training.data_helper import ModelArguments, DataArguments, TrainingArguments 7 | from deep_training.nlp.models.transformer import TransformerForSeq2SeqLM 8 | from deep_training.utils.trainer import SimpleModelCheckpoint 9 | from lightning import Trainer 10 | from torch.utils.data import DataLoader, IterableDataset 11 | from transformers import HfArgumentParser, BertTokenizer 12 | 13 | from data_utils import NN_DataHelper,train_info_args 14 | from evaluate_pclue import evaluate_pclue_fn 15 | 16 | 17 | class MyTransformer(TransformerForSeq2SeqLM, with_pl=True): 18 | def __init__(self, *args, **kwargs): 19 | super(MyTransformer, self).__init__(*args, **kwargs) 20 | 21 | 22 | class MySimpleModelCheckpoint(SimpleModelCheckpoint): 23 | def __init__(self, *args, **kwargs): 24 | super(MySimpleModelCheckpoint, self).__init__(*args, **kwargs) 25 | self.weight_file = './best.pt' 26 | 27 | @staticmethod 28 | def generate_text(pl_module: MyTransformer, prefix, tokenizer, max_target_length, device=0): 29 | device = torch.device('cuda:{}'.format(device)) 30 | # 简易测试生成 31 | o = tokenizer.encode_plus(prefix, truncation=True, max_length=512, return_attention_mask=False, 32 | return_token_type_ids=False) 33 | gen_ids, gen_tokens = [], [] 34 | batch = {} 35 | for i in range(max_target_length): 36 | batch.clear() 37 | batch['input_ids'] = [o['input_ids']] 38 | batch['decoder_input_ids'] = [[tokenizer.cls_token_id] + gen_ids] 39 | for k in batch: 40 | batch[k] = torch.tensor(batch[k], dtype=torch.int32) 41 | for k in batch: 42 | batch[k] = batch[k].to(device) 43 | out = pl_module.test_step(batch, 0) 44 | logits = out['outputs'][0] 45 | logits = np.argmax(logits[:, -1], axis=-1) 46 | logits = logits[0] 47 | gen_ids.append(logits) 48 | token = tokenizer._convert_id_to_token(logits) 49 | if token.startswith('##'): 50 | token = token.replace('##', '') 51 | gen_tokens.append(token) 52 | return ''.join(gen_tokens) 53 | 54 | def on_save_model( 55 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" 56 | ) -> None: 57 | # 保存权重 58 | super(MySimpleModelCheckpoint, self).on_save_model(trainer, pl_module) 59 | prefixs = [('law', '湖南高院给大学生送上一堂反电诈的“开学第一课”'), 60 | ('law', '最高检:检察公益诉讼5年追偿修复生态、治理环境费93.5亿'), 61 | ('classify', 62 | '我可以用以下的句子:“花呗在什么时间段可以用”,来替换这个句子:“什么时候用了花贝”,并且它们有相同的意思?。选项:是的,不是。答案:'), 63 | ('classify', 64 | '摘要:针对水平受荷桩在不同的长径比和桩土刚度比条件下可以表现出刚性桩、半刚性桩或柔性桩变形特性的特点,运用刚性桩和柔性桩的临界桩长计算公式,结合相似原理,推导了重力加速度为1g条件下缩尺模型桩与原型桩的临界桩长相似比,与几何相似比进行对比,评判模型桩和原型桩的变形特性,分析桩身材料模量相似比与几何相似比对模型桩和原型桩变形特性相似性的影响,并通过有限元方法进行数值模拟验证.研究结果表明:桩身材料模量是控制模型桩与原型桩满足变形特性相似的主要参数;直接采用原型桩材和原型土开展的模型试验与原型桩的变形特性不具有相似性,但通过选择模量相似比接近于几何相似比的模型桩材可以使得模型试验结果与原型相似.\n 以下的关键词都是这篇摘要合适的关键词吗?关键词:几何,模型试验,特性,相似性。答案是:\n选项:是的,不是\n答案:'), 65 | ('classify', 66 | '下面两个句子语义是“相同”或“不同”?“我买商品怎么用不了花呗”,“我的花呗怎么用不了”。选项:相同,不同。答案:'), 67 | ] 68 | print('*' * 30) 69 | device = trainer.global_rank 70 | self.tokenizer: BertTokenizer 71 | tokenizer = self.tokenizer 72 | data_args = self.data_args 73 | for prefix in prefixs: 74 | print(prefix[0], prefix[1]) 75 | prefix = prefix[1] 76 | output = MySimpleModelCheckpoint.generate_text(pl_module, prefix, tokenizer, 
77 | data_args.max_target_length, device=device) 78 | print('input', prefix) 79 | print('output', output) 80 | print() 81 | 82 | 83 | 84 | 85 | 86 | 87 | if __name__ == '__main__': 88 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments)) 89 | model_args, training_args, data_args = parser.parse_dict(train_info_args) 90 | # 保存最小loss模型 91 | checkpoint_callback = MySimpleModelCheckpoint(monitor="loss", 92 | every_n_epochs = 1, 93 | every_n_train_steps=2000 // training_args.gradient_accumulation_steps) 94 | trainer = Trainer( 95 | callbacks=[checkpoint_callback], 96 | max_epochs=training_args.max_epochs, 97 | max_steps=training_args.max_steps, 98 | accelerator="gpu", 99 | devices=data_args.devices, 100 | enable_progress_bar=True, 101 | default_root_dir=data_args.output_dir, 102 | gradient_clip_val=training_args.max_grad_norm, 103 | accumulate_grad_batches=training_args.gradient_accumulation_steps, 104 | num_sanity_val_steps=0, 105 | strategy='ddp' if torch.cuda.device_count() > 1 else 'auto', 106 | ) 107 | 108 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 109 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 110 | 111 | config.decoder_start_token_id = tokenizer.cls_token_id 112 | # 额外参数 113 | checkpoint_callback.tokenizer = tokenizer 114 | checkpoint_callback.data_args = data_args 115 | 116 | # 缓存数据集 117 | if data_args.do_train: 118 | dataHelper.make_dataset_with_args(data_args.train_file,mixed_data=False,shuffle=True,mode='train') 119 | if data_args.do_eval: 120 | dataHelper.make_dataset_with_args(data_args.eval_file, mode='eval') 121 | if data_args.do_test: 122 | dataHelper.make_dataset_with_args(data_args.test_file,mode='test') 123 | 124 | 125 | model = MyTransformer(config=config, model_args=model_args, training_args=training_args) 126 | 127 | if not data_args.convert_onnx: 128 | train_datasets = dataHelper.load_random_sampler(dataHelper.train_files, 129 | with_load_memory=False, 130 | with_record_iterable_dataset=True, 131 | collate_fn=dataHelper.collate_fn, 132 | batch_size=training_args.train_batch_size, 133 | shuffle=True, infinite=True, num_processes=trainer.world_size, 134 | process_index=trainer.global_rank) 135 | if train_datasets is not None: 136 | trainer.fit(model, train_dataloaders=train_datasets) 137 | else: 138 | eval_datasets = dataHelper.load_sequential_sampler(dataHelper.eval_files,batch_size=training_args.eval_batch_size,collate_fn=dataHelper.collate_fn) 139 | test_datasets = dataHelper.load_sequential_sampler(dataHelper.test_files,batch_size=training_args.test_batch_size,collate_fn=dataHelper.collate_fn) 140 | if eval_datasets is not None: 141 | trainer.validate(model, dataloaders=eval_datasets, ckpt_path='./best.pt') 142 | 143 | if test_datasets is not None: 144 | trainer.test(model, dataloaders=test_datasets, ckpt_path='best.pt') 145 | else: 146 | # 加载权重 147 | model = MyTransformer.load_from_checkpoint('./best.pt', config=config, 148 | model_args=model_args, 149 | training_args=training_args) 150 | model.convert_to_onnx('./best.onnx') 151 | -------------------------------------------------------------------------------- /pretraining/seq2seq_pretrain/data_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time: 3:12 3 | # @Author:XIE392 4 | # @File:data_utils.py 5 | import json 6 | 7 | import numpy as np 8 | import torch 9 | import typing 10 | 11 | from deep_training.data_helper import DataHelper, ModelArguments, 
TrainingArguments, DataArguments 12 | from transformers import BertTokenizer, HfArgumentParser 13 | 14 | 15 | train_info_args = { 16 | 'devices': 1, 17 | 'data_backend': 'record', 18 | 'model_type': 't5', 19 | # 'model_name_or_path': '/data/nlp/pre_models/torch/', 20 | 'tokenizer_name': './t5_small_config', 21 | 'config_name': './t5_small_config/config.json', 22 | 'convert_onnx': False, # 转换onnx模型 23 | 'do_train': True, 24 | 'train_file': [ '/data/nlp/nlp_train_data/thucnews/train.json'], 25 | 'learning_rate': 5e-5, 26 | 'max_epochs': 3, 27 | 'train_batch_size': 10, 28 | 'test_batch_size': 2, 29 | 'adam_epsilon': 1e-8, 30 | 'gradient_accumulation_steps': 1, 31 | 'max_grad_norm': 1.0, 32 | 'weight_decay': 0, 33 | 'warmup_steps': 0, 34 | 'output_dir': './output', 35 | 'train_max_seq_length': 512, 36 | 'eval_max_seq_length': 512, 37 | 'test_max_seq_length': 512, 38 | 'max_target_length': 64, 39 | } 40 | 41 | class NN_DataHelper(DataHelper): 42 | # 切分词 43 | def on_data_process(self, data: typing.Any, mode: str): 44 | tokenizer: BertTokenizer 45 | max_seq_length = self.max_seq_length_dict[mode] 46 | tokenizer = self.tokenizer 47 | 48 | x = data 49 | def get_tokenizer_output(text): 50 | o1 = tokenizer.encode_plus(text, max_length=max_seq_length, truncation=True, add_special_tokens=True, ) 51 | 52 | input_ids = np.asarray(o1['input_ids'], dtype=np.int64) 53 | attention_mask = np.asarray(o1['attention_mask'], dtype=np.int64) 54 | seqlen = np.asarray(len(input_ids), dtype=np.int64) 55 | pad_len = max_seq_length - seqlen 56 | if pad_len > 0: 57 | pad_val = tokenizer.pad_token_id 58 | input_ids = np.pad(input_ids, (0, pad_len), 'constant', constant_values=(pad_val, pad_val)) 59 | attention_mask = np.pad(attention_mask, (0, pad_len), 'constant', constant_values=(0, 0)) 60 | 61 | out = { 62 | 'input_ids': input_ids, 63 | 'attention_mask': attention_mask, 64 | 'seqlen': seqlen, 65 | } 66 | return out 67 | 68 | o1 = get_tokenizer_output(x[0]) 69 | o2 = get_tokenizer_output(x[1]) 70 | 71 | d = o1 72 | 73 | d['decoder_input_ids'] = o2['input_ids'] 74 | d['decoder_attention_mask'] = o2['attention_mask'] 75 | d['decoder_seqlen'] = o2['seqlen'] 76 | 77 | labels = np.ones_like(d['decoder_input_ids']) * -100 78 | labels[:o2['seqlen']-1] = d['decoder_input_ids'][1:o2['seqlen']] 79 | return d 80 | 81 | # 读取文件 82 | def on_get_corpus(self, files: typing.List, mode: str): 83 | D = [] 84 | for filename in files: 85 | with open(filename, mode='r', encoding='utf-8') as f: 86 | lines = f.readlines() 87 | for i, line in enumerate(lines): 88 | jd = json.loads(line) 89 | D.append((jd['content'], jd['title'])) 90 | return D 91 | 92 | def collate_fn(self, batch): 93 | o = {} 94 | for i, b in enumerate(batch): 95 | if i == 0: 96 | for k in b: 97 | o[k] = [torch.tensor(b[k])] 98 | else: 99 | for k in b: 100 | o[k].append(torch.tensor(b[k])) 101 | for k in o: 102 | o[k] = torch.stack(o[k]) 103 | 104 | 105 | max_len = torch.max(o.pop('seqlen')) 106 | o['input_ids'] = o['input_ids'][:, :max_len] 107 | o['attention_mask'] = o['attention_mask'][:, :max_len] 108 | 109 | max_len = torch.max(o.pop('decoder_seqlen')) 110 | o['decoder_input_ids'] = o['decoder_input_ids'][:, :max_len] 111 | o['decoder_attention_mask'] = o['decoder_attention_mask'][:, :max_len] 112 | o['labels'] = o['labels'][:, :max_len] 113 | return o 114 | 115 | 116 | if __name__ == '__main__': 117 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments)) 118 | model_args, training_args, data_args = parser.parse_dict(train_info_args) 119 | 120 | 
121 | 122 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 123 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 124 | 125 | 126 | # 缓存数据集 127 | if data_args.do_train: 128 | dataHelper.make_dataset_with_args(data_args.train_file,mixed_data=False, shuffle=True,mode='train') 129 | if data_args.do_eval: 130 | dataHelper.make_dataset_with_args(data_args.eval_file, mode='eval') 131 | if data_args.do_test: 132 | dataHelper.make_dataset_with_args(data_args.test_file,mode='test') -------------------------------------------------------------------------------- /pretraining/seq2seq_pretrain/t5_base_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5WithLMHeadModel" 4 | ], 5 | "d_ff": 2048, 6 | "d_kv": 64, 7 | "d_model": 768, 8 | "decoder_start_token_id": 101, 9 | "dropout_rate": 0.1, 10 | "initializer_factor": 1.0, 11 | "is_encoder_decoder": true, 12 | "layer_norm_epsilon": 1e-8, 13 | "model_type": "t5", 14 | "feed_forward_proj": "gated-gelu", 15 | "n_positions": 512, 16 | "num_heads": 12, 17 | "num_layers": 12, 18 | "num_decoder_layers": 12, 19 | "output_past": true, 20 | "pad_token_id": 0, 21 | "relative_attention_num_buckets": 32, 22 | "tokenizer_class": "BertTokenizer", 23 | "vocab_size": 16448 24 | } -------------------------------------------------------------------------------- /pretraining/seq2seq_pretrain/t5_base_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /pretraining/seq2seq_pretrain/t5_large_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5WithLMHeadModel" 4 | ], 5 | "d_ff": 2816, 6 | "d_kv": 64, 7 | "d_model": 1024, 8 | "decoder_start_token_id": 101, 9 | "dropout_rate": 0.1, 10 | "initializer_factor": 1.0, 11 | "is_encoder_decoder": true, 12 | "layer_norm_epsilon": 1e-8, 13 | "model_type": "t5", 14 | "feed_forward_proj": "gated-gelu", 15 | "n_positions": 512, 16 | "num_heads": 16, 17 | "num_layers": 24, 18 | "num_decoder_layers": 24, 19 | "output_past": true, 20 | "pad_token_id": 0, 21 | "relative_attention_num_buckets": 32, 22 | "tokenizer_class": "BertTokenizer", 23 | "vocab_size": 16448 24 | } -------------------------------------------------------------------------------- /pretraining/seq2seq_pretrain/t5_large_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /pretraining/seq2seq_pretrain/t5_small_config/config.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5WithLMHeadModel" 4 | ], 5 | "d_ff": 1024, 6 | "d_kv": 64, 7 | "d_model": 512, 8 | "decoder_start_token_id": 101, 9 | "dropout_rate": 0.1, 10 | "initializer_factor": 1.0, 11 | "is_encoder_decoder": true, 12 | "layer_norm_epsilon": 1e-8, 13 | "model_type": "t5", 14 | "feed_forward_proj": "gated-gelu", 15 | "n_positions": 512, 16 | "num_heads": 6, 17 | "num_layers": 8, 18 | "num_decoder_layers": 8, 19 | "output_past": true, 20 | "pad_token_id": 0, 21 | "relative_attention_num_buckets": 32, 22 | "tokenizer_class": "BertTokenizer", 23 | "vocab_size": 16448 24 | } -------------------------------------------------------------------------------- /pretraining/seq2seq_pretrain/t5_small_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /pretraining/seq2seq_pretrain/t5_xl_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5WithLMHeadModel" 4 | ], 5 | "d_ff": 5120, 6 | "d_kv": 64, 7 | "d_model": 2048, 8 | "decoder_start_token_id": 101, 9 | "dropout_rate": 0.1, 10 | "initializer_factor": 1.0, 11 | "is_encoder_decoder": true, 12 | "layer_norm_epsilon": 1e-8, 13 | "model_type": "t5", 14 | "feed_forward_proj": "gated-gelu", 15 | "n_positions": 512, 16 | "num_heads": 32, 17 | "num_layers": 24, 18 | "num_decoder_layers": 24, 19 | "output_past": true, 20 | "pad_token_id": 0, 21 | "relative_attention_num_buckets": 32, 22 | "tokenizer_class": "BertTokenizer", 23 | "vocab_size": 16448 24 | } -------------------------------------------------------------------------------- /pretraining/seq2seq_pretrain/t5_xl_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /pretraining/seq2seq_pretrain/t5_xxl_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5WithLMHeadModel" 4 | ], 5 | "d_ff": 10240, 6 | "d_kv": 64, 7 | "d_model": 4096, 8 | "decoder_start_token_id": 101, 9 | "dropout_rate": 0.1, 10 | "initializer_factor": 1.0, 11 | "is_encoder_decoder": true, 12 | "layer_norm_epsilon": 1e-8, 13 | "model_type": "t5", 14 | "feed_forward_proj": "gated-gelu", 15 | "n_positions": 512, 16 | "num_heads": 64, 17 | "num_layers": 24, 18 | "num_decoder_layers": 24, 19 | "output_past": true, 20 | "pad_token_id": 0, 21 | "relative_attention_num_buckets": 32, 22 | "tokenizer_class": "BertTokenizer", 23 | "vocab_size": 16448 24 | } 
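The config bundles above (small/base/large/xl/xxl) are plain Hugging Face T5 configs paired with a BERT-style vocabulary, so they can also be inspected outside the deep_training wrappers used by train.py. A minimal sketch, assuming transformers is installed and the working directory is pretraining/seq2seq_pretrain/; the model is randomly initialised, which is exactly what from-scratch pretraining expects:

    from transformers import BertTokenizer, T5Config, T5ForConditionalGeneration

    config = T5Config.from_pretrained('./t5_small_config')          # d_model=512, 8 encoder + 8 decoder layers
    tokenizer = BertTokenizer.from_pretrained('./t5_small_config')  # loads vocab.txt, [CLS]/[SEP]/[PAD]
    model = T5ForConditionalGeneration(config)                      # random weights, no checkpoint loaded
    # decoder_start_token_id is preset to 101 in config.json (intended to be the tokenizer's [CLS] id)
    print(f'{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters')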
-------------------------------------------------------------------------------- /pretraining/seq2seq_pretrain/t5_xxl_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /pretraining/seq2seq_pretrain/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from deep_training.data_helper import ModelArguments, DataArguments, TrainingArguments 5 | from deep_training.nlp.models.transformer import TransformerForSeq2SeqLM 6 | from lightning import Trainer 7 | from lightning.pytorch.callbacks import ModelCheckpoint 8 | from torch.nn import CrossEntropyLoss 9 | from torch.utils.data import DataLoader, IterableDataset 10 | from transformers import HfArgumentParser 11 | from data_utils import NN_DataHelper,train_info_args 12 | 13 | class MyTransformer(TransformerForSeq2SeqLM, with_pl=True): 14 | def __init__(self, *args, **kwargs): 15 | super(MyTransformer, self).__init__(*args, **kwargs) 16 | self.loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id) 17 | 18 | def compute_loss(self, *args, **batch) -> tuple: 19 | labels = batch.pop('labels', None) 20 | outputs = self.model(*args, **batch) 21 | lm_logits = outputs[0] 22 | if labels is not None: 23 | loss = self.loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) 24 | outputs = (loss, lm_logits, labels) 25 | else: 26 | outputs = (lm_logits,) 27 | return outputs 28 | 29 | 30 | if __name__ == '__main__': 31 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments)) 32 | model_args, training_args, data_args = parser.parse_dict(train_info_args) 33 | 34 | checkpoint_callback = ModelCheckpoint(monitor="loss", save_top_k=5, 35 | every_n_train_steps=2000 // training_args.gradient_accumulation_steps) 36 | trainer = Trainer( 37 | callbacks=[checkpoint_callback], 38 | max_epochs=training_args.max_epochs, 39 | max_steps=training_args.max_steps, 40 | accelerator="gpu", 41 | devices=data_args.devices, 42 | enable_progress_bar=True, 43 | default_root_dir=data_args.output_dir, 44 | gradient_clip_val=training_args.max_grad_norm, 45 | accumulate_grad_batches=training_args.gradient_accumulation_steps, 46 | num_sanity_val_steps=0, 47 | strategy='ddp' if torch.cuda.device_count() > 1 else 'auto', 48 | ) 49 | 50 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 51 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 52 | 53 | 54 | # 缓存数据集 55 | if data_args.do_train: 56 | dataHelper.make_dataset_with_args(data_args.train_file,mixed_data=False, shuffle=True,mode='train') 57 | if data_args.do_eval: 58 | dataHelper.make_dataset_with_args(data_args.eval_file, mode='eval') 59 | if data_args.do_test: 60 | dataHelper.make_dataset_with_args(data_args.test_file,mode='test') 61 | 62 | 63 | 64 | model = MyTransformer(config=config, model_args=model_args, training_args=training_args) 65 | 66 | if not data_args.convert_onnx: 67 | train_datasets = dataHelper.load_random_sampler(dataHelper.train_files, 68 | 
with_load_memory=False, 69 | with_record_iterable_dataset=True, 70 | collate_fn=dataHelper.collate_fn, 71 | batch_size=training_args.train_batch_size, 72 | shuffle=True, infinite=True, num_processes=trainer.world_size, 73 | process_index=trainer.global_rank) 74 | if train_datasets is not None: 75 | trainer.fit(model, train_dataloaders=train_datasets) 76 | else: 77 | eval_datasets = dataHelper.load_sequential_sampler(dataHelper.eval_files,batch_size=training_args.eval_batch_size,collate_fn=dataHelper.collate_fn) 78 | test_datasets = dataHelper.load_sequential_sampler(dataHelper.test_files,batch_size=training_args.test_batch_size,collate_fn=dataHelper.collate_fn) 79 | if eval_datasets is not None: 80 | trainer.validate(model, dataloaders=eval_datasets, ckpt_path='./best.pt') 81 | 82 | if test_datasets is not None: 83 | trainer.test(model, dataloaders=test_datasets, ckpt_path='best.pt') 84 | -------------------------------------------------------------------------------- /pretraining/simbert-v2_pretrain/data_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time: 3:15 3 | # @Author:XIE392 4 | # @File:data_utils.py 5 | import json 6 | import typing 7 | 8 | import numpy as np 9 | import torch 10 | from deep_training.data_helper import DataHelper, ModelArguments, DataArguments, TrainingArguments 11 | from transformers import BertTokenizer, HfArgumentParser 12 | 13 | train_info_args = { 14 | 'devices': '1', 15 | 'data_backend': 'record', 16 | 'model_type': 'bert', 17 | 'model_name_or_path': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 18 | 'tokenizer_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 19 | 'config_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese/config.json', 20 | 'convert_onnx': False, # 转换onnx模型 21 | 'do_train': True, 22 | 'train_file': [ '/data/nlp/nlp_train_data/thucnews/train.json'], 23 | 'max_steps': 100000, 24 | 'optimizer': 'adamw', 25 | 'learning_rate': 5e-5, 26 | 'train_batch_size': 10, 27 | 'test_batch_size': 2, 28 | 'adam_epsilon': 1e-8, 29 | 'gradient_accumulation_steps': 1, 30 | 'max_grad_norm': 1.0, 31 | 'weight_decay': 0, 32 | 'warmup_steps': 0, 33 | 'output_dir': './output', 34 | 'max_seq_length': 512, 35 | 'max_target_length': 50 36 | } 37 | 38 | 39 | 40 | class NN_DataHelper(DataHelper): 41 | # 切分词 42 | def on_data_process(self, data: typing.Any, mode: str): 43 | tokenizer: BertTokenizer 44 | max_seq_length = self.max_seq_length_dict[mode] 45 | tokenizer = self.tokenizer 46 | 47 | x = data 48 | assert isinstance(x, tuple) 49 | 50 | o = tokenizer(text=x[0], text_pair=x[1], max_length=max_seq_length, truncation=True, 51 | add_special_tokens=True) 52 | 53 | input_ids = np.asarray(o['input_ids'], dtype=np.int64) 54 | token_type_ids = np.asarray(o['token_type_ids'], dtype=np.int64) 55 | 56 | seqlen = np.asarray(len(input_ids), dtype=np.int64) 57 | pad_len = max_seq_length - len(input_ids) 58 | if pad_len > 0: 59 | pad_val = tokenizer.pad_token_id 60 | input_ids = np.pad(input_ids, (0, pad_len), 'constant', constant_values=(pad_val, pad_val)) 61 | token_type_ids = np.pad(token_type_ids, (0, pad_len), 'constant', constant_values=(0, 0)) 62 | d = { 63 | 'input_ids': input_ids, 64 | 'token_type_ids': token_type_ids, 65 | 'labels': input_ids, 66 | 'seqlen': seqlen 67 | } 68 | return d 69 | 70 | # 读取文件 71 | def on_get_corpus(self, files: typing.List, mode: str): 72 | D = [] 73 | for filename in files: 74 | with open(filename, mode='r', encoding='utf-8') as f: 75 | lines = 
f.readlines() 76 | for i, line in enumerate(lines): 77 | jd = json.loads(line) 78 | D.append((jd['content'], jd['title'])) 79 | return D[0:1000] if mode == 'train' else D[:100] 80 | 81 | def collate_fn(self, batch): 82 | o = {} 83 | for i, b in enumerate(batch): 84 | if i == 0: 85 | for k in b: 86 | o[k] = [torch.tensor(b[k])] 87 | else: 88 | for k in b: 89 | o[k].append(torch.tensor(b[k])) 90 | for k in o: 91 | o[k] = torch.stack(o[k]) 92 | 93 | max_len = torch.max(o.pop('seqlen')) 94 | 95 | o['input_ids'] = o['input_ids'][:, :max_len] 96 | o['token_type_ids'] = o['token_type_ids'][:, :max_len] 97 | o['labels'] = o['labels'][:, :max_len] 98 | return o 99 | 100 | 101 | 102 | 103 | if __name__ == '__main__': 104 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments)) 105 | model_args, training_args, data_args = parser.parse_dict(train_info_args) 106 | 107 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 108 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 109 | 110 | # 缓存数据集 111 | if data_args.do_train: 112 | dataHelper.make_dataset_with_args(data_args.train_file,mixed_data=False, shuffle=True,mode='train') 113 | if data_args.do_eval: 114 | dataHelper.make_dataset_with_args(data_args.eval_file, mode='eval') 115 | if data_args.do_test: 116 | dataHelper.make_dataset_with_args(data_args.test_file,mode='test') -------------------------------------------------------------------------------- /pretraining/simbert-v2_pretrain/task_simsce_unilm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from deep_training.data_helper import ModelArguments, TrainingArguments, DataArguments 5 | from deep_training.nlp.layers.mask import unilm_mask 6 | from deep_training.nlp.losses.contrast import SimcseLoss 7 | from deep_training.nlp.models.transformer import TransformerModelForUnilm 8 | from lightning import Trainer 9 | from lightning.pytorch.callbacks import ModelCheckpoint 10 | from torch import nn 11 | from torch.utils.data import DataLoader, IterableDataset 12 | from transformers import HfArgumentParser 13 | from data_utils import NN_DataHelper,train_info_args 14 | 15 | 16 | class MyTransformer(TransformerModelForUnilm, with_pl=True): 17 | def __init__(self, *args, **kwargs): 18 | super(MyTransformer, self).__init__(*args, **kwargs) 19 | config = self.config 20 | self.sim_head = nn.Linear(config.hidden_size, 512, bias=False) 21 | self.loss_fn = SimcseLoss() 22 | 23 | def get_model_lr(self,model=None,lr=None): 24 | return super(MyTransformer, self).get_model_lr() + [ 25 | (self.sim_head, self.config.task_specific_params['learning_rate_for_task']) 26 | ] 27 | 28 | def compute_loss(self, *args, **batch) -> tuple: 29 | if self.training: 30 | batch = {k: torch.repeat_interleave(v, 2, dim=0) for k, v in batch.items()} 31 | labels = batch.pop('labels', None) 32 | batch['attention_mask'] = unilm_mask(batch['token_type_ids']) 33 | outputs = self.model(*args, **batch) 34 | lm_logits = self.model.lm_head(outputs[0]) 35 | simcse_logits = self.sim_head(outputs[1]) 36 | 37 | if labels is not None: 38 | shift_logits = lm_logits[..., :-1, :].contiguous() 39 | shift_labels = labels[..., 1:].contiguous() 40 | loss1 = self.model.loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 41 | loss2 = self.loss_fn(simcse_logits) 42 | loss = loss1 + loss2 43 | loss_dict = { 44 | 'loss': loss, 45 | 'unilm_loss': loss1, 46 | 'simcse_loss': loss2, 47 | } 
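            # unilm_loss: causal LM loss on the shifted sequence, with unilm_mask(token_type_ids)
            #             restricting attention so the second segment is generated left-to-right
            #             conditioned on the first segment.
            # simcse_loss: contrastive loss on sim_head(pooled_output); the batch was duplicated
            #              with repeat_interleave(2, dim=0), so the two dropout passes of each
            #              sample act as positive pairs (unsupervised SimCSE).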
48 | outputs = (loss_dict, lm_logits, simcse_logits) 49 | self.log_dict(loss_dict, prog_bar=True) 50 | else: 51 | outputs = (lm_logits, simcse_logits) 52 | return outputs 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments)) 57 | model_args, training_args, data_args = parser.parse_dict(train_info_args) 58 | 59 | checkpoint_callback = ModelCheckpoint(monitor="loss", 60 | every_n_train_steps=2000 // training_args.gradient_accumulation_steps) 61 | trainer = Trainer( 62 | callbacks=[checkpoint_callback], 63 | max_epochs=training_args.max_epochs, 64 | max_steps=training_args.max_steps, 65 | accelerator="gpu", 66 | devices=data_args.devices, 67 | enable_progress_bar=True, 68 | default_root_dir=data_args.output_dir, 69 | gradient_clip_val=training_args.max_grad_norm, 70 | accumulate_grad_batches=training_args.gradient_accumulation_steps, 71 | num_sanity_val_steps=0, 72 | strategy='ddp' if torch.cuda.device_count() > 1 else 'auto', 73 | ) 74 | 75 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 76 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 77 | 78 | # 缓存数据集 79 | if data_args.do_train: 80 | dataHelper.make_dataset_with_args(data_args.train_file,mixed_data=False, shuffle=True,mode='train') 81 | if data_args.do_eval: 82 | dataHelper.make_dataset_with_args(data_args.eval_file, mode='eval') 83 | if data_args.do_test: 84 | dataHelper.make_dataset_with_args(data_args.test_file,mode='test') 85 | 86 | model = MyTransformer(config=config, model_args=model_args, training_args=training_args) 87 | 88 | if not data_args.convert_onnx: 89 | train_datasets = dataHelper.load_random_sampler(dataHelper.train_files, 90 | with_load_memory=False, 91 | with_record_iterable_dataset=True, 92 | collate_fn=dataHelper.collate_fn, 93 | batch_size=training_args.train_batch_size, 94 | shuffle=True, infinite=True, num_processes=trainer.world_size, 95 | process_index=trainer.global_rank) 96 | if train_datasets is not None: 97 | trainer.fit(model, train_dataloaders=train_datasets) 98 | else: 99 | eval_datasets = dataHelper.load_sequential_sampler(dataHelper.eval_files,batch_size=training_args.eval_batch_size,collate_fn=dataHelper.collate_fn) 100 | test_datasets = dataHelper.load_sequential_sampler(dataHelper.test_files,batch_size=training_args.test_batch_size,collate_fn=dataHelper.collate_fn) 101 | if eval_datasets is not None: 102 | trainer.validate(model, dataloaders=eval_datasets, ckpt_path='./best.pt') 103 | 104 | if test_datasets is not None: 105 | trainer.test(model, dataloaders=test_datasets, ckpt_path='best.pt') 106 | -------------------------------------------------------------------------------- /pretraining/t5encoder_mlm_pretrain/data_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time: 3:09 3 | # @File:data_utils.py 4 | import json 5 | import random 6 | 7 | import torch 8 | import typing 9 | from deep_training.data_helper import DataHelper, ModelArguments, TrainingArguments, MlmDataArguments, DataArguments 10 | from deep_training.utils.maskedlm import make_mlm_wwm_sample 11 | from transformers import BertTokenizer, HfArgumentParser 12 | 13 | 14 | train_info_args = { 15 | 'devices': 1, 16 | 'data_backend': 'record', 17 | 'model_type': 't5', 18 | # 预训练模型路径 , 从0训练,则置空 19 | # 'model_name_or_path': '/data/nlp/pre_models/torch/', 20 | 'tokenizer_name': './t5_base_config', 21 | 'config_name': './t5_base_config/config.json', 
22 | 'convert_onnx': False, # 转换onnx模型 23 | 'do_train': True, 24 | 'train_file': [ '/data/nlp/nlp_train_data/thucnews/train.json'], 25 | 'learning_rate': 5e-5, 26 | 'max_epochs': None, 27 | 'max_steps': 300000, 28 | 'train_batch_size': 8, 29 | 'test_batch_size': 2, 30 | 'adam_epsilon': 1e-8, 31 | 'gradient_accumulation_steps': 1, 32 | 'max_grad_norm': 1.0, 33 | 'weight_decay': 0.01, 34 | 'warmup_steps': 10000, 35 | 'output_dir': './output', 36 | 'train_max_seq_length': 512, 37 | 'eval_max_seq_length': 512, 38 | 'test_max_seq_length': 512, 39 | 'do_lower_case': True, 40 | 'do_whole_word_mask': True, 41 | 'max_predictions_per_seq': 20, 42 | 'dupe_factor': 5, 43 | 'masked_lm_prob': 0.15 44 | } 45 | 46 | 47 | class NN_DataHelper(DataHelper): 48 | # 切分词 49 | def on_data_process(self, data: typing.Any, mode: typing.Any): 50 | tokenizer: BertTokenizer 51 | max_seq_length = self.max_seq_length_dict[mode] 52 | tokenizer = self.tokenizer 53 | 54 | rng, do_whole_word_mask, max_predictions_per_seq, masked_lm_prob = self.external_kwargs['mlm_args'] 55 | 56 | documents = data 57 | document_text_string = ''.join(documents) 58 | document_texts = [] 59 | pos = 0 60 | while pos < len(document_text_string): 61 | text = document_text_string[pos:pos + max_seq_length - 2] 62 | pos += len(text) 63 | document_texts.append(text) 64 | # 返回多个文档 65 | document_nodes = [] 66 | for text in document_texts: 67 | node = make_mlm_wwm_sample(text, tokenizer, max_seq_length, rng, do_whole_word_mask, 68 | max_predictions_per_seq, masked_lm_prob) 69 | document_nodes.append(node) 70 | return document_nodes 71 | 72 | # 读取文件 73 | def on_get_corpus(self, files: typing.List, mode: str): 74 | D = [] 75 | line_no = 0 76 | for input_file in files: 77 | with open(input_file, 'r', encoding='utf-8') as f: 78 | lines = f.readlines() 79 | for line in lines: 80 | jd = json.loads(line) 81 | if not jd: 82 | continue 83 | text = jd['content'] 84 | docs = text.split('\n\n') 85 | D.append([doc for doc in docs if doc]) 86 | line_no += 1 87 | 88 | if line_no > 1000: 89 | break 90 | 91 | if line_no % 10000 == 0: 92 | print('read_line', line_no) 93 | print(D[-1]) 94 | return D 95 | 96 | def collate_fn(self, batch): 97 | o = {} 98 | for i, b in enumerate(batch): 99 | if i == 0: 100 | for k in b: 101 | o[k] = [torch.tensor(b[k])] 102 | else: 103 | for k in b: 104 | o[k].append(torch.tensor(b[k])) 105 | for k in o: 106 | o[k] = torch.stack(o[k]) 107 | 108 | max_len = torch.max(o.pop('seqlen')) 109 | 110 | o['input_ids'] = o['input_ids'][:, :max_len] 111 | o['attention_mask'] = o['attention_mask'][:, :max_len] 112 | if 'token_type_ids' in o: 113 | o['token_type_ids'] = o['token_type_ids'][:, :max_len] 114 | 115 | input_ids = o['input_ids'] 116 | masked_lm_positions = o.pop('masked_lm_positions') 117 | masked_lm_ids = o.pop('masked_lm_ids') 118 | masked_lm_weights = o.pop('masked_lm_weights') 119 | labels = torch.clone(input_ids) 120 | mask = torch.zeros_like(input_ids) 121 | for i, (index, value, weight) in enumerate(zip(masked_lm_positions, masked_lm_ids, masked_lm_weights.long())): 122 | s = torch.sum(weight) 123 | labels[i, index[:s]] = value[:s] 124 | mask[i, index[:s]] = 1 125 | o['labels'] = labels 126 | o['mask'] = mask 127 | return o 128 | 129 | if __name__ == '__main__': 130 | 131 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments, MlmDataArguments)) 132 | model_args, training_args, data_args, mlm_data_args = parser.parse_dict(train_info_args) 133 | 134 | rng = random.Random(training_args.seed) 135 | dataHelper = 
NN_DataHelper(model_args, training_args, data_args, mlm_args=( 136 | rng, mlm_data_args.do_whole_word_mask, mlm_data_args.max_predictions_per_seq, mlm_data_args.masked_lm_prob)) 137 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 138 | 139 | # 缓存数据集 140 | if data_args.do_train: 141 | dataHelper.make_dataset_with_args(data_args.train_file,mixed_data=False,shuffle=True,mode='train',dupe_factor=mlm_data_args.dupe_factor) 142 | if data_args.do_eval: 143 | dataHelper.make_dataset_with_args(data_args.eval_file,shuffle=False,mode='eval') 144 | if data_args.do_test: 145 | dataHelper.make_dataset_with_args(data_args.test_file,mode='test') -------------------------------------------------------------------------------- /pretraining/t5encoder_mlm_pretrain/t5_base_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5EncoderModel" 4 | ], 5 | "d_ff": 2048, 6 | "d_kv": 64, 7 | "d_model": 768, 8 | "decoder_start_token_id": 101, 9 | "dropout_rate": 0.1, 10 | "initializer_factor": 1.0, 11 | "is_encoder_decoder": false, 12 | "is_decoder": false, 13 | "layer_norm_epsilon": 1e-8, 14 | "model_type": "t5", 15 | "feed_forward_proj": "gated-gelu", 16 | "n_positions": 512, 17 | "num_heads": 12, 18 | "num_layers": 12, 19 | "output_past": true, 20 | "pad_token_id": 0, 21 | "relative_attention_num_buckets": 32, 22 | "tokenizer_class": "BertTokenizer", 23 | "vocab_size": 16448 24 | } -------------------------------------------------------------------------------- /pretraining/t5encoder_mlm_pretrain/t5_base_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /pretraining/t5encoder_mlm_pretrain/t5_large_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5EncoderModel" 4 | ], 5 | "d_ff": 2816, 6 | "d_kv": 64, 7 | "d_model": 1024, 8 | "decoder_start_token_id": 101, 9 | "dropout_rate": 0.1, 10 | "initializer_factor": 1.0, 11 | "is_encoder_decoder": false, 12 | "is_decoder": false, 13 | "layer_norm_epsilon": 1e-8, 14 | "model_type": "t5", 15 | "feed_forward_proj": "gated-gelu", 16 | "n_positions": 512, 17 | "num_heads": 16, 18 | "num_layers": 24, 19 | "output_past": true, 20 | "pad_token_id": 0, 21 | "relative_attention_num_buckets": 32, 22 | "tokenizer_class": "BertTokenizer", 23 | "vocab_size": 16448 24 | } -------------------------------------------------------------------------------- /pretraining/t5encoder_mlm_pretrain/t5_large_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } 
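Unlike the seq2seq bundles, these configs declare "architectures": ["T5EncoderModel"] with "is_encoder_decoder": false, so the MLM objective here trains the T5 encoder stack only. A minimal sketch, assuming transformers is installed and the working directory is pretraining/t5encoder_mlm_pretrain/ (random weights, dummy input):

    from transformers import BertTokenizer, T5Config, T5EncoderModel

    config = T5Config.from_pretrained('./t5_base_config')           # encoder-only settings
    tokenizer = BertTokenizer.from_pretrained('./t5_base_config')   # BERT-style vocab with [MASK]
    model = T5EncoderModel(config)                                   # no decoder is built

    enc = tokenizer('这是一个测试句子', return_tensors='pt')
    out = model(input_ids=enc['input_ids'], attention_mask=enc['attention_mask'])
    print(out.last_hidden_state.shape)   # torch.Size([1, seq_len, 768]) for the base config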
-------------------------------------------------------------------------------- /pretraining/t5encoder_mlm_pretrain/t5_xl_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5EncoderModel" 4 | ], 5 | "d_ff": 5120, 6 | "d_kv": 64, 7 | "d_model": 2048, 8 | "decoder_start_token_id": 101, 9 | "dropout_rate": 0.1, 10 | "initializer_factor": 1.0, 11 | "is_encoder_decoder": false, 12 | "is_decoder": false, 13 | "layer_norm_epsilon": 1e-8, 14 | "model_type": "t5", 15 | "feed_forward_proj": "gated-gelu", 16 | "n_positions": 512, 17 | "num_heads": 32, 18 | "num_layers": 24, 19 | "output_past": true, 20 | "pad_token_id": 0, 21 | "relative_attention_num_buckets": 32, 22 | "tokenizer_class": "BertTokenizer", 23 | "vocab_size": 16448 24 | } -------------------------------------------------------------------------------- /pretraining/t5encoder_mlm_pretrain/t5_xl_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /pretraining/t5encoder_mlm_pretrain/t5_xxl_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5EncoderModel" 4 | ], 5 | "d_ff": 10240, 6 | "d_kv": 64, 7 | "d_model": 4096, 8 | "decoder_start_token_id": 101, 9 | "dropout_rate": 0.1, 10 | "initializer_factor": 1.0, 11 | "is_encoder_decoder": false, 12 | "is_decoder": false, 13 | "layer_norm_epsilon": 1e-8, 14 | "model_type": "t5", 15 | "feed_forward_proj": "gated-gelu", 16 | "n_positions": 512, 17 | "num_heads": 64, 18 | "num_layers": 24, 19 | "output_past": true, 20 | "pad_token_id": 0, 21 | "relative_attention_num_buckets": 32, 22 | "tokenizer_class": "BertTokenizer", 23 | "vocab_size": 16448 24 | } -------------------------------------------------------------------------------- /pretraining/t5encoder_mlm_pretrain/t5_xxl_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /pretraining/t5encoder_mlm_pretrain/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import random 3 | 4 | import torch 5 | from deep_training.data_helper import ModelArguments, TrainingArguments, DataArguments, MlmDataArguments 6 | from deep_training.nlp.models.t5encoder import TransformerT5EncoderMaskedLM 7 | from lightning import Trainer 8 | from lightning.pytorch.callbacks import ModelCheckpoint 9 | from torch.nn import CrossEntropyLoss 10 | from torch.utils.data import DataLoader, IterableDataset 11 | from transformers import 
HfArgumentParser 12 | from data_utils import NN_DataHelper,train_info_args 13 | 14 | mask_token_id = None 15 | 16 | class MyTransformer(TransformerT5EncoderMaskedLM, with_pl=True): 17 | def __init__(self, *args, **kwargs): 18 | super(MyTransformer, self).__init__(*args, **kwargs) 19 | self.loss_fct = CrossEntropyLoss(reduction='mean') 20 | 21 | def compute_loss_mlm(self, y_trues, y_preds, mask): 22 | y_preds = torch.transpose(y_preds, 1, 2) 23 | masked_lm_loss = self.loss_fct(y_preds, y_trues) 24 | masked_lm_loss = torch.sum(mask * masked_lm_loss) / (torch.sum(mask) + 1e-8) 25 | return masked_lm_loss 26 | 27 | def compute_acc(self, y_trues, y_preds, mask): 28 | acc = torch.eq(torch.argmax(y_preds, dim=-1), y_trues) 29 | acc = torch.sum(mask * acc) / (torch.sum(mask) + 1e-8) 30 | return acc 31 | 32 | def compute_loss(self, *args, **batch) -> tuple: 33 | labels = None 34 | mask = None 35 | if 'labels' in batch: 36 | labels = batch.pop('labels') 37 | mask = batch.pop('mask') 38 | 39 | outputs = self.model(*args, **batch) 40 | logits = outputs[0] 41 | if labels is not None: 42 | loss = self.compute_loss_mlm(labels, logits, mask) 43 | acc = self.compute_acc(labels, logits, batch['attention_mask']) 44 | mlm_acc = self.compute_acc(labels, logits, mask) 45 | loss = { 46 | 'loss': loss, 47 | 'acc': acc, 48 | 'mlm_acc': mlm_acc, 49 | } 50 | outputs = (loss, logits, labels) 51 | else: 52 | outputs = (logits,) 53 | return outputs 54 | 55 | 56 | if __name__ == '__main__': 57 | 58 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments, MlmDataArguments)) 59 | model_args, training_args, data_args, mlm_data_args = parser.parse_dict(train_info_args) 60 | 61 | checkpoint_callback = ModelCheckpoint(save_last=True, 62 | verbose=True, 63 | monitor="loss", 64 | save_top_k=5, 65 | every_n_train_steps=2000 // training_args.gradient_accumulation_steps) 66 | trainer = Trainer( 67 | callbacks=[checkpoint_callback], 68 | max_epochs=training_args.max_epochs, 69 | max_steps=training_args.max_steps, 70 | accelerator="gpu", 71 | devices=data_args.devices, 72 | enable_progress_bar=True, 73 | default_root_dir=data_args.output_dir, 74 | gradient_clip_val=training_args.max_grad_norm, 75 | accumulate_grad_batches=training_args.gradient_accumulation_steps, 76 | num_sanity_val_steps=0, 77 | strategy='ddp' if torch.cuda.device_count() > 1 else 'auto', 78 | ) 79 | 80 | rng = random.Random(training_args.seed) 81 | dataHelper = NN_DataHelper(model_args, training_args, data_args,mlm_args = (rng, mlm_data_args.do_whole_word_mask, mlm_data_args.max_predictions_per_seq,mlm_data_args.masked_lm_prob)) 82 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 83 | mask_token_id = tokenizer.mask_token_id 84 | # 缓存数据集 85 | if data_args.do_train: 86 | dataHelper.make_dataset_with_args(data_args.train_file,mixed_data=False,shuffle=True,mode='train', dupe_factor=mlm_data_args.dupe_factor) 87 | if data_args.do_eval: 88 | dataHelper.make_dataset_with_args(data_args.eval_file,shuffle=False,mode='eval') 89 | if data_args.do_test: 90 | dataHelper.make_dataset_with_args(data_args.test_file,mode='test') 91 | 92 | model = MyTransformer(config=config, model_args=model_args, training_args=training_args) 93 | 94 | if not data_args.convert_onnx: 95 | train_datasets = dataHelper.load_random_sampler(dataHelper.train_files, 96 | with_load_memory=False, 97 | with_record_iterable_dataset=True, 98 | collate_fn=dataHelper.collate_fn, 99 | batch_size=training_args.train_batch_size, 100 | shuffle=True, 
infinite=True, num_processes=trainer.world_size, 101 | process_index=trainer.global_rank) 102 | if train_datasets is not None: 103 | trainer.fit(model, train_dataloaders=train_datasets) 104 | else: 105 | eval_datasets = dataHelper.load_sequential_sampler(dataHelper.eval_files,batch_size=training_args.eval_batch_size,collate_fn=dataHelper.collate_fn) 106 | test_datasets = dataHelper.load_sequential_sampler(dataHelper.test_files,batch_size=training_args.test_batch_size,collate_fn=dataHelper.collate_fn) 107 | if eval_datasets is not None: 108 | trainer.validate(model, dataloaders=eval_datasets, ckpt_path='./best.pt') 109 | 110 | if test_datasets is not None: 111 | trainer.test(model, dataloaders=test_datasets, ckpt_path='best.pt') 112 | -------------------------------------------------------------------------------- /task_classify/task_tnews.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import logging 4 | import typing 5 | 6 | import numpy as np 7 | import torch 8 | from deep_training.data_helper import DataHelper 9 | from deep_training.data_helper import ModelArguments, TrainingArguments, DataArguments 10 | from deep_training.nlp.models.transformer import TransformerForSequenceClassification 11 | from deep_training.utils.trainer import SimpleModelCheckpoint 12 | from lightning import Trainer 13 | 14 | from sklearn.metrics import f1_score, classification_report 15 | from torch.utils.data import DataLoader, IterableDataset 16 | from tqdm import tqdm 17 | from transformers import HfArgumentParser, BertTokenizer 18 | 19 | train_info_args = { 20 | 'devices': 1, 21 | 'data_backend': 'memory_raw', 22 | 'model_type': 'bert', 23 | 'model_name_or_path': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 24 | 'tokenizer_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 25 | 'config_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese/config.json', 26 | 'convert_onnx': False, # 转换onnx模型 27 | 'do_train': True, 28 | 'do_eval': True, 29 | 'train_file': [ '/data/nlp/nlp_train_data/clue/tnews/train.json'], 30 | 'eval_file': [ '/data/nlp/nlp_train_data/clue/tnews/dev.json'], 31 | 'test_file': [ '/data/nlp/nlp_train_data/clue/tnews/test.json'], 32 | 'label_file': [ '/data/nlp/nlp_train_data/clue/tnews/labels.json'], 33 | 'learning_rate': 5e-5, 34 | 'max_epochs': 3, 35 | 'train_batch_size': 10, 36 | 'test_batch_size': 2, 37 | 'adam_epsilon': 1e-8, 38 | 'gradient_accumulation_steps': 1, 39 | 'max_grad_norm': 1.0, 40 | 'weight_decay': 0, 41 | 'warmup_steps': 0, 42 | 'output_dir': './output', 43 | 'train_max_seq_length': 380, 44 | 'eval_max_seq_length': 512, 45 | 'test_max_seq_length': 512, 46 | } 47 | 48 | 49 | class NN_DataHelper(DataHelper): 50 | # 切分词 51 | def on_data_process(self, data: typing.Any, mode: str): 52 | tokenizer: BertTokenizer 53 | max_seq_length = self.max_seq_length_dict[mode] 54 | tokenizer = self.tokenizer 55 | do_lower_case = tokenizer.do_lower_case 56 | label2id = self.label2id 57 | 58 | sentence, label_str = data 59 | 60 | o = tokenizer(sentence, max_length=max_seq_length, truncation=True, add_special_tokens=True, ) 61 | input_ids = np.asarray(o['input_ids'], dtype=np.int64) 62 | attention_mask = np.asarray(o['attention_mask'], dtype=np.int64) 63 | 64 | labels = np.asarray(label2id[label_str] if label_str is not None else 0, dtype=np.int64) 65 | seqlen = np.asarray(len(input_ids), dtype=np.int64) 66 | pad_len = max_seq_length - len(input_ids) 67 | if pad_len > 0: 68 | pad_val = 
tokenizer.pad_token_id 69 | input_ids = np.pad(input_ids, (0, pad_len), 'constant', constant_values=(pad_val, pad_val)) 70 | attention_mask = np.pad(attention_mask, (0, pad_len), 'constant', constant_values=(0, 0)) 71 | d = { 72 | 'input_ids': input_ids, 73 | 'attention_mask': attention_mask, 74 | 'labels': labels, 75 | 'seqlen': seqlen 76 | } 77 | return d 78 | 79 | # 读取标签 80 | def on_get_labels(self, files: typing.List[str]): 81 | if files is None: 82 | return None, None 83 | label_fname = files[0] 84 | is_json_file = label_fname.endswith('.json') 85 | D = set() 86 | with open(label_fname, 'r', encoding='utf-8') as f: 87 | lines = f.readlines() 88 | for line in lines: 89 | line = line.replace('\r\n', '').replace('\n', '') 90 | if not line: continue 91 | if is_json_file: 92 | jd = json.loads(line) 93 | line = jd['label'] 94 | D.add(line) 95 | label2id = {label: i for i, label in enumerate(D)} 96 | id2label = {i: label for i, label in enumerate(D)} 97 | return label2id, id2label 98 | 99 | # 读取文件 100 | def on_get_corpus(self, files: typing.List, mode: str): 101 | D = [] 102 | for filename in files: 103 | with open(filename, mode='r', encoding='utf-8') as f: 104 | lines = f.readlines() 105 | for line in lines: 106 | jd = json.loads(line) 107 | if not jd: 108 | continue 109 | D.append((jd['sentence'], jd.get('label', None))) 110 | return D[0:1000] if mode == 'train' else D[:100] 111 | 112 | def collate_fn(self,batch): 113 | o = {} 114 | for i, b in enumerate(batch): 115 | if i == 0: 116 | for k in b: 117 | o[k] = [torch.tensor(b[k])] 118 | else: 119 | for k in b: 120 | o[k].append(torch.tensor(b[k])) 121 | for k in o: 122 | o[k] = torch.stack(o[k]) 123 | 124 | max_len = torch.max(o.pop('seqlen')) 125 | 126 | o['input_ids'] = o['input_ids'][:, :max_len] 127 | o['attention_mask'] = o['attention_mask'][:, :max_len] 128 | if 'token_type_ids' in o: 129 | o['token_type_ids'] = o['token_type_ids'][:, :max_len] 130 | return o 131 | 132 | 133 | class MyTransformer(TransformerForSequenceClassification, with_pl=True): 134 | def __init__(self, *args, **kwargs): 135 | super(MyTransformer, self).__init__(*args, **kwargs) 136 | 137 | def compute_loss(self, *args, **batch) -> tuple: 138 | outputs = self.model(*args, **batch) 139 | labels = batch.get('labels', None) 140 | if labels is not None: 141 | loss, logits = outputs[0:2] 142 | acc = torch.sum(torch.eq(labels.view(-1), 143 | torch.argmax(logits, dim=1, keepdim=False))) / labels.view(-1).size()[0] 144 | loss_dict = { 145 | 'loss': loss, 146 | 'acc': acc 147 | } 148 | outputs = (loss_dict, logits, labels) 149 | else: 150 | outputs = (outputs[0],) 151 | 152 | return outputs 153 | 154 | # def validation_epoch_end(self, outputs: typing.Union[EPOCH_OUTPUT, typing.List[EPOCH_OUTPUT]]) -> None: 155 | # y_preds, y_trues = [], [] 156 | # for o in outputs: 157 | # preds, labels = o['outputs'] 158 | # preds = np.argmax(preds, -1) 159 | # for p, l in zip(preds, labels): 160 | # y_preds.append(p) 161 | # y_trues.append(int(l)) 162 | # 163 | # y_preds = np.asarray(y_preds, dtype=np.int32) 164 | # y_trues = np.asarray(y_trues, dtype=np.int32) 165 | # f1 = f1_score(y_trues, y_preds, average='micro') 166 | # report = classification_report(y_trues, y_preds, digits=4, 167 | # labels=list(self.config.label2id.values()), 168 | # target_names=list(self.config.label2id.keys())) 169 | # 170 | # print(f1, report) 171 | # self.log('val_f1', f1) 172 | 173 | 174 | class MySimpleModelCheckpoint(SimpleModelCheckpoint): 175 | def __init__(self, *args, **kwargs): 176 | 
super(MySimpleModelCheckpoint, self).__init__(*args, **kwargs) 177 | self.weight_file = './best.pt' 178 | 179 | def on_save_model( 180 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" 181 | ) -> None: 182 | pl_module: MyTransformer 183 | 184 | # 当前设备 185 | device = torch.device('cuda:{}'.format(trainer.global_rank)) 186 | eval_datasets = dataHelper.load_sequential_sampler(dataHelper.eval_files,batch_size=training_args.eval_batch_size,collate_fn=dataHelper.collate_fn) 187 | 188 | config = pl_module.config 189 | 190 | y_preds, y_trues = [], [] 191 | for i, batch in tqdm(enumerate(eval_datasets), total=len(eval_datasets), desc='evalute'): 192 | for k in batch: 193 | batch[k] = batch[k].to(device) 194 | o = pl_module.validation_step(batch, i) 195 | 196 | preds, labels = o['outputs'] 197 | preds = np.argmax(preds, -1) 198 | for p, l in zip(preds, labels): 199 | y_preds.append(p) 200 | y_trues.append(int(l)) 201 | 202 | y_preds = np.asarray(y_preds, dtype=np.int32) 203 | y_trues = np.asarray(y_trues, dtype=np.int32) 204 | f1 = f1_score(y_trues, y_preds, average='micro') 205 | report = classification_report(y_trues, y_preds, digits=4, 206 | labels=list(config.label2id.values()), 207 | target_names=list(config.label2id.keys())) 208 | 209 | print(f1, report) 210 | 211 | best_f1 = self.best.get('f1', -np.inf) 212 | print('current', f1, 'best', best_f1) 213 | if f1 >= best_f1: 214 | self.best['f1'] = f1 215 | logging.info('save best {}, {}\n'.format(self.best['f1'], self.weight_file)) 216 | trainer.save_checkpoint(self.weight_file) 217 | 218 | 219 | if __name__ == '__main__': 220 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments)) 221 | model_args, training_args, data_args = parser.parse_dict(train_info_args) 222 | 223 | checkpoint_callback = MySimpleModelCheckpoint(monitor="val_f1", every_n_epochs=1) 224 | trainer = Trainer( 225 | callbacks=[checkpoint_callback], 226 | max_epochs=training_args.max_epochs, 227 | max_steps=training_args.max_steps, 228 | accelerator="gpu", 229 | devices=data_args.devices, 230 | enable_progress_bar=True, 231 | default_root_dir=data_args.output_dir, 232 | gradient_clip_val=training_args.max_grad_norm, 233 | accumulate_grad_batches=training_args.gradient_accumulation_steps, 234 | num_sanity_val_steps=0, 235 | strategy='ddp' if torch.cuda.device_count() > 1 else 'auto', 236 | ) 237 | 238 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 239 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 240 | 241 | # 缓存数据集 242 | if data_args.do_train: 243 | dataHelper.make_dataset_with_args(data_args.train_file, shuffle=True,mode='train') 244 | if data_args.do_eval: 245 | dataHelper.make_dataset_with_args(data_args.eval_file, mode='eval') 246 | if data_args.do_test: 247 | dataHelper.make_dataset_with_args(data_args.test_file,mode='test') 248 | 249 | model = MyTransformer(config=config, model_args=model_args, training_args=training_args) 250 | 251 | if not data_args.convert_onnx: 252 | train_datasets = dataHelper.load_distributed_random_sampler( 253 | dataHelper.train_files, 254 | with_load_memory=True, 255 | collate_fn=dataHelper.collate_fn, 256 | batch_size=training_args.train_batch_size, 257 | num_processes = trainer.world_size, process_index=trainer.global_rank) 258 | 259 | if train_datasets is not None: 260 | trainer.fit(model, train_dataloaders=train_datasets) 261 | else: 262 | eval_datasets = 
dataHelper.load_sequential_sampler(dataHelper.eval_files,batch_size=training_args.eval_batch_size,collate_fn=dataHelper.collate_fn) 263 | test_datasets = dataHelper.load_sequential_sampler(dataHelper.test_files,batch_size=training_args.test_batch_size,collate_fn=dataHelper.collate_fn) 264 | if eval_datasets is not None: 265 | trainer.validate(model, dataloaders=eval_datasets, ckpt_path='./best.pt') 266 | 267 | if test_datasets is not None: 268 | trainer.test(model, dataloaders=test_datasets, ckpt_path='best.pt') 269 | -------------------------------------------------------------------------------- /task_classify/task_tnews_adversarial.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import logging 4 | import typing 5 | 6 | import numpy as np 7 | import torch 8 | from deep_training.data_helper import DataHelper 9 | from deep_training.data_helper import ModelArguments, TrainingArguments, DataArguments 10 | from deep_training.nlp.models.transformer import TransformerForSequenceClassification 11 | from deep_training.utils.trainer import SimpleModelCheckpoint 12 | from lightning import Trainer 13 | # 14 | from sklearn.metrics import f1_score, classification_report 15 | from torch.utils.data import DataLoader, IterableDataset 16 | from tqdm import tqdm 17 | from transformers import HfArgumentParser, BertTokenizer 18 | 19 | train_info_args = { 20 | 'devices': 1, 21 | 'data_backend': 'memory_raw', 22 | 'model_type': 'bert', 23 | 'model_name_or_path': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 24 | 'tokenizer_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 25 | 'config_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese/config.json', 26 | 'convert_onnx': False, # 转换onnx模型 27 | 'do_train': True, 28 | 'do_eval': True, 29 | 'train_file': [ '/data/nlp/nlp_train_data/clue/tnews/train.json'], 30 | 'eval_file': [ '/data/nlp/nlp_train_data/clue/tnews/dev.json'], 31 | 'test_file': [ '/data/nlp/nlp_train_data/clue/tnews/test.json'], 32 | 'label_file': [ '/data/nlp/nlp_train_data/clue/tnews/labels.json'], 33 | 'learning_rate': 5e-5, 34 | 'max_epochs': 3, 35 | 'train_batch_size': 10, 36 | 'test_batch_size': 2, 37 | 'adam_epsilon': 1e-8, 38 | 'gradient_accumulation_steps': 1, 39 | 'max_grad_norm': 1.0, 40 | 'weight_decay': 0, 41 | 'warmup_steps': 0, 42 | 'output_dir': './output', 43 | 'train_max_seq_length': 380, 44 | 'eval_max_seq_length': 512, 45 | 'test_max_seq_length': 512, 46 | # 对抗训练就一个配置 47 | 'adv': { 48 | 'mode': 'fgm', # None, fgm, fgsm_local, fgsm(不推荐), pgd, free_local, free(不推荐) 49 | 'emb_name': 'embedding', 50 | 'attack_iters': 2, # pgd 51 | 'minibatch_replays': 2, # free 52 | 'alpha': 0.5, # pgd,fgsm 53 | 'epsilon': 0.5, # pgd,fgm 54 | } 55 | } 56 | 57 | 58 | class NN_DataHelper(DataHelper): 59 | # 切分词 60 | def on_data_process(self, data: typing.Any, mode: str): 61 | tokenizer: BertTokenizer 62 | max_seq_length = self.max_seq_length_dict[mode] 63 | tokenizer = self.tokenizer 64 | do_lower_case = tokenizer.do_lower_case 65 | label2id = self.label2id 66 | 67 | sentence, label_str = data 68 | 69 | o = tokenizer(sentence, max_length=max_seq_length, truncation=True, add_special_tokens=True, ) 70 | input_ids = np.asarray(o['input_ids'], dtype=np.int64) 71 | attention_mask = np.asarray(o['attention_mask'], dtype=np.int64) 72 | 73 | labels = np.asarray(label2id[label_str] if label_str is not None else 0, dtype=np.int64) 74 | seqlen = np.asarray(len(input_ids), dtype=np.int64) 75 | pad_len = 
max_seq_length - len(input_ids) 76 | if pad_len > 0: 77 | pad_val = tokenizer.pad_token_id 78 | input_ids = np.pad(input_ids, (0, pad_len), 'constant', constant_values=(pad_val, pad_val)) 79 | attention_mask = np.pad(attention_mask, (0, pad_len), 'constant', constant_values=(0, 0)) 80 | d = { 81 | 'input_ids': input_ids, 82 | 'attention_mask': attention_mask, 83 | 'labels': labels, 84 | 'seqlen': seqlen 85 | } 86 | return d 87 | 88 | # 读取标签 89 | def on_get_labels(self, files: typing.List[str]): 90 | if files is None: 91 | return None, None 92 | label_fname = files[0] 93 | is_json_file = label_fname.endswith('.json') 94 | D = set() 95 | with open(label_fname, 'r', encoding='utf-8') as f: 96 | lines = f.readlines() 97 | for line in lines: 98 | line = line.replace('\r\n', '').replace('\n', '') 99 | if not line: continue 100 | if is_json_file: 101 | jd = json.loads(line) 102 | line = jd['label'] 103 | D.add(line) 104 | label2id = {label: i for i, label in enumerate(D)} 105 | id2label = {i: label for i, label in enumerate(D)} 106 | return label2id, id2label 107 | 108 | # 读取文件 109 | def on_get_corpus(self, files: typing.List, mode: str): 110 | D = [] 111 | for filename in files: 112 | with open(filename, mode='r', encoding='utf-8') as f: 113 | lines = f.readlines() 114 | for line in lines: 115 | jd = json.loads(line) 116 | if not jd: 117 | continue 118 | D.append((jd['sentence'], jd.get('label', None))) 119 | return D[0:1000] if mode == 'train' else D[:100] 120 | 121 | def collate_fn(self,batch): 122 | o = {} 123 | for i, b in enumerate(batch): 124 | if i == 0: 125 | for k in b: 126 | o[k] = [torch.tensor(b[k])] 127 | else: 128 | for k in b: 129 | o[k].append(torch.tensor(b[k])) 130 | for k in o: 131 | o[k] = torch.stack(o[k]) 132 | 133 | max_len = torch.max(o.pop('seqlen')) 134 | 135 | o['input_ids'] = o['input_ids'][:, :max_len] 136 | o['attention_mask'] = o['attention_mask'][:, :max_len] 137 | if 'token_type_ids' in o: 138 | o['token_type_ids'] = o['token_type_ids'][:, :max_len] 139 | return o 140 | 141 | 142 | class MyTransformer(TransformerForSequenceClassification, with_pl=True): 143 | def __init__(self, *args, **kwargs): 144 | super(MyTransformer, self).__init__(*args, **kwargs) 145 | 146 | def compute_loss(self, *args, **batch) -> tuple: 147 | outputs = self.model(*args, **batch) 148 | labels = batch.get('labels', None) 149 | if labels is not None: 150 | loss, logits = outputs[0:2] 151 | acc = torch.sum(torch.eq(labels.view(-1), 152 | torch.argmax(logits, dim=1, keepdim=False))) / labels.view(-1).size()[0] 153 | loss_dict = { 154 | 'loss': loss, 155 | 'acc': acc 156 | } 157 | outputs = (loss_dict, logits, labels) 158 | else: 159 | outputs = (outputs[0],) 160 | 161 | return outputs 162 | 163 | # def validation_epoch_end(self, outputs: typing.Union[EPOCH_OUTPUT, typing.List[EPOCH_OUTPUT]]) -> None: 164 | # y_preds, y_trues = [], [] 165 | # for o in outputs: 166 | # preds, labels = o['outputs'] 167 | # preds = np.argmax(preds, -1) 168 | # for p, l in zip(preds, labels): 169 | # y_preds.append(p) 170 | # y_trues.append(int(l)) 171 | # 172 | # y_preds = np.asarray(y_preds, dtype=np.int32) 173 | # y_trues = np.asarray(y_trues, dtype=np.int32) 174 | # f1 = f1_score(y_trues, y_preds, average='micro') 175 | # report = classification_report(y_trues, y_preds, digits=4, 176 | # labels=list(self.config.label2id.values()), 177 | # target_names=list(self.config.label2id.keys())) 178 | # 179 | # print(f1, report) 180 | # self.log('val_f1', f1) 181 | 182 | 183 | class 
MySimpleModelCheckpoint(SimpleModelCheckpoint): 184 | def __init__(self, *args, **kwargs): 185 | super(MySimpleModelCheckpoint, self).__init__(*args, **kwargs) 186 | self.weight_file = './best.pt' 187 | 188 | def on_save_model( 189 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" 190 | ) -> None: 191 | pl_module: MyTransformer 192 | 193 | # 当前设备 194 | device = torch.device('cuda:{}'.format(trainer.global_rank)) 195 | eval_datasets = dataHelper.load_sequential_sampler(dataHelper.eval_files,batch_size=training_args.eval_batch_size,collate_fn=dataHelper.collate_fn) 196 | 197 | config = pl_module.config 198 | 199 | y_preds, y_trues = [], [] 200 | for i, batch in tqdm(enumerate(eval_datasets), total=len(eval_datasets), desc='evalute'): 201 | for k in batch: 202 | batch[k] = batch[k].to(device) 203 | o = pl_module.validation_step(batch, i) 204 | 205 | preds, labels = o['outputs'] 206 | preds = np.argmax(preds, -1) 207 | for p, l in zip(preds, labels): 208 | y_preds.append(p) 209 | y_trues.append(int(l)) 210 | 211 | y_preds = np.asarray(y_preds, dtype=np.int32) 212 | y_trues = np.asarray(y_trues, dtype=np.int32) 213 | f1 = f1_score(y_trues, y_preds, average='micro') 214 | report = classification_report(y_trues, y_preds, digits=4, 215 | labels=list(config.label2id.values()), 216 | target_names=list(config.label2id.keys())) 217 | 218 | print(f1, report) 219 | 220 | best_f1 = self.best.get('f1', -np.inf) 221 | print('current', f1, 'best', best_f1) 222 | if f1 >= best_f1: 223 | self.best['f1'] = f1 224 | logging.info('save best {}, {}\n'.format(self.best['f1'], self.weight_file)) 225 | trainer.save_checkpoint(self.weight_file) 226 | 227 | 228 | if __name__ == '__main__': 229 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments)) 230 | model_args, training_args, data_args = parser.parse_dict(train_info_args) 231 | 232 | checkpoint_callback = MySimpleModelCheckpoint(monitor="val_f1", every_n_epochs=1) 233 | trainer = Trainer( 234 | callbacks=[checkpoint_callback], 235 | max_epochs=training_args.max_epochs, 236 | max_steps=training_args.max_steps, 237 | accelerator="gpu", 238 | devices=data_args.devices, 239 | enable_progress_bar=True, 240 | default_root_dir=data_args.output_dir, 241 | gradient_clip_val=training_args.max_grad_norm, 242 | accumulate_grad_batches=training_args.gradient_accumulation_steps, 243 | num_sanity_val_steps=0, 244 | strategy='ddp' if torch.cuda.device_count() > 1 else 'auto', 245 | ) 246 | 247 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 248 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 249 | 250 | # 缓存数据集 251 | if data_args.do_train: 252 | dataHelper.make_dataset_with_args(data_args.train_file, shuffle=True,mode='train') 253 | if data_args.do_eval: 254 | dataHelper.make_dataset_with_args(data_args.eval_file, mode='eval') 255 | if data_args.do_test: 256 | dataHelper.make_dataset_with_args(data_args.test_file,mode='test') 257 | 258 | 259 | model = MyTransformer(config=config, model_args=model_args, training_args=training_args) 260 | 261 | if not data_args.convert_onnx: 262 | train_datasets = dataHelper.load_distributed_random_sampler( 263 | dataHelper.train_files, 264 | with_load_memory=True, 265 | collate_fn=dataHelper.collate_fn, 266 | batch_size=training_args.train_batch_size, 267 | num_processes = trainer.world_size, process_index=trainer.global_rank) 268 | if train_datasets is not None: 269 | trainer.fit(model, train_dataloaders=train_datasets) 270 | else: 271 | eval_datasets = 
dataHelper.load_sequential_sampler(dataHelper.eval_files,batch_size=training_args.eval_batch_size,collate_fn=dataHelper.collate_fn) 272 | test_datasets = dataHelper.load_sequential_sampler(dataHelper.test_files,batch_size=training_args.test_batch_size,collate_fn=dataHelper.collate_fn) 273 | if eval_datasets is not None: 274 | trainer.validate(model, dataloaders=eval_datasets, ckpt_path='./best.pt') 275 | 276 | if test_datasets is not None: 277 | trainer.test(model, dataloaders=test_datasets, ckpt_path='best.pt') 278 | -------------------------------------------------------------------------------- /task_classify/task_tnews_hierarchical_position.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import logging 4 | import typing 5 | 6 | import numpy as np 7 | import torch 8 | from deep_training.data_helper import DataHelper 9 | from deep_training.data_helper import ModelArguments, TrainingArguments, DataArguments 10 | from deep_training.nlp.models.transformer import TransformerForSequenceClassification 11 | from deep_training.utils.trainer import SimpleModelCheckpoint 12 | from lightning import Trainer 13 | 14 | from sklearn.metrics import f1_score, classification_report 15 | from torch.utils.data import DataLoader, IterableDataset 16 | from tqdm import tqdm 17 | from transformers import HfArgumentParser, BertTokenizer 18 | 19 | train_info_args = { 20 | 'devices': 1, 21 | 'data_backend': 'memory_raw', 22 | 'model_type': 'bert', 23 | 'model_name_or_path': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 24 | 'tokenizer_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 25 | 'config_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese/config.json', 26 | 'convert_onnx': False, # 转换onnx模型 27 | 'do_train': True, 28 | 'do_eval': True, 29 | 'train_file': [ '/data/nlp/nlp_train_data/clue/tnews/train.json'], 30 | 'eval_file': [ '/data/nlp/nlp_train_data/clue/tnews/dev.json'], 31 | 'test_file': [ '/data/nlp/nlp_train_data/clue/tnews/test.json'], 32 | 'label_file': [ '/data/nlp/nlp_train_data/clue/tnews/labels.json'], 33 | 'hierarchical_position': 0.4, 34 | 'learning_rate': 5e-5, 35 | 'max_epochs': 3, 36 | 'train_batch_size': 10, 37 | 'test_batch_size': 2, 38 | 'adam_epsilon': 1e-8, 39 | 'gradient_accumulation_steps': 1, 40 | 'max_grad_norm': 1.0, 41 | 'weight_decay': 0, 42 | 'warmup_steps': 0, 43 | 'output_dir': './output', 44 | 'train_max_seq_length': 1024, 45 | 'eval_max_seq_length': 1024, 46 | 'test_max_seq_length': 1024, 47 | } 48 | 49 | 50 | class NN_DataHelper(DataHelper): 51 | # 切分词 52 | def on_data_process(self, data: typing.Any, mode: str): 53 | tokenizer: BertTokenizer 54 | max_seq_length = self.max_seq_length_dict[mode] 55 | tokenizer = self.tokenizer 56 | do_lower_case = tokenizer.do_lower_case 57 | label2id = self.label2id 58 | 59 | sentence, label_str = data 60 | 61 | o = tokenizer(sentence, max_length=max_seq_length, truncation=True, add_special_tokens=True, ) 62 | input_ids = np.asarray(o['input_ids'], dtype=np.int64) 63 | attention_mask = np.asarray(o['attention_mask'], dtype=np.int64) 64 | 65 | labels = np.asarray(label2id[label_str] if label_str is not None else 0, dtype=np.int64) 66 | seqlen = np.asarray(len(input_ids), dtype=np.int64) 67 | pad_len = max_seq_length - len(input_ids) 68 | if pad_len > 0: 69 | pad_val = tokenizer.pad_token_id 70 | input_ids = np.pad(input_ids, (0, pad_len), 'constant', constant_values=(pad_val, pad_val)) 71 | attention_mask = np.pad(attention_mask, (0, 
pad_len), 'constant', constant_values=(0, 0)) 72 | d = { 73 | 'input_ids': input_ids, 74 | 'attention_mask': attention_mask, 75 | 'labels': labels, 76 | 'seqlen': seqlen 77 | } 78 | return d 79 | 80 | # 读取标签 81 | def on_get_labels(self, files: typing.List[str]): 82 | if files is None: 83 | return None, None 84 | label_fname = files[0] 85 | is_json_file = label_fname.endswith('.json') 86 | D = set() 87 | with open(label_fname, 'r', encoding='utf-8') as f: 88 | lines = f.readlines() 89 | for line in lines: 90 | line = line.replace('\r\n', '').replace('\n', '') 91 | if not line: continue 92 | if is_json_file: 93 | jd = json.loads(line) 94 | line = jd['label'] 95 | D.add(line) 96 | label2id = {label: i for i, label in enumerate(D)} 97 | id2label = {i: label for i, label in enumerate(D)} 98 | return label2id, id2label 99 | 100 | # 读取文件 101 | def on_get_corpus(self, files: typing.List, mode: str): 102 | D = [] 103 | for filename in files: 104 | with open(filename, mode='r', encoding='utf-8') as f: 105 | lines = f.readlines() 106 | for line in lines: 107 | jd = json.loads(line) 108 | if not jd: 109 | continue 110 | D.append((jd['sentence'], jd.get('label', None))) 111 | return D[0:1000] if mode == 'train' else D[:100] 112 | 113 | def collate_fn(self,batch): 114 | o = {} 115 | for i, b in enumerate(batch): 116 | if i == 0: 117 | for k in b: 118 | o[k] = [torch.tensor(b[k])] 119 | else: 120 | for k in b: 121 | o[k].append(torch.tensor(b[k])) 122 | for k in o: 123 | o[k] = torch.stack(o[k]) 124 | 125 | max_len = torch.max(o.pop('seqlen')) 126 | 127 | o['input_ids'] = o['input_ids'][:, :max_len] 128 | o['attention_mask'] = o['attention_mask'][:, :max_len] 129 | if 'token_type_ids' in o: 130 | o['token_type_ids'] = o['token_type_ids'][:, :max_len] 131 | return o 132 | 133 | 134 | class MyTransformer(TransformerForSequenceClassification, with_pl=True): 135 | def __init__(self, *args, **kwargs): 136 | super(MyTransformer, self).__init__(*args, **kwargs) 137 | 138 | def compute_loss(self, *args, **batch) -> tuple: 139 | outputs = self.model(*args, **batch) 140 | labels = batch.get('labels', None) 141 | if labels is not None: 142 | loss, logits = outputs[0:2] 143 | acc = torch.sum(torch.eq(labels.view(-1), 144 | torch.argmax(logits, dim=1, keepdim=False))) / labels.view(-1).size()[0] 145 | loss_dict = { 146 | 'loss': loss, 147 | 'acc': acc 148 | } 149 | outputs = (loss_dict, logits, labels) 150 | else: 151 | outputs = (outputs[0],) 152 | 153 | return outputs 154 | 155 | # def validation_epoch_end(self, outputs: typing.Union[EPOCH_OUTPUT, typing.List[EPOCH_OUTPUT]]) -> None: 156 | # y_preds, y_trues = [], [] 157 | # for o in outputs: 158 | # preds, labels = o['outputs'] 159 | # preds = np.argmax(preds, -1) 160 | # for p, l in zip(preds, labels): 161 | # y_preds.append(p) 162 | # y_trues.append(int(l)) 163 | # 164 | # y_preds = np.asarray(y_preds, dtype=np.int32) 165 | # y_trues = np.asarray(y_trues, dtype=np.int32) 166 | # f1 = f1_score(y_trues, y_preds, average='micro') 167 | # report = classification_report(y_trues, y_preds, digits=4, 168 | # labels=list(self.config.label2id.values()), 169 | # target_names=list(self.config.label2id.keys())) 170 | # 171 | # print(f1, report) 172 | # self.log('val_f1', f1) 173 | 174 | 175 | class MySimpleModelCheckpoint(SimpleModelCheckpoint): 176 | def __init__(self, *args, **kwargs): 177 | super(MySimpleModelCheckpoint, self).__init__(*args, **kwargs) 178 | self.weight_file = './best.pt' 179 | 180 | def on_save_model( 181 | self, trainer: "pl.Trainer", pl_module: 
"pl.LightningModule" 182 | ) -> None: 183 | pl_module: MyTransformer 184 | 185 | # 当前设备 186 | device = torch.device('cuda:{}'.format(trainer.global_rank)) 187 | eval_datasets = dataHelper.load_sequential_sampler(dataHelper.eval_files,batch_size=training_args.eval_batch_size,collate_fn=dataHelper.collate_fn) 188 | 189 | config = pl_module.config 190 | 191 | y_preds, y_trues = [], [] 192 | for i, batch in tqdm(enumerate(eval_datasets), total=len(eval_datasets), desc='evalute'): 193 | for k in batch: 194 | batch[k] = batch[k].to(device) 195 | o = pl_module.validation_step(batch, i) 196 | 197 | preds, labels = o['outputs'] 198 | preds = np.argmax(preds, -1) 199 | for p, l in zip(preds, labels): 200 | y_preds.append(p) 201 | y_trues.append(int(l)) 202 | 203 | y_preds = np.asarray(y_preds, dtype=np.int32) 204 | y_trues = np.asarray(y_trues, dtype=np.int32) 205 | f1 = f1_score(y_trues, y_preds, average='micro') 206 | report = classification_report(y_trues, y_preds, digits=4, 207 | labels=list(config.label2id.values()), 208 | target_names=list(config.label2id.keys())) 209 | 210 | print(f1, report) 211 | 212 | best_f1 = self.best.get('f1', -np.inf) 213 | print('current', f1, 'best', best_f1) 214 | if f1 >= best_f1: 215 | self.best['f1'] = f1 216 | logging.info('save best {}, {}\n'.format(self.best['f1'], self.weight_file)) 217 | trainer.save_checkpoint(self.weight_file) 218 | 219 | 220 | if __name__ == '__main__': 221 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments)) 222 | model_args, training_args, data_args = parser.parse_dict(train_info_args) 223 | 224 | checkpoint_callback = MySimpleModelCheckpoint(monitor="val_f1", every_n_epochs=1) 225 | trainer = Trainer( 226 | callbacks=[checkpoint_callback], 227 | max_epochs=training_args.max_epochs, 228 | max_steps=training_args.max_steps, 229 | accelerator="gpu", 230 | devices=data_args.devices, 231 | enable_progress_bar=True, 232 | default_root_dir=data_args.output_dir, 233 | gradient_clip_val=training_args.max_grad_norm, 234 | accumulate_grad_batches=training_args.gradient_accumulation_steps, 235 | num_sanity_val_steps=0, 236 | strategy='ddp' if torch.cuda.device_count() > 1 else 'auto', 237 | ) 238 | 239 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 240 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 241 | 242 | # 缓存数据集 243 | if data_args.do_train: 244 | dataHelper.make_dataset_with_args(data_args.train_file, shuffle=True,mode='train') 245 | if data_args.do_eval: 246 | dataHelper.make_dataset_with_args(data_args.eval_file, mode='eval') 247 | if data_args.do_test: 248 | dataHelper.make_dataset_with_args(data_args.test_file,mode='test') 249 | 250 | 251 | model = MyTransformer(config=config, model_args=model_args, training_args=training_args) 252 | 253 | if not data_args.convert_onnx: 254 | train_datasets = dataHelper.load_distributed_random_sampler( 255 | dataHelper.train_files, 256 | with_load_memory=True, 257 | collate_fn=dataHelper.collate_fn, 258 | batch_size=training_args.train_batch_size, 259 | num_processes = trainer.world_size, process_index=trainer.global_rank) 260 | 261 | if train_datasets is not None: 262 | trainer.fit(model, train_dataloaders=train_datasets) 263 | else: 264 | eval_datasets = dataHelper.load_sequential_sampler(dataHelper.eval_files,batch_size=training_args.eval_batch_size,collate_fn=dataHelper.collate_fn) 265 | test_datasets = 
dataHelper.load_sequential_sampler(dataHelper.test_files,batch_size=training_args.test_batch_size,collate_fn=dataHelper.collate_fn) 266 | if eval_datasets is not None: 267 | trainer.validate(model, dataloaders=eval_datasets, ckpt_path='./best.pt') 268 | 269 | if test_datasets is not None: 270 | trainer.test(model, dataloaders=test_datasets, ckpt_path='best.pt') 271 | -------------------------------------------------------------------------------- /task_extract_ner/task_cluener_pointer_prefixtuning.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import logging 4 | import typing 5 | 6 | import numpy as np 7 | import torch 8 | from deep_training.data_helper import DataHelper 9 | from deep_training.data_helper import ModelArguments, TrainingArguments, DataArguments, \ 10 | PrefixModelArguments 11 | from deep_training.nlp.metrics.pointer import metric_for_pointer 12 | from deep_training.nlp.models.prefixtuning import PrefixTransformerPointer 13 | from deep_training.utils.trainer import SimpleModelCheckpoint 14 | from lightning import Trainer 15 | from torch.utils.data import DataLoader, IterableDataset 16 | from tqdm import tqdm 17 | from transformers import HfArgumentParser, BertTokenizer 18 | 19 | train_info_args = { 20 | 'devices': 1, 21 | 'data_backend': 'memory_raw', 22 | 'model_type': 'bert', 23 | 'model_name_or_path': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 24 | 'tokenizer_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 25 | 'config_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese/config.json', 26 | 'convert_onnx': False, # 转换onnx模型 27 | 'do_train': True, 28 | 'do_eval': True, 29 | 'train_file': [ '/data/nlp/nlp_train_data/clue/cluener/train.json'], 30 | 'eval_file': [ '/data/nlp/nlp_train_data/clue/cluener/dev.json'], 31 | 'test_file': [ '/data/nlp/nlp_train_data/clue/cluener/test.json'], 32 | 'learning_rate': 1e-3, 33 | 'max_epochs': 80, 34 | 'train_batch_size': 140, 35 | 'eval_batch_size': 2, 36 | 'test_batch_size': 2, 37 | 'adam_epsilon': 1e-8, 38 | 'gradient_accumulation_steps': 1, 39 | 'max_grad_norm': 1.0, 40 | 'weight_decay': 0, 41 | 'warmup_steps': 0, 42 | 'output_dir': './output', 43 | 'max_seq_length': 160, 44 | 'pre_seq_len': 16 45 | } 46 | 47 | 48 | class NN_DataHelper(DataHelper): 49 | index = -1 50 | eval_labels = [] 51 | 52 | # 切分成开始 53 | def on_data_ready(self): 54 | self.index = -1 55 | 56 | # 切分词 57 | def on_data_process(self, data: typing.Any, mode: str): 58 | self.index += 1 59 | tokenizer: BertTokenizer 60 | max_seq_length = self.max_seq_length_dict[mode] 61 | tokenizer = self.tokenizer 62 | do_lower_case = tokenizer.do_lower_case 63 | label2id = self.label2id 64 | sentence, label_dict = data 65 | 66 | tokens = list(sentence) if not do_lower_case else list(sentence.lower()) 67 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 68 | if len(input_ids) > max_seq_length - 2: 69 | input_ids = input_ids[:max_seq_length - 2] 70 | input_ids = [tokenizer.cls_token_id] + input_ids + [tokenizer.sep_token_id] 71 | attention_mask = [1] * len(input_ids) 72 | 73 | input_ids = np.asarray(input_ids, dtype=np.int32) 74 | attention_mask = np.asarray(attention_mask, dtype=np.int32) 75 | seqlen = np.asarray(len(input_ids), dtype=np.int32) 76 | labels = np.zeros(shape=(len(label2id), max_seq_length, max_seq_length), dtype=np.int32) 77 | real_label = [] 78 | 79 | if label_dict is not None: 80 | for label_str, o in label_dict.items(): 81 | pts = [_ for a_ in 
list(o.values()) for _ in a_] 82 | labelid = label2id[label_str] 83 | for pt in pts: 84 | assert pt[0] <= pt[1] 85 | if pt[1] < max_seq_length - 2: 86 | labels[labelid, pt[0] + 1, pt[1] + 1] = 1 87 | real_label.append((labelid, pt[0], pt[1])) 88 | 89 | pad_len = max_seq_length - len(input_ids) 90 | if pad_len > 0: 91 | input_ids = np.pad(input_ids, (0, pad_len), 'constant', 92 | constant_values=(tokenizer.pad_token_id, tokenizer.pad_token_id)) 93 | attention_mask = np.pad(attention_mask, (0, pad_len), 'constant', constant_values=(0, 0)) 94 | d = { 95 | 'input_ids': input_ids, 96 | 'attention_mask': attention_mask, 97 | 'labels': labels, 98 | 'seqlen': seqlen, 99 | } 100 | if self.index < 5: 101 | print(tokens) 102 | print(input_ids[:seqlen]) 103 | print(attention_mask[:seqlen]) 104 | print(seqlen) 105 | 106 | if mode == 'eval': 107 | self.eval_labels.append(real_label) 108 | # if mode == 'eval': 109 | # d['real_label'] = np.asarray(bytes(json.dumps(real_label, ensure_ascii=False), encoding='utf-8')) 110 | return d 111 | 112 | # 读取标签 113 | def on_get_labels(self, files: typing.List[str]): 114 | labels = [ 115 | 'address', 'book', 'company', 'game', 'government', 'movie', 'name', 'organization', 'position', 'scene' 116 | ] 117 | labels = list(set(labels)) 118 | labels = sorted(labels) 119 | label2id = {label: i for i, label in enumerate(labels)} 120 | id2label = {i: label for i, label in enumerate(labels)} 121 | return label2id, id2label 122 | 123 | # 读取文件 124 | def on_get_corpus(self, files: typing.List, mode: str): 125 | D = [] 126 | for filename in files: 127 | with open(filename, mode='r', encoding='utf-8') as f: 128 | lines = f.readlines() 129 | for line in lines: 130 | jd = json.loads(line) 131 | if not jd: 132 | continue 133 | D.append((jd['text'], jd.get('label', None))) 134 | return D 135 | 136 | def collate_fn(self,batch): 137 | o = {} 138 | for i, b in enumerate(batch): 139 | if i == 0: 140 | for k in b: 141 | o[k] = [torch.tensor(b[k])] 142 | else: 143 | for k in b: 144 | o[k].append(torch.tensor(b[k])) 145 | for k in o: 146 | o[k] = torch.stack(o[k]) 147 | max_len = torch.max(o.pop('seqlen')) 148 | o['input_ids'] = o['input_ids'][:, :max_len] 149 | o['attention_mask'] = o['attention_mask'][:, :max_len] 150 | if 'token_type_ids' in o: 151 | o['token_type_ids'] = o['token_type_ids'][:, :max_len] 152 | o['labels'] = o['labels'][:, :, :max_len, :max_len] 153 | return o 154 | 155 | 156 | class MyTransformer(PrefixTransformerPointer, with_pl=True): 157 | def __init__(self, eval_labels, *args, **kwargs): 158 | super(MyTransformer, self).__init__(*args, **kwargs) 159 | self.model.eval_labels = eval_labels 160 | self.eval_labels = eval_labels 161 | 162 | 163 | class MySimpleModelCheckpoint(SimpleModelCheckpoint): 164 | def __init__(self, *args, **kwargs): 165 | super(MySimpleModelCheckpoint, self).__init__(*args, **kwargs) 166 | self.weight_file = './best.pt' 167 | 168 | def on_save_model( 169 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" 170 | ) -> None: 171 | pl_module: MyTransformer 172 | 173 | # 当前设备 174 | device = torch.device('cuda:{}'.format(trainer.global_rank)) 175 | eval_datasets = dataHelper.load_sequential_sampler(dataHelper.eval_files,batch_size=training_args.eval_batch_size,collate_fn=dataHelper.collate_fn) 176 | 177 | threshold = 1e-8 178 | # eval_labels = pl_module.eval_labels 179 | config = pl_module.config 180 | 181 | y_preds, y_trues = [], [] 182 | for i, batch in tqdm(enumerate(eval_datasets), total=len(eval_datasets), desc='evalute'): 183 | for k in 
batch: 184 | batch[k] = batch[k].to(device) 185 | o = pl_module.validation_step(batch, i) 186 | 187 | logits, label = o['outputs'] 188 | logits[:, :, [0, -1]] -= np.inf 189 | logits[:, :, :, [0, -1]] -= np.inf 190 | assert len(logits) == len(label) 191 | for p, t in zip(logits, label): 192 | a_result = [] 193 | for (l, s, e) in zip(*np.where(p > threshold)): 194 | a_result.append((l, s, e)) 195 | y_preds.append(a_result) 196 | b_result = [] 197 | for (l, s, e) in zip(*np.where(t > threshold)): 198 | b_result.append((l, s, e)) 199 | y_trues.append(b_result) 200 | f1, str_report = metric_for_pointer(y_trues, y_preds, config.id2label) 201 | print(f1) 202 | print(str_report) 203 | 204 | best_f1 = self.best.get('f1', -np.inf) 205 | print('current', f1, 'best', best_f1) 206 | if f1 >= best_f1: 207 | self.best['f1'] = f1 208 | logging.info('save best {}, {}\n'.format(self.best['f1'], self.weight_file)) 209 | trainer.save_checkpoint(self.weight_file) 210 | 211 | 212 | if __name__ == '__main__': 213 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments, PrefixModelArguments)) 214 | model_args, training_args, data_args, prompt_args = parser.parse_dict(train_info_args) 215 | 216 | checkpoint_callback = MySimpleModelCheckpoint(monitor='val_f1', every_n_epochs=1) 217 | trainer = Trainer( 218 | log_every_n_steps=10, 219 | callbacks=[checkpoint_callback], 220 | max_epochs=training_args.max_epochs, 221 | max_steps=training_args.max_steps, 222 | accelerator="gpu", 223 | devices=data_args.devices, 224 | enable_progress_bar=True, 225 | default_root_dir=data_args.output_dir, 226 | gradient_clip_val=training_args.max_grad_norm, 227 | accumulate_grad_batches=training_args.gradient_accumulation_steps, 228 | num_sanity_val_steps=0, 229 | strategy='ddp' if torch.cuda.device_count() > 1 else 'auto', 230 | ) 231 | 232 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 233 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 234 | 235 | # 缓存数据集 236 | if data_args.do_train: 237 | dataHelper.make_dataset_with_args(data_args.train_file, shuffle=True,mode='train') 238 | if data_args.do_eval: 239 | dataHelper.make_dataset_with_args(data_args.eval_file, mode='eval') 240 | if data_args.do_test: 241 | dataHelper.make_dataset_with_args(data_args.test_file,mode='test') 242 | 243 | model = MyTransformer(dataHelper.eval_labels, with_efficient=True, prompt_args=prompt_args, config=config, 244 | model_args=model_args, training_args=training_args) 245 | 246 | if not data_args.convert_onnx: 247 | train_datasets = dataHelper.load_distributed_random_sampler( 248 | dataHelper.train_files, 249 | with_load_memory=True, 250 | collate_fn=dataHelper.collate_fn, 251 | batch_size=training_args.train_batch_size, 252 | num_processes = trainer.world_size, process_index=trainer.global_rank) 253 | 254 | if train_datasets is not None: 255 | trainer.fit(model, train_dataloaders=train_datasets) 256 | else: 257 | eval_datasets = dataHelper.load_sequential_sampler(dataHelper.eval_files,batch_size=training_args.eval_batch_size,collate_fn=dataHelper.collate_fn) 258 | test_datasets = dataHelper.load_sequential_sampler(dataHelper.test_files,batch_size=training_args.test_batch_size,collate_fn=dataHelper.collate_fn) 259 | if eval_datasets is not None: 260 | trainer.validate(model, dataloaders=eval_datasets, ckpt_path='./best.pt') 261 | 262 | if test_datasets is not None: 263 | trainer.test(model, dataloaders=test_datasets, ckpt_path='best.pt') 264 | 
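Note: the evaluation loop above thresholds a pointer score matrix with np.where and collects (label, start, end) triples before calling metric_for_pointer. The snippet below is a minimal, self-contained sketch of that decoding step only, written for clarity and not taken from deep_training's API. Assumptions: the score matrix has shape (num_labels, seq_len, seq_len) with [CLS]/[SEP] at positions 0 and -1, as built in on_data_process above; the helper names decode_pointer_spans and spans_to_entities are illustrative.

# -*- coding: utf-8 -*-
# Hedged sketch of pointer-matrix decoding (illustrative names, assumed shapes).
import numpy as np


def decode_pointer_spans(scores: np.ndarray, threshold: float = 0.0):
    """Return (label_id, start, end) triples with the +1 CLS offset removed."""
    scores = np.array(scores, dtype=np.float64)
    # Mask rows/columns belonging to [CLS] and [SEP], mirroring the eval loop above.
    scores[:, [0, -1], :] = -np.inf
    scores[:, :, [0, -1]] = -np.inf
    spans = []
    for label_id, start, end in zip(*np.where(scores > threshold)):
        if start <= end:
            spans.append((int(label_id), int(start) - 1, int(end) - 1))
    return spans


def spans_to_entities(sentence: str, spans, id2label: dict):
    """Map decoded spans back to surface text for quick inspection."""
    return [(id2label[l], sentence[s:e + 1]) for l, s, e in spans]


if __name__ == '__main__':
    rng = np.random.RandomState(0)
    fake_scores = rng.randn(10, 8, 8)  # 10 labels, seq_len 8 (incl. CLS/SEP)
    print(decode_pointer_spans(fake_scores, threshold=2.0))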
-------------------------------------------------------------------------------- /task_grammatical_error_correction/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ssbuild/pytorch-task-example/7d2341562c4ae3070fc7fc18b3b1886a74391ca2/task_grammatical_error_correction/1.png -------------------------------------------------------------------------------- /task_grammatical_error_correction/README.md: -------------------------------------------------------------------------------- 1 | ## 概述 2 | 3 | 由于Seq2Seq在机器翻译等领域的成功应用,把这种方法用到类似的语法纠错问题上也是非常自然的想法。 4 | 机器翻译的输入是源语言(比如英语),输出是另外一个目标语言(比如法语)。 5 | 而语法纠错的输入是有语法错误的句子,输出是与之对应的语法正确的句子,区别似乎只在于机器翻译的输入输出是不同的语言而语法纠错的输入输出是相同的语言。 6 | 随着Transformer在机器翻译领域的成功,主流的语法纠错也都使用了Transformer来作为Seq2Seq模型的Encoder和Decoder。 7 | 当然随着BERT等Pretraining模型的出现,机器翻译和语法纠错都使用了这些Pretraining的Transformer模型来作为初始化参数,并且使用领域的数据进行Fine-Tuning。 8 | 由于领域数据相对Pretraining的无监督数据量太少,最近合成的(synthetic)数据用于Fine-tuning变得流行起来。 9 | 查看一下nlpprogress的GEC任务,排行榜里的方法大多都是使用了BERT等Pretraining的Seq2Seq模型。 10 | 11 | ## seq2seq 缺点 12 | 13 | 但是Seq2Seq模型有如下缺点: 14 | 15 | 解码速度慢 16 | 因为解码不能并行计算 17 | 需要大量训练数据 18 | 因为输出的长度不定,相对本文的序列标签模型需要更多的数据 19 | 不可解释 20 | 输入了错误的句子,输出只是正确的句子,不能直接知道到底是什么类型的语法错误,通常还需要使用其它工具来分析错误,比如errant。 21 | 22 | ## gector 23 | 24 | gector思路是使用序列标签模型替代生成模型。注意:我这里使用的是序列标签而不是更常见的序列标注来翻译Sequence Tagging,原因在于它和用来解决NER等问题的序列标注不同。序列标注的标签通常是有关联的,比如以”BIO”三标签为例,I只能出现在B或者I后面,它们的组合是有意义的。 25 | 而本文的给每一个Token打的标签和前后的标签没有关联,当然给当前Token打标签需要参考上下文,但这只是在输入层面,而在标签层面是无关的。本文的训练分为三个阶段:在合成数据上的Pretraining; 26 | 在错误-正确的句对上的fine-tuning;在同时包含错误-正确和正确-正确句对数据上的fine-tuning。 27 | 28 | 29 | 怎么把纠错问题用序列标注来解决呢?我们的数据是有语法错误和语法正确的两个句子。和机器翻译不同,语法纠错的两个句子通常非常相似,只是在某些局部会有不同的地方。因此类似于比较两个句子的diff,我们可以找到一系列编辑操作,从而把语法错误的句子变成语法正确的句子,这和编辑距离的编辑很类似。编辑操作怎么变成序列打标签呢?我们可以把编辑映射某个Token上,认为是对这个Token的操作。但是这里还有一个问题,有时候需要对同一个Token进行多个编辑操作,因为序列打标签的输出只能是一个,那怎么办呢?本文采取了一种迭代的方法,也就是通过多次(其实最多也就两三次)序列打标签。说起来有点抽象,我们来看一个例子。 30 | 31 | ![image](1.png) 32 | 33 | 比如上图的例子,红色的句子是语法错误的句子:”A ten years old boy go school”。 34 | 35 | 我们先经过一次序列打标签,找到了需要对ten和go进行操作,也就是把ten和years合并成ten-years,把go变成goes。注意:这里的用连字符”-“把两个词合并的操作定义在前面的Token上。 36 | 37 | 接着再进行一次序列打标签,发现需要对ten-years和goes进行操作,把ten-years变成ten-year然后与old合并,在goes后面增加to。 38 | 39 | 最后一次序列打标签在school后面增加句号”.”。 40 | 41 | -------------------------------------------------------------------------------- /task_grammatical_error_correction/task_ctc_gector/data_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/2/13 16:21 3 | import json 4 | import typing 5 | 6 | import Levenshtein 7 | import numpy as np 8 | import torch 9 | from deep_training.data_helper import DataHelper, TrainingArguments, DataArguments, ModelArguments 10 | from transformers import BertTokenizer, HfArgumentParser 11 | 12 | train_info_args = { 13 | 'devices': 1, 14 | 'data_backend': 'memory_raw', 15 | 'model_type': 'bert', 16 | 'model_name_or_path': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 17 | 'tokenizer_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 18 | 'config_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese/config.json', 19 | 'convert_onnx': False, # 转换onnx模型 20 | 'do_train': True, 21 | 'do_eval': True, 22 | 'train_file': [ '/data/nlp/nlp_train_data/clue/CTC2021/train.json'], 23 | 'eval_file': [ '/data/nlp/nlp_train_data/clue/CTC2021/dev.json'], 24 | 'test_file': [ '/data/nlp/nlp_train_data/clue/CTC2021/test.json'], 25 | # 'label_file': [ 
'/data/nlp/nlp_train_data/clue/CTC2021/labels.json'], 26 | 'label_file': [ '/data/nlp/nlp_train_data/clue/CTC2021/vocab.txt'], 27 | 'learning_rate': 5e-5, 28 | 'max_epochs': 3, 29 | 'train_batch_size': 10, 30 | 'test_batch_size': 2, 31 | 'adam_epsilon': 1e-8, 32 | 'gradient_accumulation_steps': 1, 33 | 'max_grad_norm': 1.0, 34 | 'weight_decay': 0, 35 | 'warmup_steps': 0, 36 | 'output_dir': './output', 37 | 'train_max_seq_length': 380, 38 | 'eval_max_seq_length': 512, 39 | 'test_max_seq_length': 512, 40 | } 41 | 42 | 43 | class NN_DataHelper(DataHelper): 44 | # 切分词 45 | def on_data_process(self, data: typing.Any, mode: str): 46 | tokenizer: BertTokenizer 47 | max_seq_length = self.max_seq_length_dict[mode] 48 | tokenizer = self.tokenizer 49 | do_lower_case = tokenizer.do_lower_case 50 | label2id = self.label2id 51 | 52 | sentence, label_ops = data 53 | 54 | tokens = list(sentence) if not do_lower_case else list(sentence.lower()) 55 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 56 | tokens = ['[CLS]'] + tokenizer.convert_ids_to_tokens(input_ids) + ['[SEP]'] 57 | if len(input_ids) > max_seq_length - 2: 58 | input_ids = input_ids[:max_seq_length - 2] 59 | input_ids = [tokenizer.cls_token_id] + input_ids + [tokenizer.sep_token_id] 60 | attention_mask = [1] * len(input_ids) 61 | 62 | labels_action = [-100] * max_seq_length 63 | labels_probs = [-100] * max_seq_length 64 | 65 | 66 | for op in label_ops: 67 | s = op[1] + 1 68 | e = op[2] + 1 69 | 70 | if e >= max_seq_length: 71 | print('corpus long length!') 72 | continue 73 | for j in range(s,e): 74 | labels_action[j] = op[0] 75 | labels_probs[j] = label2id[tokens[j]] 76 | 77 | input_ids = np.asarray(input_ids,np.int32) 78 | attention_mask = np.asarray(attention_mask, np.int32) 79 | labels_action = np.asarray(labels_action, np.int32) 80 | labels_probs = np.asarray(labels_probs, np.int32) 81 | 82 | seqlen = np.asarray(len(input_ids), dtype=np.int64) 83 | pad_len = max_seq_length - len(input_ids) 84 | if pad_len > 0: 85 | pad_val = tokenizer.pad_token_id 86 | input_ids = np.pad(input_ids, (0, pad_len), 'constant', constant_values=(pad_val, pad_val)) 87 | attention_mask = np.pad(attention_mask, (0, pad_len), 'constant', constant_values=(0, 0)) 88 | d = { 89 | 'input_ids': input_ids, 90 | 'attention_mask': attention_mask, 91 | 'labels_action': labels_action, 92 | 'labels_probs': labels_probs, 93 | 'seqlen': seqlen 94 | } 95 | return d 96 | 97 | # 读取标签 98 | def on_get_labels(self, files: typing.List[str]): 99 | if files is None: 100 | return None, None 101 | label_fname = files[0] 102 | is_json_file = label_fname.endswith('.json') 103 | D = set() 104 | with open(label_fname, 'r', encoding='utf-8') as f: 105 | lines = f.readlines() 106 | for line in lines: 107 | line = line.replace('\r\n', '').replace('\n', '') 108 | if not line: continue 109 | if is_json_file: 110 | jd = json.loads(line) 111 | line = jd['label'] 112 | D.add(line) 113 | label2id = {label: i for i, label in enumerate(D)} 114 | id2label = {i: label for i, label in enumerate(D)} 115 | return label2id, id2label 116 | 117 | # 读取文件 118 | def on_get_corpus(self, files: typing.List, mode: str): 119 | op_map = { 120 | 'equal': 0, 121 | 'insert': 1, 122 | 'delete': 2, 123 | 'replace':3 124 | } 125 | D = [] 126 | for filename in files: 127 | with open(filename, mode='r', encoding='utf-8') as f: 128 | lines = f.readlines() 129 | for line in lines: 130 | jd = json.loads(line) 131 | if not jd: 132 | continue 133 | src = jd['source'] 134 | dst = jd.get('target',None) 135 | if mode != 'test': 
136 | assert dst is not None 137 | if dst is not None: 138 | edits = Levenshtein.opcodes(src, dst) 139 | ops = [] 140 | for item in edits: 141 | op = op_map[item[0]] 142 | s = item[1] 143 | e = item[2] 144 | ops.append((op,s,e)) 145 | else: 146 | ops = None 147 | D.append((src,ops)) 148 | if mode == 'eval': 149 | return D[:500] 150 | return D 151 | 152 | def collate_fn(self,batch): 153 | o = {} 154 | for i, b in enumerate(batch): 155 | if i == 0: 156 | for k in b: 157 | o[k] = [torch.tensor(b[k])] 158 | else: 159 | for k in b: 160 | o[k].append(torch.tensor(b[k])) 161 | for k in o: 162 | o[k] = torch.stack(o[k]) 163 | 164 | max_len = torch.max(o.pop('seqlen')) 165 | 166 | o['input_ids'] = o['input_ids'][:, :max_len].long() 167 | o['attention_mask'] = o['attention_mask'][:, :max_len].long() 168 | if 'token_type_ids' in o: 169 | o['token_type_ids'] = o['token_type_ids'][:, :max_len].long() 170 | 171 | if 'labels_action' in o: 172 | o['labels_action'] = o['labels_action'][:, :max_len].long() 173 | o['labels_probs'] = o['labels_probs'][:, :max_len].long() 174 | return o 175 | 176 | 177 | 178 | if __name__ == '__main__': 179 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments)) 180 | model_args, training_args, data_args = parser.parse_dict(train_info_args) 181 | 182 | 183 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 184 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 185 | 186 | # 缓存数据集 187 | if data_args.do_train: 188 | dataHelper.make_dataset_with_args(data_args.train_file, shuffle=True,mode='train') 189 | if data_args.do_eval: 190 | dataHelper.make_dataset_with_args(data_args.eval_file, mode='eval') 191 | if data_args.do_test: 192 | dataHelper.make_dataset_with_args(data_args.test_file,mode='test') -------------------------------------------------------------------------------- /task_grammatical_error_correction/task_ctc_gector/task_ctc_gector.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/2/10 17:18 3 | 4 | import logging 5 | import numpy as np 6 | import torch 7 | from deep_training.data_helper import ModelArguments, TrainingArguments, DataArguments 8 | from deep_training.nlp.metrics.pointer import metric_for_pointer 9 | from deep_training.nlp.models.gec_model import TransformerForGec, extract_gec, extract_gec_from_labels 10 | from deep_training.utils.trainer import SimpleModelCheckpoint 11 | from lightning import Trainer 12 | from torch.utils.data import DataLoader, IterableDataset 13 | from tqdm import tqdm 14 | from transformers import HfArgumentParser 15 | from data_utils import NN_DataHelper, train_info_args 16 | 17 | 18 | class MyTransformer(TransformerForGec, with_pl=True): 19 | def __init__(self, *args, **kwargs): 20 | super(MyTransformer, self).__init__(*args, **kwargs) 21 | 22 | 23 | 24 | class MySimpleModelCheckpoint(SimpleModelCheckpoint): 25 | def __init__(self, *args, **kwargs): 26 | super(MySimpleModelCheckpoint, self).__init__(*args, **kwargs) 27 | self.weight_file = './best.pt' 28 | 29 | def on_save_model( 30 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" 31 | ) -> None: 32 | pl_module: MyTransformer 33 | 34 | # 当前设备 35 | device = torch.device('cuda:{}'.format(trainer.global_rank)) 36 | eval_datasets = dataHelper.load_sequential_sampler(dataHelper.eval_files,batch_size=training_args.eval_batch_size,collate_fn=dataHelper.collate_fn) 37 | 38 | config = pl_module.config 39 | 40 | y_preds, y_trues 
= [], [] 41 | for i, batch in tqdm(enumerate(eval_datasets), total=len(eval_datasets), desc='evalute'): 42 | for k in batch: 43 | batch[k] = batch[k].to(device) 44 | o = pl_module.validation_step(batch, i) 45 | logits_action,logits_probs,seqlens,labels_action,labels_probs = o['outputs'] 46 | #抽取 三元组(action,position,vocab) 47 | output_list = extract_gec([logits_action,logits_probs,seqlens]) 48 | true_list = extract_gec_from_labels([labels_action,labels_probs,seqlens]) 49 | # y_preds.extend(output_list) 50 | # y_trues.extend(true_list) 51 | for ones in output_list: 52 | y_preds.append( [(_[0]-1,*_[1:]) for _ in ones]) 53 | 54 | for ones in true_list: 55 | y_trues.append([(_[0] - 1, *_[1:]) for _ in ones]) 56 | 57 | 58 | # 三元组(action,position,vocab) 59 | print(y_preds[:3]) 60 | print(y_trues[:3]) 61 | 62 | label2id = { 63 | 'insert': 0, 64 | 'delete': 1, 65 | 'replace': 2 66 | } 67 | 68 | f1, str_report = metric_for_pointer(y_trues, y_preds, label2id) 69 | print(f1) 70 | print(str_report) 71 | 72 | best_f1 = self.best.get('f1', -np.inf) 73 | print('current', f1, 'best', best_f1) 74 | if f1 >= best_f1: 75 | self.best['f1'] = f1 76 | logging.info('save best {}, {}\n'.format(self.best['f1'], self.weight_file)) 77 | trainer.save_checkpoint(self.weight_file) 78 | 79 | 80 | 81 | if __name__ == '__main__': 82 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments)) 83 | model_args, training_args, data_args = parser.parse_dict(train_info_args) 84 | 85 | checkpoint_callback = MySimpleModelCheckpoint(every_n_epochs=1, 86 | every_n_train_steps=2000) 87 | trainer = Trainer( 88 | callbacks=[checkpoint_callback], 89 | max_epochs=training_args.max_epochs, 90 | max_steps=training_args.max_steps, 91 | accelerator="gpu", 92 | devices=data_args.devices, 93 | enable_progress_bar=True, 94 | default_root_dir=data_args.output_dir, 95 | gradient_clip_val=training_args.max_grad_norm, 96 | accumulate_grad_batches=training_args.gradient_accumulation_steps, 97 | num_sanity_val_steps=0, 98 | strategy='ddp' if torch.cuda.device_count() > 1 else 'auto', 99 | ) 100 | 101 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 102 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 103 | 104 | # 缓存数据集 105 | if data_args.do_train: 106 | dataHelper.make_dataset_with_args(data_args.train_file, shuffle=True,mode='train') 107 | if data_args.do_eval: 108 | dataHelper.make_dataset_with_args(data_args.eval_file, mode='eval') 109 | if data_args.do_test: 110 | dataHelper.make_dataset_with_args(data_args.test_file,mode='test') 111 | 112 | 113 | model = MyTransformer(config=config, model_args=model_args, training_args=training_args) 114 | 115 | if not data_args.convert_onnx: 116 | train_datasets = dataHelper.load_distributed_random_sampler( 117 | dataHelper.train_files, 118 | with_load_memory=True, 119 | collate_fn=dataHelper.collate_fn, 120 | batch_size=training_args.train_batch_size, 121 | num_processes = trainer.world_size, process_index=trainer.global_rank) 122 | 123 | if train_datasets is not None: 124 | trainer.fit(model, train_dataloaders=train_datasets) 125 | else: 126 | eval_datasets = dataHelper.load_sequential_sampler(dataHelper.eval_files,batch_size=training_args.eval_batch_size,collate_fn=dataHelper.collate_fn) 127 | test_datasets = dataHelper.load_sequential_sampler(dataHelper.test_files,batch_size=training_args.test_batch_size,collate_fn=dataHelper.collate_fn) 128 | if eval_datasets is not None: 129 | trainer.validate(model, dataloaders=eval_datasets, 
ckpt_path='./best.pt') 130 | 131 | if test_datasets is not None: 132 | trainer.test(model, dataloaders=test_datasets, ckpt_path='best.pt') -------------------------------------------------------------------------------- /task_grammatical_error_correction/task_ctc_seq2seq/data_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time: 3:12 3 | # @File:data_utils.py 4 | import json 5 | import typing 6 | 7 | import numpy as np 8 | import torch 9 | from deep_training.data_helper import DataHelper, ModelArguments, TrainingArguments, DataArguments 10 | from transformers import BertTokenizer, HfArgumentParser 11 | 12 | train_info_args = { 13 | 'devices': 1, 14 | 'data_backend': 'record', 15 | 'model_type': 't5', 16 | # 'model_name_or_path': '/data/nlp/pre_models/torch/', 17 | 'tokenizer_name': './t5_small_config', 18 | 'config_name': './t5_small_config/config.json', 19 | 'convert_onnx': False, # 转换onnx模型 20 | 'do_train': True, 21 | 'do_eval': True, 22 | 'train_file': [ '/data/nlp/nlp_train_data/clue/CTC2021/train.json'], 23 | 'eval_file': [ '/data/nlp/nlp_train_data/clue/CTC2021/dev.json'], 24 | 'test_file': [ '/data/nlp/nlp_train_data/clue/CTC2021/test.json'], 25 | #'label_file': [ '/data/nlp/nlp_train_data/clue/CTC2021/labels.json'], 26 | 'label_file': [ ], 27 | 'learning_rate': 5e-5, 28 | 'max_epochs': 3, 29 | 'train_batch_size': 10, 30 | 'eval_batch_size': 10, 31 | 'test_batch_size': 2, 32 | 'adam_epsilon': 1e-8, 33 | 'gradient_accumulation_steps': 1, 34 | 'max_grad_norm': 1.0, 35 | 'weight_decay': 0, 36 | 'warmup_steps': 0, 37 | 'output_dir': './output', 38 | 'train_max_seq_length': 512, 39 | 'eval_max_seq_length': 512, 40 | 'test_max_seq_length': 512, 41 | 'max_target_length': 64, 42 | } 43 | 44 | class NN_DataHelper(DataHelper): 45 | # 切分词 46 | def on_data_process(self, data: typing.Any, mode: str): 47 | tokenizer: BertTokenizer 48 | max_seq_length = self.max_seq_length_dict[mode] 49 | tokenizer = self.tokenizer 50 | 51 | x = data 52 | def get_tokenizer_output(text): 53 | o1 = tokenizer.encode_plus(text, max_length=max_seq_length, truncation=True, add_special_tokens=True, ) 54 | 55 | input_ids = np.asarray(o1['input_ids'], dtype=np.int32) 56 | attention_mask = np.asarray(o1['attention_mask'], dtype=np.int32) 57 | seqlen = np.asarray(len(input_ids), dtype=np.int32) 58 | pad_len = max_seq_length - seqlen 59 | if pad_len > 0: 60 | pad_val = tokenizer.pad_token_id 61 | input_ids = np.pad(input_ids, (0, pad_len), 'constant', constant_values=(pad_val, pad_val)) 62 | attention_mask = np.pad(attention_mask, (0, pad_len), 'constant', constant_values=(0, 0)) 63 | 64 | out = { 65 | 'input_ids': input_ids, 66 | 'attention_mask': attention_mask, 67 | 'seqlen': seqlen, 68 | } 69 | return out 70 | 71 | o1 = get_tokenizer_output(x[0]) 72 | o2 = get_tokenizer_output(x[1]) 73 | 74 | d = o1 75 | 76 | d['decoder_input_ids'] = o2['input_ids'] 77 | d['decoder_attention_mask'] = o2['attention_mask'] 78 | d['decoder_seqlen'] = o2['seqlen'] 79 | 80 | labels = np.ones_like(d['decoder_input_ids'],dtype=np.int32) * -100 81 | labels[:o2['seqlen']-1] = d['decoder_input_ids'][1:o2['seqlen']] 82 | 83 | d['labels'] = labels 84 | return d 85 | 86 | 87 | 88 | # 读取文件 89 | def on_get_corpus(self, files: typing.List, mode: str): 90 | D = [] 91 | for filename in files: 92 | with open(filename, mode='r', encoding='utf-8') as f: 93 | lines = f.readlines() 94 | for i, line in enumerate(lines): 95 | jd = json.loads(line) 96 | 97 | if mode == 'eval': 98 | if i 
> 50: 99 | break 100 | D.append((jd['source'], jd['target'])) 101 | return D 102 | 103 | def collate_fn(self, batch): 104 | o = {} 105 | for i, b in enumerate(batch): 106 | if i == 0: 107 | for k in b: 108 | o[k] = [torch.tensor(b[k])] 109 | else: 110 | for k in b: 111 | o[k].append(torch.tensor(b[k])) 112 | for k in o: 113 | o[k] = torch.stack(o[k]) 114 | 115 | 116 | max_len = torch.max(o.pop('seqlen')) 117 | o['input_ids'] = o['input_ids'][:, :max_len].long() 118 | o['attention_mask'] = o['attention_mask'][:, :max_len].long() 119 | 120 | max_len = torch.max(o.pop('decoder_seqlen')) 121 | o['decoder_input_ids'] = o['decoder_input_ids'][:, :max_len].long() 122 | o['decoder_attention_mask'] = o['decoder_attention_mask'][:, :max_len].long() 123 | o['labels'] = o['labels'][:, :max_len].long() 124 | return o 125 | 126 | 127 | if __name__ == '__main__': 128 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments)) 129 | model_args, training_args, data_args = parser.parse_dict(train_info_args) 130 | 131 | 132 | 133 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 134 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 135 | 136 | 137 | # 缓存数据集 138 | if data_args.do_train: 139 | dataHelper.make_dataset_with_args(data_args.train_file, shuffle=True,mode='train') 140 | if data_args.do_eval: 141 | dataHelper.make_dataset_with_args(data_args.eval_file, mode='eval') 142 | if data_args.do_test: 143 | dataHelper.make_dataset_with_args(data_args.test_file,mode='test') -------------------------------------------------------------------------------- /task_grammatical_error_correction/task_ctc_seq2seq/t5_base_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5WithLMHeadModel" 4 | ], 5 | "d_ff": 2048, 6 | "d_kv": 64, 7 | "d_model": 768, 8 | "decoder_start_token_id": 101, 9 | "dropout_rate": 0.1, 10 | "initializer_factor": 1.0, 11 | "is_encoder_decoder": true, 12 | "layer_norm_epsilon": 1e-8, 13 | "model_type": "t5", 14 | "feed_forward_proj": "gated-gelu", 15 | "n_positions": 512, 16 | "num_heads": 12, 17 | "num_layers": 12, 18 | "num_decoder_layers": 12, 19 | "output_past": true, 20 | "pad_token_id": 0, 21 | "relative_attention_num_buckets": 32, 22 | "tokenizer_class": "BertTokenizer", 23 | "vocab_size": 16448 24 | } -------------------------------------------------------------------------------- /task_grammatical_error_correction/task_ctc_seq2seq/t5_base_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /task_grammatical_error_correction/task_ctc_seq2seq/t5_large_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5WithLMHeadModel" 4 | ], 5 | "d_ff": 2816, 6 | "d_kv": 64, 7 | "d_model": 1024, 8 | "decoder_start_token_id": 101, 9 | "dropout_rate": 0.1, 10 | "initializer_factor": 1.0, 11 | "is_encoder_decoder": true, 12 | "layer_norm_epsilon": 1e-8, 13 | "model_type": 
"t5", 14 | "feed_forward_proj": "gated-gelu", 15 | "n_positions": 512, 16 | "num_heads": 16, 17 | "num_layers": 24, 18 | "num_decoder_layers": 24, 19 | "output_past": true, 20 | "pad_token_id": 0, 21 | "relative_attention_num_buckets": 32, 22 | "tokenizer_class": "BertTokenizer", 23 | "vocab_size": 16448 24 | } -------------------------------------------------------------------------------- /task_grammatical_error_correction/task_ctc_seq2seq/t5_large_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /task_grammatical_error_correction/task_ctc_seq2seq/t5_small_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5WithLMHeadModel" 4 | ], 5 | "d_ff": 1024, 6 | "d_kv": 64, 7 | "d_model": 512, 8 | "decoder_start_token_id": 101, 9 | "dropout_rate": 0.1, 10 | "initializer_factor": 1.0, 11 | "is_encoder_decoder": true, 12 | "layer_norm_epsilon": 1e-8, 13 | "model_type": "t5", 14 | "feed_forward_proj": "gated-gelu", 15 | "n_positions": 512, 16 | "num_heads": 6, 17 | "num_layers": 8, 18 | "num_decoder_layers": 8, 19 | "output_past": true, 20 | "pad_token_id": 0, 21 | "relative_attention_num_buckets": 32, 22 | "tokenizer_class": "BertTokenizer", 23 | "vocab_size": 16448 24 | } -------------------------------------------------------------------------------- /task_grammatical_error_correction/task_ctc_seq2seq/t5_small_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /task_grammatical_error_correction/task_ctc_seq2seq/t5_xl_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5WithLMHeadModel" 4 | ], 5 | "d_ff": 5120, 6 | "d_kv": 64, 7 | "d_model": 2048, 8 | "decoder_start_token_id": 101, 9 | "dropout_rate": 0.1, 10 | "initializer_factor": 1.0, 11 | "is_encoder_decoder": true, 12 | "layer_norm_epsilon": 1e-8, 13 | "model_type": "t5", 14 | "feed_forward_proj": "gated-gelu", 15 | "n_positions": 512, 16 | "num_heads": 32, 17 | "num_layers": 24, 18 | "num_decoder_layers": 24, 19 | "output_past": true, 20 | "pad_token_id": 0, 21 | "relative_attention_num_buckets": 32, 22 | "tokenizer_class": "BertTokenizer", 23 | "vocab_size": 16448 24 | } -------------------------------------------------------------------------------- /task_grammatical_error_correction/task_ctc_seq2seq/t5_xl_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": 
"[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /task_grammatical_error_correction/task_ctc_seq2seq/t5_xxl_config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "T5WithLMHeadModel" 4 | ], 5 | "d_ff": 10240, 6 | "d_kv": 64, 7 | "d_model": 4096, 8 | "decoder_start_token_id": 101, 9 | "dropout_rate": 0.1, 10 | "initializer_factor": 1.0, 11 | "is_encoder_decoder": true, 12 | "layer_norm_epsilon": 1e-8, 13 | "model_type": "t5", 14 | "feed_forward_proj": "gated-gelu", 15 | "n_positions": 512, 16 | "num_heads": 64, 17 | "num_layers": 24, 18 | "num_decoder_layers": 24, 19 | "output_past": true, 20 | "pad_token_id": 0, 21 | "relative_attention_num_buckets": 32, 22 | "tokenizer_class": "BertTokenizer", 23 | "vocab_size": 16448 24 | } -------------------------------------------------------------------------------- /task_grammatical_error_correction/task_ctc_seq2seq/t5_xxl_config/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "mask_token": "[MASK]", 6 | "never_split": null, 7 | "pad_token": "[PAD]", 8 | "sep_token": "[SEP]", 9 | "special_tokens_map_file": null, 10 | "strip_accents": null, 11 | "tokenize_chinese_chars": true, 12 | "tokenizer_class": "BertTokenizer", 13 | "unk_token": "[UNK]" 14 | } -------------------------------------------------------------------------------- /task_grammatical_error_correction/task_ctc_seq2seq/task_ctc_seq2seq.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/2/10 17:18 3 | 4 | import logging 5 | 6 | import Levenshtein 7 | import numpy as np 8 | import torch 9 | from deep_training.data_helper import ModelArguments, TrainingArguments, DataArguments 10 | from deep_training.nlp.metrics.pointer import metric_for_pointer 11 | from deep_training.nlp.models.transformer import TransformerForSeq2SeqLM 12 | from deep_training.utils.trainer import SimpleModelCheckpoint 13 | from lightning import Trainer 14 | from torch.utils.data import DataLoader, IterableDataset 15 | from tqdm import tqdm 16 | from transformers import HfArgumentParser, T5ForConditionalGeneration 17 | 18 | from data_utils import train_info_args, NN_DataHelper 19 | 20 | 21 | class MyTransformer(TransformerForSeq2SeqLM, with_pl=True): 22 | def __init__(self, *args, **kwargs): 23 | super(MyTransformer, self).__init__(*args, **kwargs) 24 | 25 | 26 | 27 | class MySimpleModelCheckpoint(SimpleModelCheckpoint): 28 | def __init__(self, *args, **kwargs): 29 | super(MySimpleModelCheckpoint, self).__init__(*args, **kwargs) 30 | self.weight_file = './best.pt' 31 | 32 | @staticmethod 33 | def generate_text_huggingface(pl_module: MyTransformer, input_ids, tokenizer, max_target_length, device=0): 34 | device = torch.device('cuda:{}'.format(device)) 35 | 36 | input_ids = torch.tensor(input_ids, dtype=torch.int32,device = device).unsqueeze(0) 37 | output = pl_module.backbone.model.generate(input_ids, 38 | max_length = max_target_length, 39 | bos_token_id = 
tokenizer.cls_token_id, 40 | pad_token_id = tokenizer.pad_token_id, 41 | eos_token_id = tokenizer.sep_token_id, 42 | ) 43 | 44 | gen_tokens = [] 45 | gen_ids = output[0].cpu().numpy() 46 | for logits in output[0]: 47 | # gen_ids.append(logits.cpu().numpy()) 48 | token = tokenizer._convert_id_to_token(logits) 49 | if token.startswith('##'): 50 | token = token.replace('##', '') 51 | gen_tokens.append(token) 52 | return ''.join(gen_tokens),gen_ids 53 | 54 | def on_save_model( 55 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" 56 | ) -> None: 57 | pl_module: MyTransformer 58 | 59 | # 当前设备 60 | device = torch.device('cuda:{}'.format(trainer.global_rank)) 61 | eval_datasets = dataHelper.load_sequential_sampler(dataHelper.eval_files, 62 | batch_size=training_args.eval_batch_size, 63 | collate_fn=dataHelper.collate_fn) 64 | 65 | config = pl_module.config 66 | 67 | y_preds, y_trues = [], [] 68 | 69 | op_map = { 70 | 'insert': 0, 71 | 'delete': 1, 72 | 'replace': 2 73 | } 74 | 75 | # 三元组(action,position,vocab) 76 | def get_ops(source,target): 77 | edits = Levenshtein.opcodes(source, target) 78 | ops = [] 79 | for item in edits: 80 | if item[0] == 'equal': 81 | continue 82 | action = op_map[item[0]] 83 | s = item[1] 84 | e = item[2] 85 | ds = item[3] 86 | de = item[4] 87 | #insert,replace 88 | if action == 0 or action == 2: 89 | for idx in range(de-ds): 90 | ops.append((action, s+idx, target[ds + idx])) 91 | #delete 92 | elif action == 1: 93 | for idx in range(s, e): 94 | ops.append((action, s+idx, 0)) 95 | else: 96 | raise ValueError('invalid action ',action) 97 | 98 | return ops 99 | 100 | for i, batch in tqdm(enumerate(eval_datasets), total=len(eval_datasets), desc='evalute'): 101 | batch_labels = batch.pop('labels',None) 102 | for k in batch: 103 | batch[k] = batch[k].to(device) 104 | for input_ids,attention_mask,labels in zip(batch['input_ids'],batch['attention_mask'],batch_labels): 105 | seqlen = torch.sum(attention_mask,dim=-1) 106 | output = MySimpleModelCheckpoint.generate_text_huggingface(pl_module, 107 | input_ids, 108 | tokenizer=tokenizer, 109 | max_target_length=data_args.max_target_length, 110 | device=trainer.global_rank) 111 | source = input_ids[1:seqlen-1].cpu().numpy() 112 | # 三元组(action,position,vocab) 113 | pred_ops = get_ops(source, output[1]) 114 | 115 | 116 | _ = np.where(labels==-100)[0] 117 | if len(_): 118 | seqlen = _[0] + 1 119 | else: 120 | seqlen = len(labels) 121 | labels = labels[1:seqlen - 1] 122 | # 三元组(action,position,vocab) 123 | true_ops = get_ops(source, labels) 124 | 125 | y_preds.append(pred_ops) 126 | y_trues.append(true_ops) 127 | 128 | print(y_preds[:3]) 129 | print(y_trues[:3]) 130 | 131 | label2id = { 132 | 'insert': 0, 133 | 'delete': 1, 134 | 'replace': 2 135 | } 136 | 137 | f1, str_report = metric_for_pointer(y_trues, y_preds, label2id) 138 | print(f1) 139 | print(str_report) 140 | 141 | best_f1 = self.best.get('f1', -np.inf) 142 | print('current', f1, 'best', best_f1) 143 | if f1 >= best_f1: 144 | self.best['f1'] = f1 145 | logging.info('save best {}, {}\n'.format(self.best['f1'], self.weight_file)) 146 | trainer.save_checkpoint(self.weight_file) 147 | 148 | 149 | if __name__ == '__main__': 150 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments)) 151 | model_args, training_args, data_args = parser.parse_dict(train_info_args) 152 | 153 | checkpoint_callback = MySimpleModelCheckpoint(every_n_epochs=1, 154 | every_n_train_steps=2000) 155 | trainer = Trainer( 156 | callbacks=[checkpoint_callback], 157 | 
max_epochs=training_args.max_epochs, 158 | max_steps=training_args.max_steps, 159 | accelerator="gpu", 160 | devices=data_args.devices, 161 | enable_progress_bar=True, 162 | default_root_dir=data_args.output_dir, 163 | gradient_clip_val=training_args.max_grad_norm, 164 | accumulate_grad_batches=training_args.gradient_accumulation_steps, 165 | num_sanity_val_steps=0, 166 | strategy='ddp' if torch.cuda.device_count() > 1 else 'auto', 167 | ) 168 | 169 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 170 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 171 | 172 | # 缓存数据集 173 | if data_args.do_train: 174 | dataHelper.make_dataset_with_args(data_args.train_file, shuffle=True,mode='train') 175 | if data_args.do_eval: 176 | dataHelper.make_dataset_with_args(data_args.eval_file, mode='eval') 177 | if data_args.do_test: 178 | dataHelper.make_dataset_with_args(data_args.test_file,mode='test') 179 | 180 | 181 | model = MyTransformer(config=config, model_args=model_args, training_args=training_args) 182 | 183 | if not data_args.convert_onnx: 184 | 185 | train_datasets = dataHelper.load_distributed_random_sampler( 186 | dataHelper.train_files, 187 | with_load_memory=True, 188 | collate_fn=dataHelper.collate_fn, 189 | batch_size=training_args.train_batch_size, 190 | num_processes = trainer.world_size, process_index=trainer.global_rank) 191 | 192 | if train_datasets is not None: 193 | trainer.fit(model, train_dataloaders=train_datasets) 194 | else: 195 | eval_datasets = dataHelper.load_sequential_sampler(dataHelper.eval_files,batch_size=training_args.eval_batch_size,collate_fn=dataHelper.collate_fn) 196 | test_datasets = dataHelper.load_sequential_sampler(dataHelper.test_files,batch_size=training_args.test_batch_size,collate_fn=dataHelper.collate_fn) 197 | if eval_datasets is not None: 198 | trainer.validate(model, dataloaders=eval_datasets, ckpt_path='./best.pt') 199 | 200 | if test_datasets is not None: 201 | trainer.test(model, dataloaders=test_datasets, ckpt_path='best.pt') -------------------------------------------------------------------------------- /task_sentence_vector/task_classify_vector_record/convert_train_pos_neg_for_infonce.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/12/16 11:03 2 | # @Author : tk 3 | # @FileName: split_record.py 4 | import copy 5 | import os 6 | import random 7 | 8 | import numpy as np 9 | from fastdatasets.record import load_dataset as Loader, RECORD, NumpyWriter 10 | from tqdm import tqdm 11 | 12 | 13 | # 从分类数据构造正负样本池 14 | def gen_pos_neg_records(all_example): 15 | all_example_new = [] 16 | all_keys = list(all_example.keys()) 17 | all_example_num = {lable: list(range(len(all_example[lable]))) for lable in all_example} 18 | 19 | np.random.shuffle(all_keys) 20 | while len(all_keys): 21 | current_labels = np.random.choice(all_keys, replace=False, size=min(40, len(all_keys))) 22 | pos_label, neg_labels = current_labels[0], current_labels[1:] 23 | 24 | examples = all_example[pos_label] 25 | idx_list: list 26 | idx_list_negs: list 27 | idx_list = all_example_num[pos_label] 28 | 29 | if len(idx_list) == 0: 30 | continue 31 | 32 | one_sample_pos, one_sample_neg = [], [] 33 | idx = np.random.choice(idx_list, replace=False, size=min(10, len(idx_list))) 34 | for value in idx: 35 | idx_list.remove(value) 36 | one_sample_pos.append(examples[value]) 37 | 38 | # 去除空标签数据 39 | if len(idx_list) == 0: 40 | all_keys.remove(pos_label) 41 | 42 | if len(one_sample_pos) < 2: 43 | 
continue 44 | 45 | neg_labels = list(set(copy.deepcopy(neg_labels))) 46 | for key in neg_labels: 47 | examples_negs = all_example[key] 48 | idx_list_negs = all_example_num[key] 49 | if len(idx_list_negs) == 0: 50 | # 去除空标签数据 51 | all_keys.remove(key) 52 | continue 53 | ids = np.random.choice(idx_list_negs, replace=False, size=min(10, len(idx_list_negs))) 54 | for value in ids: 55 | if random.random() < 0.7: 56 | idx_list_negs.remove(value) 57 | one_sample_neg.append(examples_negs[value]) 58 | 59 | if len(idx_list_negs) == 0: 60 | # 去除空标签数据 61 | all_keys.remove(key) 62 | 63 | if len(one_sample_neg) < 5: 64 | continue 65 | 66 | all_example_new.append((one_sample_pos, one_sample_neg)) 67 | 68 | if len(all_example_new) % 10000 == 0: 69 | print('current num', len(all_example_new)) 70 | 71 | return all_example_new 72 | 73 | 74 | def make_pos_neg_records(input_record_filenames, output_file, compression_type='GZIP'): 75 | print('make_pos_neg_records record...') 76 | options = RECORD.TFRecordOptions(compression_type=compression_type) 77 | dataset_reader = Loader.RandomDataset(input_record_filenames, options=options, 78 | with_share_memory=True).parse_from_numpy_writer() 79 | data_size = len(dataset_reader) 80 | all_example = {} 81 | 82 | for i in tqdm(range(data_size), desc='load records'): 83 | serialized = dataset_reader[i] 84 | labels = serialized['labels'] 85 | labels = np.squeeze(labels).tolist() 86 | if labels not in all_example: 87 | all_example[labels] = [] 88 | all_example[labels].append(serialized) 89 | 90 | if hasattr(dataset_reader, 'close'): 91 | dataset_reader.close() 92 | else: 93 | dataset_reader.reset() 94 | 95 | print(all_example.keys()) 96 | all_example_new = gen_pos_neg_records(all_example) 97 | print('all_example_new', len(all_example_new)) 98 | writer = NumpyWriter(output_file, options=options) 99 | shuffle_idx = list(range(len(all_example_new))) 100 | random.shuffle(shuffle_idx) 101 | 102 | num_train = 0 103 | total_n = 0 104 | for i in tqdm(shuffle_idx, desc='shuffle record', total=len(shuffle_idx)): 105 | example = all_example_new[i] 106 | num_train += 1 107 | example_new = {} 108 | pos, neg = example 109 | total_n += len(pos) + len(neg) 110 | 111 | example_new['pos_len'] = np.asarray(len(pos), dtype=np.int32) 112 | example_new['neg_len'] = np.asarray(len(neg), dtype=np.int32) 113 | d: dict 114 | for idx, d in enumerate(pos): 115 | example_new['input_ids_pos{}'.format(idx)] = d['input_ids'] 116 | example_new['attention_mask_pos{}'.format(idx)] = d['attention_mask'] 117 | example_new['labels_pos{}'.format(idx)] = d['labels'] 118 | example_new['seqlen_pos{}'.format(idx)] = d['seqlen'] 119 | 120 | for idx, d in enumerate(neg): 121 | example_new['input_ids_neg{}'.format(idx)] = d['input_ids'] 122 | example_new['attention_mask_neg{}'.format(idx)] = d['attention_mask'] 123 | example_new['labels_neg{}'.format(idx)] = d['labels'] 124 | example_new['seqlen_neg{}'.format(idx)] = d['seqlen'] 125 | 126 | writer.write(example_new) 127 | writer.close() 128 | print('num train record', num_train, 'total record', total_n) 129 | 130 | 131 | if __name__ == '__main__': 132 | example_files = '/data/record/cse_0130/train.record' 133 | output_train_file = os.path.join('/data/record/cse_0130/train_pos_neg.record') 134 | make_pos_neg_records(input_record_filenames=example_files, output_file=output_train_file, ) 135 | 136 | example_files = '/data/record/cse_0130/train_jieba.record' 137 | output_train_file = os.path.join('/data/record/cse_0130/train_jieba_pos_neg.record') 138 | 
make_pos_neg_records(input_record_filenames=example_files, output_file=output_train_file, ) 139 | -------------------------------------------------------------------------------- /task_sentence_vector/task_classify_vector_record/corpus_process/jieba_process_corpus.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/1/9 23:26 2 | # @Author : tk 3 | # @FileName: stopwards.py 4 | import json 5 | import jieba 6 | from tqdm import tqdm 7 | import re 8 | from collections import Counter 9 | import os 10 | 11 | 12 | def get_cipin(fs,outdir,stopwards_file='./stopwards.txt'): 13 | stopwards = set() 14 | with open(stopwards_file, mode='r', encoding='utf-8', newline='\n') as f: 15 | while True: 16 | text = f.readline() 17 | if not text: 18 | break 19 | text = text.strip('\r\n').strip('\n') 20 | stopwards.add(text) 21 | 22 | print(list(stopwards)[:100]) 23 | counter = Counter() 24 | f_out = open(os.path.join(outdir, 'jieba_process.json'), mode='w', encoding='utf-8', newline='\n') 25 | f_out2 = open(os.path.join(outdir, 'raw.json'), mode='w', encoding='utf-8', newline='\n') 26 | for filename in tqdm(fs,total=len(fs)): 27 | with open(filename,mode='r',encoding='utf-8') as f: 28 | while True: 29 | line = f.readline() 30 | if not line: 31 | break 32 | jd = json.loads(line) 33 | if not jd: 34 | continue 35 | text = jd['text'] 36 | label = jd['label'] 37 | text = text.strip('\n') 38 | text = re.sub("[A-Za-z0-9\:\·\—\,\。\“ \”]", "", text) 39 | seg_list = jieba.cut(text,cut_all=False) 40 | 41 | seg_list_new = [s for s in seg_list if s not in stopwards] 42 | counter.update(seg_list_new) 43 | 44 | o = { 45 | 'text': ' '.join(seg_list_new), 46 | 'label': label 47 | } 48 | f_out.write(json.dumps(o,ensure_ascii=False) + '\n') 49 | f_out2.write(json.dumps(jd, ensure_ascii=False) + '\n') 50 | f_out.close() 51 | f_out2.close() 52 | 53 | print('\n词频统计结果:') 54 | vocabfile = os.path.join(outdir, 'vocab.txt') 55 | with open(vocabfile,mode='w',encoding='utf-8',newline='\n') as f: 56 | for (k,v) in counter.most_common():# 输出词频最高的前两个词 57 | print("%s:%d"%(k,v)) 58 | f.write("{} {}\n".format(k,v)) 59 | 60 | 61 | if __name__ == '__main__': 62 | data_dir = '/data/nlp/nlp_train_data/lawcup2018/top122/process' 63 | outdir='/data/nlp/nlp_train_data/lawcup2018/top122/jieba_process_output' 64 | 65 | fs = os.listdir(data_dir) 66 | get_cipin([os.path.join(data_dir,f) for f in fs],outdir) -------------------------------------------------------------------------------- /task_sentence_vector/task_classify_vector_record/corpus_process/split_corpus.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/1/30 15:32 3 | import json 4 | import random 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | np.random.seed(123456) 9 | 10 | shuffle_idx = None 11 | 12 | def process_file(in_file,train_file,eval_file): 13 | global shuffle_idx 14 | with open(in_file,mode='r',encoding='utf-8') as f: 15 | lines = f.readlines() 16 | 17 | f1 = open(train_file, mode='w', encoding='utf-8', newline='\n') 18 | f2 = open(eval_file, mode='w', encoding='utf-8', newline='\n') 19 | 20 | if shuffle_idx is None: 21 | shuffle_idx = list(range(len(lines))) 22 | np.random.shuffle(shuffle_idx) 23 | else: 24 | if len(lines) != len(shuffle_idx): 25 | raise ValueError('NOT EQ') 26 | print(shuffle_idx[:100]) 27 | for i,idx in tqdm(enumerate(shuffle_idx),total=len(lines)): 28 | jd = json.loads(lines[idx]) 29 | if i % 15 == 0: 30 | f = f2 31 | else: 
32 | f = f1 33 | f.write(json.dumps(jd, ensure_ascii=False) + '\n') 34 | f1.close() 35 | f2.close() 36 | 37 | 38 | 39 | 40 | if __name__ == '__main__': 41 | in_file = '/data/nlp/nlp_train_data/lawcup2018/top122/jieba_process_output/jieba_process.json' 42 | train_file = '/data/nlp/nlp_train_data/lawcup2018/top122/jieba_process_output/train_jieba.json' 43 | eval_file = '/data/nlp/nlp_train_data/lawcup2018/top122/jieba_process_output/eval_jieba.json' 44 | 45 | process_file(in_file,train_file,eval_file) 46 | 47 | in_file = '/data/nlp/nlp_train_data/lawcup2018/top122/jieba_process_output/raw.json' 48 | train_file = '/data/nlp/nlp_train_data/lawcup2018/top122/jieba_process_output/train.json' 49 | eval_file = '/data/nlp/nlp_train_data/lawcup2018/top122/jieba_process_output/eval.json' 50 | 51 | process_file(in_file, train_file, eval_file) -------------------------------------------------------------------------------- /task_sentence_vector/task_classify_vector_record/corpus_process/stopwards.txt: -------------------------------------------------------------------------------- 1 | 的 2 | 点多 3 | 晚 4 | 分 5 | 傍晚 6 | 月份 7 | 日夜 8 | 许 9 | 年月日时 10 | 年月日 11 | 年 12 | 月 13 | 日 14 | 清晨 15 | 凌晨 16 | 中午 17 | 晌午 18 | 早上 19 | 上午 20 | 晚上 21 | 下午 22 | 时许 23 | 时分 24 | 左右 25 | 的一天 26 | 一天 27 | 份许 28 | 时许 29 | 分许 30 | 民警 31 | 公安民警 32 | 查获 33 | 公安局 34 | 另案处理 35 | , 36 | , 37 | 、 38 | × 39 | . 40 | / 41 | ; 42 | 》 43 | 《 44 | + 45 | ? 46 | _ 47 | “ 48 | ” 49 | 。 50 | ! 51 | , 52 | : 53 | ; 54 | ? 55 | ×× 56 | × 57 | * 58 | - -------------------------------------------------------------------------------- /task_sentence_vector/task_classify_vector_record/load_record.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/1/30 11:18 3 | 4 | import os 5 | import random 6 | import numpy as np 7 | from fastdatasets.record import load_dataset as Loader, gfile, RECORD, WriterObject 8 | from tqdm import tqdm 9 | from transformers import BertTokenizer 10 | 11 | 12 | 13 | path_list = [ 14 | '/data/nlp/pre_models/torch/bert/bert-base-chinese', 15 | '/data/torch/bert-base-chinese', 16 | '/opt/tk/torch/bert-base-chinese' 17 | ] 18 | path = '' 19 | for p in path_list: 20 | if os.path.exists(p): 21 | path = p 22 | break 23 | 24 | tokenizer = BertTokenizer.from_pretrained(path) 25 | 26 | # load a record file and print a few decoded samples 27 | def load_record(input_record_filenames, compression_type='GZIP'): 28 | print('load_record record...') 29 | options = RECORD.TFRecordOptions(compression_type=compression_type) 30 | dataset_reader = Loader.RandomDataset(input_record_filenames, options=options, with_share_memory=True) 31 | dataset_reader = dataset_reader.parse_from_numpy_writer() 32 | 33 | for i in tqdm(range(len(dataset_reader)), desc='load records'): 34 | example = dataset_reader[i] 35 | 36 | print(example.keys()) 37 | seqlen = example['seqlen'] 38 | seqlen = np.squeeze(seqlen,axis=-1) 39 | input_ids = example['input_ids'] 40 | input_ids = input_ids[:seqlen] 41 | tokens = tokenizer.decode(input_ids) 42 | print(tokens) 43 | if i > 10: 44 | break 45 | dataset_reader.close() 46 | 47 | print('*' * 30) 48 | 49 | load_record('./train.record') 50 | 51 | -------------------------------------------------------------------------------- /task_sentence_vector/task_classify_vector_record/make_record_for_classify.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/12/13 8:55 3 | 4 | import json 5 | import random 6 | import typing 
7 | 8 | import numpy as np 9 | from deep_training.data_helper import DataHelper 10 | from deep_training.data_helper import ModelArguments, TrainingArguments, DataArguments 11 | from fastdatasets import gfile 12 | from transformers import HfArgumentParser, BertTokenizer 13 | 14 | train_info_args = { 15 | 'devices': 1, 16 | 'data_backend': 'record', 17 | 'model_type': 'bert', 18 | 'model_name_or_path': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 19 | 'tokenizer_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 20 | 'config_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese/config.json', 21 | 'convert_onnx': False, # convert the model to onnx 22 | 'do_train': True, 23 | 'do_eval': False, 24 | 'train_file': gfile.glob('/data/nlp/nlp_train_data/lawcup2018/top122/process/*.json'), 25 | 'eval_file': [ ''], 26 | 'test_file': [ ''], 27 | 'label_file': [ '/data/nlp/nlp_train_data/lawcup2018/top122/labels_122.txt'], 28 | # 'train_file': [ '/data/nlp/nlp_train_data/clue/tnews/train.json'], 29 | # 'eval_file': [ ''], 30 | # 'test_file': [ ''], 31 | # 'label_file': [ '/data/nlp/nlp_train_data/clue/tnews/labels.txt'], 32 | 'learning_rate': 5e-5, 33 | 'max_epochs': 3, 34 | 'train_batch_size': 10, 35 | 'test_batch_size': 2, 36 | 'adam_epsilon': 1e-8, 37 | 'gradient_accumulation_steps': 1, 38 | 'max_grad_norm': 1.0, 39 | 'weight_decay': 0, 40 | 'warmup_steps': 0, 41 | 'output_dir': './output', 42 | 'max_seq_length': 512 43 | } 44 | 45 | 46 | class NN_DataHelper(DataHelper): 47 | # tokenize one sample 48 | def on_data_process(self, data: typing.Any, mode: str): 49 | tokenizer: BertTokenizer 50 | max_seq_length = self.max_seq_length_dict[mode] 51 | tokenizer = self.tokenizer 52 | do_lower_case = tokenizer.do_lower_case 53 | label2id = self.label2id 54 | sentence, label_str = data 55 | 56 | o = tokenizer(sentence, max_length=max_seq_length, truncation=True, add_special_tokens=True, ) 57 | input_ids = np.asarray(o['input_ids'], dtype=np.int64) 58 | attention_mask = np.asarray(o['attention_mask'], dtype=np.int64) 59 | 60 | labels = np.asarray(label2id[label_str] if label_str is not None else 0, dtype=np.int64) 61 | seqlen = np.asarray(len(input_ids), dtype=np.int64) 62 | pad_len = max_seq_length - len(input_ids) 63 | if pad_len > 0: 64 | pad_val = tokenizer.pad_token_id 65 | input_ids = np.pad(input_ids, (0, pad_len), 'constant', constant_values=(pad_val, pad_val)) 66 | attention_mask = np.pad(attention_mask, (0, pad_len), 'constant', constant_values=(0, 0)) 67 | d = { 68 | 'input_ids': input_ids, 69 | 'attention_mask': attention_mask, 70 | 'labels': labels, 71 | 'seqlen': seqlen 72 | } 73 | return d 74 | 75 | # read the label file 76 | def on_get_labels(self, files: typing.List[str]): 77 | file = files[0] 78 | with open(file, mode='r', encoding='utf-8') as f: 79 | lines = f.readlines() 80 | 81 | labels = [] 82 | for line in lines: 83 | line = line.replace('\r\n', '').replace('\n', '') 84 | if not line: 85 | continue 86 | labels.append(line) 87 | labels = list(set(labels)) 88 | labels = sorted(labels) 89 | label2id = {l: i for i, l in enumerate(labels)} 90 | id2label = {i: l for i, l in enumerate(labels)} 91 | return label2id, id2label 92 | 93 | # read the corpus files 94 | def on_get_corpus(self, files: typing.List, mode: str): 95 | assert len(files) > 0 96 | D = [] 97 | filenames = gfile.glob(files[0]) 98 | for fname in filenames: 99 | with open(fname, mode='r', encoding='utf-8') as f: 100 | lines = f.readlines() 101 | for line in lines: 102 | jd = json.loads(line) 103 | if not jd: 104 | continue 105 | if 'text' in jd: 106 | text = jd['text'] 107 | 
else: 108 | text = jd['sentence'] 109 | D.append((text, jd.get('label', None))) 110 | random.shuffle(D) 111 | return D 112 | 113 | 114 | if __name__ == '__main__': 115 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments)) 116 | model_args, training_args, data_args = parser.parse_dict(train_info_args) 117 | 118 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 119 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 120 | 121 | if data_args.do_train: 122 | dataHelper.make_dataset_with_args(data_args.train_file,shuffle=True, mode='train') 123 | if data_args.do_eval: 124 | dataHelper.make_dataset_with_args(data_args.eval_file, mode='eval') 125 | if data_args.do_test: 126 | dataHelper.make_dataset_with_args(data_args.test_file, mode='test') 127 | 128 | -------------------------------------------------------------------------------- /task_sentence_vector/task_classify_vector_record/merge_record.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/12/16 11:03 2 | # @Author : tk 3 | # @FileName: merge_record.py 4 | 5 | import os 6 | 7 | from fastdatasets.record import load_dataset as Loader, gfile, RECORD, WriterObject 8 | from tfrecords import TFRecordOptions 9 | from tqdm import tqdm 10 | 11 | 12 | # merge several record files into one 13 | def merge_records(input_record_filenames, output_file, compression_type='GZIP'): 14 | print('merge_records record...') 15 | options = RECORD.TFRecordOptions(compression_type=compression_type) 16 | dataset_reader = Loader.RandomDataset(input_record_filenames, options=options, with_share_memory=True) 17 | 18 | all_example = [] 19 | for i in tqdm(range(len(dataset_reader)), desc='load records'): 20 | serialized = dataset_reader[i] 21 | all_example.append(serialized) 22 | dataset_reader.close() 23 | 24 | # # small sample for quick testing 25 | # all_example = all_example[:10000] 26 | data_size = len(all_example) 27 | shuffle_idx = list(range(data_size)) 28 | writer_output = WriterObject(output_file, options=TFRecordOptions(compression_type='GZIP')) 29 | 30 | for i in tqdm(shuffle_idx, desc='write record'): 31 | example = all_example[i] 32 | writer_output.write(example) 33 | writer_output.close() 34 | 35 | print('num', len(shuffle_idx)) 36 | 37 | 38 | if __name__ == '__main__': 39 | src_files = [ 40 | '/data/record/cse_0110/eval_pos.record.cache', 41 | '/data/record/cse_0110/eval_neg.record.cache' 42 | ] 43 | dst_dir = './' 44 | if not os.path.exists(dst_dir): 45 | gfile.makedirs(dst_dir) 46 | 47 | output_file = os.path.join(dst_dir, 'eval_pos_neg.record.cache') 48 | 49 | merge_records(input_record_filenames=src_files, 50 | output_file=output_file, ) 51 | -------------------------------------------------------------------------------- /task_sentence_vector/task_classify_vector_record/shuffle_record.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/12/16 10:58 2 | # @Author : tk 3 | # @FileName: shuffle_record.py 4 | 5 | import os 6 | import random 7 | 8 | from fastdatasets.record import load_dataset as Loader, RECORD, WriterObject 9 | from tqdm import tqdm 10 | 11 | 12 | def shuffle_records(record_filenames, out_dir, out_record_num, compression_type='GZIP'): # shuffle one or more record files and re-shard into out_record_num files 13 | print('shuffle_records record...') 14 | options = RECORD.TFRecordOptions(compression_type=compression_type) 15 | dataset_reader = Loader.RandomDataset(record_filenames, options=options, with_share_memory=True) 16 | data_size = len(dataset_reader) 17 | all_example = [] 18 | for i in 
tqdm(range(data_size), desc='load records'): 19 | serialized = dataset_reader[i] 20 | all_example.append(serialized) 21 | dataset_reader.close() 22 | 23 | shuffle_idx = list(range(data_size)) 24 | random.shuffle(shuffle_idx) 25 | writers = [WriterObject(os.path.join(out_dir, 'record_gzip_shuffle_{}.record'.format(i)), options=options) for i in 26 | range(out_record_num)] 27 | for i in tqdm(shuffle_idx, desc='shuffle record'): 28 | example = all_example[i] 29 | writers[i % out_record_num].write(example) 30 | for writer in writers: 31 | writer.close() 32 | 33 | 34 | if __name__ == '__main__': 35 | src_records = ['/tmp/train.record'] 36 | dst_dir = '/tmp/' 37 | shuffle_records(record_filenames=src_records, out_dir=dst_dir, out_record_num=1) 38 | -------------------------------------------------------------------------------- /task_sentence_vector/task_classify_vector_record/split_record.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/12/16 11:03 2 | # @Author : tk 3 | # @FileName: split_record.py 4 | 5 | import os 6 | import random 7 | 8 | from fastdatasets.record import load_dataset as Loader, gfile, RECORD, WriterObject 9 | from tfrecords import TFRecordOptions 10 | from tqdm import tqdm 11 | 12 | 13 | # split a record dataset into train/eval 14 | def split_records(input_record_filenames, output_train_file, output_eval_file, compression_type='GZIP'): 15 | print('split_records record...') 16 | options = RECORD.TFRecordOptions(compression_type=compression_type) 17 | dataset_reader = Loader.RandomDataset(input_record_filenames, options=options, with_share_memory=True) 18 | 19 | all_example = [] 20 | for i in tqdm(range(len(dataset_reader)), desc='load records'): 21 | serialized = dataset_reader[i] 22 | all_example.append(serialized) 23 | dataset_reader.close() 24 | 25 | # # small sample for quick testing 26 | # all_example = all_example[:10000] 27 | data_size = len(all_example) 28 | shuffle_idx = list(range(data_size)) 29 | random.shuffle(shuffle_idx) 30 | 31 | writer_train = WriterObject(output_train_file, options=TFRecordOptions(compression_type='GZIP')) 32 | writer_eval = WriterObject(output_eval_file, options=TFRecordOptions(compression_type='GZIP')) 33 | 34 | num_train = 0 35 | num_eval = 0 36 | for i in tqdm(shuffle_idx, desc='shuffle record'): 37 | example = all_example[i] 38 | 39 | if (i + 1) % 15 == 0: 40 | num_eval += 1 41 | writer = writer_eval 42 | else: 43 | num_train += 1 44 | writer = writer_train 45 | 46 | writer.write(example) 47 | 48 | writer_train.close() 49 | writer_eval.close() 50 | 51 | print('num_train', num_train, 'num_eval', num_eval) 52 | 53 | 54 | if __name__ == '__main__': 55 | src_files = [ 56 | '/data/record/cse/dataset_0-train.record' 57 | ] 58 | dst_dir = '/data/record/cse_1226/' 59 | 60 | if not os.path.exists(dst_dir): 61 | gfile.makedirs(dst_dir) 62 | 63 | output_train_file = os.path.join(dst_dir, 'train.record') 64 | output_eval_file = os.path.join(dst_dir, 'eval.record') 65 | 66 | split_records(input_record_filenames=src_files, 67 | output_train_file=output_train_file, 68 | output_eval_file=output_eval_file) 69 | -------------------------------------------------------------------------------- /task_sentence_vector/task_classify_vector_record/split_record_and_modify.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/12/16 11:03 2 | # @Author : tk 3 | # @FileName: split_record_and_modify.py 4 | 5 | import os 6 | import random 7 | 8 | from fastdatasets.record import load_dataset as Loader, RECORD, NumpyWriter 9 | from tqdm 
import tqdm 10 | 11 | 12 | # split a record dataset into train/eval, with optional per-example modification 13 | def split_records(input_record_filenames, output_train_file, output_eval_file, compression_type='GZIP'): 14 | print('split_records record...') 15 | options = RECORD.TFRecordOptions(compression_type=compression_type) 16 | dataset_reader = Loader.RandomDataset(input_record_filenames, options=options, 17 | with_share_memory=True).parse_from_numpy_writer() 18 | data_size = len(dataset_reader) 19 | all_example = [] 20 | for i in tqdm(range(data_size), desc='load records'): 21 | serialized = dataset_reader[i] 22 | all_example.append(serialized) 23 | 24 | if hasattr(dataset_reader, 'close'): 25 | dataset_reader.close() 26 | else: 27 | dataset_reader.reset() 28 | 29 | shuffle_idx = list(range(data_size)) 30 | random.shuffle(shuffle_idx) 31 | 32 | writer_train = NumpyWriter(output_train_file, options=options) 33 | writer_eval = NumpyWriter(output_eval_file, options=options) 34 | 35 | num_train = 0 36 | num_eval = 0 37 | for i in tqdm(shuffle_idx, desc='shuffle record'): 38 | example = all_example[i] 39 | 40 | if (i + 1) % 8 == 0: 41 | num_eval += 1 42 | count = num_eval 43 | writer = writer_eval 44 | else: 45 | num_train += 1 46 | count = num_train 47 | writer = writer_train 48 | # add an extra key, for testing 49 | # example['id'] = np.asarray(count - 1,dtype=np.int32) 50 | 51 | writer.write(example) 52 | 53 | writer_train.close() 54 | writer_eval.close() 55 | print('num_train', num_train, 'num_eval', num_eval) 56 | 57 | 58 | if __name__ == '__main__': 59 | example_files = r'/home/tk/train/make_big_data/output/dataset_0-train.record' 60 | 61 | output_train_file = os.path.join('/home/tk/train/make_big_data/output', 'train.record') 62 | 63 | output_eval_file = os.path.join('/home/tk/train/make_big_data/output', 'eval.record') 64 | 65 | split_records(input_record_filenames=example_files, 66 | output_train_file=output_train_file, 67 | output_eval_file=output_eval_file) 68 | -------------------------------------------------------------------------------- /task_text_generate/task_autotitle_unilm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import typing 4 | 5 | import numpy as np 6 | import torch 7 | from deep_training.data_helper import DataHelper 8 | from deep_training.data_helper import ModelArguments, DataArguments, TrainingArguments 9 | from deep_training.nlp.models.transformer import TransformerModelForUnilm 10 | from deep_training.utils.func import seq_padding 11 | from deep_training.utils.trainer import SimpleModelCheckpoint 12 | from lightning import Trainer 13 | from torch.utils.data import DataLoader, IterableDataset 14 | from transformers import BertTokenizer 15 | from transformers import HfArgumentParser 16 | 17 | train_info_args = { 18 | 'devices': 1, 19 | 'data_backend': 'memory_raw', 20 | 'model_type': 'bert', 21 | 'model_name_or_path': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 22 | 'tokenizer_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 23 | 'config_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese/config.json', 24 | 'convert_onnx': False, # convert the model to onnx 25 | 'do_train': True, 26 | 'train_file': [ '/data/nlp/nlp_train_data/thucnews/train.json'], 27 | 'max_steps': 100000, 28 | 'train_batch_size': 8, 29 | 'test_batch_size': 2, 30 | 'adam_epsilon': 1e-8, 31 | 'gradient_accumulation_steps': 1, 32 | 'max_grad_norm': 1.0, 33 | 'weight_decay': 0, 34 | 'warmup_steps': 0, 35 | 'output_dir': './output', 36 | 'max_seq_length': 512, 37 | 'max_target_length': 50 38 
| } 39 | 40 | 41 | class NN_DataHelper(DataHelper): 42 | # tokenize one (content, title) pair 43 | def on_data_process(self, data: typing.Any, mode: str): 44 | tokenizer: BertTokenizer 45 | max_seq_length = self.max_seq_length_dict[mode] 46 | tokenizer = self.tokenizer 47 | do_lower_case = tokenizer.do_lower_case 48 | label2id = self.label2id 49 | x = data 50 | assert isinstance(x, tuple) 51 | o = tokenizer.encode_plus(text=x[0], text_pair=x[1], max_length=max_seq_length, truncation=True) 52 | seqlen = np.asarray(len(o['input_ids']), dtype=np.int32) 53 | input_ids = seq_padding(o['input_ids'], max_seq_length=max_seq_length, pad_val=tokenizer.pad_token_id) 54 | token_type_ids = seq_padding(o['token_type_ids'], max_seq_length=max_seq_length, pad_val=0) 55 | 56 | d = { 57 | 'input_ids': input_ids, 58 | 'token_type_ids': token_type_ids, 59 | 'labels': input_ids, 60 | 'seqlen': seqlen 61 | } 62 | return d 63 | 64 | # read the corpus files 65 | def on_get_corpus(self, files: typing.List, mode: str): 66 | D = [] 67 | for filename in files: 68 | with open(filename, mode='r', encoding='utf-8') as f: 69 | lines = f.readlines() 70 | for i, line in enumerate(lines): 71 | jd = json.loads(line) 72 | D.append((jd['content'], jd['title'])) 73 | if i > 1000: 74 | break 75 | return D 76 | 77 | def collate_fn(self,batch): 78 | o = {} 79 | for i, b in enumerate(batch): 80 | if i == 0: 81 | for k in b: 82 | o[k] = [torch.tensor(b[k])] 83 | else: 84 | for k in b: 85 | o[k].append(torch.tensor(b[k])) 86 | for k in o: 87 | o[k] = torch.stack(o[k]) 88 | 89 | max_len = torch.max(o.pop('seqlen')) 90 | o['input_ids'] = o['input_ids'][:, :max_len] 91 | o['token_type_ids'] = o['token_type_ids'][:, :max_len] 92 | o['labels'] = o['labels'][:, :max_len] 93 | return o 94 | 95 | 96 | class MyTransformer(TransformerModelForUnilm, with_pl=True): 97 | def __init__(self, *args, **kwargs): 98 | super(MyTransformer, self).__init__(*args, **kwargs) 99 | 100 | 101 | if __name__ == '__main__': 102 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments)) 103 | model_args, training_args, data_args = parser.parse_dict(train_info_args) 104 | 105 | checkpoint_callback = SimpleModelCheckpoint(monitor="loss", 106 | every_n_train_steps=2000 // training_args.gradient_accumulation_steps) 107 | trainer = Trainer( 108 | callbacks=[checkpoint_callback], 109 | max_epochs=training_args.max_epochs, 110 | max_steps=training_args.max_steps, 111 | accelerator="gpu", 112 | devices=data_args.devices, 113 | enable_progress_bar=True, 114 | default_root_dir=data_args.output_dir, 115 | gradient_clip_val=training_args.max_grad_norm, 116 | accumulate_grad_batches=training_args.gradient_accumulation_steps, 117 | num_sanity_val_steps=0, 118 | strategy='ddp' if torch.cuda.device_count() > 1 else 'auto', 119 | ) 120 | 121 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 122 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 123 | 124 | # cache the datasets 125 | if data_args.do_train: 126 | dataHelper.make_dataset_with_args(data_args.train_file, shuffle=True,mode='train') 127 | if data_args.do_eval: 128 | dataHelper.make_dataset_with_args(data_args.eval_file, mode='eval') 129 | if data_args.do_test: 130 | dataHelper.make_dataset_with_args(data_args.test_file,mode='test') 131 | 132 | model = MyTransformer(config=config, model_args=model_args, training_args=training_args) 133 | 134 | if not data_args.convert_onnx: 135 | train_datasets = dataHelper.load_distributed_random_sampler( 136 | dataHelper.train_files, 137 | with_load_memory=True, 138 | 
collate_fn=dataHelper.collate_fn, 139 | batch_size=training_args.train_batch_size, 140 | num_processes = trainer.world_size, process_index=trainer.global_rank) 141 | if train_datasets is not None: 142 | trainer.fit(model, train_dataloaders=train_datasets) 143 | else: 144 | eval_datasets = dataHelper.load_sequential_sampler(dataHelper.eval_files,batch_size=training_args.eval_batch_size,collate_fn=dataHelper.collate_fn) 145 | test_datasets = dataHelper.load_sequential_sampler(dataHelper.test_files,batch_size=training_args.test_batch_size,collate_fn=dataHelper.collate_fn) 146 | if eval_datasets is not None: 147 | trainer.validate(model, dataloaders=eval_datasets, ckpt_path='./best.pt') 148 | 149 | if test_datasets is not None: 150 | trainer.test(model, dataloaders=test_datasets, ckpt_path='best.pt') 151 | -------------------------------------------------------------------------------- /task_text_generate/task_autotitle_unilm_distillation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import typing 4 | 5 | import numpy as np 6 | import torch 7 | from deep_training.data_helper import DataHelper 8 | from deep_training.data_helper import ModelArguments, DataArguments, TrainingArguments 9 | from deep_training.nlp.layers.mask import unilm_mask 10 | from deep_training.nlp.losses.loss_kl import KLDivLoss 11 | from deep_training.nlp.models.transformer import TransformerModelForUnilm 12 | from deep_training.utils.func import seq_padding 13 | from deep_training.utils.trainer import SimpleModelCheckpoint 14 | from lightning import Trainer 15 | from torch.utils.data import DataLoader, IterableDataset 16 | from transformers import BertTokenizer 17 | from transformers import HfArgumentParser 18 | 19 | train_info_args = { 20 | 'devices': 1, 21 | 'data_backend': 'memory_raw', 22 | 'model_type': 'bert', 23 | 'model_name_or_path': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 24 | 'tokenizer_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese', 25 | 'config_name': '/data/nlp/pre_models/torch/bert/bert-base-chinese/config.json', 26 | 'convert_onnx': False, # convert the model to onnx 27 | 'do_train': True, 28 | 'train_file': [ '/data/nlp/nlp_train_data/thucnews/train.json'], 29 | 'max_steps': 100000, 30 | 'train_batch_size': 8, 31 | 'test_batch_size': 2, 32 | 'adam_epsilon': 1e-8, 33 | 'gradient_accumulation_steps': 1, 34 | 'max_grad_norm': 1.0, 35 | 'weight_decay': 0, 36 | 'warmup_steps': 0, 37 | 'output_dir': './output', 38 | 'max_seq_length': 200, 39 | 'max_target_length': 50 40 | } 41 | 42 | 43 | class NN_DataHelper(DataHelper): 44 | # tokenize one (content, title) pair 45 | def on_data_process(self, data: typing.Any, mode: str): 46 | tokenizer: BertTokenizer 47 | max_seq_length = self.max_seq_length_dict[mode] 48 | tokenizer = self.tokenizer 49 | do_lower_case = tokenizer.do_lower_case 50 | label2id = self.label2id 51 | x = data 52 | assert isinstance(x, tuple) 53 | o = tokenizer.encode_plus(text=x[0], text_pair=x[1], max_length=max_seq_length, truncation=True) 54 | seqlen = np.asarray(len(o['input_ids']), dtype=np.int32) 55 | input_ids = seq_padding(o['input_ids'], max_seq_length=max_seq_length, pad_val=tokenizer.pad_token_id) 56 | token_type_ids = seq_padding(o['token_type_ids'], max_seq_length=max_seq_length, pad_val=0) 57 | 58 | d = { 59 | 'input_ids': input_ids, 60 | 'token_type_ids': token_type_ids, 61 | 'labels': input_ids, 62 | 'seqlen': seqlen 63 | } 64 | return d 65 | 66 | # read the corpus files 67 | def on_get_corpus(self, files: typing.List, mode: str): 68 | D 
= [] 69 | for filename in files: 70 | with open(filename, mode='r', encoding='utf-8') as f: 71 | lines = f.readlines() 72 | for i, line in enumerate(lines): 73 | jd = json.loads(line) 74 | D.append((jd['content'], jd['title'])) 75 | if i > 1000: 76 | break 77 | return D 78 | 79 | def collate_fn(self,batch): 80 | o = {} 81 | for i, b in enumerate(batch): 82 | if i == 0: 83 | for k in b: 84 | o[k] = [torch.tensor(b[k])] 85 | else: 86 | for k in b: 87 | o[k].append(torch.tensor(b[k])) 88 | for k in o: 89 | o[k] = torch.stack(o[k]) 90 | 91 | max_len = torch.max(o.pop('seqlen')) 92 | o['input_ids'] = o['input_ids'][:, :max_len] 93 | o['token_type_ids'] = o['token_type_ids'][:, :max_len] 94 | o['labels'] = o['labels'][:, :max_len] 95 | return o 96 | 97 | 98 | # teacher: the full 12-layer model 99 | class TeacherTransformer(TransformerModelForUnilm, with_pl=True): 100 | def __init__(self, *args, **kwargs): 101 | super(TeacherTransformer, self).__init__(*args, **kwargs) 102 | 103 | def compute_loss(self, *args, **batch) -> tuple: 104 | batch['attention_mask'] = unilm_mask(batch['token_type_ids']) 105 | if getattr(self.config, 'type_vocab_size', 0) != 2: 106 | batch.pop('token_type_ids') 107 | 108 | labels = batch.pop('labels', None) 109 | outputs = self.model(*args, **batch) 110 | hidden_states = outputs[0] 111 | lm_logits = self.model.lm_head(hidden_states) 112 | 113 | if labels is not None: 114 | labels = labels.long() 115 | shift_logits = lm_logits[..., :-1, :].contiguous() 116 | shift_labels = labels[..., 1:].contiguous() 117 | loss = self.model.loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 118 | 119 | outputs = (loss, lm_logits, labels) 120 | else: 121 | outputs = (lm_logits,) 122 | return outputs 123 | 124 | 125 | # student: distilled down to 6 layers 126 | class StudentTransformer(TransformerModelForUnilm, with_pl=True): 127 | def __init__(self, teacher_model, *args, **kwargs): 128 | super(StudentTransformer, self).__init__(*args, **kwargs) 129 | self.teacher_model = teacher_model 130 | self.kl_loss = KLDivLoss('sum') 131 | 132 | def compute_loss(self, *args, **batch) -> tuple: 133 | labels = batch.pop('labels', None) 134 | 135 | inputs = {k: v for k, v in batch.items()} 136 | inputs['attention_mask'] = unilm_mask(inputs['token_type_ids']) 137 | if getattr(self.config, 'type_vocab_size', 0) != 2: 138 | inputs.pop('token_type_ids') 139 | 140 | outputs = self.model(*args, **inputs, output_hidden_states=True) 141 | # hidden_states = outputs[0] 142 | # hidden states of the sixth layer from the end 143 | hidden_states = outputs[2][-6] 144 | lm_logits = self.model.lm_head(hidden_states) 145 | if labels is not None: 146 | labels = labels.long() 147 | shift_logits = lm_logits[..., :-1, :].contiguous() 148 | shift_labels = labels[..., 1:].contiguous() 149 | loss_student = self.model.loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 150 | 151 | teacher_logits = self.teacher_model.compute_loss(*args, **batch)[0] 152 | kl_Loss = self.kl_loss([teacher_logits, lm_logits]) 153 | loss_dict = { 154 | 'loss_student': loss_student, 155 | 'kl_Loss': kl_Loss, 156 | 'loss': loss_student * 0.1 + kl_Loss 157 | } 158 | 159 | outputs = (loss_dict, lm_logits, labels) 160 | else: 161 | outputs = (lm_logits,) 162 | return outputs 163 | 164 | 165 | if __name__ == '__main__': 166 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments)) 167 | model_args, training_args, data_args = parser.parse_dict(train_info_args) 168 | 169 | checkpoint_callback = SimpleModelCheckpoint(monitor="loss", 170 | every_n_train_steps=2000 // 
training_args.gradient_accumulation_steps) 171 | trainer = Trainer( 172 | callbacks=[checkpoint_callback], 173 | max_epochs=training_args.max_epochs, 174 | max_steps=training_args.max_steps, 175 | accelerator="gpu", 176 | devices=data_args.devices, 177 | enable_progress_bar=True, 178 | default_root_dir=data_args.output_dir, 179 | gradient_clip_val=training_args.max_grad_norm, 180 | accumulate_grad_batches=training_args.gradient_accumulation_steps, 181 | num_sanity_val_steps=0, 182 | strategy='ddp' if torch.cuda.device_count() > 1 else 'auto', 183 | ) 184 | 185 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 186 | tokenizer, config, label2id, id2label = dataHelper.load_tokenizer_and_config() 187 | 188 | # cache the datasets 189 | if data_args.do_train: 190 | dataHelper.make_dataset_with_args(data_args.train_file, shuffle=True,mode='train') 191 | if data_args.do_eval: 192 | dataHelper.make_dataset_with_args(data_args.eval_file, mode='eval') 193 | if data_args.do_test: 194 | dataHelper.make_dataset_with_args(data_args.test_file,mode='test') 195 | 196 | 197 | # whether to train the teacher model first 198 | is_training_teacher = True 199 | 200 | if is_training_teacher: # train the teacher model 201 | model = TeacherTransformer(config=config, model_args=model_args, training_args=training_args) 202 | else: # distill the student model 203 | teacher_weight = './best_teacher.pt' 204 | # load the trained teacher weights 205 | teacher_model = TeacherTransformer.load_from_checkpoint(teacher_weight, config=config, model_args=model_args, 206 | training_args=training_args) 207 | for k, p in teacher_model.named_parameters(): 208 | p.requires_grad = False 209 | model = StudentTransformer(teacher_model, config=config, model_args=model_args, training_args=training_args) 210 | 211 | if not data_args.convert_onnx: 212 | train_datasets = dataHelper.load_distributed_random_sampler( 213 | dataHelper.train_files, 214 | with_load_memory=True, 215 | collate_fn=dataHelper.collate_fn, 216 | batch_size=training_args.train_batch_size, 217 | num_processes = trainer.world_size, process_index=trainer.global_rank) 218 | if train_datasets is not None: 219 | trainer.fit(model, train_dataloaders=train_datasets) 220 | else: 221 | eval_datasets = dataHelper.load_sequential_sampler(dataHelper.eval_files,batch_size=training_args.eval_batch_size,collate_fn=dataHelper.collate_fn) 222 | test_datasets = dataHelper.load_sequential_sampler(dataHelper.test_files,batch_size=training_args.test_batch_size,collate_fn=dataHelper.collate_fn) 223 | if eval_datasets is not None: 224 | trainer.validate(model, dataloaders=eval_datasets, ckpt_path='./best.pt') 225 | 226 | if test_datasets is not None: 227 | trainer.test(model, dataloaders=test_datasets, ckpt_path='best.pt') 228 | --------------------------------------------------------------------------------
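Note on task_autotitle_unilm_distillation.py: the script runs in two phases controlled by is_training_teacher. First train TeacherTransformer and keep its best checkpoint at ./best_teacher.pt, then rerun with is_training_teacher = False so the frozen teacher supervises StudentTransformer with the combined objective loss_student * 0.1 + kl_Loss. The sketch below restates that objective in plain PyTorch for reference only; it is a minimal approximation that uses torch.nn.functional instead of the repo's KLDivLoss helper (whose exact reduction and KL direction may differ), and it assumes [batch, seq, vocab] teacher/student logits plus the same labels tensor used above.

import torch
import torch.nn.functional as F

def distill_loss(student_logits, teacher_logits, labels, alpha=0.1):
    # student cross-entropy on the gold next tokens (same shift-by-one as compute_loss above)
    shift_logits = student_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous().long()
    loss_student = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    # KL term pulling the student's token distribution toward the frozen teacher's
    kl_loss = F.kl_div(F.log_softmax(student_logits, dim=-1),
                       F.softmax(teacher_logits, dim=-1),
                       reduction='sum')
    return alpha * loss_student + kl_loss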