├── .DS_Store
├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── bert_seq2seq.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── .vscode
│   └── settings.json
├── LICENSE
├── README.md
├── bert_seq2seq
│   ├── .DS_Store
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── bart_chinese.cpython-37.pyc
│   │   ├── basic_bert.cpython-37.pyc
│   │   ├── bert_cls_classifier.cpython-37.pyc
│   │   ├── bert_cls_multi_classifier.cpython-37.pyc
│   │   ├── bert_cls_multi_seq2seq.cpython-37.pyc
│   │   ├── bert_encoder.cpython-37.pyc
│   │   ├── bert_relation_extraction.cpython-37.pyc
│   │   ├── bert_seq_labeling.cpython-37.pyc
│   │   ├── bert_seq_labeling_crf.cpython-37.pyc
│   │   ├── config.cpython-37.pyc
│   │   ├── extend_model_method.cpython-37.pyc
│   │   ├── gpt2_generate_model.cpython-37.pyc
│   │   ├── helper.cpython-37.pyc
│   │   ├── seq2seq_model.cpython-37.pyc
│   │   ├── simbert_model.cpython-37.pyc
│   │   ├── t5_ch.cpython-37.pyc
│   │   ├── tokenizer.cpython-37.pyc
│   │   └── utils.cpython-37.pyc
│   ├── bart_chinese.py
│   ├── basic_bert.py
│   ├── bert_cls_classifier.py
│   ├── bert_cls_multi_classifier.py
│   ├── bert_cls_multi_seq2seq.py
│   ├── bert_relation_extraction.py
│   ├── bert_seq_labeling.py
│   ├── bert_seq_labeling_crf.py
│   ├── config.py
│   ├── dataset.py
│   ├── extend_model_method.py
│   ├── gpt2_generate_model.py
│   ├── helper.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── bart_model.cpython-37.pyc
│   │   │   ├── bert_model.cpython-37.pyc
│   │   │   ├── crf.cpython-37.pyc
│   │   │   ├── gpt2_model.cpython-37.pyc
│   │   │   ├── nezha_model.cpython-37.pyc
│   │   │   ├── roberta_model.cpython-37.pyc
│   │   │   └── t5_model.cpython-37.pyc
│   │   ├── bart_model.py
│   │   ├── bert_model.py
│   │   ├── crf.py
│   │   ├── gpt2_model.py
│   │   ├── nezha_model.py
│   │   ├── roberta_model.py
│   │   └── t5_model.py
│   ├── paddle_model
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── collate.py
│   │   │   ├── iterator.py
│   │   │   ├── sampler.py
│   │   │   ├── tokenizer.py
│   │   │   └── vocab.py
│   │   ├── transformers
│   │   │   ├── __init__.py
│   │   │   ├── attention_utils.py
│   │   │   ├── bert
│   │   │   │   ├── __init__.py
│   │   │   │   ├── modeling.py
│   │   │   │   └── tokenizer.py
│   │   │   ├── generation_utils.py
│   │   │   ├── gpt
│   │   │   │   ├── __init__.py
│   │   │   │   ├── modeling.py
│   │   │   │   └── tokenizer.py
│   │   │   ├── model_utils.py
│   │   │   ├── nezha
│   │   │   │   ├── __init__.py
│   │   │   │   ├── modeling.py
│   │   │   │   └── tokenizer.py
│   │   │   ├── optimization.py
│   │   │   ├── roberta
│   │   │   │   ├── README.md
│   │   │   │   ├── __init__.py
│   │   │   │   ├── modeling.py
│   │   │   │   └── tokenizer.py
│   │   │   ├── tokenizer_utils.py
│   │   │   └── utils.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── batch_sampler.py
│   │       ├── downloader.py
│   │       ├── env.py
│   │       ├── log.py
│   │       ├── profiler.py
│   │       └── tools.py
│   ├── seq2seq_model.py
│   ├── simbert_model.py
│   ├── t5_ch.py
│   ├── tokenizer.py
│   └── utils.py
├── examples
│   ├── .DS_Store
│   ├── bart_auto_title_train.py
│   ├── examples_paddle
│   │   ├── bert_couplet_paddle_train.py
│   │   ├── mBart_translation_en_ro.py
│   │   ├── roberta_autotitle_paddle_train.py
│   │   └── roberta_math_paddle_train.py
│   ├── gpt2_ancient_translation_train.py
│   ├── gpt2_english_story_train.py
│   ├── gpt2_explain_dream_train.py
│   ├── gpt2_generate_article.py
│   ├── nezha_auto_title_train.py
│   ├── nezha_couplets_train.py
│   ├── nezha_relation_extract_train.py
│   ├── readme.md
│   ├── relationship_classify_train.py
│   ├── roberta_THUCNews_auto_title.py
│   ├── roberta_auto_title_train.py
│   ├── roberta_coarsness_NER_CRF_train.py
│   ├── roberta_coarsness_NER_train.py
│   ├── roberta_couplets_train.py
│   ├── roberta_fine_grained_NER_CRF_train.py
│   ├── roberta_large_auto_article_gen.py
│   ├── roberta_large_auto_title_train.py
│   ├── roberta_math_ques_train.py
│   ├── roberta_medical_ner_train.py
│   ├── roberta_news_classification_train.py
│   ├── roberta_participle_CRF_train.py
│   ├── roberta_poem_train.py
│   ├── roberta_relation_extract_train.py
│   ├── roberta_semantic_matching_train.py
│   ├── simbert_train.py
│   ├── t5_ancient_translation_train.py
│   └── t5_auto_title_train.py
├── img
│   ├── .DS_Store
│   ├── fenci.png
│   ├── ner-input.png
│   ├── ner-out.png
│   └── ner.jpg
├── setup.py
└── test
    ├── auto_title_test.py
    ├── bert_english_autotitle_test.py
    ├── english_t5_test.py
    ├── get_bert_embedding.py
    ├── gpt_ancient_translation_test.py
    ├── gpt_english_story_test.py
    ├── gpt_explain_dream_test.py
    ├── gpt_test_english.py
    ├── nezha_auto_title_test.py
    ├── nezha_relation_extract_test.py
    ├── poem_test.py
    ├── relation_extract_test.py
    ├── semantic_matching_test.py
    ├── t5_chinese_autotitle_test.py
    ├── t5_chinese_test.py
    ├── test_paddle
    │   ├── bert_couplet_test_paddle.py
    │   ├── roberta_autotitle_test_paddle.py
    │   └── roberta_math_test_paddle.py
    ├── 做数学题_test.py
    ├── 新闻标题文本分类_test.py
    ├── 粗粒度ner_test.py
    └── 细粒度ner_test.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | nouse
2 | corpus
3 | dist
4 | 写文章
5 | 代码
6 | state_dict
7 | bert_seq2seq.egg-info
8 | .idea
9 | .vscode
10 |
11 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Default ignored files
3 | /workspace.xml
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 |     "python.pythonPath": "/Users/xingzhaohu/.local/share/virtualenvs/ml-5foBrNl9/bin/python",
3 |     "git.ignoreLimitWarning": true
4 | }
--------------------------------------------------------------------------------
/bert_seq2seq/__init__.py:
--------------------------------------------------------------------------------
1 | from .tokenizer import load_chinese_base_vocab, Tokenizer
2 | from .utils import load_bert, load_gpt
3 | from .t5_ch import T5PegasusTokenizer, T5Model
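The package `__init__` above re-exports the tokenizer helpers and the `load_bert`/`load_gpt` model loaders. Below is a minimal, hypothetical sketch of wiring them together; the vocab path is a placeholder, and the `load_bert` keyword names are assumptions inferred from the `model_name`/`model_class` strings handled in `bert_seq2seq/dataset.py`, not a documented signature.

from bert_seq2seq import load_chinese_base_vocab, load_bert

# Placeholder path; load_chinese_base_vocab is assumed to map a BERT-style
# vocab.txt file to a word -> id dict, as its uses elsewhere in this repo suggest.
word2ix = load_chinese_base_vocab("./state_dict/vocab.txt")
# "roberta"/"seq2seq" mirror values accepted by AbstractDataset in dataset.py;
# treat these keyword names as assumptions.
model = load_bert(word2ix, model_name="roberta", model_class="seq2seq")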
--------------------------------------------------------------------------------
/bert_seq2seq/bart_chinese.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from bert_seq2seq.model.bart_model import BartConfig, BartForConditionalGeneration, BartModel, shift_tokens_right
4 | from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
5 | from bert_seq2seq.basic_bert import BasicBart
6 | from bert_seq2seq.seq2seq_model import top_k_top_p_filtering
7 | import torch.nn.functional as F
8 | import torch.nn as nn
9 |
10 | class BartGenerationModel(BasicBart):
11 |
12 |     def __init__(self, word2idx):
13 |         super().__init__()
14 |         config = BartConfig()
15 |         self.config = config
16 |         self.model = BartModel(config)
17 |         self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
18 |
19 |         self.word2idx = word2idx
20 |         self.tokenizer = Tokenizer(self.word2idx)
21 |         self.bos_id = self.word2idx["[CLS]"]
22 |         self.eos_id = self.word2idx["[SEP]"]
23 |         self.unk_id = self.word2idx["[UNK]"]
24 |
25 |     def forward(self, input_ids, decoder_input_ids, labels=None):
26 |         input_ids = input_ids.to(self.device)
27 |         decoder_input_ids = decoder_input_ids.to(self.device)
28 |         if labels is not None:
29 |             labels = labels.to(self.device)
30 |         if labels is not None:
31 |             if decoder_input_ids is None:
32 |                 decoder_input_ids = shift_tokens_right(
33 |                     labels, self.config.pad_token_id, self.config.decoder_start_token_id
34 |                 )
35 |
36 |         decoder_out, _ = self.model(
37 |             input_ids,
38 |             decoder_input_ids=decoder_input_ids,
39 |         )
40 |
41 |         lm_logits = self.lm_head(decoder_out)
42 |         target_mask = (decoder_input_ids > 0).float().view(-1)
43 |         masked_lm_loss = None
44 |         if labels is not None:
45 |             loss_fct = nn.CrossEntropyLoss(reduction="none")  # per-token loss so the pad mask below takes effect
46 |             masked_lm_loss = (loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) * target_mask).sum() / target_mask.sum()
47 |
48 |         output = (lm_logits,)
49 |         return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
50 |
51 |
52 |     def sample_generate_encoder_decoder(self, text, input_max_length=256, out_max_length=200, top_k=30, top_p=0.0, add_eos=True):
53 |
54 |         token_out = self.tokenizer.encode(text, max_length=input_max_length)
55 |         if len(token_out) == 2:
56 |             token_ids = token_out[0]
57 |         else:
58 |             token_ids = token_out
59 |         if not add_eos:
60 |             token_ids = token_ids[:-1]
61 |         token_ids = torch.tensor(token_ids, device=self.device, dtype=torch.long).view(1, -1)
62 |         output_ids = []
63 |
64 |         input_decoder_ids = torch.tensor(self.bos_id, device=self.device, dtype=torch.long).view(1, -1)
65 |         with torch.no_grad():
66 |             for step in range(out_max_length):
67 |                 scores = self.lm_head(self.model(input_ids=token_ids, decoder_input_ids=input_decoder_ids)[0])  # project decoder states to vocab logits
68 |                 logit_score = torch.log_softmax(scores[:, -1], dim=-1).squeeze(0)
69 |                 logit_score[self.unk_id] = -float('Inf')
70 |                 filtered_logits = top_k_top_p_filtering(logit_score, top_k=top_k, top_p=top_p)
71 |                 next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
72 |                 if self.eos_id == next_token.item():
73 |                     break
74 |                 output_ids.append(next_token.item())
75 |                 input_decoder_ids = torch.cat((input_decoder_ids, next_token.long().unsqueeze(0)), dim=1)
76 |
77 |         return self.tokenizer.decode(output_ids)
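A minimal usage sketch for the `BartGenerationModel` defined above, assuming a trained checkpoint; both paths are placeholders, and `load_chinese_base_vocab` taking a vocab-file path is an assumption consistent with its use in this repo.

import torch
from bert_seq2seq.tokenizer import load_chinese_base_vocab
from bert_seq2seq.bart_chinese import BartGenerationModel

word2idx = load_chinese_base_vocab("./state_dict/vocab.txt")  # placeholder path
model = BartGenerationModel(word2idx)
# Hypothetical fine-tuned checkpoint; without it, sampling runs on random weights.
model.load_state_dict(torch.load("./state_dict/bart_auto_title.bin", map_location="cpu"))
model.eval()
title = model.sample_generate_encoder_decoder("一段新闻正文……", out_max_length=40, top_k=30)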
-------------------------------------------------------------------------------- /bert_seq2seq/bert_cls_classifier.py: -------------------------------------------------------------------------------- 1 | ## bert encoder模型 2 | import torch 3 | import torch.nn as nn 4 | from bert_seq2seq.tokenizer import Tokenizer 5 | from bert_seq2seq.basic_bert import BasicBert 6 | 7 | class BertClsClassifier(BasicBert): 8 | """ 9 | """ 10 | def __init__(self, word2ix, target_size, model_name="roberta"): 11 | super(BertClsClassifier, self).__init__(word2ix=word2ix, model_name=model_name) 12 | self.target_size = target_size 13 | self.final_dense = nn.Linear(self.config.hidden_size, self.target_size) 14 | 15 | def compute_loss(self, predictions, labels): 16 | """ 17 | 计算loss 18 | predictions: (batch_size, 1) 19 | """ 20 | predictions = predictions.view(-1, self.target_size) 21 | labels = labels.view(-1) 22 | loss = nn.CrossEntropyLoss(reduction="mean") 23 | return loss(predictions, labels) 24 | 25 | def forward(self, text, position_enc=None, labels=None, use_layer_num=-1): 26 | 27 | text = text.to(self.device) 28 | if position_enc is not None: 29 | position_enc = position_enc.to(self.device) 30 | if labels is not None: 31 | labels = labels.to(self.device) 32 | all_layers, pooled_out = self.bert(text, 33 | output_all_encoded_layers=True) 34 | if use_layer_num != -1: 35 | pooled_out = all_layers[use_layer_num][:, 0] 36 | 37 | predictions = self.final_dense(pooled_out) 38 | if labels is not None: 39 | ## 计算loss 40 | loss = self.compute_loss(predictions, labels) 41 | return predictions, loss 42 | else : 43 | return predictions 44 | -------------------------------------------------------------------------------- /bert_seq2seq/bert_cls_multi_classifier.py: -------------------------------------------------------------------------------- 1 | ## bert encoder模型 2 | from multiprocessing import pool 3 | import torch 4 | import torch.nn as nn 5 | from bert_seq2seq.tokenizer import Tokenizer 6 | from bert_seq2seq.basic_bert import BasicBert 7 | 8 | class BertClsMultiClassifier(BasicBert): 9 | """ 10 | """ 11 | def __init__(self, word2ix, target_size, model_name="roberta"): 12 | super(BertClsMultiClassifier, self).__init__(word2ix=word2ix, model_name=model_name) 13 | self.target_size = target_size 14 | self.final_dense = nn.Linear(self.config.hidden_size, self.target_size) 15 | 16 | def compute_loss(self, predictions, labels): 17 | """ 18 | 计算loss 19 | predictions: (batch_size, 1) 20 | """ 21 | # predictions = torch.sigmoid(predictions) 22 | batch_size = predictions.shape[0] 23 | # predictions = predictions.view(-1) 24 | # labels = labels.view(-1) 25 | loss = nn.BCEWithLogitsLoss(reduction="none") 26 | return loss(predictions, labels).sum() / batch_size 27 | 28 | def forward(self, text, position_enc=None, labels=None, use_layer_num=-1): 29 | 30 | text = text.to(self.device) 31 | if position_enc is not None: 32 | position_enc = position_enc.to(self.device) 33 | if labels is not None: 34 | labels = labels.to(self.device) 35 | all_layers, pooled_out = self.bert(text, 36 | output_all_encoded_layers=True) 37 | 38 | if use_layer_num != -1: 39 | pooled_out = all_layers[use_layer_num][:, 0] 40 | 41 | predictions = self.final_dense(pooled_out) 42 | 43 | if labels is not None: 44 | ## 计算loss 45 | loss = self.compute_loss(predictions, labels) 46 | return predictions, loss 47 | else : 48 | return predictions 49 | -------------------------------------------------------------------------------- /bert_seq2seq/bert_relation_extraction.py: 
-------------------------------------------------------------------------------- 1 | ## bert 关系抽取模型 2 | import torch 3 | import torch.nn as nn 4 | from bert_seq2seq.tokenizer import load_chinese_base_vocab, Tokenizer 5 | from bert_seq2seq.basic_bert import BasicBert 6 | 7 | class BertRelationExtrac(BasicBert): 8 | """ 9 | """ 10 | def __init__(self, word2ix, predicate_num, model_name="roberta"): 11 | super(BertRelationExtrac, self).__init__(word2ix=word2ix, model_name=model_name) 12 | 13 | self.predicate_num = predicate_num 14 | self.subject_pred_norm = nn.LayerNorm(self.config.hidden_size) 15 | self.subject_pred = nn.Linear(self.config.hidden_size, 2) 16 | self.activation = nn.Sigmoid() 17 | self.object_pred = nn.Linear(self.config.hidden_size, 2 * self.predicate_num) 18 | 19 | 20 | def binary_crossentropy(self, labels, pred): 21 | labels = labels.float() 22 | loss = (-labels) * torch.log(pred) - (1.0 - labels) * torch.log(1.0 - pred) 23 | return loss 24 | 25 | def compute_total_loss(self, subject_pred, object_pred, subject_labels, object_labels): 26 | """ 27 | 计算loss 28 | """ 29 | subject_loss = self.binary_crossentropy(subject_labels, subject_pred) 30 | subject_loss = torch.mean(subject_loss, dim=2) 31 | subject_loss = (subject_loss * self.target_mask).sum() / self.target_mask.sum() 32 | 33 | object_loss = self.binary_crossentropy(object_labels, object_pred) 34 | object_loss = torch.mean(object_loss, dim=3).sum(dim=2) 35 | object_loss = (object_loss * self.target_mask).sum() / self.target_mask.sum() 36 | 37 | return subject_loss + object_loss 38 | 39 | def extrac_subject(self, output, subject_ids): 40 | ## 抽取subject的向量表征 41 | batch_size = output.shape[0] 42 | hidden_size = output.shape[-1] 43 | start_end = torch.gather(output, index=subject_ids.unsqueeze(-1).expand((batch_size, 2, hidden_size)), dim=1) 44 | subject = torch.cat((start_end[:, 0], start_end[:, 1]), dim=-1) 45 | return subject 46 | 47 | def forward(self, text, subject_ids, position_enc=None, subject_labels=None, object_labels=None, use_layer_num=-1): 48 | if use_layer_num != -1: 49 | 50 | raise Exception("目前 use_layer_num 只支持-1") 51 | # 计算target mask 52 | text = text.to(self.device) 53 | subject_ids = subject_ids.to(self.device) 54 | self.target_mask = (text > 0).float() 55 | enc_layers, _ = self.bert(text, 56 | output_all_encoded_layers=True) 57 | 58 | squence_out = enc_layers[use_layer_num] 59 | 60 | tokens_hidden_state, _ = self.cls(squence_out) 61 | 62 | subject_pred_out = self.subject_pred(self.subject_pred_norm(tokens_hidden_state)) 63 | 64 | subject_pred_act = self.activation(subject_pred_out) 65 | 66 | subject_pred_act = subject_pred_act**2 67 | 68 | subject_vec = self.extrac_subject(tokens_hidden_state, subject_ids) 69 | object_layer_norm = self.layer_norm_cond([tokens_hidden_state, subject_vec]) 70 | object_pred_out = self.object_pred(object_layer_norm) 71 | object_pred_act = self.activation(object_pred_out) 72 | 73 | object_pred_act = object_pred_act**4 74 | 75 | batch_size, seq_len, target_size = object_pred_act.shape 76 | 77 | object_pred_act = object_pred_act.reshape((batch_size, seq_len, int(target_size/2), 2)) 78 | predictions = object_pred_act 79 | if subject_labels is not None and object_labels is not None: 80 | ## 计算loss 81 | subject_labels = subject_labels.to(self.device) 82 | object_labels = object_labels.to(self.device) 83 | loss = self.compute_total_loss(subject_pred_act, object_pred_act, subject_labels, object_labels) 84 | return predictions, loss 85 | else : 86 | return predictions 87 | 88 | def 
predict_subject(self, text,use_layer_num=-1): 89 | if use_layer_num != -1: 90 | 91 | raise Exception("use_layer_num目前只支持-1") 92 | text = text.to(self.device) 93 | 94 | self.target_mask = (text > 0).float() 95 | enc_layers, _ = self.bert(text, output_all_encoded_layers=True) 96 | squence_out = enc_layers[use_layer_num] 97 | tokens_hidden_state, _ = self.cls(squence_out) 98 | subject_pred_out = self.subject_pred(self.subject_pred_norm(tokens_hidden_state)) 99 | subject_pred_act = self.activation(subject_pred_out) 100 | 101 | subject_pred_act = subject_pred_act**2 102 | 103 | # subject_pred_act = (subject_pred_act > 0.5).long() 104 | return subject_pred_act 105 | 106 | def predict_object_predicate(self, text, subject_ids, use_layer_num=-1): 107 | if use_layer_num != -1: 108 | 109 | raise Exception("use_layer_num目前只支持-1") 110 | # 计算target mask 111 | text = text.to(self.device) 112 | subject_ids = subject_ids.to(self.device) 113 | 114 | enc_layers, _ = self.bert(text, output_all_encoded_layers=True) 115 | squence_out = enc_layers[use_layer_num] 116 | tokens_hidden_state, _ = self.cls(squence_out) 117 | 118 | subject_vec = self.extrac_subject(tokens_hidden_state, subject_ids) 119 | object_layer_norm = self.layer_norm_cond([tokens_hidden_state, subject_vec]) 120 | object_pred_out = self.object_pred(object_layer_norm) 121 | object_pred_act = self.activation(object_pred_out) 122 | 123 | object_pred_act = object_pred_act**4 124 | 125 | batch_size, seq_len, target_size = object_pred_act.shape 126 | object_pred_act = object_pred_act.view((batch_size, seq_len, int(target_size/2), 2)) 127 | predictions = object_pred_act 128 | return predictions -------------------------------------------------------------------------------- /bert_seq2seq/bert_seq_labeling.py: -------------------------------------------------------------------------------- 1 | ## bert encoder模型 2 | import torch 3 | import torch.nn as nn 4 | from bert_seq2seq.tokenizer import load_chinese_base_vocab, Tokenizer 5 | from bert_seq2seq.basic_bert import BasicBert 6 | 7 | class BertSeqLabeling(BasicBert): 8 | """ 9 | """ 10 | def __init__(self, word2ix, target_size, model_name="roberta"): 11 | super(BertSeqLabeling, self).__init__(word2ix=word2ix, model_name=model_name) 12 | self.target_size = target_size 13 | 14 | 15 | self.final_dense = nn.Linear(self.config.hidden_size, self.target_size) 16 | 17 | def compute_loss(self, predictions, labels): 18 | """ 19 | 计算loss 20 | predictions: (batch_size, 1) 21 | """ 22 | predictions = predictions.view(-1, self.target_size) 23 | labels = labels.view(-1) 24 | self.target_mask = self.target_mask.view(-1) 25 | loss = nn.CrossEntropyLoss(reduction="none") 26 | return (loss(predictions, labels) * self.target_mask).sum() / self.target_mask.sum() 27 | 28 | def forward(self, text, position_enc=None, labels=None, use_layer_num=-1): 29 | if use_layer_num != -1: 30 | if use_layer_num < 0 or use_layer_num > 7: 31 | # 越界 32 | raise Exception("层数选择错误,因为bert base模型共8层,所以参数只只允许0 - 7, 默认为-1,取最后一层") 33 | self.target_mask = (text > 0).float().to(self.device) 34 | text = text.to(self.device) 35 | if position_enc is not None: 36 | position_enc = position_enc.to(self.device) 37 | if labels is not None: 38 | labels = labels.to(self.device) 39 | 40 | enc_layers, _ = self.bert(text, 41 | output_all_encoded_layers=True) 42 | squence_out = enc_layers[use_layer_num] 43 | 44 | tokens_hidden_state, _ = self.cls(squence_out) 45 | predictions = self.final_dense(tokens_hidden_state) 46 | if labels is not None: 47 | ## 计算loss 48 | loss 
= self.compute_loss(predictions, labels) 49 | return predictions, loss 50 | else : 51 | return predictions 52 | -------------------------------------------------------------------------------- /bert_seq2seq/bert_seq_labeling_crf.py: -------------------------------------------------------------------------------- 1 | ## bert encoder模型 2 | import torch 3 | import torch.nn as nn 4 | from bert_seq2seq.tokenizer import Tokenizer 5 | from bert_seq2seq.model.crf import CRFLayer 6 | from bert_seq2seq.basic_bert import BasicBert 7 | 8 | class BertSeqLabelingCRF(BasicBert): 9 | """ 10 | """ 11 | def __init__(self, word2ix, target_size, model_name="roberta"): 12 | super(BertSeqLabelingCRF, self).__init__(word2ix=word2ix, model_name=model_name) 13 | self.target_size = target_size 14 | 15 | self.final_dense = nn.Linear(self.config.hidden_size, self.target_size) 16 | self.crf_layer = CRFLayer(self.target_size) 17 | 18 | def compute_loss(self, predictions, labels): 19 | """ 20 | 计算loss 21 | """ 22 | loss = self.crf_layer(predictions, labels, self.target_mask) 23 | 24 | return loss.mean() 25 | 26 | def forward(self, text, position_enc=None, labels=None, use_layer_num=-1): 27 | if use_layer_num != -1: 28 | # 越界 29 | raise Exception("use_layer_num目前只支持-1") 30 | # 计算target mask 31 | self.target_mask = (text > 0).float().to(self.device) 32 | text = text.to(self.device) 33 | if position_enc is not None : 34 | position_enc = position_enc.to(self.device) 35 | if labels is not None : 36 | labels = labels.to(self.device) 37 | enc_layers, _ = self.bert(text, 38 | output_all_encoded_layers=True) 39 | squence_out = enc_layers[use_layer_num] 40 | 41 | tokens_hidden_state, _ = self.cls(squence_out) 42 | # print(cls_token) 43 | predictions = self.final_dense(tokens_hidden_state) 44 | 45 | if labels is not None: 46 | ## 计算loss 47 | loss = self.compute_loss(predictions, labels) 48 | return predictions, loss 49 | else : 50 | return predictions 51 | -------------------------------------------------------------------------------- /bert_seq2seq/config.py: -------------------------------------------------------------------------------- 1 | 2 | max_length = 256 3 | 4 | yayun_list = [ 5 | "东同铜桐筒童僮瞳中衷忠虫终戎崇嵩弓躬宫融雄熊穹穷冯风枫丰充隆空公功工攻蒙笼聋珑洪红鸿虹丛翁聪通蓬烘潼胧砻峒螽梦讧冻忡酆恫总侗窿懵庞种盅芎倥艨绒葱匆骢", 6 | "冬农宗钟龙舂松冲容蓉庸封胸雍浓重从逢缝踪茸峰锋烽蛩慵恭供淙侬松凶墉镛佣溶邛共憧喁邕壅纵龚枞脓淞匈汹禺蚣榕彤", 7 | "江扛窗邦缸降双庞逄腔撞幢桩淙豇", 8 | "支枝移为垂吹陂碑奇宜仪皮儿离施知驰池规危夷师姿迟眉悲之芝时诗棋旗辞词期祠基疑姬丝司葵医帷思滋持随痴维卮麋螭麾墀弥慈遗肌脂雌披嬉尸狸炊篱兹差疲茨卑亏蕤陲骑曦歧岐谁斯私窥熙欺疵赀笞羁彝颐资糜饥衰锥姨楣夔涯伊蓍追", 9 | "缁箕椎罴篪萎匙脾坻嶷治骊尸綦怡尼漪累牺饴而鸱推縻璃祁绥逵羲羸肢骐訾狮奇嗤咨堕其睢漓蠡噫馗辎胝鳍蛇陴淇淄丽筛厮氏痍貔比僖贻祺嘻鹂瓷琦嵋怩熹孜台蚩罹魑丕琪耆衰惟剂提禧居栀戏畸椅磁痿离佳虽仔寅委崎隋逶倭黎犁郦", 10 | "微薇晖徽挥韦围帏违霏菲妃绯飞非扉肥腓威畿机几讥矶稀希衣依沂巍归诽痱欷葳颀圻", 11 | "鱼渔初书舒居裾车渠余予誉舆胥狙锄疏蔬梳虚嘘徐猪闾庐驴诸除储如墟与畲疽苴于茹蛆且沮祛蜍榈淤好雎纾躇趄滁屠据匹咀衙涂虑", 12 | "虞愚娱隅刍无芜巫于盂衢儒濡襦须株诛蛛殊瑜榆谀愉腴区驱躯朱珠趋扶符凫雏敷夫肤纡输枢厨俱驹模谟蒲胡湖瑚乎壶狐弧孤辜姑觚菰徒途涂荼图屠奴呼吾七虞梧吴租卢鲈苏酥乌枯都铺禺诬竽吁瞿劬需俞逾觎揄萸臾渝岖镂娄夫孚桴俘迂姝拘摹糊鸪沽呱蛄驽逋舻垆徂孥泸栌嚅蚨诹扶母毋芙喁颅轳句邾洙麸机膜瓠恶芋呕驺喻枸侏龉葫懦帑拊", 13 | "齐蛴脐黎犁梨黧妻萋凄堤低氐诋题提荑缔折篦鸡稽兮奚嵇蹊倪霓西栖犀嘶撕梯鼙批挤迷泥溪圭闺睽奎携畦骊鹂儿", 14 | "佳街鞋牌柴钗差涯阶偕谐骸排乖怀淮豺侪埋霾斋娲蜗娃哇皆喈揩蛙楷槐俳", 15 | "灰恢魁隈回徊枚梅媒煤瑰雷催摧堆陪杯醅嵬推开哀埃台苔该才材财裁来莱栽哉灾猜胎孩虺崔裴培坏垓陔徕皑傀崃诙煨桅唉颏能茴酶偎隗咳", 16 | "真因茵辛新薪晨辰臣人仁神亲申伸绅身宾滨邻鳞麟珍尘陈春津秦频苹颦银垠筠巾民珉缗贫淳醇纯唇伦纶轮沦匀旬巡驯钧均臻榛姻寅彬鹑皴遵循振甄岷谆椿询恂峋莘堙屯呻粼磷辚濒闽豳逡填狺泯洵溱夤荀竣娠纫鄞抡畛嶙斌氤", 17 | "文闻纹云氛分纷芬焚坟群裙君军勤斤筋勋薰曛熏荤耘芸汾氲员欣芹殷昕贲郧雯蕲", 18 | "元原源园猿辕坦烦繁蕃樊翻萱喧冤言轩藩魂浑温孙门尊存蹲敦墩暾屯豚村盆奔论坤昏婚阍痕根恩吞沅媛援爰幡番反埙鸳宛掀昆琨鲲扪荪髡跟垠抡蕴犍袁怨蜿溷昆炖饨臀喷纯", 19 | "寒韩翰丹殚单安难餐滩坛檀弹残干肝竿乾阑栏澜兰看刊丸桓纨端湍酸团抟攒官观冠鸾銮栾峦欢宽盘蟠漫汗郸叹摊奸剜棺钻瘢谩瞒潘胖弁拦完莞獾拌掸萑倌繁曼馒鳗谰洹滦", 20 | "删潸关弯湾还环鹌鬟寰班斑颁般蛮颜菅攀顽山鳏艰闲娴悭孱潺殷扳讪患", 21 | "先前千阡笺天坚肩贤弦烟燕莲怜田填钿年颠巅牵妍研眠渊涓蠲编玄县泉迁仙鲜钱煎然延筵禅蝉缠连联涟篇偏便全宣镌穿川缘鸢铅捐旋娟船涎鞭专圆员乾虔愆骞权拳椽传焉跹溅舷咽零骈阗鹃翩扁平沿诠痊悛荃遄卷挛戋佃滇婵颛犍搴嫣癣澶单竣鄢扇键蜷棉", 22 | 
"萧箫挑貂刁凋雕迢条跳苕调枭浇聊辽寥撩僚寮尧幺宵消霄绡销超朝潮嚣樵谯骄娇焦蕉椒饶烧遥姚摇谣瑶韶昭招飚标杓镳瓢苗描猫要腰邀乔桥侨妖夭漂飘翘祧佻徼侥哨娆陶橇劭潇骁獠料硝灶鹞钊蛲峤轿荞嘹逍燎憔剽", 23 | "肴巢交郊茅嘲钞包胶爻苞梢蛟庖匏坳敲胞抛鲛崤铙炮哮捎茭淆泡跑咬啁教咆鞘剿刨佼抓姣唠", 24 | "豪毫操髦刀萄猱桃糟漕旄袍挠蒿涛皋号陶翱敖遭篙羔高嘈搔毛艘滔骚韬缫膏牢醪逃槽劳洮叨绸饕骜熬臊涝淘尻挑嚣捞嗥薅咎谣", 25 | "歌多罗河戈阿和波科柯陀娥蛾鹅萝荷过磨螺禾哥娑驼佗沱峨那苛诃珂轲莎蓑梭婆摩魔讹坡颇俄哦呵皤么涡窝茄迦伽磋跎番蹉搓驮献蝌箩锅倭罗嵯锣", 26 | "麻花霞家茶华沙车牙蛇瓜斜邪芽嘉瑕纱鸦遮叉葩奢楂琶衙赊涯夸巴加耶嗟遐笳差蟆蛙虾拿葭茄挝呀枷哑娲爬杷蜗爷芭鲨珈骅娃哇洼畲丫夸裟瘕些桠杈痂哆爹椰咤笆桦划迦揶吾佘", 27 | "阳杨扬香乡光昌堂章张王房芳长塘妆常凉霜藏场央泱鸯秧嫱床方浆觞梁娘庄黄仓皇装殇襄骧相湘箱缃创忘芒望尝偿樯枪坊囊郎唐狂强肠康冈苍匡荒遑行妨棠翔良航倡伥羌庆姜僵缰疆粮穰将墙桑刚祥详洋徉佯粱量羊伤汤鲂樟彰漳璋猖商防", 28 | "筐煌隍凰蝗惶璜廊浪裆沧纲亢吭潢钢丧盲簧忙茫傍汪臧琅当庠裳昂障糖疡锵杭邙赃滂禳攘瓤抢螳踉眶炀阊彭蒋亡殃蔷镶孀搪彷胱磅膀螃八庚更羹盲横觥彭棚亨英瑛烹平评京惊荆明盟鸣荣莹兵卿生甥笙牲檠擎鲸迎行衡耕萌氓宏闳茎莺樱泓橙筝争清情晴精睛菁旌晶盈瀛嬴营婴缨贞成盛城诚呈程声征正轻名令并倾萦琼赓撑瞠枪伧峥猩珩蘅铿嵘丁嘤鹦铮砰绷轰訇瞪侦顷榜抨趟坪请", 29 | "青经泾形刑邢型陉亭庭廷霆蜓停丁宁钉仃馨星腥醒惺娉灵棂龄铃苓伶零玲翎瓴囹聆听厅汀冥溟螟铭瓶屏萍荧萤荥扃町瞑暝", 30 | "蒸承丞惩陵凌绫冰膺鹰应蝇绳渑乘升胜兴缯凭仍兢矜征凝称登灯僧增曾憎层能棱朋鹏弘肱腾滕藤恒冯瞢扔誊", 31 | "尤邮优忧流留榴骝刘由油游猷悠攸牛修羞秋周州洲舟酬仇柔俦畴筹稠邱抽湫遒收鸠不愁休囚求裘球浮谋牟眸矛侯猴喉讴沤鸥瓯楼娄陬偷头投钩沟幽彪疣绸浏瘤犹啾酋售蹂揉搜叟邹貅泅球逑俅蜉桴罘欧搂抠髅蝼兜句妯惆呕缪繇偻篓馗区", 32 | "侵寻浔林霖临针箴斟沈深淫心琴禽擒钦衾吟今襟金音阴岑簪琳琛椹谌忱壬任黔歆禁喑森参淋郴妊湛", 33 | "覃潭谭参骖南男谙庵含涵函岚蚕探贪耽龛堪戡谈甘三酣篮柑惭蓝郯婪庵颔褴澹", 34 | "盐檐廉帘嫌严占髯谦奁纤签瞻蟾炎添兼缣尖潜阎镰粘淹箝甜恬拈暹詹渐歼黔沾苫占崦阉砭", 35 | "咸缄谗衔岩帆衫杉监凡馋芟喃嵌掺搀严"] -------------------------------------------------------------------------------- /bert_seq2seq/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, Dataset 3 | 4 | 5 | def gpt_collate_fn(batch): 6 | """ 7 | 动态padding, batch为一部分sample 8 | """ 9 | 10 | def padding(indice, max_length, pad_idx=0): 11 | """ 12 | pad 函数 13 | """ 14 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice] 15 | return torch.tensor(pad_indice) 16 | 17 | token_ids = [data["token_ids"] for data in batch] 18 | max_length = max([len(t) for t in token_ids]) 19 | 20 | token_ids_padded = padding(token_ids, max_length) 21 | target_ids_padded = token_ids_padded.clone() 22 | target_ids_padded[target_ids_padded == 0] = -100 23 | 24 | return token_ids_padded, target_ids_padded 25 | 26 | def bert_seq2seq_collate_fn(batch): 27 | """ 28 | 动态padding, batch为一部分sample 29 | """ 30 | 31 | def padding(indice, max_length, pad_idx=0): 32 | """ 33 | pad 函数 34 | """ 35 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice] 36 | return torch.tensor(pad_indice) 37 | 38 | token_ids = [data["token_ids"] for data in batch] 39 | max_length = max([len(t) for t in token_ids]) 40 | token_type_ids = [data["token_type_ids"] for data in batch] 41 | 42 | token_ids_padded = padding(token_ids, max_length) 43 | token_type_ids_padded = padding(token_type_ids, max_length) 44 | target_ids_padded = token_ids_padded[:, 1:].contiguous() 45 | 46 | return token_ids_padded, token_type_ids_padded, target_ids_padded 47 | 48 | def bert_cls_collate_fn(batch): 49 | """ 50 | 动态padding, batch为一部分sample 51 | """ 52 | 53 | def padding(indice, max_length, pad_idx=0): 54 | """ 55 | pad 函数 56 | """ 57 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice] 58 | return torch.tensor(pad_indice) 59 | 60 | token_ids = [data["token_ids"] for data in batch] 61 | max_length = max([len(t) for t in token_ids]) 62 | token_type_ids = [data["token_type_ids"] for data in batch] 63 | target_ids = [data["target_id"] for data in batch] 64 | target_ids = torch.tensor(target_ids, dtype=torch.long) 65 | 66 | token_ids_padded = padding(token_ids, max_length) 67 | token_type_ids_padded = padding(token_type_ids, max_length) 68 | # target_ids_padded = token_ids_padded[:, 1:].contiguous() 69 | 70 | return token_ids_padded, token_type_ids_padded, target_ids 71 | 72 | def 
bert_squence_label_collate_fn(batch): 73 | """ 74 | 动态padding, batch为一部分sample 75 | """ 76 | 77 | def padding(indice, max_length, pad_idx=0): 78 | """ 79 | pad 函数 80 | """ 81 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice] 82 | return torch.tensor(pad_indice) 83 | 84 | token_ids = [data["token_ids"] for data in batch] 85 | 86 | max_length = max([len(t) for t in token_ids]) 87 | token_type_ids = [data["token_type_ids"] for data in batch] 88 | target_ids = [data["target_id"] for data in batch] 89 | 90 | token_ids_padded = padding(token_ids, max_length) 91 | token_type_ids_padded = padding(token_type_ids, max_length) 92 | # target_ids_padded = token_ids_padded[:, 1:].contiguous() 93 | target_ids_padded = padding(target_ids, max_length) 94 | 95 | return token_ids_padded, token_type_ids_padded, target_ids_padded 96 | 97 | class AbstractDataset(Dataset): 98 | 99 | def __init__(self, model_name, model_class, collate_fn=None) -> None: 100 | super().__init__() 101 | if model_name == "gpt2": 102 | self.collate_fn = gpt_collate_fn 103 | if model_name == "bert" or model_name == "roberta" or model_name == "roberta-large" or model_name == "nezha": 104 | if model_class == "seq2seq": 105 | self.collate_fn = bert_seq2seq_collate_fn 106 | elif model_class == "cls": 107 | self.collate_fn = bert_cls_collate_fn 108 | elif model_class == "sequence_labeling_crf" or model_class == "sequence_labeling": 109 | self.collate_fn = bert_squence_label_collate_fn 110 | 111 | 112 | if collate_fn is not None : 113 | self.collate_fn = collate_fn 114 | 115 | 116 | def __getitem__(self, index): 117 | return NotImplemented 118 | 119 | def __len__(self): 120 | return NotImplemented -------------------------------------------------------------------------------- /bert_seq2seq/gpt2_generate_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from bert_seq2seq.seq2seq_model import top_k_top_p_filtering 5 | from bert_seq2seq.model.gpt2_model import GPT2LMHeadModel, GPT2Config 6 | from bert_seq2seq.basic_bert import BasicGPT 7 | from bert_seq2seq.tokenizer import Tokenizer 8 | import torch.nn.functional as F 9 | 10 | from bert_seq2seq.helper import RepetitionPenaltyLogitsProcessor, TemperatureLogitsProcessor, TopKLogitsProcessor, \ 11 | TopPLogitsProcessor, ListProcessor 12 | 13 | class GPT2(BasicGPT): 14 | def __init__(self, word2ix, tokenizer=None, 15 | ): 16 | super().__init__() 17 | self.word2ix = word2ix 18 | if tokenizer is not None: 19 | self.tokenizer = tokenizer 20 | else: 21 | self.tokenizer = Tokenizer(word2ix) 22 | self.config = GPT2Config(len(word2ix)) 23 | self.model = GPT2LMHeadModel(self.config) 24 | 25 | def sample_generate(self, text, input_max_length=256, out_max_length=200, 26 | top_k=30, top_p=1.0, add_eos=False, repetition_penalty=1.0, 27 | temperature=1.0): 28 | 29 | lp = [RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty), 30 | TemperatureLogitsProcessor(temperature=temperature), 31 | TopKLogitsProcessor(top_k=top_k), 32 | TopPLogitsProcessor(top_p=top_p) 33 | ] 34 | 35 | self.list_processor = ListProcessor(lp) 36 | 37 | token_ids, _ = self.tokenizer.encode(text, max_length=input_max_length) 38 | if not add_eos: 39 | token_ids = torch.tensor(token_ids, device=self.device, dtype=torch.long)[:-1].view(1, -1) 40 | else: 41 | token_ids = torch.tensor(token_ids, device=self.device, dtype=torch.long).view(1, -1) 42 | 43 | output_ids = [] 44 | sep_id = 
self.word2ix["[SEP]"] 45 | with torch.no_grad(): 46 | for step in range(out_max_length): 47 | _, scores = self.model(token_ids) 48 | logit_score = torch.log_softmax(scores[:, -1], dim=-1) 49 | logit_score[:, self.word2ix["[UNK]"]] = -float('Inf') 50 | 51 | filtered_logits = self.list_processor(token_ids, logit_score) 52 | 53 | # filtered_logits = top_k_top_p_filtering(logit_score, top_k=top_k, top_p=top_p) 54 | next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) 55 | if sep_id == next_token.item(): 56 | break 57 | output_ids.append(next_token.item()) 58 | token_ids = torch.cat((token_ids, next_token.long()), dim=1) 59 | 60 | return self.tokenizer.decode(np.array(output_ids)) 61 | 62 | def sample_generate_once(self, text, input_max_length=256, out_max_length=200, top_k=30, top_p=0.0, sep="。"): 63 | 64 | token_ids, _ = self.tokenizer.encode(text, max_length=input_max_length) 65 | # 不加任何的开始符号和结束符号,就是输入一句话。 66 | token_ids = torch.tensor(token_ids, device=self.device, dtype=torch.long)[1:-1].view(1, -1) 67 | 68 | 69 | output_ids = [] 70 | sep_id = self.word2ix[sep] # 句号结尾 71 | with torch.no_grad(): 72 | for step in range(out_max_length): 73 | _, scores = self.model(token_ids) 74 | logit_score = torch.log_softmax(scores[:, -1], dim=-1).squeeze(0) 75 | logit_score[self.word2ix["[UNK]"]] = -float('Inf') 76 | filtered_logits = top_k_top_p_filtering(logit_score, top_k=top_k, top_p=top_p) 77 | next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) 78 | if sep_id == next_token.item(): 79 | break 80 | output_ids.append(next_token.item()) 81 | token_ids = torch.cat((token_ids, next_token.long().unsqueeze(0)), dim=1) 82 | 83 | return self.tokenizer.decode(np.array(output_ids)) 84 | 85 | def sample_generate_english(self, text, input_max_length=256, out_max_length=200, top_k=30, top_p=0.0, add_eos=False): 86 | 87 | token_ids = self.tokenizer.encode(text, max_length=input_max_length, truncation=True) 88 | if add_eos: 89 | token_ids = token_ids + [self.word2ix[""]] 90 | token_ids = torch.tensor(token_ids, device=self.device, dtype=torch.long).view(1, -1) 91 | output_ids = [] 92 | sep_id = self.word2ix[""] 93 | with torch.no_grad(): 94 | for step in range(out_max_length): 95 | _, scores = self.model(token_ids) 96 | # print(scores.shape) 97 | logit_score = torch.log_softmax(scores[:, -1], dim=-1).squeeze(0) 98 | # print(logit_score.shape) 99 | logit_score[self.word2ix["unk"]] = -float('Inf') 100 | filtered_logits = top_k_top_p_filtering(logit_score, top_k=top_k, top_p=top_p) 101 | next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) 102 | if sep_id == next_token.item(): 103 | break 104 | # pass 105 | output_ids.append(next_token.item()) 106 | token_ids = torch.cat((token_ids, next_token.long().unsqueeze(0)), dim=1) 107 | 108 | return self.tokenizer.decode(output_ids) 109 | 110 | 111 | def _make_causal_mask(self, input_ids_shape: torch.Size): 112 | 113 | bsz, tgt_len = input_ids_shape 114 | mask = torch.full((tgt_len, tgt_len), 0.0).to(self.device) 115 | mask_cond = torch.arange(mask.size(-1)).to(self.device) 116 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 1.0) 117 | 118 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) 119 | 120 | 121 | def forward(self, x, labels=None): 122 | if labels is not None: 123 | labels = labels.to(self.device) 124 | x = x.to(self.device) 125 | # input_ids = torch.tensor([[1, 2, 3, 5, -100], [4, 5, 6, -100, -100]]) 126 | attention_mask = 
self._make_causal_mask(x.shape)
127 |         pad_mask = (labels != -100).float()
128 |         attention_mask = attention_mask * pad_mask.unsqueeze(1).unsqueeze(1)
129 |
130 |         loss, lm_logit = self.model(x, labels=labels, attention_mask=attention_mask)
131 |
132 |         return loss, lm_logit
--------------------------------------------------------------------------------
/bert_seq2seq/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/model/__init__.py
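A short, hypothetical driver for the `GPT2` wrapper above, showing how `sample_generate`'s logits processors (repetition penalty, temperature, top-k/top-p) are parameterized; the vocab and checkpoint paths are placeholders.

import torch
from bert_seq2seq.tokenizer import load_chinese_base_vocab
from bert_seq2seq.gpt2_generate_model import GPT2

word2ix = load_chinese_base_vocab("./state_dict/gpt2_vocab.txt")  # placeholder path
gpt = GPT2(word2ix)
gpt.load_state_dict(torch.load("./state_dict/gpt2_model.bin", map_location="cpu"))  # hypothetical checkpoint
gpt.eval()
# repetition_penalty > 1.0 down-weights already-generated tokens;
# temperature < 1.0 sharpens the distribution before top-k/top-p filtering.
out = gpt.sample_generate("今天天气不错", out_max_length=64, top_k=20,
                          repetition_penalty=1.2, temperature=0.9)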
--------------------------------------------------------------------------------
/bert_seq2seq/model/crf.py:
--------------------------------------------------------------------------------
1 | ## CRF layer
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class CRFLayer(nn.Module):
7 |     """
8 |     """
9 |     def __init__(self, output_dim):
10 |         super(CRFLayer, self).__init__()
11 |
12 |         self.output_dim = output_dim
13 |         self.trans = nn.Parameter(torch.Tensor(output_dim, output_dim))
14 |         self.trans.data.uniform_(-0.1, 0.1)
15 |
16 |     def compute_loss(self, y_pred, y_true, mask):
17 |         """
18 |         Compute the CRF loss.
19 |         """
20 |         y_pred = y_pred * mask
21 |         y_true = y_true * mask
22 |         target_score = self.target_score(y_pred, y_true)
23 |         log_norm = self.log_norm_step(y_pred, mask)
24 |         log_norm = self.logsumexp(log_norm, dim=1)  # reduce to one scalar per sample
25 |         return log_norm - target_score
26 |
27 |     def forward(self, y_pred, y_true, mask):
28 |         """
29 |         y_true: [[1, 2, 3], [2, 3, 0] ]
30 |         mask: [[1, 1, 1], [1, 1, 0]]
31 |         """
32 |         if y_pred.shape[0] != mask.shape[0] or y_pred.shape[1] != mask.shape[1]:
33 |             raise Exception("mask shape does not match y_pred shape")
34 |         mask = mask.reshape((mask.shape[0], mask.shape[1], 1))
35 |         mask = mask.float()
36 |         y_true = y_true.reshape(y_pred.shape[:-1])
37 |         y_true = y_true.long()
38 |         y_true_onehot = F.one_hot(y_true, self.output_dim)
39 |         y_true_onehot = y_true_onehot.float()
40 |
41 |         return self.compute_loss(y_pred, y_true_onehot, mask)
42 |
43 |     def target_score(self, y_pred, y_true):
44 |         """
45 |         Emission (per-tag state) score plus transition score of the gold path.
46 |         y_true: (batch, seq_len, out_dim)
47 |         y_pred: (batch, seq_len, out_dim)
48 |         """
49 |         # print(y_pred.shape)
50 |         # print(y_true.shape)
51 |         point_score = torch.einsum("bni,bni->b", y_pred, y_true)
52 |         trans_score = torch.einsum("bni,ij,bnj->b", y_true[:, :-1], self.trans, y_true[:, 1: ])
53 |
54 |         return point_score + trans_score
55 |
56 |     def log_norm_step(self, y_pred, mask):
57 |         """
58 |         Compute the normalization factor Z(X) by dynamic programming.
59 |         """
60 |         state = y_pred[:, 0]  # initial Z(X)
61 |         y_pred = y_pred[:, 1: ].contiguous()
62 |         mask = mask[:, 1:].contiguous()
63 |         batch, seq_len, out_dim = y_pred.shape
64 |         for t in range(seq_len):
65 |             cur_mask = mask[:, t]
66 |             state = torch.unsqueeze(state, 2)  # (batch, out_dim, 1)
67 |             g = torch.unsqueeze(self.trans, 0)  # (1, out_dim, out_dim)
68 |             outputs = self.logsumexp(state + g, dim=1)  # (batch, out_dim)
69 |             outputs = outputs + y_pred[:, t]
70 |             outputs = cur_mask * outputs + (1 - cur_mask) * state.squeeze(-1)
71 |             state = outputs
72 |
73 |         return outputs
74 |
75 |     def logsumexp(self, x, dim=None, keepdim=False):
76 |         """
77 |         Numerically stable log-sum-exp (avoids overflow).
78 |         """
79 |         if dim is None:
80 |             x, dim = x.view(-1), 0
81 |         xm, _ = torch.max(x, dim, keepdim=True)
82 |         out = xm + torch.log(torch.sum(torch.exp(x - xm), dim=dim, keepdim=True))
83 |         return out if keepdim else out.squeeze(dim)
84 |
85 |
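A self-contained sanity check for the `CRFLayer` above: random emissions stand in for BERT outputs, and the per-sample losses it returns are averaged the same way `BertSeqLabelingCRF.compute_loss` does elsewhere in this package.

import torch
from bert_seq2seq.model.crf import CRFLayer

crf = CRFLayer(output_dim=5)                       # 5 tag types
emissions = torch.randn(2, 7, 5)                   # (batch, seq_len, num_tags)
tags = torch.randint(0, 5, (2, 7))                 # gold tag ids per token
mask = torch.tensor([[1] * 7, [1] * 5 + [0] * 2])  # 1 = real token, 0 = padding
loss = crf(emissions, tags, mask).mean()           # mean negative log-likelihood
loss.backward()                                    # the transition matrix receives gradients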
--------------------------------------------------------------------------------
/bert_seq2seq/paddle_model/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .collate import *
16 | from .vocab import *
17 | from .sampler import *
18 | from .tokenizer import *
19 |
--------------------------------------------------------------------------------
/bert_seq2seq/paddle_model/data/iterator.py:
--------------------------------------------------------------------------------
1 | # Iterator for NLP Dataset
2 |
--------------------------------------------------------------------------------
/bert_seq2seq/paddle_model/data/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import jieba
16 | from .vocab import Vocab
17 |
18 |
19 | def get_idx_from_word(word, word_to_idx, unk_word):
20 |     if word in word_to_idx:
21 |         return word_to_idx[word]
22 |     return word_to_idx[unk_word]
23 |
24 |
25 | class BaseTokenizer(object):
26 |     def __init__(self, vocab):
27 |         self.vocab = vocab
28 |
29 |     def get_tokenizer(self):
30 |         return self.tokenizer
31 |
32 |     def cut(self, sentence):
33 |         pass
34 |
35 |     def encode(self, sentence):
36 |         pass
37 |
38 |
39 | class JiebaTokenizer(BaseTokenizer):
40 |     """
41 |     Constructs a tokenizer based on `jieba <https://github.com/fxsjy/jieba>`__.
42 |     It supports the :meth:`cut` method to split the text into tokens, and the :meth:`encode`
43 |     method to convert text to token ids.
44 |
45 |     Args:
46 |         vocab(paddlenlp.data.Vocab): An instance of :class:`paddlenlp.data.Vocab`.
47 |     """
48 |
49 |     def __init__(self, vocab):
50 |         super(JiebaTokenizer, self).__init__(vocab)
51 |         self.tokenizer = jieba.Tokenizer()
52 |         # initialize tokenizer
53 |         self.tokenizer.FREQ = {key: 1 for key in self.vocab.token_to_idx.keys()}
54 |         self.tokenizer.total = len(self.tokenizer.FREQ)
55 |         self.tokenizer.initialized = True
56 |
57 |     def cut(self, sentence, cut_all=False, use_hmm=True):
58 |         """
59 |         The method used to cut the text into tokens.
60 |
61 |         Args:
62 |             sentence(str): The text that needs to be cut.
63 |             cut_all(bool, optional): Whether to use the full mode. If True,
64 |                 using full mode that gets all the possible words from the
65 |                 sentence, which is fast but not accurate. If False, using
66 |                 accurate mode that attempts to cut the sentence into the most
67 |                 accurate segmentations, which is suitable for text analysis.
68 |                 Default: False.
69 |             use_hmm(bool, optional): Whether to use the HMM model. Default: True.
70 |
71 |         Returns:
72 |             list[str]: A list of tokens.
73 |
74 |         Example:
75 |             .. code-block:: python
76 |
77 |                 from paddlenlp.data import Vocab, JiebaTokenizer
78 |                 # The vocab file. The sample file can be downloaded firstly.
79 |                 # wget https://bj.bcebos.com/paddlenlp/data/senta_word_dict.txt
80 |                 vocab_file_path = './senta_word_dict.txt'
81 |                 # Initialize the Vocab
82 |                 vocab = Vocab.load_vocabulary(
83 |                     vocab_file_path,
84 |                     unk_token='[UNK]',
85 |                     pad_token='[PAD]')
86 |                 tokenizer = JiebaTokenizer(vocab)
87 |
88 |                 tokens = tokenizer.cut('我爱你中国')
89 |                 print(tokens)
90 |                 # ['我爱你', '中国']
91 |         """
92 |         return self.tokenizer.lcut(sentence, cut_all, use_hmm)
93 |
94 |     def encode(self, sentence, cut_all=False, use_hmm=True):
95 |         """
96 |         The method used to convert the text to ids. It first calls the
97 |         :meth:`cut` method to split the text into tokens, then converts the tokens to
98 |         ids using `vocab`.
99 |
100 |         Args:
101 |             sentence(str): The text that needs to be cut.
102 |             cut_all(bool, optional): Whether to use the full mode. If True,
103 |                 using full mode that gets all the possible words from the
104 |                 sentence, which is fast but not accurate. If False, using
105 |                 accurate mode that attempts to cut the sentence into the most
106 |                 accurate segmentations, which is suitable for text analysis.
107 |                 Default: False.
108 |             use_hmm(bool, optional): Whether to use the HMM model. Default: True.
109 |
110 |         Returns:
111 |             list[int]: A list of ids.
112 |
113 |         Example:
114 |             .. code-block:: python
115 |
116 |                 from paddlenlp.data import Vocab, JiebaTokenizer
117 |                 # The vocab file. The sample file can be downloaded firstly.
118 |                 # wget https://bj.bcebos.com/paddlenlp/data/senta_word_dict.txt
119 |                 vocab_file_path = './senta_word_dict.txt'
120 |                 # Initialize the Vocab
121 |                 vocab = Vocab.load_vocabulary(
122 |                     vocab_file_path,
123 |                     unk_token='[UNK]',
124 |                     pad_token='[PAD]')
125 |                 tokenizer = JiebaTokenizer(vocab)
126 |
127 |                 ids = tokenizer.encode('我爱你中国')
128 |                 print(ids)
129 |                 # [1170578, 575565]
130 |         """
131 |         words = self.cut(sentence, cut_all, use_hmm)
132 |         return [
133 |             get_idx_from_word(word, self.vocab.token_to_idx,
134 |                               self.vocab.unk_token) for word in words
135 |         ]
136 |
--------------------------------------------------------------------------------
/bert_seq2seq/paddle_model/transformers/__init__.py:
--------------------------------------------------------------------------------
1 | # # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 | # #
3 | # # Licensed under the Apache License, Version 2.0 (the "License");
4 | # # you may not use this file except in compliance with the License.
5 | # # You may obtain a copy of the License at
6 | # #
7 | # #     http://www.apache.org/licenses/LICENSE-2.0
8 | # #
9 | # # Unless required by applicable law or agreed to in writing, software
10 | # # distributed under the License is distributed on an "AS IS" BASIS,
11 | # # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # # See the License for the specific language governing permissions and
13 | # # limitations under the License.
14 | # 15 | from .model_utils import PretrainedModel, register_base_model 16 | from .tokenizer_utils import PretrainedTokenizer, BPETokenizer, tokenize_chinese_chars, is_chinese_char, AddedToken 17 | from .attention_utils import create_bigbird_rand_mask_idx_list 18 | # 19 | from .bert.modeling import * 20 | from .bert.tokenizer import * 21 | # from .bert_japanese.tokenizer import * 22 | # from .ernie.modeling import * 23 | # from .ernie.tokenizer import * 24 | from .gpt.modeling import * 25 | from .gpt.tokenizer import * 26 | from .roberta.modeling import * 27 | from .roberta.tokenizer import * 28 | # from .electra.modeling import * 29 | # from .electra.tokenizer import * 30 | # from .transformer.modeling import * 31 | # from .ernie_gen.modeling import ErnieForGeneration 32 | # from .optimization import * 33 | # from .ppminilm.modeling import * 34 | # from .ppminilm.tokenizer import * 35 | # from .bigbird.modeling import * 36 | # from .bigbird.tokenizer import * 37 | # from .unified_transformer.modeling import * 38 | # from .unified_transformer.tokenizer import * 39 | # from .ernie_ctm.modeling import * 40 | # from .ernie_ctm.tokenizer import * 41 | # from .tinybert.modeling import * 42 | # from .tinybert.tokenizer import * 43 | # from .distilbert.modeling import * 44 | # from .distilbert.tokenizer import * 45 | # from .skep.modeling import * 46 | # from .skep.tokenizer import * 47 | # from .xlnet.modeling import * 48 | # from .xlnet.tokenizer import * 49 | # from .albert.modeling import * 50 | # from .albert.tokenizer import * 51 | # from .ernie_gram.modeling import * 52 | # from .ernie_gram.tokenizer import * 53 | from .nezha.modeling import * 54 | from .nezha.tokenizer import * 55 | # from .ernie_doc.modeling import * 56 | # from .ernie_doc.tokenizer import * 57 | # from .bart.modeling import * 58 | # from .bart.tokenizer import * 59 | # from .roformer.modeling import * 60 | # from .roformer.tokenizer import * 61 | # from .blenderbot.modeling import * 62 | # from .blenderbot.tokenizer import * 63 | # from .blenderbot_small.modeling import * 64 | # from .blenderbot_small.tokenizer import * 65 | # from .unimo.modeling import * 66 | # from .unimo.tokenizer import * 67 | # from .squeezebert.modeling import * 68 | # from .squeezebert.tokenizer import * 69 | # from .convbert.modeling import * 70 | # from .convbert.tokenizer import * 71 | # from .mpnet.modeling import * 72 | # from .mpnet.tokenizer import * 73 | # from .auto.modeling import * 74 | # from .auto.tokenizer import * 75 | # from .ctrl.modeling import * 76 | # from .ctrl.tokenizer import * 77 | # from .layoutlmv2.modeling import * 78 | # from .layoutlmv2.tokenizer import * 79 | # from .layoutxlm.modeling import * 80 | # from .layoutxlm.tokenizer import * 81 | # from .layoutlm.modeling import * 82 | # from .layoutlm.tokenizer import * 83 | # from .t5.modeling import * 84 | # from .t5.tokenizer import * 85 | # from .mbart.modeling import * 86 | # from .mbart.tokenizer import * 87 | # from .reformer.modeling import * 88 | # from .reformer.tokenizer import * 89 | # from .mobilebert.modeling import * 90 | # from .mobilebert.tokenizer import * 91 | # from .chinesebert.modeling import * 92 | # from .chinesebert.tokenizer import * 93 | # from .funnel.modeling import * 94 | # from .funnel.tokenizer import * 95 | # from .ernie_m.modeling import * 96 | # from .ernie_m.tokenizer import * 97 | -------------------------------------------------------------------------------- /bert_seq2seq/paddle_model/transformers/bert/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/paddle_model/transformers/bert/__init__.py -------------------------------------------------------------------------------- /bert_seq2seq/paddle_model/transformers/gpt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/paddle_model/transformers/gpt/__init__.py -------------------------------------------------------------------------------- /bert_seq2seq/paddle_model/transformers/nezha/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling import * 2 | from .tokenizer import * 3 | -------------------------------------------------------------------------------- /bert_seq2seq/paddle_model/transformers/roberta/README.md: -------------------------------------------------------------------------------- 1 | # RoBERTa 2 | -------------------------------------------------------------------------------- /bert_seq2seq/paddle_model/transformers/roberta/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/paddle_model/transformers/roberta/__init__.py -------------------------------------------------------------------------------- /bert_seq2seq/paddle_model/transformers/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import functools 16 | import inspect 17 | 18 | import paddle 19 | from paddle.nn import Layer 20 | 21 | 22 | def fn_args_to_dict(func, *args, **kwargs): 23 | """ 24 | Inspect function `func` and its actual arguments, and build a 25 | dict mapping argument names to values. 26 | """ 27 | if hasattr(inspect, 'getfullargspec'): 28 | (spec_args, spec_varargs, spec_varkw, spec_defaults, _, _, 29 | _) = inspect.getfullargspec(func) 30 | else: 31 | (spec_args, spec_varargs, spec_varkw, 32 | spec_defaults) = inspect.getargspec(func) 33 | # add positional argument values 34 | init_dict = dict(zip(spec_args, args)) 35 | # add default argument values 36 | kwargs_dict = dict(zip(spec_args[-len(spec_defaults):], 37 | spec_defaults)) if spec_defaults else {} 38 | kwargs_dict.update(kwargs) 39 | init_dict.update(kwargs_dict) 40 | return init_dict 41 | 42 | 43 | class InitTrackerMeta(type(Layer)): 44 | """ 45 | This metaclass wraps the `__init__` method of a class to add an `init_config` 46 | attribute to instances of that class; `init_config` uses a dict to track 47 | the initial configuration.
If the class has a `_wrap_init` method, it is 48 | hooked after `__init__` and called as `_wrap_init(self, init_fn, *init_args, **init_kwargs)`. 49 | Since InitTrackerMeta is used as the metaclass for pretrained model classes, 50 | which are always `Layer` subclasses, and `type(Layer)` is not `type`, it uses `type(Layer)` 51 | rather than `type` as its base class to avoid inheritance metaclass 52 | conflicts. 53 | """ 54 | 55 | def __init__(cls, name, bases, attrs): 56 | init_func = cls.__init__ 57 | # If attrs has `__init__`, wrap it using the accessible `_wrap_init`. 58 | # Otherwise, no need to wrap again since the super cls has been wrapped. 59 | # TODO: remove reduplicated tracker if using super cls `__init__` 60 | help_func = getattr(cls, '_wrap_init', 61 | None) if '__init__' in attrs else None 62 | cls.__init__ = InitTrackerMeta.init_and_track_conf(init_func, help_func) 63 | super(InitTrackerMeta, cls).__init__(name, bases, attrs) 64 | 65 | @staticmethod 66 | def init_and_track_conf(init_func, help_func=None): 67 | """ 68 | Wraps `init_func`, the `__init__` method of a class, to add an `init_config` 69 | attribute to instances of that class. 70 | Args: 71 | init_func (callable): It should be the `__init__` method of a class. 72 | help_func (callable, optional): If provided, it is hooked after 73 | `init_func` and called as `_wrap_init(self, init_func, *init_args, **init_kwargs)`. 74 | Default None. 75 | 76 | Returns: 77 | function: the wrapped function 78 | """ 79 | 80 | @functools.wraps(init_func) 81 | def __impl__(self, *args, **kwargs): 82 | # keep full configuration 83 | init_func(self, *args, **kwargs) 84 | # helper registered via `_wrap_init` 85 | if help_func: 86 | help_func(self, init_func, *args, **kwargs) 87 | self.init_config = kwargs 88 | if args: 89 | kwargs['init_args'] = args 90 | kwargs['init_class'] = self.__class__.__name__ 91 | 92 | return __impl__ 93 | -------------------------------------------------------------------------------- /bert_seq2seq/paddle_model/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/paddle_model/utils/__init__.py -------------------------------------------------------------------------------- /bert_seq2seq/paddle_model/utils/env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | ''' 15 | This module stores the environment variables used by PaddleNLP. 16 | PPNLP_HOME --> the root directory for storing PaddleNLP related data. Defaults to ~/.paddlenlp. Users can change the 17 | ├ default value through the PPNLP_HOME environment variable. 18 | ├─ MODEL_HOME --> Store model files. 19 | └─ DATA_HOME --> Store automatically downloaded datasets.
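For example (hypothetical paths), running ``export PPNLP_HOME=/data/ppnlp`` before a job makes MODEL_HOME resolve to /data/ppnlp/models and DATA_HOME to /data/ppnlp/datasets. Note the variable must be set before this module is imported, since the constants below are computed at import time.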
20 | ''' 21 | import os 22 | 23 | 24 | def _get_user_home(): 25 | return os.path.expanduser('~') 26 | 27 | 28 | def _get_ppnlp_home(): 29 | if 'PPNLP_HOME' in os.environ: 30 | home_path = os.environ['PPNLP_HOME'] 31 | if os.path.exists(home_path): 32 | if os.path.isdir(home_path): 33 | return home_path 34 | else: 35 | raise RuntimeError( 36 | 'The environment variable PPNLP_HOME {} is not a directory.'. 37 | format(home_path)) 38 | else: 39 | return home_path 40 | return os.path.join(_get_user_home(), '.paddlenlp') 41 | 42 | 43 | def _get_sub_home(directory, parent_home=_get_ppnlp_home()): 44 | home = os.path.join(parent_home, directory) 45 | if not os.path.exists(home): 46 | os.makedirs(home) 47 | return home 48 | 49 | 50 | USER_HOME = _get_user_home() 51 | PPNLP_HOME = _get_ppnlp_home() 52 | MODEL_HOME = _get_sub_home('models') 53 | DATA_HOME = _get_sub_home('datasets') 54 | DOWNLOAD_SERVER = "http://paddlepaddle.org.cn/paddlehub" 55 | FAILED_STATUS = -1 56 | SUCCESS_STATUS = 0 57 | -------------------------------------------------------------------------------- /bert_seq2seq/paddle_model/utils/log.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License" 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
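# The Logger below registers two custom levels, TRAIN (21) and EVAL (22),
# between the standard INFO (20) and WARNING (30), and exposes one lowercase
# helper per level. A hypothetical usage sketch:
#
#     from bert_seq2seq.paddle_model.utils.log import logger
#     logger.train("step 100, loss 0.32")   # rendered in cyan
#     logger.eval("BLEU 28.4")              # rendered in blue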
15 | 16 | import contextlib 17 | import copy 18 | import functools 19 | import logging 20 | import os 21 | import sys 22 | import time 23 | import threading 24 | from typing import List 25 | 26 | import colorlog 27 | from colorama import Fore 28 | 29 | loggers = {} 30 | 31 | log_config = { 32 | 'DEBUG': { 33 | 'level': 10, 34 | 'color': 'purple' 35 | }, 36 | 'INFO': { 37 | 'level': 20, 38 | 'color': 'green' 39 | }, 40 | 'TRAIN': { 41 | 'level': 21, 42 | 'color': 'cyan' 43 | }, 44 | 'EVAL': { 45 | 'level': 22, 46 | 'color': 'blue' 47 | }, 48 | 'WARNING': { 49 | 'level': 30, 50 | 'color': 'yellow' 51 | }, 52 | 'ERROR': { 53 | 'level': 40, 54 | 'color': 'red' 55 | }, 56 | 'CRITICAL': { 57 | 'level': 50, 58 | 'color': 'bold_red' 59 | } 60 | } 61 | 62 | 63 | class Logger(object): 64 | ''' 65 | Default logger in PaddleNLP 66 | 67 | Args: 68 | name(str) : Logger name, default is 'PaddleNLP' 69 | ''' 70 | 71 | def __init__(self, name: str=None): 72 | name = 'PaddleNLP' if not name else name 73 | self.logger = logging.getLogger(name) 74 | 75 | for key, conf in log_config.items(): 76 | logging.addLevelName(conf['level'], key) 77 | self.__dict__[key] = functools.partial(self.__call__, conf['level']) 78 | self.__dict__[key.lower()] = functools.partial(self.__call__, 79 | conf['level']) 80 | 81 | self.format = colorlog.ColoredFormatter( 82 | '%(log_color)s[%(asctime)-15s] [%(levelname)8s]%(reset)s - %(message)s', 83 | log_colors={ 84 | key: conf['color'] 85 | for key, conf in log_config.items() 86 | }) 87 | 88 | self.handler = logging.StreamHandler() 89 | self.handler.setFormatter(self.format) 90 | 91 | self.logger.addHandler(self.handler) 92 | self.logLevel = 'DEBUG' 93 | self.logger.setLevel(logging.DEBUG) 94 | self.logger.propagate = False 95 | self._is_enable = True 96 | 97 | def disable(self): 98 | self._is_enable = False 99 | 100 | def enable(self): 101 | self._is_enable = True 102 | 103 | @property 104 | def is_enable(self) -> bool: 105 | return self._is_enable 106 | 107 | def __call__(self, log_level: int, msg: str): 108 | if not self.is_enable: 109 | return 110 | 111 | self.logger.log(log_level, msg) 112 | 113 | @contextlib.contextmanager 114 | def use_terminator(self, terminator: str): 115 | old_terminator = self.handler.terminator 116 | self.handler.terminator = terminator 117 | yield 118 | self.handler.terminator = old_terminator 119 | 120 | @contextlib.contextmanager 121 | def processing(self, msg: str, interval: float=0.1): 122 | ''' 123 | Continuously print a progress indicator with a rotating spinner. 124 | 125 | Args: 126 | msg(str): Message to be printed. 127 | interval(float): Rotation interval. Defaults to 0.1. 128 | ''' 129 | end = False 130 | 131 | def _printer(): 132 | index = 0 133 | flags = ['\\', '|', '/', '-'] 134 | while not end: 135 | flag = flags[index % len(flags)] 136 | with self.use_terminator('\r'): 137 | self.info('{}: {}'.format(msg, flag)) 138 | time.sleep(interval) 139 | index += 1 140 | 141 | t = threading.Thread(target=_printer) 142 | t.start() 143 | yield 144 | end = True 145 | 146 | 147 | logger = Logger() 148 | -------------------------------------------------------------------------------- /bert_seq2seq/paddle_model/utils/profiler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import sys 16 | import paddle 17 | 18 | # A global variable to record the number of calling times for profiler 19 | # functions. It is used to specify the tracing range of training steps. 20 | _profiler_step_id = 0 21 | 22 | # A global variable to avoid parsing from string every time. 23 | _profiler_options = None 24 | 25 | 26 | class ProfilerOptions(object): 27 | ''' 28 | Use a string to initialize a ProfilerOptions. 29 | The string should be in the format: "key1=value1;key2=value2;key3=value3". 30 | For example: 31 | "profile_path=model.profile" 32 | "batch_range=[50, 60]; profile_path=model.profile" 33 | "batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile" 34 | 35 | ProfilerOptions supports the following key-value pairs: 36 | batch_range - an integer list, e.g. [100, 110]. 37 | state - a string, the optional values are 'CPU', 'GPU' or 'All'. 38 | sorted_key - a string, the optional values are 'calls', 'total', 39 | 'max', 'min' or 'ave'. 40 | tracer_option - a string, the optional values are 'Default', 'OpDetail', 41 | 'AllOpDetail'. 42 | profile_path - a string, the path to save the serialized profile data, 43 | which can be used to generate a timeline. 44 | exit_on_finished - a boolean. 45 | ''' 46 | 47 | def __init__(self, options_str): 48 | assert isinstance(options_str, str) 49 | 50 | self._options = { 51 | 'batch_range': [10, 20], 52 | 'state': 'All', 53 | 'sorted_key': 'total', 54 | 'tracer_option': 'Default', 55 | 'profile_path': '/tmp/profile', 56 | 'exit_on_finished': True 57 | } 58 | self._parse_from_string(options_str) 59 | 60 | def _parse_from_string(self, options_str): 61 | for kv in options_str.replace(' ', '').split(';'): 62 | key, value = kv.split('=') 63 | if key == 'batch_range': 64 | value_list = value.replace('[', '').replace(']', '').split(',') 65 | value_list = list(map(int, value_list)) 66 | if len(value_list) >= 2 and value_list[0] >= 0 and value_list[ 67 | 1] > value_list[0]: 68 | self._options[key] = value_list 69 | elif key == 'exit_on_finished': 70 | self._options[key] = value.lower() in ("yes", "true", "t", "1") 71 | elif key in [ 72 | 'state', 'sorted_key', 'tracer_option', 'profile_path' 73 | ]: 74 | self._options[key] = value 75 | 76 | def __getitem__(self, name): 77 | if self._options.get(name, None) is None: 78 | raise ValueError( 79 | "ProfilerOptions does not have an option named %s." % name) 80 | return self._options[name] 81 | 82 | 83 | def add_profiler_step(options_str=None): 84 | ''' 85 | Enable operator-level timing using PaddlePaddle's profiler. 86 | The profiler uses an independent variable to count the profiler steps. 87 | One call of this function is treated as a profiler step. 88 | 89 | Args: 90 | options_str - a string used to initialize the ProfilerOptions. 91 | Default is None, in which case the profiler is disabled.
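        A minimal sketch of the intended call pattern (`dataloader` and
        `train_step` here are hypothetical names):

            for batch in dataloader:
                add_profiler_step("batch_range=[10, 20]; profile_path=/tmp/profile")
                loss = train_step(batch)   # steps 10..20 get profiled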
92 | ''' 93 | if options_str is None: 94 | return 95 | 96 | global _profiler_step_id 97 | global _profiler_options 98 | 99 | if _profiler_options is None: 100 | _profiler_options = ProfilerOptions(options_str) 101 | 102 | if _profiler_step_id == _profiler_options['batch_range'][0]: 103 | paddle.utils.profiler.start_profiler(_profiler_options['state'], 104 | _profiler_options['tracer_option']) 105 | elif _profiler_step_id == _profiler_options['batch_range'][1]: 106 | paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'], 107 | _profiler_options['profile_path']) 108 | if _profiler_options['exit_on_finished']: 109 | sys.exit(0) 110 | 111 | _profiler_step_id += 1 112 | -------------------------------------------------------------------------------- /bert_seq2seq/paddle_model/utils/tools.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import paddle 17 | from .log import logger 18 | 19 | 20 | def static_params_to_dygraph(model, static_tensor_dict): 21 | """Simple tool to convert static-graph parameters to a dygraph parameters dict. 22 | 23 | **NOTE** The model must support both static graph and dygraph mode. 24 | 25 | Args: 26 | model (nn.Layer): the model of a neural network. 27 | static_tensor_dict (dict): the parameters saved in static mode, 28 | usually loaded by `paddle.static.load_program_state`. 29 | 30 | Returns: 31 | dict: a state dict keyed the same way as in dygraph mode. 32 | """ 33 | state_dict = model.state_dict() 34 | # static_tensor_dict = paddle.static.load_program_state(static_params_path) 35 | 36 | ret_dict = dict() 37 | for n, p in state_dict.items(): 38 | if p.name not in static_tensor_dict: 39 | logger.info("%s parameter is missing from your state dict." % n) 40 | continue 41 | ret_dict[n] = static_tensor_dict[p.name] 42 | 43 | return ret_dict 44 | 45 | 46 | def dygraph_params_to_static(model, dygraph_tensor_dict, topo=None): 47 | """Simple tool to convert dygraph parameters to a static-graph parameters dict. 48 | 49 | **NOTE** The model must support both static graph and dygraph mode. 50 | 51 | Args: 52 | model (nn.Layer): the model of a neural network. 53 | dygraph_tensor_dict (dict): the state dict saved in dygraph mode. 54 | 55 | Returns: 56 | dict: a state dict keyed by the static-graph parameter names. 57 | """ 58 | state_dict = model.state_dict() 59 | 60 | ret_dict = dict() 61 | for name, parm in state_dict.items(): 62 | if name not in dygraph_tensor_dict: 63 | logger.info("%s parameter is missing from your state dict."
% name) 64 | continue 65 | 66 | tensor = dygraph_tensor_dict[name] 67 | if parm.is_distributed: 68 | assert topo is not None 69 | for dim, v in enumerate(tensor.shape): 70 | if parm.shape[dim] != v: 71 | break 72 | 73 | splited = np.split( 74 | tensor, topo.mp_info.size, axis=dim)[topo.mp_info.rank] 75 | ret_dict[parm.name] = splited 76 | else: 77 | ret_dict[parm.name] = tensor 78 | 79 | return ret_dict 80 | 81 | 82 | class TimeCostAverage(object): 83 | """ 84 | Simple tool for calculating the average time cost during training and inference. 85 | """ 86 | 87 | def __init__(self): 88 | self.reset() 89 | 90 | def reset(self): 91 | """ 92 | Reset the recorder state, and reset the `cnt` to zero. 93 | """ 94 | self.cnt = 0 95 | self.total_time = 0 96 | 97 | def record(self, usetime): 98 | """ 99 | Record the time cost of the current step and increment `cnt`. 100 | """ 101 | self.cnt += 1 102 | self.total_time += usetime 103 | 104 | def get_average(self): 105 | """ 106 | Return the average time cost since the last reset. 107 | """ 108 | if self.cnt == 0: 109 | return 0 110 | return self.total_time / self.cnt 111 | 112 | 113 | def get_env_device(): 114 | """ 115 | Return the device name of the running environment. 116 | """ 117 | if paddle.is_compiled_with_cuda(): 118 | return 'gpu' 119 | elif paddle.is_compiled_with_npu(): 120 | return 'npu' 121 | elif paddle.is_compiled_with_rocm(): 122 | return 'rocm' 123 | elif paddle.is_compiled_with_xpu(): 124 | return 'xpu' 125 | return 'cpu' 126 | 127 | 128 | def compare_version(version, pair_version): 129 | """ 130 | Args: 131 | version (str): The first version string to be compared. 132 | The format of the version string should be as follows: "xxx.yyy.zzz". 133 | pair_version (str): The second version string to be compared. 134 | The format of the version string should be as follows: "xxx.yyy.zzz". 135 | Returns: 136 | int: The result of the comparison. 1 means version > pair_version; 0 means 137 | version = pair_version; -1 means version < pair_version.
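        Segments are compared left to right after splitting on "."; a
        non-numeric segment (such as the "0-rc0" in "2.2.0-rc0") compares
        lower than any numeric one, as the examples below show.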
138 | 139 | Examples: 140 | >>> compare_version("2.2.1", "2.2.0") 141 | 1 142 | >>> compare_version("2.2.0", "2.2.0") 143 | 0 144 | >>> compare_version("2.2.0-rc0", "2.2.0") 145 | -1 146 | >>> compare_version("2.3.0-rc0", "2.2.0") 147 | 1 148 | """ 149 | version = version.strip() 150 | pair_version = pair_version.strip() 151 | if version == pair_version: 152 | return 0 153 | version_list = version.split(".") 154 | pair_version_list = pair_version.split(".") 155 | for version_code, pair_version_code in zip(version_list, pair_version_list): 156 | if not version_code.isnumeric(): 157 | return -1 158 | if not pair_version_code.isnumeric(): 159 | return 1 160 | if int(version_code) > int(pair_version_code): 161 | return 1 162 | elif int(version_code) < int(pair_version_code): 163 | return -1 164 | return 0 165 | -------------------------------------------------------------------------------- /bert_seq2seq/t5_ch.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from bert_seq2seq.model.t5_model import T5ForConditionalGeneration, T5Config, T5SmallConfig 4 | from bert_seq2seq.tokenizer import T5PegasusTokenizer, load_chinese_base_vocab 5 | from bert_seq2seq.basic_bert import BasicT5 6 | from bert_seq2seq.seq2seq_model import top_k_top_p_filtering 7 | import torch.nn.functional as F 8 | 9 | class T5Model(BasicT5): 10 | 11 | def __init__(self, word2idx, size="base"): 12 | super().__init__() 13 | if size == "base": 14 | config = T5Config() 15 | elif size == "small": 16 | config = T5SmallConfig() 17 | else: 18 | raise Exception("unsupported model size: expected 'base' or 'small'") 19 | self.model = T5ForConditionalGeneration(config) 20 | 21 | self.word2idx = word2idx 22 | self.tokenizer = T5PegasusTokenizer(self.word2idx) 23 | self.bos_id = self.word2idx["[CLS]"] 24 | self.eos_id = self.word2idx["[SEP]"] 25 | self.unk_id = self.word2idx["[UNK]"] 26 | 27 | def forward(self, input_ids, decoder_input_ids, labels=None): 28 | input_ids = input_ids.to(self.device) 29 | decoder_input_ids = decoder_input_ids.to(self.device) 30 | if labels is not None: 31 | labels = labels.to(self.device) 32 | return self.model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, labels=labels) 33 | 34 | 35 | def sample_generate_encoder_decoder(self, text, input_max_length=256, out_max_length=200, top_k=30, top_p=1.0, add_eos=True): 36 | 37 | token_out = self.tokenizer.encode(text, max_length=input_max_length) 38 | if len(token_out) == 2: 39 | token_ids = token_out[0] 40 | else: 41 | token_ids = token_out 42 | if not add_eos: 43 | token_ids = token_ids[:-1] 44 | token_ids = torch.tensor(token_ids, device=self.device, dtype=torch.long).view(1, -1) 45 | output_ids = [] 46 | 47 | input_decoder_ids = torch.tensor(self.bos_id, device=self.device, dtype=torch.long).view(1, -1) 48 | with torch.no_grad(): 49 | for step in range(out_max_length): 50 | scores = self.model(input_ids=token_ids, decoder_input_ids=input_decoder_ids)[0] 51 | logit_score = torch.log_softmax(scores[:, -1], dim=-1).squeeze(0) 52 | logit_score[self.unk_id] = -float('Inf') 53 | filtered_logits = top_k_top_p_filtering(logit_score, top_k=top_k, top_p=top_p) 54 | 55 | filtered_logits_prob = F.softmax(filtered_logits, dim=-1) 56 | 57 | next_token = torch.multinomial(filtered_logits_prob, num_samples=1) 58 | if self.eos_id == next_token.item(): 59 | break 60 | 61 | output_ids.append(next_token.item()) 62 | input_decoder_ids = torch.cat((input_decoder_ids, next_token.long().unsqueeze(0)), dim=1) 63 | 64 | return
self.tokenizer.decode(output_ids) -------------------------------------------------------------------------------- /bert_seq2seq/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bert_seq2seq.seq2seq_model import Seq2SeqModel 3 | from bert_seq2seq.bert_cls_classifier import BertClsClassifier 4 | from bert_seq2seq.bert_seq_labeling import BertSeqLabeling 5 | from bert_seq2seq.bert_seq_labeling_crf import BertSeqLabelingCRF 6 | from bert_seq2seq.bert_relation_extraction import BertRelationExtrac 7 | from bert_seq2seq.bert_cls_multi_classifier import BertClsMultiClassifier 8 | import torch.nn.functional as F 9 | from bert_seq2seq.bert_cls_multi_seq2seq import ClsMultiSeq2SeqModel 10 | from bert_seq2seq.simbert_model import SimBertModel 11 | from bert_seq2seq.gpt2_generate_model import GPT2 12 | from bert_seq2seq.basic_bert import BasicBert 13 | 14 | 15 | def load_bert(word2ix, tokenizer=None, model_name="roberta", model_class="seq2seq", target_size=0, target=None): 16 | """ 17 | model_path: 模型位置 18 | 这是个统一的接口,用来加载模型的 19 | model_class : seq2seq or encoder 20 | """ 21 | 22 | if model_class == "seq2seq": 23 | bert_model = Seq2SeqModel(word2ix, model_name=model_name, tokenizer=tokenizer) 24 | return bert_model 25 | elif model_class == "cls": 26 | if target_size == 0: 27 | raise Exception("必须传入参数 target_size,才能确定预测多少分类") 28 | bert_model = BertClsClassifier(word2ix, target_size, model_name=model_name) 29 | return bert_model 30 | elif model_class == "sequence_labeling": 31 | ## 序列标注模型 32 | if target_size == 0: 33 | raise Exception("必须传入参数 target_size,才能确定预测多少分类") 34 | bert_model = BertSeqLabeling(word2ix, target_size, model_name=model_name) 35 | return bert_model 36 | elif model_class == "sequence_labeling_crf": 37 | # 带有crf层的序列标注模型 38 | if target_size == 0: 39 | raise Exception("必须传入参数 target_size,才能确定预测多少分类") 40 | bert_model = BertSeqLabelingCRF(word2ix, target_size, model_name=model_name) 41 | return bert_model 42 | elif model_class == "relation_extrac": 43 | if target_size == 0: 44 | raise Exception("必须传入参数 target_size 表示预测predicate的种类") 45 | bert_model = BertRelationExtrac(word2ix, target_size, model_name=model_name) 46 | return bert_model 47 | elif model_class == "simbert": 48 | bert_model = SimBertModel(word2ix, model_name=model_name, tokenizer=tokenizer) 49 | return bert_model 50 | elif model_class == "multi_label_cls": 51 | bert_model = BertClsMultiClassifier(word2ix, target_size, model_name=model_name) 52 | return bert_model 53 | elif model_class == "multi_label_cls_seq2seq": 54 | bert_model = ClsMultiSeq2SeqModel(word2ix, target, model_name=model_name) 55 | return bert_model 56 | elif model_class == "embedding": 57 | bert_model = BasicBert(word2ix, model_name=model_name, tokenizer=tokenizer) 58 | return bert_model 59 | else : 60 | raise Exception("model_name_err") 61 | 62 | 63 | def load_gpt(word2ix, tokenizer=None): 64 | model = GPT2(word2ix, tokenizer=tokenizer) 65 | return model 66 | 67 | 68 | -------------------------------------------------------------------------------- /examples/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/examples/.DS_Store -------------------------------------------------------------------------------- /examples/bart_auto_title_train.py: -------------------------------------------------------------------------------- 1 | ## model url : 
https://huggingface.co/fnlp/bart-base-chinese 2 | import torch 3 | import time 4 | import glob 5 | from torch.utils.data import Dataset, DataLoader 6 | from tqdm import tqdm 7 | from bert_seq2seq.extend_model_method import ExtendModel 8 | 9 | from transformers import BertTokenizer, BartForConditionalGeneration, Text2TextGenerationPipeline 10 | 11 | src_dir = './corpus/auto_title/train.src' 12 | tgt_dir = './corpus/auto_title/train.tgt' 13 | 14 | vocab_path = "./state_dict/bart-chinese" ## 字典 15 | model_path = "./state_dict/bart-chinese" ## 预训练参数 16 | 17 | model_save_path = "./state_dict/bart_autotile.bin" ## 训练完模型 保存在哪里 18 | batch_size = 8 19 | lr = 1e-5 20 | 21 | tokenizer = BertTokenizer.from_pretrained(vocab_path) 22 | word2idx = tokenizer.vocab 23 | model = BartForConditionalGeneration.from_pretrained(model_path) 24 | 25 | def read_file(): 26 | src = [] 27 | tgt = [] 28 | 29 | with open(src_dir,'r',encoding='utf-8') as f: 30 | lines = f.readlines() 31 | 32 | for line in lines: 33 | src.append(line.strip('\n').lower()) 34 | 35 | with open(tgt_dir,'r',encoding='utf-8') as f: 36 | lines = f.readlines() 37 | for line in lines: 38 | tgt.append(line.strip('\n').lower()) 39 | 40 | return src, tgt 41 | 42 | 43 | class SeqDataset(Dataset): 44 | """ 45 | 针对特定数据集,定义一个相关的取数据的方式 46 | """ 47 | 48 | def __init__(self, sents_src, sents_tgt): 49 | ## 一般init函数是加载所有数据 50 | super(SeqDataset, self).__init__() 51 | # 读原始数据 52 | self.sents_src = sents_src 53 | self.sents_tgt = sents_tgt 54 | 55 | self.idx2word = {k: v for v, k in word2idx.items()} 56 | 57 | def __getitem__(self, i): 58 | ## 得到单个数据 59 | # print(i) 60 | src = self.sents_src[i] 61 | tgt = self.sents_tgt[i] 62 | token_ids_src = tokenizer.encode(src, max_length=256) 63 | token_ids_tgt = tokenizer.encode(tgt, max_length=256) 64 | output = { 65 | "token_ids_src": token_ids_src, 66 | "token_ids_tgt": token_ids_tgt, 67 | } 68 | return output 69 | 70 | def __len__(self): 71 | return len(self.sents_src) 72 | 73 | 74 | def collate_fn(batch): 75 | """ 76 | 动态padding, batch为一部分sample 77 | """ 78 | 79 | def padding(indice, max_length, pad_idx=0): 80 | """ 81 | pad 函数 82 | """ 83 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice] 84 | return torch.tensor(pad_indice) 85 | 86 | token_ids_src = [data["token_ids_src"] for data in batch] 87 | max_length_src = max([len(t) for t in token_ids_src]) 88 | token_ids_tgt = [data["token_ids_tgt"] for data in batch] 89 | max_length_tgt = max([len(t) for t in token_ids_tgt]) 90 | 91 | token_ids_padded = padding(token_ids_src, max_length_src) 92 | target_ids_padded = padding(token_ids_tgt, max_length_tgt) 93 | labels_ids = target_ids_padded.clone() 94 | target_ids_padded = target_ids_padded[:, :-1].contiguous() 95 | labels_ids = labels_ids[:, 1:].contiguous() 96 | 97 | return token_ids_padded, target_ids_padded, labels_ids 98 | 99 | 100 | class Trainer: 101 | def __init__(self): 102 | # 加载数据 103 | self.sents_src, self.sents_tgt = read_file() 104 | 105 | # 判断是否有可用GPU 106 | self.device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") 107 | print("device: " + str(self.device)) 108 | # 定义模型 109 | self.model = ExtendModel(model, tokenizer=tokenizer, bos_id=word2idx["[CLS]"], eos_id=word2idx["[SEP]"], device=self.device) 110 | 111 | # 将模型发送到计算设备(GPU或CPU) 112 | self.model.to(self.device) 113 | # self.model.set_device(self.device) 114 | # 声明需要优化的参数 115 | self.optim_parameters = list(self.model.parameters()) 116 | self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, 
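# (weight_decay=1e-3 adds an L2-style penalty on top of the 1e-5 learning rate)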
weight_decay=1e-3) 117 | # 声明自定义的数据加载器 118 | dataset = SeqDataset(self.sents_src, self.sents_tgt) 119 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn) 120 | 121 | def train(self, epoch): 122 | # 一个epoch的训练 123 | self.model.train() 124 | self.iteration(epoch, dataloader=self.dataloader, train=True) 125 | 126 | def save(self, save_path): 127 | """ 128 | 保存模型 129 | """ 130 | self.model.save_all_params(save_path) 131 | print("{} saved!".format(save_path)) 132 | 133 | def iteration(self, epoch, dataloader, train=True): 134 | total_loss = 0 135 | report_loss = 0 136 | start_time = time.time() ## 得到当前时间 137 | step = 0 138 | for token_ids, target_ids, labels_ids in tqdm(dataloader, total=len(dataloader)): 139 | step += 1 140 | token_ids = token_ids.to(self.device) 141 | target_ids = target_ids.to(self.device) 142 | labels_ids = labels_ids.to(self.device) 143 | if step % 100 == 0: 144 | # self.save(model_save_path) 145 | self.model.eval() 146 | test_data = ["本文总结了十个可穿戴产品的设计原则,而这些原则同样也是笔者认为是这个行业最吸引人的地方:1为人们解决重复性问题,2从人开始而不是从机器开始,3要引起注意但不要刻意,4提升用户能力而不是取代人", 147 | "2007年乔布斯向人们展示iPhone并宣称它将会改变世界,还有人认为他在夸大其词然而在8年后以iPhone为代表的触屏智能手机已经席卷全球各个角落,未来智能手机将会成为真正的个人电脑为人类发展做出更大的贡献", 148 | "雅虎发布2014年第四季度财报并推出了免税方式剥离其持有的阿里巴巴集团15%股权的计划打算将这一价值约400亿美元的宝贵投资分配给股东截止发稿前雅虎股价上涨了大约7%至5145美元"] 149 | 150 | for text in test_data: 151 | print(self.model.sample_generate_encoder_decoder(text, add_eos=True, top_k=20)) 152 | self.model.train() 153 | print("report loss is " + str(report_loss)) 154 | report_loss = 0 155 | 156 | # 因为传入了target标签,因此会计算loss并且返回 157 | loss = self.model(token_ids,labels=labels_ids, decoder_input_ids=target_ids)[0] 158 | # 反向传播 159 | if train: 160 | # 清空之前的梯度 161 | self.optimizer.zero_grad() 162 | # 反向传播, 获取新的梯度 163 | loss.backward() 164 | # 用获取的梯度更新模型参数 165 | self.optimizer.step() 166 | 167 | # 为计算当前epoch的平均loss 168 | total_loss += loss.item() 169 | report_loss += loss.item() 170 | 171 | end_time = time.time() 172 | spend_time = end_time - start_time 173 | # 打印训练信息 174 | print("epoch is " + str(epoch) + ". loss is " + str(total_loss) + ". 
spend time is " + str(spend_time)) 175 | # 保存模型 176 | # self.save(model_save_path) 177 | 178 | 179 | if __name__ == '__main__': 180 | 181 | trainer = Trainer() 182 | train_epoches = 10 183 | for epoch in range(train_epoches): 184 | # 训练一个epoch 185 | trainer.train(epoch) 186 | -------------------------------------------------------------------------------- /examples/examples_paddle/mBart_translation_en_ro.py: -------------------------------------------------------------------------------- 1 | from paddlenlp.transformers import MBartForConditionalGeneration,MBartTokenizer 2 | import paddle 3 | from paddle.io import Dataset 4 | import argparse 5 | from tqdm import tqdm 6 | 7 | parse = argparse.ArgumentParser() 8 | parse.add_argument("--model_name",type=str) 9 | parse.add_argument("--epoches",type=int) 10 | parse.add_argument("--test_step",type=int) 11 | parse.add_argument("--save_step",type=int) 12 | parse.add_argument("--batch_size",type=int) 13 | parse.add_argument("--src_lang",type=str) 14 | parse.add_argument("--tgt_lang",type=str) 15 | parse.add_argument("--datapath_en",type=str) 16 | parse.add_argument("--datapath_ro",type=str) 17 | parse.add_argument("--max_length",type=int) 18 | opt = parse.parse_args() 19 | 20 | def read_dataset(datapath_en,datapath_ro): 21 | dataset_en = [] 22 | dataset_ro = [] 23 | with open(datapath_en) as f: 24 | for line in f: 25 | dataset_en.append(line.strip("\n")) 26 | with open(datapath_ro) as ff: 27 | for line in ff: 28 | dataset_ro.append(line.strip("\n")) 29 | return dataset_en,dataset_ro 30 | 31 | class EnRoDataset(Dataset): 32 | def __init__(self,dataset_en,dataset_ro): 33 | super(EnRoDataset,self).__init__() 34 | self.dataset_en = dataset_en 35 | self.dataset_ro = dataset_ro 36 | def __getitem__(self,index): 37 | data_for_en = self.dataset_en[index] 38 | data_for_ro = self.dataset_ro[index] 39 | input_ids = tokenizer.encode(data_for_en)["input_ids"] 40 | decoder_input_ids = [tokenizer.lang_code_to_id[opt.tgt_lang]]+tokenizer.encode(data_for_ro)["input_ids"][:-1] 41 | output = { 42 | "input_ids":input_ids, 43 | "decoder_input_ids":decoder_input_ids 44 | } 45 | return output 46 | def __len__(self): 47 | return len(self.dataset_en) 48 | @staticmethod 49 | def collate_fn(batch): 50 | def padding(indice,max_length=50,pad_idx=tokenizer.pad_token_id): 51 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice] 52 | return paddle.to_tensor(pad_indice) 53 | 54 | input_ids = [data["input_ids"] for data in batch] 55 | decoder_input_ids = [data["decoder_input_ids"] for data in batch] 56 | max_length_of_input_ids = max([len(text) for text in input_ids]) 57 | max_length_of_decoder_input_ids = max([len(text) for text in decoder_input_ids]) 58 | 59 | input_ids_padded = padding(input_ids,max_length_of_input_ids) 60 | decoder_input_ids_padded = padding(decoder_input_ids,max_length_of_decoder_input_ids) 61 | return input_ids_padded,decoder_input_ids_padded 62 | 63 | 64 | model = MBartForConditionalGeneration.from_pretrained(opt.model_name) 65 | tokenizer = MBartTokenizer.from_pretrained(opt.model_name,src_lang=opt.src_lang,tgt_lang=opt.tgt_lang) 66 | dataset_en,dataset_ro = read_dataset(opt.datapath_en,opt.datapath_ro) 67 | 68 | dataset = EnRoDataset( 69 | dataset_en, 70 | dataset_ro) 71 | dataloader = paddle.io.DataLoader( 72 | dataset, 73 | batch_size=opt.batch_size, 74 | shuffle=True, 75 | collate_fn=dataset.collate_fn) 76 | optimizer = paddle.optimizer.AdamW( 77 | learning_rate=1e-5, 78 | parameters=model.parameters(), 79 | 
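# (AdamW applies this weight decay decoupled from the gradient update, unlike plain Adam's L2 penalty)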
weight_decay=1e-5) 80 | 81 | def calculate_loss(logits,label): 82 | return paddle.nn.functional.cross_entropy(logits.reshape([-1,tokenizer.vocab_size]),label.reshape([-1])) 83 | 84 | def generate_text_for_test(text,target_language,max_length): 85 | with paddle.no_grad(): 86 | input_ids = paddle.to_tensor(tokenizer.encode(text)["input_ids"]).unsqueeze(0) 87 | bos_id = tokenizer.lang_code_to_id[target_language] 88 | outputs, _ = model.generate( 89 | input_ids=input_ids, 90 | forced_bos_token_id=bos_id, 91 | decode_strategy="beam_search", 92 | num_beams=4, 93 | max_length=50) 94 | return tokenizer.convert_ids_to_string(outputs[0].numpy().tolist()[1:-1]) 95 | 96 | def train(): 97 | 98 | global_step = 1 99 | report_loss = 0 100 | 101 | for epoch in range(opt.epoches): 102 | for input_ids, decoder_input_ids in tqdm(dataloader(), total=len(dataloader)): 103 | model.train() 104 | if global_step % opt.test_step == 0: 105 | model.eval() 106 | texts = ["election of Vice-Presidents of the European Parliament ( deadline for submitting nominations ) : see Minutes","agenda for next sitting : see Minutes"] 107 | for text in texts: 108 | print("English:",text) 109 | print("Romanian",generate_text_for_test(text,opt.tgt_lang,opt.max_length)) 110 | print("loss is {}".format(report_loss)) 111 | report_loss = 0 112 | model.train() 113 | if global_step % opt.save_step == 0: 114 | pass 115 | logits = model(input_ids=input_ids,decoder_input_ids=decoder_input_ids) 116 | loss = calculate_loss(logits[:,:-2],decoder_input_ids[:,1:-1]) 117 | report_loss = report_loss + loss.item() 118 | loss.backward() 119 | optimizer.step() 120 | optimizer.clear_grad() 121 | global_step = global_step + 1 122 | 123 | 124 | 125 | if __name__ == "__main__": 126 | train() 127 | 128 | # python new.py --model_name "mbart-large-en-ro" --epoches 3 --test_step 10 --save_step 10000 --batch_size 3 --src_lang "en_XX" --tgt_lang "ro_RO" --datapath_en "./wmt16_en_ro/corpus.en" --datapath_ro "./wmt16_en_ro/corpus.ro" --max_length 128 129 | 130 | -------------------------------------------------------------------------------- /examples/gpt2_ancient_translation_train.py: -------------------------------------------------------------------------------- 1 | ## gpt2 进行文言文翻译 2 | from bert_seq2seq import load_gpt 3 | import torch 4 | from tqdm import tqdm 5 | import time 6 | import glob 7 | from torch.utils.data import Dataset, DataLoader 8 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 9 | 10 | vocab_path = "./state_dict/gpt2通用中文模型/vocab.txt" 11 | model_path = "./state_dict/gpt2通用中文模型/pytorch_model.bin" 12 | model_save_path = "./state_dict/gpt_ancient_trans_model.bin" 13 | batch_size = 8 14 | lr = 1e-5 15 | word2idx = load_chinese_base_vocab(vocab_path) 16 | 17 | def read_corpus(): 18 | """ 19 | 读原始数据 20 | """ 21 | src = [] 22 | tgt = [] 23 | data_path = glob.glob("./corpus/文言文翻译/*") 24 | for p in data_path: 25 | dir = p.split("/")[:-1] 26 | dir = "/".join(dir) 27 | # print(dir) 28 | name = p.split("/")[-1] 29 | if "翻译" in name: 30 | # 找到了一个翻译文件 31 | tgt_name = name 32 | src_name = name[:-2] 33 | with open(dir + "/" + src_name) as fs: 34 | lines = fs.readlines() 35 | for line in lines: 36 | src.append(line.strip("\n").strip()) 37 | 38 | with open(dir + "/" + tgt_name) as ft: 39 | lines = ft.readlines() 40 | for line in lines: 41 | tgt.append(line.strip("\n").strip()) 42 | 43 | else: 44 | pass 45 | 46 | return src, tgt 47 | 48 | class SeqDataset(Dataset): 49 | """ 50 | 针对特定数据集,定义一个相关的取数据的方式 51 | """ 52 | 53 | def __init__(self, sents_src, 
sents_tgt): 54 | ## 一般init函数是加载所有数据 55 | super(SeqDataset, self).__init__() 56 | # 读原始数据 57 | # self.sents_src, self.sents_tgt = read_corpus(poem_corpus_dir) 58 | self.sents_src = sents_src 59 | self.sents_tgt = sents_tgt 60 | 61 | self.idx2word = {k: v for v, k in word2idx.items()} 62 | self.tokenizer = Tokenizer(word2idx) 63 | 64 | def __getitem__(self, i): 65 | ## 得到单个数据 66 | # print(i) 67 | src = self.sents_src[i] 68 | tgt = self.sents_tgt[i] 69 | token_ids, _ = self.tokenizer.encode(src, tgt, max_length=256) 70 | output = { 71 | "token_ids": token_ids, 72 | } 73 | return output 74 | 75 | def __len__(self): 76 | return len(self.sents_src) 77 | 78 | 79 | def collate_fn(batch): 80 | """ 81 | 动态padding, batch为一部分sample 82 | """ 83 | 84 | def padding(indice, max_length, pad_idx=0): 85 | """ 86 | pad 函数 87 | """ 88 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice] 89 | return torch.tensor(pad_indice) 90 | 91 | token_ids = [data["token_ids"] for data in batch] 92 | max_length = max([len(t) for t in token_ids]) 93 | 94 | token_ids_padded = padding(token_ids, max_length) 95 | target_ids_padded = token_ids_padded.clone() 96 | target_ids_padded[target_ids_padded == 0] = -100 97 | 98 | return token_ids_padded, target_ids_padded 99 | 100 | 101 | class Trainer: 102 | def __init__(self): 103 | # 加载数据 104 | self.sents_src, self.sents_tgt = read_corpus() 105 | 106 | # 判断是否有可用GPU 107 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 108 | print("device: " + str(self.device)) 109 | # 定义模型 110 | self.gpt_model = load_gpt(word2idx) 111 | ## 加载预训练的模型参数~ 112 | self.gpt_model.load_pretrain_params(model_path) 113 | # 将模型发送到计算设备(GPU或CPU) 114 | self.gpt_model.set_device(self.device) 115 | # 声明需要优化的参数 116 | self.optim_parameters = list(self.gpt_model.parameters()) 117 | self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3) 118 | # 声明自定义的数据加载器 119 | dataset = SeqDataset(self.sents_src, self.sents_tgt) 120 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn) 121 | 122 | def train(self, epoch): 123 | # 一个epoch的训练 124 | self.gpt_model.train() 125 | self.iteration(epoch, dataloader=self.dataloader, train=True) 126 | 127 | def save(self, save_path): 128 | """ 129 | 保存模型 130 | """ 131 | self.gpt_model.save_all_params(save_path) 132 | print("{} saved!".format(save_path)) 133 | 134 | def iteration(self, epoch, dataloader, train=True): 135 | total_loss = 0 136 | report_loss = 0 137 | start_time = time.time() ## 得到当前时间 138 | step = 0 139 | # for token_ids, target_ids in tqdm(dataloader, position=0, leave=True): 140 | for token_ids, target_ids in dataloader: 141 | step += 1 142 | if step % 4000 == 0: 143 | self.gpt_model.eval() 144 | test_data = ["遂入颍川。", "会日暝,结陈相持。", "一言兴邦,斯近之矣。"] 145 | for text in test_data: 146 | print(self.gpt_model.sample_generate(text, add_eos=True)) 147 | self.gpt_model.train() 148 | print("report loss is " + str(report_loss)) 149 | report_loss = 0.0 150 | 151 | # 因为传入了target标签,因此会计算loss并且返回 152 | loss, _ = self.gpt_model(token_ids, 153 | labels=target_ids, 154 | ) 155 | # 反向传播 156 | if train: 157 | # 清空之前的梯度 158 | self.optimizer.zero_grad() 159 | # 反向传播, 获取新的梯度 160 | loss.backward() 161 | # 用获取的梯度更新模型参数 162 | self.optimizer.step() 163 | 164 | # 为计算当前epoch的平均loss 165 | total_loss += loss.item() 166 | report_loss += loss.item() 167 | 168 | end_time = time.time() 169 | spend_time = end_time - start_time 170 | # 打印训练信息 171 | print("epoch is " + str(epoch) + ". 
loss is " + str(total_loss) + ". spend time is " + str(spend_time)) 172 | # 保存模型 173 | self.save(model_save_path) 174 | 175 | 176 | if __name__ == '__main__': 177 | 178 | trainer = Trainer() 179 | train_epoches = 100 180 | for epoch in range(train_epoches): 181 | # 训练一个epoch 182 | trainer.train(epoch) 183 | -------------------------------------------------------------------------------- /examples/gpt2_english_story_train.py: -------------------------------------------------------------------------------- 1 | # gpt2模型进行英文讲故事 给一个开头 继续讲五句话 2 | import torch 3 | from tqdm import tqdm 4 | import time 5 | import pandas as pd 6 | from torch.utils.data import Dataset, DataLoader 7 | from bert_seq2seq import load_gpt 8 | from transformers import AutoTokenizer 9 | tokenizer = AutoTokenizer.from_pretrained("pranavpsv/gpt2-genre-story-generator") 10 | word2ix = tokenizer.get_vocab() 11 | data_path = "./corpus/英文讲故事数据集/train.csv" 12 | model_path = "./state_dict/english_gpt_model/english_gpt_story.bin" 13 | model_save_path = "./state_dict/gpt_auto_story.bin" 14 | batch_size = 8 15 | lr = 1e-5 16 | maxlen = 256 17 | 18 | def load_data(): 19 | sents_src = [] 20 | sents_tgt = [] 21 | df = pd.read_csv(data_path) 22 | for i, row in df.iterrows(): 23 | sents_src.append(row[1]) 24 | tgt = "" 25 | for j in range(2, 7): 26 | tgt += row[j] 27 | sents_tgt.append(tgt) 28 | 29 | return sents_src, sents_tgt 30 | 31 | class GPTDataset(Dataset): 32 | """ 33 | 针对特定数据集,定义一个相关的取数据的方式 34 | """ 35 | 36 | def __init__(self): 37 | ## 一般init函数是加载所有数据 38 | super(GPTDataset, self).__init__() 39 | ## 拿到所有文件名字 40 | self.sents_src, self.sents_tgt = load_data() 41 | self.tokenizer = tokenizer 42 | 43 | def __getitem__(self, i): 44 | ## 得到单个数据 45 | 46 | src_d = self.sents_src[i] 47 | tgt_d = self.sents_tgt[i] 48 | src_ids = self.tokenizer.encode(src_d) + [self.tokenizer.eos_token_id] 49 | tgt_ids = self.tokenizer.encode(tgt_d) + [self.tokenizer.eos_token_id] 50 | output = { 51 | "token_ids": src_ids + tgt_ids, 52 | } 53 | return output 54 | 55 | 56 | 57 | def __len__(self): 58 | return len(self.sents_src) 59 | 60 | 61 | def collate_fn(batch): 62 | """ 63 | 动态padding, batch为一部分sample 64 | """ 65 | 66 | def padding(indice, max_length, pad_idx=0): 67 | """ 68 | pad 函数 69 | """ 70 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice] 71 | return torch.tensor(pad_indice) 72 | 73 | token_ids = [data["token_ids"] for data in batch] 74 | max_length = max([len(t) for t in token_ids]) 75 | 76 | token_ids_padded = padding(token_ids, max_length, pad_idx=word2ix[""]) 77 | token_target_padded = token_ids_padded.clone() 78 | token_target_padded[token_target_padded == word2ix[""]] = -100 79 | return token_ids_padded, token_target_padded 80 | 81 | 82 | class Trainer: 83 | def __init__(self): 84 | # 判断是否有可用GPU 85 | # self.device = torch.device("cpu") 86 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 87 | print("device: " + str(self.device)) 88 | # 定义模型 89 | self.model = load_gpt(word2ix, tokenizer=tokenizer) 90 | self.model.load_pretrain_params(model_path) 91 | # 加载已经训练好的模型,继续训练 92 | 93 | # 将模型发送到计算设备(GPU或CPU) 94 | self.model.set_device(self.device) 95 | # 声明需要优化的参数 96 | self.optim_parameters = list(self.model.parameters()) 97 | self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3) 98 | # 声明自定义的数据加载器 99 | dataset = GPTDataset() 100 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn) 101 | 102 | def train(self, 
epoch): 103 | # 一个epoch的训练 104 | self.model.train() 105 | self.iteration(epoch, dataloader=self.dataloader, train=True) 106 | 107 | def save(self, save_path): 108 | """ 109 | 保存模型 110 | """ 111 | self.model.save_all_params(save_path) 112 | print("{} saved!".format(save_path)) 113 | 114 | def iteration(self, epoch, dataloader, train=True): 115 | total_loss = 0 116 | start_time = time.time() ## 得到当前时间 117 | step = 0 118 | report_loss = 0 119 | for token_ids, token_target in tqdm(dataloader, position=0, leave=True): 120 | step += 1 121 | if step % 1000 == 0: 122 | self.model.eval() 123 | print(self.model.sample_generate_english("David Drops the Weight", out_max_length=300, add_eos=True)) 124 | print("loss is " + str(report_loss)) 125 | report_loss = 0 126 | self.model.train() 127 | if step % 6000 == 0: 128 | self.save(model_save_path) 129 | 130 | # 因为传入了target标签,因此会计算loss并且返回 131 | loss, pred_logit = self.model(token_ids, labels=token_target) 132 | report_loss += loss.item() 133 | # 反向传播 134 | if train: 135 | # 清空之前的梯度 136 | self.optimizer.zero_grad() 137 | # 反向传播, 获取新的梯度 138 | loss.backward() 139 | # 用获取的梯度更新模型参数 140 | self.optimizer.step() 141 | 142 | # 为计算当前epoch的平均loss 143 | total_loss += loss.item() 144 | 145 | end_time = time.time() 146 | spend_time = end_time - start_time 147 | # 打印训练信息 148 | print("epoch is " + str(epoch) + ". loss is " + str(total_loss) + ". spend time is " + str(spend_time)) 149 | 150 | 151 | if __name__ == '__main__': 152 | 153 | trainer = Trainer() 154 | train_epoches = 20 155 | 156 | for epoch in range(train_epoches): 157 | # 训练一个epoch 158 | trainer.train(epoch) -------------------------------------------------------------------------------- /examples/gpt2_explain_dream_train.py: -------------------------------------------------------------------------------- 1 | ## gpt2模型进行周公解梦 2 | from bert_seq2seq import load_gpt 3 | import torch 4 | from tqdm import tqdm 5 | import pandas as pd 6 | import time 7 | from torch.utils.data import Dataset, DataLoader 8 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 9 | 10 | vocab_path = "./state_dict/gpt_vocab.txt" 11 | model_path = "./state_dict/gpt_pytorch_model.bin" 12 | model_save_path = "./state_dict/gpt_explain_dream_model.bin" 13 | batch_size = 16 14 | lr = 1e-5 15 | data_path = "./corpus/周公解梦/dream_data.csv" 16 | word2idx = load_chinese_base_vocab(vocab_path) 17 | 18 | def read_corpus(): 19 | """ 20 | 读原始数据 21 | """ 22 | sents_src = [] 23 | sents_tgt = [] 24 | 25 | df = pd.read_csv(data_path, delimiter="\t") 26 | for i, row in df.iterrows(): 27 | # print(row) 28 | json_s = eval(row[0]) 29 | sents_src.append(json_s["dream"]) 30 | sents_tgt.append(json_s["decode"]) 31 | 32 | return sents_src, sents_tgt 33 | 34 | 35 | class SeqDataset(Dataset): 36 | """ 37 | 针对特定数据集,定义一个相关的取数据的方式 38 | """ 39 | 40 | def __init__(self, sents_src, sents_tgt): 41 | ## 一般init函数是加载所有数据 42 | super(SeqDataset, self).__init__() 43 | # 读原始数据 44 | # self.sents_src, self.sents_tgt = read_corpus(poem_corpus_dir) 45 | self.sents_src = sents_src 46 | self.sents_tgt = sents_tgt 47 | 48 | self.idx2word = {k: v for v, k in word2idx.items()} 49 | self.tokenizer = Tokenizer(word2idx) 50 | 51 | def __getitem__(self, i): 52 | ## 得到单个数据 53 | # print(i) 54 | src = self.sents_src[i] 55 | tgt = self.sents_tgt[i] 56 | token_ids, _ = self.tokenizer.encode(src, tgt) 57 | output = { 58 | "token_ids": token_ids, 59 | } 60 | return output 61 | 62 | def __len__(self): 63 | return len(self.sents_src) 64 | 65 | 66 | def collate_fn(batch): 67 | """ 68 | 
动态padding, batch为一部分sample 69 | """ 70 | 71 | def padding(indice, max_length, pad_idx=0): 72 | """ 73 | pad 函数 74 | """ 75 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice] 76 | return torch.tensor(pad_indice) 77 | 78 | token_ids = [data["token_ids"] for data in batch] 79 | max_length = max([len(t) for t in token_ids]) 80 | 81 | token_ids_padded = padding(token_ids, max_length) 82 | target_ids_padded = token_ids_padded.clone() 83 | target_ids_padded[target_ids_padded == 0] = -100 84 | 85 | return token_ids_padded, target_ids_padded 86 | 87 | 88 | class Trainer: 89 | def __init__(self): 90 | # 加载数据 91 | self.sents_src, self.sents_tgt = read_corpus() 92 | 93 | # 判断是否有可用GPU 94 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 95 | print("device: " + str(self.device)) 96 | # 定义模型 97 | self.gpt_model = load_gpt(word2idx) 98 | ## 加载预训练的模型参数~ 99 | self.gpt_model.load_pretrain_params(model_path) 100 | # 将模型发送到计算设备(GPU或CPU) 101 | self.gpt_model.set_device(self.device) 102 | # 声明需要优化的参数 103 | self.optim_parameters = list(self.gpt_model.parameters()) 104 | self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3) 105 | # 声明自定义的数据加载器 106 | dataset = SeqDataset(self.sents_src, self.sents_tgt) 107 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn) 108 | 109 | def train(self, epoch): 110 | # 一个epoch的训练 111 | self.gpt_model.train() 112 | self.iteration(epoch, dataloader=self.dataloader, train=True) 113 | 114 | def save(self, save_path): 115 | """ 116 | 保存模型 117 | """ 118 | self.gpt_model.save_all_params(save_path) 119 | print("{} saved!".format(save_path)) 120 | 121 | def iteration(self, epoch, dataloader, train=True): 122 | total_loss = 0 123 | start_time = time.time() ## 得到当前时间 124 | step = 0 125 | for token_ids, target_ids in tqdm(dataloader, position=0, leave=True): 126 | step += 1 127 | if step % 1000 == 0: 128 | self.gpt_model.eval() 129 | test_data = ["梦见领袖", "梦见海盗抢东西", "梦见和自己的导师谈话"] 130 | for text in test_data: 131 | print(self.gpt_model.sample_generate(text, add_eos=True)) 132 | self.gpt_model.train() 133 | 134 | # 因为传入了target标签,因此会计算loss并且返回 135 | loss, _ = self.gpt_model(token_ids, 136 | labels=target_ids, 137 | ) 138 | # 反向传播 139 | if train: 140 | # 清空之前的梯度 141 | self.optimizer.zero_grad() 142 | # 反向传播, 获取新的梯度 143 | loss.backward() 144 | # 用获取的梯度更新模型参数 145 | self.optimizer.step() 146 | 147 | # 为计算当前epoch的平均loss 148 | total_loss += loss.item() 149 | 150 | end_time = time.time() 151 | spend_time = end_time - start_time 152 | # 打印训练信息 153 | print("epoch is " + str(epoch) + ". loss is " + str(total_loss) + ". 
spend time is " + str(spend_time)) 154 | # 保存模型 155 | self.save(model_save_path) 156 | 157 | 158 | if __name__ == '__main__': 159 | 160 | trainer = Trainer() 161 | train_epoches = 10 162 | for epoch in range(train_epoches): 163 | # 训练一个epoch 164 | trainer.train(epoch) 165 | -------------------------------------------------------------------------------- /examples/nezha_couplets_train.py: -------------------------------------------------------------------------------- 1 | ## 自动对对联的例子 2 | import torch 3 | from tqdm import tqdm 4 | import time 5 | from torch.utils.data import Dataset, DataLoader 6 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 7 | from bert_seq2seq import load_bert 8 | 9 | vocab_path = "./state_dict/nezha-base-www/vocab.txt" # roberta模型字典的位置 10 | model_name = "nezha" # 选择模型名字 11 | model_path = "./state_dict/nezha-base-www/pytorch_model.bin" # roberta模型位置 12 | recent_model_path = "" # 用于把已经训练好的模型继续训练 13 | model_save_path = "./nezha-duilian.bin" 14 | batch_size = 16 15 | lr = 1e-5 16 | data_dir = "./corpus/对联" 17 | word2idx = load_chinese_base_vocab(vocab_path) 18 | 19 | 20 | def read_corpus(dir_path): 21 | """ 22 | 读原始数据 23 | """ 24 | sents_src = [] 25 | sents_tgt = [] 26 | in_path = dir_path + "/in.txt" 27 | out_path = dir_path + "/out.txt" 28 | with open(in_path, "r", encoding="utf-8") as f: 29 | lines = f.readlines() 30 | for line in lines: 31 | sents_src.append(line.strip()) 32 | with open(out_path, "r", encoding="utf-8") as f: 33 | lines = f.readlines() 34 | for line in lines: 35 | sents_tgt.append(line.strip()) 36 | 37 | return sents_src, sents_tgt 38 | 39 | 40 | class BertDataset(Dataset): 41 | """ 42 | 针对特定数据集,定义一个相关的取数据的方式 43 | """ 44 | 45 | def __init__(self, sents_src, sents_tgt): 46 | ## 一般init函数是加载所有数据 47 | super(BertDataset, self).__init__() 48 | # 读原始数据 49 | # self.sents_src, self.sents_tgt = read_corpus(poem_corpus_dir) 50 | self.sents_src = sents_src 51 | self.sents_tgt = sents_tgt 52 | 53 | self.idx2word = {k: v for v, k in word2idx.items()} 54 | self.tokenizer = Tokenizer(word2idx) 55 | 56 | def __getitem__(self, i): 57 | ## 得到单个数据 58 | # print(i) 59 | src = self.sents_src[i] 60 | tgt = self.sents_tgt[i] 61 | token_ids, token_type_ids = self.tokenizer.encode(src, tgt) 62 | output = { 63 | "token_ids": token_ids, 64 | "token_type_ids": token_type_ids, 65 | } 66 | return output 67 | 68 | def __len__(self): 69 | return len(self.sents_src) 70 | 71 | 72 | def collate_fn(batch): 73 | """ 74 | 动态padding, batch为一部分sample 75 | """ 76 | 77 | def padding(indice, max_length, pad_idx=0): 78 | """ 79 | pad 函数 80 | """ 81 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice] 82 | return torch.tensor(pad_indice) 83 | 84 | token_ids = [data["token_ids"] for data in batch] 85 | max_length = max([len(t) for t in token_ids]) 86 | token_type_ids = [data["token_type_ids"] for data in batch] 87 | 88 | token_ids_padded = padding(token_ids, max_length) 89 | token_type_ids_padded = padding(token_type_ids, max_length) 90 | target_ids_padded = token_ids_padded[:, 1:].contiguous() 91 | 92 | return token_ids_padded, token_type_ids_padded, target_ids_padded 93 | 94 | 95 | class Trainer: 96 | def __init__(self): 97 | # 加载数据 98 | self.sents_src, self.sents_tgt = read_corpus(data_dir) 99 | 100 | # 判断是否有可用GPU 101 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 102 | print("device: " + str(self.device)) 103 | # 定义模型 104 | self.bert_model = load_bert(word2idx, model_name=model_name) 105 | ## 加载预训练的模型参数~ 106 | 
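# (restores the released NeZha-base checkpoint pointed to by model_path above)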
self.bert_model.load_pretrain_params(model_path) 107 | # 将模型发送到计算设备(GPU或CPU) 108 | self.bert_model.set_device(self.device) 109 | # 声明需要优化的参数 110 | self.optim_parameters = list(self.bert_model.parameters()) 111 | self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3) 112 | # 声明自定义的数据加载器 113 | dataset = BertDataset(self.sents_src, self.sents_tgt) 114 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn) 115 | 116 | def train(self, epoch): 117 | # 一个epoch的训练 118 | self.bert_model.train() 119 | self.iteration(epoch, dataloader=self.dataloader, train=True) 120 | 121 | def save(self, save_path): 122 | """ 123 | 保存模型 124 | """ 125 | self.bert_model.save_all_params(save_path) 126 | print("{} saved!".format(save_path)) 127 | 128 | def iteration(self, epoch, dataloader, train=True): 129 | total_loss = 0 130 | start_time = time.time() ## 得到当前时间 131 | step = 0 132 | for token_ids, token_type_ids, target_ids in tqdm(dataloader, position=0, leave=True): 133 | step += 1 134 | if step % 5000 == 0: 135 | self.bert_model.eval() 136 | test_data = ["花海总涵功德水", "广汉飞霞诗似玉", "执政为民,德行天下"] 137 | for text in test_data: 138 | print(self.bert_model.generate(text, beam_size=3)) 139 | self.bert_model.train() 140 | 141 | # 因为传入了target标签,因此会计算loss并且返回 142 | predictions, loss = self.bert_model(token_ids, 143 | token_type_ids, 144 | labels=target_ids, 145 | ) 146 | # 反向传播 147 | if train: 148 | # 清空之前的梯度 149 | self.optimizer.zero_grad() 150 | # 反向传播, 获取新的梯度 151 | loss.backward() 152 | # 用获取的梯度更新模型参数 153 | self.optimizer.step() 154 | 155 | # 为计算当前epoch的平均loss 156 | total_loss += loss.item() 157 | 158 | end_time = time.time() 159 | spend_time = end_time - start_time 160 | # 打印训练信息 161 | print("epoch is " + str(epoch) + ". loss is " + str(total_loss) + ". spend time is " + str(spend_time)) 162 | # 保存模型 163 | self.save(model_save_path) 164 | 165 | 166 | if __name__ == '__main__': 167 | 168 | trainer = Trainer() 169 | train_epoches = 10 170 | for epoch in range(train_epoches): 171 | # 训练一个epoch 172 | trainer.train(epoch) 173 | -------------------------------------------------------------------------------- /examples/readme.md: -------------------------------------------------------------------------------- 1 | ## 训练文件说明 2 | 3 | ### roberta、bert 4 | 1. [roberta_THUCNews_auto_title.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_THUCNews_auto_title.py) 自动摘要任务,使用THUCNews数据集,数据量较大。 5 | 2. [roberta_auto_title_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_auto_title_train.py) 自动摘要任务,使用了另一个小的数据集。 6 | 3. [roberta_math_ques_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_math_ques_train.py) 自动解答小学数学题。 7 | 4. [relationship_classify_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/relationship_classify_train.py) 人物关系分类任务。 8 | 5. [roberta_semantic_matching_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_semantic_matching_train.py) 语义匹配任务。 9 | 6. [roberta_relation_extract_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_relation_extract_train.py) 三元组抽取任务。 10 | 7. [roberta_poem_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_poem_train.py) roberta模型自动写诗任务。 11 | 8. [roberta_participle_CRF_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_participle_CRF_train.py) 中文分词任务。 12 | 9. 
[roberta_medical_ner_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_medical_ner_train.py) NER任务,使用医学数据。 13 | 10. [roberta_couplets_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_couplets_train.py) roberta模型对联任务。 14 | 11. [roberta_news_classification_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_news_classification_train.py) 文本分类任务。 15 | 12. [roberta_coarsness_NER_CRF_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_coarsness_NER_CRF_train.py) 粗粒度NER任务,使用roberta+CRF。 16 | 13. [roberta_coarsness_NER_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_coarsness_NER_train.py) 粗粒度NER任务,使用roberta。 17 | 14. [roberta_fine_grained_NER_CRF_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_fine_grained_NER_CRF_train.py) 细粒度NER任务,使用roberta+CRF。 18 | 15. [roberta_large_auto_title_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_large_auto_title_train.py) roberta-large模型,自动标题任务。 19 | 20 | ### nezha 21 | 1. [nezha_auto_title_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/nezha_auto_title_train.py) 华为nezha模型,自动摘要任务。 22 | 2. [nezha_relation_extract_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/nezha_relation_extract_train.py) nezha模型,关系抽取任务。 23 | 3. [nezha_couplets_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/nezha_couplets_train.py) 华为nezha模型,自动对联任务。 24 | 25 | ### T5 26 | 1. [t5_ancient_translation_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/t5_ancient_translation_train.py) t5模型进行古文翻译。 27 | 2. [t5_auto_title_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/t5_auto_title_train.py) t5模型,自动标题任务。 28 | 29 | ### GPT-2 30 | 1. [gpt2_generate_article.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/gpt2_generate_article.py) GPT-2自动生成文章任务。 31 | 2. [gpt2_explain_dream_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/gpt2_explain_dream_train.py) gpt2模型,使用周公解梦数据集。 32 | 3. [gpt2_ancient_translation_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/gpt2_ancient_translation_train.py) gpt2模型进行古文翻译。 33 | 4. [gpt2_english_story_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/gpt2_english_story_train.py) gpt2模型自动生成英文故事。 34 | 
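All of the training scripts above share one Trainer skeleton: load the vocabulary and pre-trained weights, build a DataLoader with a dynamic-padding collate_fn, optimize with Adam, and checkpoint after every epoch. The sketch below condenses that shared pattern; it is illustrative only — the paths are placeholders, `dataset` and `collate_fn` stand for the script-specific ones, and the forward signature shown is the seq2seq variant (classification scripts pass only `token_ids` and class labels).

```python
import torch
from torch.utils.data import DataLoader
from bert_seq2seq import load_chinese_base_vocab, load_bert

word2idx = load_chinese_base_vocab("./state_dict/roberta_wwm_vocab.txt")
model = load_bert(word2idx, model_name="roberta")
model.load_pretrain_params("./state_dict/roberta_wwm_pytorch_model.bin")
model.set_device(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-3)
# dataset / collate_fn: see the individual scripts for each task
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

for epoch in range(10):
    model.train()
    for token_ids, token_type_ids, target_ids in dataloader:
        # passing labels makes the forward pass return the loss as well
        predictions, loss = model(token_ids, token_type_ids, labels=target_ids)
        optimizer.zero_grad()   # clear stale gradients
        loss.backward()         # backpropagate
        optimizer.step()        # update parameters
    model.save_all_params("./model.bin")
```

35 | ### Simbert 36 | 1. 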
[simbert_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/simbert_train.py) SimBert模型生成相似句子。 -------------------------------------------------------------------------------- /examples/roberta_THUCNews_auto_title.py: -------------------------------------------------------------------------------- 1 | ## THUCNews 原始数据集 2 | import torch 3 | from tqdm import tqdm 4 | import time 5 | import glob 6 | from torch.utils.data import Dataset, DataLoader 7 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 8 | from bert_seq2seq import load_bert 9 | 10 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置 11 | word2idx = load_chinese_base_vocab(vocab_path) 12 | model_name = "roberta" # 选择模型名字 13 | model_path = "./state_dict/roberta_wwm_pytorch_model.bin" # 模型位置 14 | recent_model_path = "./state_dict/bert_auto_title_model.bin" # 用于把已经训练好的模型继续训练 15 | model_save_path = "./state_dict/bert_auto_title_model.bin" 16 | batch_size = 16 17 | lr = 1e-5 18 | maxlen = 256 19 | 20 | class BertDataset(Dataset): 21 | """ 22 | 针对特定数据集,定义一个相关的取数据的方式 23 | """ 24 | def __init__(self) : 25 | ## 一般init函数是加载所有数据 26 | super(BertDataset, self).__init__() 27 | ## 拿到所有文件名字 28 | self.txts = glob.glob('./corpus/THUCNews/*/*.txt') 29 | self.idx2word = {k: v for v, k in word2idx.items()} 30 | self.tokenizer = Tokenizer(word2idx) 31 | 32 | def __getitem__(self, i): 33 | ## 得到单个数据 34 | # print(i) 35 | text_name = self.txts[i] 36 | with open(text_name, "r", encoding="utf-8") as f: 37 | text = f.read() 38 | text = text.split('\n') 39 | if len(text) > 1: 40 | title = text[0] 41 | content = '\n'.join(text[1:]) 42 | token_ids, token_type_ids = self.tokenizer.encode( 43 | content, title, max_length=maxlen 44 | ) 45 | output = { 46 | "token_ids": token_ids, 47 | "token_type_ids": token_type_ids, 48 | } 49 | return output 50 | 51 | return self.__getitem__(i + 1) 52 | 53 | def __len__(self): 54 | 55 | return len(self.txts) 56 | 57 | def collate_fn(batch): 58 | """ 59 | 动态padding, batch为一部分sample 60 | """ 61 | 62 | def padding(indice, max_length, pad_idx=0): 63 | """ 64 | pad 函数 65 | """ 66 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice] 67 | return torch.tensor(pad_indice) 68 | 69 | token_ids = [data["token_ids"] for data in batch] 70 | max_length = max([len(t) for t in token_ids]) 71 | token_type_ids = [data["token_type_ids"] for data in batch] 72 | 73 | token_ids_padded = padding(token_ids, max_length) 74 | token_type_ids_padded = padding(token_type_ids, max_length) 75 | target_ids_padded = token_ids_padded[:, 1:].contiguous() 76 | 77 | return token_ids_padded, token_type_ids_padded, target_ids_padded 78 | 79 | class Trainer: 80 | def __init__(self): 81 | # 判断是否有可用GPU 82 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 83 | print("device: " + str(self.device)) 84 | # 定义模型 85 | self.bert_model = load_bert(word2idx, model_name=model_name) 86 | ## 加载预训练的模型参数~ 87 | 88 | self.bert_model.load_pretrain_params(model_path) 89 | # 加载已经训练好的模型,继续训练 90 | 91 | # 将模型发送到计算设备(GPU或CPU) 92 | self.bert_model.set_device(self.device) 93 | # 声明需要优化的参数 94 | self.optim_parameters = list(self.bert_model.parameters()) 95 | self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3) 96 | # 声明自定义的数据加载器 97 | dataset = BertDataset() 98 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn) 99 | 100 | def train(self, epoch): 101 | # 一个epoch的训练 102 | self.bert_model.train() 103 | 
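# (How the collate_fn above prepares a batch, shown on a toy example with
# made-up ids and pad_idx=0. Sequences are padded only to the longest sample
# in the current batch, and the targets are the inputs shifted left by one,
# because the seq2seq head predicts token t+1 from tokens up to t:
#   token_ids         = [[101, 5, 6, 102], [101, 7, 102]]
#   token_ids_padded  = [[101, 5, 6, 102], [101, 7, 102, 0]]
#   target_ids_padded = [[5, 6, 102], [7, 102, 0]]   # token_ids_padded[:, 1:]
# )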
self.iteration(epoch, dataloader=self.dataloader, train=True) 104 | 105 | def save(self, save_path): 106 | """ 107 | 保存模型 108 | """ 109 | self.bert_model.save_all_params(save_path) 110 | print("{} saved!".format(save_path)) 111 | 112 | def iteration(self, epoch, dataloader, train=True): 113 | total_loss = 0 114 | start_time = time.time() ## 得到当前时间 115 | step = 0 116 | report_loss = 0 117 | for token_ids, token_type_ids, target_ids in tqdm(dataloader,position=0, leave=True): 118 | step += 1 119 | if step % 1000 == 0: 120 | self.bert_model.eval() 121 | test_data = ["夏天来临,皮肤在强烈紫外线的照射下,晒伤不可避免,因此,晒后及时修复显得尤为重要,否则可能会造成长期伤害。专家表示,选择晒后护肤品要慎重,芦荟凝胶是最安全,有效的一种选择,晒伤严重者,还请及 时 就医 。", 122 | "2007年乔布斯向人们展示iPhone并宣称它将会改变世界还有人认为他在夸大其词然而在8年后以iPhone为代表的触屏智能手机已经席卷全球各个角落未来智能手机将会成为真正的个人电脑为人类发展做出更大的贡献", 123 | "8月28日,网络爆料称,华住集团旗下连锁酒店用户数据疑似发生泄露。从卖家发布的内容看,数据包含华住旗下汉庭、禧玥、桔子、宜必思等10余个品牌酒店的住客信息。泄露的信息包括华住官网注册资料、酒店入住登记的身份信息及酒店开房记录,住客姓名、手机号、邮箱、身份证号、登录账号密码等。卖家对这个约5亿条数据打包出售。第三方安全平台威胁猎人对信息出售者提供的三万条数据进行验证,认为数据真实性非常高。当天下午 ,华 住集 团发声明称,已在内部迅速开展核查,并第一时间报警。当晚,上海警方消息称,接到华住集团报案,警方已经介入调查。"] 124 | for text in test_data: 125 | print(self.bert_model.generate(text, beam_size=3)) 126 | print("loss is " + str(report_loss)) 127 | report_loss = 0 128 | # self.eval(epoch) 129 | self.bert_model.train() 130 | if step % 8000 == 0: 131 | self.save(model_save_path) 132 | 133 | # 因为传入了target标签,因此会计算loss并且返回 134 | predictions, loss = self.bert_model(token_ids, 135 | token_type_ids, 136 | labels=target_ids, 137 | 138 | ) 139 | report_loss += loss.item() 140 | # 反向传播 141 | if train: 142 | # 清空之前的梯度 143 | self.optimizer.zero_grad() 144 | # 反向传播, 获取新的梯度 145 | loss.backward() 146 | # 用获取的梯度更新模型参数 147 | self.optimizer.step() 148 | 149 | # 为计算当前epoch的平均loss 150 | total_loss += loss.item() 151 | 152 | end_time = time.time() 153 | spend_time = end_time - start_time 154 | # 打印训练信息 155 | print("epoch is " + str(epoch)+". loss is " + str(total_loss) + ". 
spend time is "+ str(spend_time)) 156 | # 保存模型 157 | self.save(model_save_path) 158 | 159 | if __name__ == '__main__': 160 | 161 | trainer = Trainer() 162 | train_epoches = 20 163 | 164 | for epoch in range(train_epoches): 165 | # 训练一个epoch 166 | trainer.train(epoch) -------------------------------------------------------------------------------- /examples/roberta_couplets_train.py: -------------------------------------------------------------------------------- 1 | ## 自动对对联的例子 2 | import sys 3 | import torch 4 | from tqdm import tqdm 5 | import time 6 | from torch.utils.data import Dataset, DataLoader 7 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 8 | from bert_seq2seq import load_bert 9 | 10 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置 11 | model_name = "roberta" # 选择模型名字 12 | model_path = "./state_dict/roberta_wwm_pytorch_model.bin" # roberta模型位置 13 | recent_model_path = "" # 用于把已经训练好的模型继续训练 14 | model_save_path = "./bert_duilian_model.bin" 15 | batch_size = 16 16 | lr = 1e-5 17 | data_dir = "./corpus/对联" 18 | word2idx = load_chinese_base_vocab(vocab_path) 19 | 20 | def read_corpus(dir_path): 21 | """ 22 | 读原始数据 23 | """ 24 | sents_src = [] 25 | sents_tgt = [] 26 | in_path = dir_path + "/in.txt" 27 | out_path = dir_path + "/out.txt" 28 | with open(in_path, "r", encoding="utf-8") as f: 29 | lines = f.readlines() 30 | for line in lines: 31 | sents_src.append(line.strip()) 32 | with open(out_path, "r", encoding="utf-8") as f: 33 | lines = f.readlines() 34 | for line in lines: 35 | sents_tgt.append(line.strip()) 36 | 37 | return sents_src, sents_tgt 38 | 39 | class BertDataset(Dataset): 40 | """ 41 | 针对特定数据集,定义一个相关的取数据的方式 42 | """ 43 | def __init__(self, sents_src, sents_tgt) : 44 | ## 一般init函数是加载所有数据 45 | super(BertDataset, self).__init__() 46 | # 读原始数据 47 | # self.sents_src, self.sents_tgt = read_corpus(poem_corpus_dir) 48 | self.sents_src = sents_src 49 | self.sents_tgt = sents_tgt 50 | 51 | self.idx2word = {k: v for v, k in word2idx.items()} 52 | self.tokenizer = Tokenizer(word2idx) 53 | 54 | def __getitem__(self, i): 55 | ## 得到单个数据 56 | # print(i) 57 | src = self.sents_src[i] 58 | tgt = self.sents_tgt[i] 59 | token_ids, token_type_ids = self.tokenizer.encode(src, tgt) 60 | output = { 61 | "token_ids": token_ids, 62 | "token_type_ids": token_type_ids, 63 | } 64 | return output 65 | 66 | def __len__(self): 67 | 68 | return len(self.sents_src) 69 | 70 | def collate_fn(batch): 71 | """ 72 | 动态padding, batch为一部分sample 73 | """ 74 | 75 | def padding(indice, max_length, pad_idx=0): 76 | """ 77 | pad 函数 78 | """ 79 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice] 80 | return torch.tensor(pad_indice) 81 | 82 | token_ids = [data["token_ids"] for data in batch] 83 | max_length = max([len(t) for t in token_ids]) 84 | token_type_ids = [data["token_type_ids"] for data in batch] 85 | 86 | token_ids_padded = padding(token_ids, max_length) 87 | token_type_ids_padded = padding(token_type_ids, max_length) 88 | target_ids_padded = token_ids_padded[:, 1:].contiguous() 89 | 90 | return token_ids_padded, token_type_ids_padded, target_ids_padded 91 | 92 | class Trainer: 93 | def __init__(self): 94 | # 加载数据 95 | self.sents_src, self.sents_tgt = read_corpus(data_dir) 96 | 97 | # 判断是否有可用GPU 98 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 99 | print("device: " + str(self.device)) 100 | # 定义模型 101 | self.bert_model = load_bert(word2idx, model_name=model_name) 102 | ## 加载预训练的模型参数~ 103 | 
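# (What tokenizer.encode(src, tgt) in BertDataset produces, assuming the usual
# bert4keras-style convention this tokenizer follows -- a sketch for intuition,
# not an exact trace:
#   tokens         : [CLS] 上联... [SEP] 下联... [SEP]
#   token_type_ids :   0     0...    0     1...    1
# Segment 0 marks the source half, segment 1 the target half; the UniLM-style
# attention mask then lets the model read the source bidirectionally while the
# target is generated left to right.)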
self.bert_model.load_pretrain_params(model_path) 104 | # 将模型发送到计算设备(GPU或CPU) 105 | self.bert_model.set_device(self.device) 106 | # 声明需要优化的参数 107 | self.optim_parameters = list(self.bert_model.parameters()) 108 | self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3) 109 | # 声明自定义的数据加载器 110 | dataset = BertDataset(self.sents_src, self.sents_tgt) 111 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn) 112 | 113 | def train(self, epoch): 114 | # 一个epoch的训练 115 | self.bert_model.train() 116 | self.iteration(epoch, dataloader=self.dataloader, train=True) 117 | 118 | def save(self, save_path): 119 | """ 120 | 保存模型 121 | """ 122 | self.bert_model.save_all_params(save_path) 123 | print("{} saved!".format(save_path)) 124 | 125 | def iteration(self, epoch, dataloader, train=True): 126 | total_loss = 0 127 | start_time = time.time() ## 得到当前时间 128 | step = 0 129 | for token_ids, token_type_ids, target_ids in tqdm(dataloader,position=0, leave=True): 130 | step += 1 131 | if step % 5000 == 0: 132 | self.bert_model.eval() 133 | test_data = ["花海总涵功德水", "广汉飞霞诗似玉", "执政为民,德行天下"] 134 | for text in test_data: 135 | print(self.bert_model.generate(text, beam_size=3)) 136 | self.bert_model.train() 137 | 138 | # 因为传入了target标签,因此会计算loss并且返回 139 | predictions, loss = self.bert_model(token_ids, 140 | token_type_ids, 141 | labels=target_ids, 142 | ) 143 | # 反向传播 144 | if train: 145 | # 清空之前的梯度 146 | self.optimizer.zero_grad() 147 | # 反向传播, 获取新的梯度 148 | loss.backward() 149 | # 用获取的梯度更新模型参数 150 | self.optimizer.step() 151 | 152 | # 为计算当前epoch的平均loss 153 | total_loss += loss.item() 154 | 155 | end_time = time.time() 156 | spend_time = end_time - start_time 157 | # 打印训练信息 158 | print("epoch is " + str(epoch)+". loss is " + str(total_loss) + ". 
spend time is "+ str(spend_time)) 159 | # 保存模型 160 | self.save(model_save_path) 161 | 162 | if __name__ == '__main__': 163 | 164 | trainer = Trainer() 165 | train_epoches = 10 166 | for epoch in range(train_epoches): 167 | # 训练一个epoch 168 | trainer.train(epoch) 169 | -------------------------------------------------------------------------------- /examples/roberta_news_classification_train.py: -------------------------------------------------------------------------------- 1 | ## 文本分类的例子 2 | import torch 3 | from tqdm import tqdm 4 | import time 5 | from torch.utils.data import Dataset, DataLoader 6 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 7 | from bert_seq2seq import load_bert 8 | 9 | target = ["财经", "彩票", "房产", "股票", "家居", "教育", "科技", "社会", "时尚", "时政", "体育", "星座", "游戏", "娱乐"] 10 | 11 | data_path = "./corpus/新闻标题文本分类/Train.txt" 12 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置 13 | model_name = "roberta" # 选择模型名字 14 | model_path = "./state_dict/roberta_wwm_pytorch_model.bin" # roberta模型位置 15 | recent_model_path = "" # 用于把已经训练好的模型继续训练 16 | model_save_path = "./bert_multi_classify_model.bin" 17 | batch_size = 16 18 | lr = 1e-5 19 | # 加载字典 20 | word2idx = load_chinese_base_vocab(vocab_path) 21 | 22 | def read_corpus(): 23 | """ 24 | 读原始数据 25 | """ 26 | sents_src = [] 27 | sents_tgt = [] 28 | 29 | with open(data_path) as f: 30 | lines = f.readlines() 31 | for line in lines: 32 | line = line.split("\t") 33 | sents_tgt.append(int(line[0])) 34 | sents_src.append(line[2]) 35 | return sents_src, sents_tgt 36 | 37 | ## 自定义dataset 38 | class NLUDataset(Dataset): 39 | """ 40 | 针对特定数据集,定义一个相关的取数据的方式 41 | """ 42 | def __init__(self, sents_src, sents_tgt) : 43 | ## 一般init函数是加载所有数据 44 | super(NLUDataset, self).__init__() 45 | # 读原始数据 46 | # self.sents_src, self.sents_tgt = read_corpus(poem_corpus_dir) 47 | self.sents_src = sents_src 48 | self.sents_tgt = sents_tgt 49 | 50 | self.idx2word = {k: v for v, k in word2idx.items()} 51 | self.tokenizer = Tokenizer(word2idx) 52 | 53 | def __getitem__(self, i): 54 | ## 得到单个数据 55 | # print(i) 56 | src = self.sents_src[i] 57 | tgt = self.sents_tgt[i] 58 | token_ids, token_type_ids = self.tokenizer.encode(src) 59 | output = { 60 | "token_ids": token_ids, 61 | "token_type_ids": token_type_ids, 62 | "target_id": tgt 63 | } 64 | return output 65 | 66 | def __len__(self): 67 | return len(self.sents_src) 68 | 69 | def collate_fn(batch): 70 | """ 71 | 动态padding, batch为一部分sample 72 | """ 73 | 74 | def padding(indice, max_length, pad_idx=0): 75 | """ 76 | pad 函数 77 | """ 78 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice] 79 | return torch.tensor(pad_indice) 80 | 81 | token_ids = [data["token_ids"] for data in batch] 82 | max_length = max([len(t) for t in token_ids]) 83 | token_type_ids = [data["token_type_ids"] for data in batch] 84 | target_ids = [data["target_id"] for data in batch] 85 | target_ids = torch.tensor(target_ids, dtype=torch.long) 86 | 87 | token_ids_padded = padding(token_ids, max_length) 88 | token_type_ids_padded = padding(token_type_ids, max_length) 89 | # target_ids_padded = token_ids_padded[:, 1:].contiguous() 90 | 91 | return token_ids_padded, token_type_ids_padded, target_ids 92 | 93 | class Trainer: 94 | def __init__(self): 95 | # 加载数据 96 | self.sents_src, self.sents_tgt = read_corpus() 97 | self.tokenier = Tokenizer(word2idx) 98 | # 判断是否有可用GPU 99 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 100 | print("device: " + str(self.device)) 101 | # 
定义模型 102 | self.bert_model = load_bert(word2idx, model_name=model_name, model_class="cls", target_size=len(target)) 103 | ## 加载预训练的模型参数~ 104 | self.bert_model.load_pretrain_params(model_path) 105 | # 将模型发送到计算设备(GPU或CPU) 106 | self.bert_model.set_device(self.device) 107 | # 声明需要优化的参数 108 | self.optim_parameters = list(self.bert_model.parameters()) 109 | self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3) 110 | # 声明自定义的数据加载器 111 | dataset = NLUDataset(self.sents_src, self.sents_tgt) 112 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn) 113 | 114 | def train(self, epoch): 115 | # 一个epoch的训练 116 | self.bert_model.train() 117 | self.iteration(epoch, dataloader=self.dataloader, train=True) 118 | 119 | def save(self, save_path): 120 | """ 121 | 保存模型 122 | """ 123 | self.bert_model.save_all_params(save_path) 124 | print("{} saved!".format(save_path)) 125 | 126 | def iteration(self, epoch, dataloader, train=True): 127 | total_loss = 0 128 | start_time = time.time() ## 得到当前时间 129 | step = 0 130 | for token_ids, token_type_ids, target_ids in tqdm(dataloader,position=0, leave=True): 131 | step += 1 132 | if step % 2000 == 0: 133 | self.bert_model.eval() 134 | test_data = ["编剧梁馨月讨稿酬六六何念助阵 公司称协商解决", "西班牙BBVA第三季度净利降至15.7亿美元", "基金巨亏30亿 欲打开云天系跌停自救"] 135 | for text in test_data: 136 | text, text_ids = self.tokenier.encode(text) 137 | text = torch.tensor(text, device=self.device).view(1, -1) 138 | print(target[torch.argmax(self.bert_model(text)).item()]) 139 | self.bert_model.train() 140 | 141 | # 因为传入了target标签,因此会计算loss并且返回 142 | predictions, loss = self.bert_model(token_ids, 143 | labels=target_ids, 144 | ) 145 | # 反向传播 146 | if train: 147 | # 清空之前的梯度 148 | self.optimizer.zero_grad() 149 | # 反向传播, 获取新的梯度 150 | loss.backward() 151 | # 用获取的梯度更新模型参数 152 | self.optimizer.step() 153 | 154 | # 为计算当前epoch的平均loss 155 | total_loss += loss.item() 156 | 157 | end_time = time.time() 158 | spend_time = end_time - start_time 159 | # 打印训练信息 160 | print("epoch is " + str(epoch)+". loss is " + str(total_loss) + ". 
spend time is "+ str(spend_time)) 161 | # 保存模型 162 | self.save(model_save_path) 163 | 164 | if __name__ == '__main__': 165 | 166 | trainer = Trainer() 167 | train_epoches = 10 168 | for epoch in range(train_epoches): 169 | # 训练一个epoch 170 | trainer.train(epoch) -------------------------------------------------------------------------------- /examples/roberta_semantic_matching_train.py: -------------------------------------------------------------------------------- 1 | # https://tianchi.aliyun.com/competition/entrance/531851/information 2 | import torch 3 | from tqdm import tqdm 4 | import time 5 | from torch.utils.data import Dataset, DataLoader 6 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 7 | from bert_seq2seq import load_bert 8 | 9 | target = [0, 1] 10 | train_path = "./data/语义匹配/train.tsv" 11 | test_path = "./data/语义匹配/test.tsv" 12 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置 13 | model_name = "roberta" # 选择模型名字 14 | model_path = "./state_dict/roberta_wwm_pytorch_model.bin" # roberta模型位置 15 | recent_model_path = "" # 用于把已经训练好的模型继续训练 16 | model_save_path = "./bert_semantic_matching.bin" 17 | batch_size = 16 18 | lr = 1e-5 19 | # 加载字典 20 | word2idx = load_chinese_base_vocab(vocab_path) 21 | 22 | def read_corpus(data_path): 23 | """ 24 | 读原始数据 25 | """ 26 | sents_src = [] 27 | sents_tgt = [] 28 | 29 | with open(data_path) as f: 30 | lines = f.readlines() 31 | for line in lines: 32 | line = line.split("\t") 33 | sents_tgt.append(int(line[2])) 34 | sents_src.append(line[0] + "#" +line[1]) 35 | return sents_src, sents_tgt 36 | 37 | ## 自定义dataset 38 | class NLUDataset(Dataset): 39 | """ 40 | 针对特定数据集,定义一个相关的取数据的方式 41 | """ 42 | def __init__(self, sents_src, sents_tgt) : 43 | ## 一般init函数是加载所有数据 44 | super(NLUDataset, self).__init__() 45 | # 读原始数据 46 | self.sents_src = sents_src 47 | self.sents_tgt = sents_tgt 48 | 49 | self.idx2word = {k: v for v, k in word2idx.items()} 50 | self.tokenizer = Tokenizer(word2idx) 51 | 52 | def __getitem__(self, i): 53 | ## 得到单个数据 54 | # print(i) 55 | src = self.sents_src[i] 56 | tgt = self.sents_tgt[i] 57 | token_ids, token_type_ids = self.tokenizer.encode(src) 58 | output = { 59 | "token_ids": token_ids, 60 | "token_type_ids": token_type_ids, 61 | "target_id": tgt 62 | } 63 | return output 64 | 65 | def __len__(self): 66 | return len(self.sents_src) 67 | 68 | def collate_fn(batch): 69 | """ 70 | 动态padding, batch为一部分sample 71 | """ 72 | 73 | def padding(indice, max_length, pad_idx=0): 74 | """ 75 | pad 函数 76 | """ 77 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice] 78 | return torch.tensor(pad_indice) 79 | 80 | token_ids = [data["token_ids"] for data in batch] 81 | max_length = max([len(t) for t in token_ids]) 82 | token_type_ids = [data["token_type_ids"] for data in batch] 83 | target_ids = [data["target_id"] for data in batch] 84 | target_ids = torch.tensor(target_ids, dtype=torch.long) 85 | 86 | token_ids_padded = padding(token_ids, max_length) 87 | token_type_ids_padded = padding(token_type_ids, max_length) 88 | # target_ids_padded = token_ids_padded[:, 1:].contiguous() 89 | 90 | return token_ids_padded, token_type_ids_padded, target_ids 91 | 92 | class Trainer: 93 | def __init__(self): 94 | # 加载数据 95 | self.sents_src, self.sents_tgt = read_corpus(train_path) 96 | self.tokenier = Tokenizer(word2idx) 97 | # 判断是否有可用GPU 98 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 99 | print("device: " + str(self.device)) 100 | # 定义模型 101 | self.bert_model = 
load_bert(word2idx, model_name=model_name, model_class="cls", target_size=len(target)) 102 | ## 加载预训练的模型参数~ 103 | self.bert_model.load_pretrain_params(model_path) 104 | # 将模型发送到计算设备(GPU或CPU) 105 | self.bert_model.set_device(self.device) 106 | # 声明需要优化的参数 107 | self.optim_parameters = list(self.bert_model.parameters()) 108 | self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3) 109 | # 声明自定义的数据加载器 110 | dataset = NLUDataset(self.sents_src, self.sents_tgt) 111 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn) 112 | 113 | def train(self, epoch): 114 | # 一个epoch的训练 115 | self.bert_model.train() 116 | self.iteration(epoch, dataloader=self.dataloader, train=True) 117 | 118 | def save(self, save_path): 119 | """ 120 | 保存模型 121 | """ 122 | self.bert_model.save_all_params(save_path) 123 | print("{} saved!".format(save_path)) 124 | 125 | def iteration(self, epoch, dataloader, train=True): 126 | total_loss = 0 127 | start_time = time.time() ## 得到当前时间 128 | step = 0 129 | for token_ids, token_type_ids, target_ids in tqdm(dataloader,position=0, leave=True): 130 | step += 1 131 | if step % 2000 == 0: 132 | self.bert_model.eval() 133 | test_data = ["后悔了吗#你有没有后悔", "打开自动横屏#开启移动数据", "我觉得你很聪明#你聪明我是这么觉得"] 134 | for text in test_data: 135 | text, text_ids = self.tokenier.encode(text) 136 | text = torch.tensor(text, device=self.device).view(1, -1) 137 | print(target[torch.argmax(self.bert_model(text)).item()]) 138 | self.bert_model.train() 139 | # 保存模型 140 | self.save(model_save_path) 141 | 142 | # 因为传入了target标签,因此会计算loss并且返回 143 | predictions, loss = self.bert_model(token_ids, 144 | labels=target_ids, 145 | ) 146 | # 反向传播 147 | if train: 148 | # 清空之前的梯度 149 | self.optimizer.zero_grad() 150 | # 反向传播, 获取新的梯度 151 | loss.backward() 152 | # 用获取的梯度更新模型参数 153 | self.optimizer.step() 154 | 155 | # 为计算当前epoch的平均loss 156 | total_loss += loss.item() 157 | 158 | end_time = time.time() 159 | spend_time = end_time - start_time 160 | # 打印训练信息 161 | print("epoch is " + str(epoch)+". loss is " + str(total_loss) + ". 
spend time is "+ str(spend_time)) 162 | 163 | 164 | if __name__ == '__main__': 165 | 166 | trainer = Trainer() 167 | train_epoches = 10 168 | for epoch in range(train_epoches): 169 | # 训练一个epoch 170 | trainer.train(epoch) -------------------------------------------------------------------------------- /examples/t5_ancient_translation_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import glob 4 | from torch.utils.data import Dataset, DataLoader 5 | from bert_seq2seq import T5PegasusTokenizer, load_chinese_base_vocab 6 | from bert_seq2seq import T5Model 7 | 8 | vocab_path = "./state_dict/t5-chinese/vocab.txt" 9 | model_path = "./state_dict/t5-chinese/pytorch_model.bin" 10 | model_save_path = "./state_dict/t5_ancient_trans_model.bin" 11 | batch_size = 8 12 | lr = 1e-5 13 | word2idx = load_chinese_base_vocab(vocab_path) 14 | tokenizer = T5PegasusTokenizer(word2idx) 15 | 16 | 17 | def read_corpus(): 18 | """ 19 | 读原始数据 20 | """ 21 | src = [] 22 | tgt = [] 23 | data_path = glob.glob("./corpus/文言文翻译/*") 24 | for p in data_path: 25 | dir = p.split("/")[:-1] 26 | dir = "/".join(dir) 27 | # print(dir) 28 | name = p.split("/")[-1] 29 | if "翻译" in name: 30 | # 找到了一个翻译文件 31 | tgt_name = name 32 | src_name = name[:-2] 33 | with open(dir + "/" + src_name) as fs: 34 | lines = fs.readlines() 35 | for line in lines: 36 | src.append(line.strip("\n").strip()) 37 | 38 | with open(dir + "/" + tgt_name) as ft: 39 | lines = ft.readlines() 40 | for line in lines: 41 | tgt.append(line.strip("\n").strip()) 42 | 43 | else: 44 | pass 45 | 46 | return src, tgt 47 | 48 | class SeqDataset(Dataset): 49 | """ 50 | 针对特定数据集,定义一个相关的取数据的方式 51 | """ 52 | 53 | def __init__(self, sents_src, sents_tgt): 54 | ## 一般init函数是加载所有数据 55 | super(SeqDataset, self).__init__() 56 | # 读原始数据 57 | # self.sents_src, self.sents_tgt = read_corpus(poem_corpus_dir) 58 | self.sents_src = sents_src 59 | self.sents_tgt = sents_tgt 60 | 61 | self.idx2word = {k: v for v, k in word2idx.items()} 62 | 63 | def __getitem__(self, i): 64 | ## 得到单个数据 65 | # print(i) 66 | src = self.sents_src[i] 67 | tgt = self.sents_tgt[i] 68 | token_ids_src, _ = tokenizer.encode(src, max_length=256) 69 | token_ids_tgt, _ = tokenizer.encode(tgt, max_length=256) 70 | output = { 71 | "token_ids_src": token_ids_src, 72 | "token_ids_tgt": token_ids_tgt, 73 | } 74 | return output 75 | 76 | def __len__(self): 77 | return len(self.sents_src) 78 | 79 | 80 | def collate_fn(batch): 81 | """ 82 | 动态padding, batch为一部分sample 83 | """ 84 | 85 | def padding(indice, max_length, pad_idx=0): 86 | """ 87 | pad 函数 88 | """ 89 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice] 90 | return torch.tensor(pad_indice) 91 | 92 | token_ids_src = [data["token_ids_src"] for data in batch] 93 | max_length_src = max([len(t) for t in token_ids_src]) 94 | token_ids_tgt = [data["token_ids_tgt"] for data in batch] 95 | max_length_tgt = max([len(t) for t in token_ids_tgt]) 96 | 97 | token_ids_padded = padding(token_ids_src, max_length_src) 98 | target_ids_padded = padding(token_ids_tgt, max_length_tgt) 99 | labels_ids = target_ids_padded.clone() 100 | labels_ids[labels_ids == 0] = -100 101 | target_ids_padded = target_ids_padded[:, :-1].contiguous() 102 | labels_ids = labels_ids[:, 1:].contiguous() 103 | 104 | return token_ids_padded, target_ids_padded, labels_ids 105 | 106 | 107 | class Trainer: 108 | def __init__(self): 109 | # 加载数据 110 | self.sents_src, self.sents_tgt = read_corpus() 111 | 112 | # 
判断是否有可用GPU 113 | # self.device = torch.device("cpu") 114 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 115 | print("device: " + str(self.device)) 116 | # 定义模型 117 | self.model = T5Model(word2idx) 118 | ## 加载预训练的模型参数~ 119 | self.model.load_pretrain_params(model_path) 120 | # 将模型发送到计算设备(GPU或CPU) 121 | self.model.set_device(self.device) 122 | # 声明需要优化的参数 123 | self.optim_parameters = list(self.model.parameters()) 124 | self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3) 125 | # 声明自定义的数据加载器 126 | dataset = SeqDataset(self.sents_src, self.sents_tgt) 127 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn) 128 | 129 | def train(self, epoch): 130 | # 一个epoch的训练 131 | self.model.train() 132 | self.iteration(epoch, dataloader=self.dataloader, train=True) 133 | 134 | def save(self, save_path): 135 | """ 136 | 保存模型 137 | """ 138 | self.model.save_all_params(save_path) 139 | print("{} saved!".format(save_path)) 140 | 141 | def iteration(self, epoch, dataloader, train=True): 142 | total_loss = 0 143 | report_loss = 0 144 | start_time = time.time() ## 得到当前时间 145 | step = 0 146 | for token_ids, target_ids, labels_ids in dataloader: 147 | step += 1 148 | # print(token_ids.shape) 149 | # print(target_ids.shape) 150 | # print(labels_ids.shape) 151 | if step % 4000 == 0: 152 | self.save(model_save_path) 153 | self.model.eval() 154 | test_data = ["遂入颍川。", "会日暝,结陈相持。", "一言兴邦,斯近之矣。"] 155 | for text in test_data: 156 | print(self.model.sample_generate_encoder_decoder(text, add_eos=True)) 157 | self.model.train() 158 | print("report loss is " + str(report_loss)) 159 | 160 | # 因为传入了target标签,因此会计算loss并且返回 161 | loss = self.model(token_ids,labels=labels_ids, decoder_input_ids=target_ids)[0] 162 | # 反向传播 163 | if train: 164 | # 清空之前的梯度 165 | self.optimizer.zero_grad() 166 | # 反向传播, 获取新的梯度 167 | loss.backward() 168 | # 用获取的梯度更新模型参数 169 | self.optimizer.step() 170 | 171 | # 为计算当前epoch的平均loss 172 | total_loss += loss.item() 173 | report_loss += loss.item() 174 | 175 | end_time = time.time() 176 | spend_time = end_time - start_time 177 | # 打印训练信息 178 | print("epoch is " + str(epoch) + ". loss is " + str(total_loss) + ". 
spend time is " + str(spend_time)) 179 | # 保存模型 180 | self.save(model_save_path) 181 | 182 | 183 | if __name__ == '__main__': 184 | 185 | trainer = Trainer() 186 | train_epoches = 10 187 | for epoch in range(train_epoches): 188 | # 训练一个epoch 189 | trainer.train(epoch) 190 | -------------------------------------------------------------------------------- /examples/t5_auto_title_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import glob 4 | from torch.utils.data import Dataset, DataLoader 5 | from bert_seq2seq import T5PegasusTokenizer, load_chinese_base_vocab 6 | from bert_seq2seq import T5Model 7 | from tqdm import tqdm 8 | 9 | src_dir = './corpus/auto_title/train.src' 10 | tgt_dir = './corpus/auto_title/train.tgt' 11 | 12 | vocab_path = "./state_dict/t5-chinese/vocab.txt" ## 字典 13 | model_path = "./state_dict/t5-chinese/pytorch_model.bin" ## 预训练参数 14 | 15 | model_save_path = "./state_dict/t5_autotile.bin" ## 训练完模型 保存在哪里 16 | batch_size = 1 17 | lr = 1e-5 18 | word2idx = load_chinese_base_vocab(vocab_path) 19 | tokenizer = T5PegasusTokenizer(word2idx) 20 | 21 | def read_file(src_dir, tgt_dir): 22 | src = [] 23 | tgt = [] 24 | 25 | with open(src_dir,'r',encoding='utf-8') as f: 26 | lines = f.readlines() 27 | 28 | for line in lines: 29 | src.append(line.strip('\n').lower()) 30 | 31 | with open(tgt_dir,'r',encoding='utf-8') as f: 32 | lines = f.readlines() 33 | for line in lines: 34 | tgt.append(line.strip('\n').lower()) 35 | 36 | return src, tgt 37 | 38 | 39 | class SeqDataset(Dataset): 40 | """ 41 | 针对特定数据集,定义一个相关的取数据的方式 42 | """ 43 | 44 | def __init__(self, sents_src, sents_tgt): 45 | ## 一般init函数是加载所有数据 46 | super(SeqDataset, self).__init__() 47 | # 读原始数据 48 | # self.sents_src, self.sents_tgt = read_corpus(poem_corpus_dir) 49 | self.sents_src = sents_src 50 | self.sents_tgt = sents_tgt 51 | 52 | self.idx2word = {k: v for v, k in word2idx.items()} 53 | 54 | def __getitem__(self, i): 55 | ## 得到单个数据 56 | # print(i) 57 | src = self.sents_src[i] 58 | tgt = self.sents_tgt[i] 59 | token_ids_src, _ = tokenizer.encode(src, max_length=256) 60 | token_ids_tgt, _ = tokenizer.encode(tgt, max_length=256) 61 | output = { 62 | "token_ids_src": token_ids_src, 63 | "token_ids_tgt": token_ids_tgt, 64 | } 65 | return output 66 | 67 | def __len__(self): 68 | return len(self.sents_src) 69 | 70 | 71 | def collate_fn(batch): 72 | """ 73 | 动态padding, batch为一部分sample 74 | """ 75 | 76 | def padding(indice, max_length, pad_idx=0): 77 | """ 78 | pad 函数 79 | """ 80 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice] 81 | return torch.tensor(pad_indice) 82 | 83 | token_ids_src = [data["token_ids_src"] for data in batch] 84 | max_length_src = max([len(t) for t in token_ids_src]) 85 | token_ids_tgt = [data["token_ids_tgt"] for data in batch] 86 | max_length_tgt = max([len(t) for t in token_ids_tgt]) 87 | 88 | token_ids_padded = padding(token_ids_src, max_length_src) 89 | target_ids_padded = padding(token_ids_tgt, max_length_tgt) 90 | labels_ids = target_ids_padded.clone() 91 | labels_ids[labels_ids == 0] = -100 92 | target_ids_padded = target_ids_padded[:, :-1].contiguous() 93 | labels_ids = labels_ids[:, 1:].contiguous() 94 | 95 | return token_ids_padded, target_ids_padded, labels_ids 96 | 97 | 98 | class Trainer: 99 | def __init__(self): 100 | # 加载数据 101 | self.sents_src, self.sents_tgt = read_file(src_dir, tgt_dir) 102 | 103 | # 判断是否有可用GPU 104 | # self.device = torch.device("cpu") 105 | self.device = 
torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 106 | print("device: " + str(self.device)) 107 | # 定义模型 108 | self.model = T5Model(word2idx) 109 | ## 加载预训练的模型参数~ 110 | self.model.load_pretrain_params(model_path) 111 | # self.model.load_all_params(save_model_path) 112 | # 将模型发送到计算设备(GPU或CPU) 113 | self.model.set_device(self.device) 114 | # 声明需要优化的参数 115 | self.optim_parameters = list(self.model.parameters()) 116 | self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3) 117 | # 声明自定义的数据加载器 118 | dataset = SeqDataset(self.sents_src, self.sents_tgt) 119 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn) 120 | 121 | def train(self, epoch): 122 | # 一个epoch的训练 123 | self.model.train() 124 | self.iteration(epoch, dataloader=self.dataloader, train=True) 125 | 126 | def save(self, save_path): 127 | """ 128 | 保存模型 129 | """ 130 | self.model.save_all_params(save_path) 131 | print("{} saved!".format(save_path)) 132 | 133 | def iteration(self, epoch, dataloader, train=True): 134 | total_loss = 0 135 | report_loss = 0 136 | start_time = time.time() ## 得到当前时间 137 | step = 0 138 | for token_ids, target_ids, labels_ids in tqdm(dataloader, total=len(dataloader)): 139 | step += 1 140 | if step % 100 == 0: 141 | self.save(model_save_path) 142 | self.model.eval() 143 | test_data = ["本文总结了十个可穿戴产品的设计原则,而这些原则同样也是笔者认为是这个行业最吸引人的地方:1为人们解决重复性问题,2从人开始而不是从机器开始,3要引起注意但不要刻意,4提升用户能力而不是取代人", 144 | "2007年乔布斯向人们展示iPhone并宣称它将会改变世界,还有人认为他在夸大其词然而在8年后以iPhone为代表的触屏智能手机已经席卷全球各个角落,未来智能手机将会成为真正的个人电脑为人类发展做出更大的贡献", 145 | "雅虎发布2014年第四季度财报并推出了免税方式剥离其持有的阿里巴巴集团15%股权的计划打算将这一价值约400亿美元的宝贵投资分配给股东截止发稿前雅虎股价上涨了大约7%至5145美元"] 146 | 147 | for text in test_data: 148 | print(self.model.sample_generate_encoder_decoder(text, add_eos=True, top_k=5)) 149 | self.model.train() 150 | print("report loss is " + str(report_loss)) 151 | report_loss = 0 152 | 153 | # 因为传入了target标签,因此会计算loss并且返回 154 | loss = self.model(token_ids,labels=labels_ids, decoder_input_ids=target_ids)[0] 155 | # 反向传播 156 | if train: 157 | # 清空之前的梯度 158 | self.optimizer.zero_grad() 159 | # 反向传播, 获取新的梯度 160 | loss.backward() 161 | # 用获取的梯度更新模型参数 162 | self.optimizer.step() 163 | 164 | # 为计算当前epoch的平均loss 165 | total_loss += loss.item() 166 | report_loss += loss.item() 167 | 168 | end_time = time.time() 169 | spend_time = end_time - start_time 170 | # 打印训练信息 171 | print("epoch is " + str(epoch) + ". loss is " + str(total_loss) + ". 
spend time is " + str(spend_time)) 172 | # 保存模型 173 | self.save(model_save_path) 174 | 175 | 176 | if __name__ == '__main__': 177 | 178 | trainer = Trainer() 179 | train_epoches = 10 180 | for epoch in range(train_epoches): 181 | # 训练一个epoch 182 | trainer.train(epoch) 183 | -------------------------------------------------------------------------------- /img/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/img/.DS_Store -------------------------------------------------------------------------------- /img/fenci.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/img/fenci.png -------------------------------------------------------------------------------- /img/ner-input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/img/ner-input.png -------------------------------------------------------------------------------- /img/ner-out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/img/ner-out.png -------------------------------------------------------------------------------- /img/ner.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/img/ner.jpg -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='bert_seq2seq', 5 | version='2.3.5', 6 | description='use torch to do bert_seq2seq task', 7 | long_description='bert_seq2seq: https://github.com/920232796/bert_seq2seq', 8 | license='Apache License 2.0', 9 | url='https://github.com/920232796/bert_seq2seq', 10 | author='xingzhaohu', 11 | author_email='920232796@qq.com', 12 | packages=find_packages() 13 | ) 14 | -------------------------------------------------------------------------------- /test/auto_title_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 3 | from bert_seq2seq import load_bert 4 | 5 | auto_title_model = "./state_dict/bert_auto_title_model2.bin" 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | if __name__ == "__main__": 9 | vocab_path = "../state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置 10 | model_name = "roberta" # 选择模型名字 11 | # 加载字典 12 | word2idx = load_chinese_base_vocab(vocab_path) 13 | # 定义模型 14 | bert_model = load_bert(word2idx, model_name=model_name) 15 | bert_model.set_device(device) 16 | bert_model.eval() 17 | ## 加载训练的模型参数~ 18 | bert_model.load_all_params(model_path=auto_title_model, device=device) 19 | 20 | test_data = ["针对央视3·15晚会曝光的电信行业乱象,工信部在公告中表示将严查央视3·15晚会曝光通信违规违法行为,工信部称已约谈三大运营商有关负责人并连夜责成三大运营商和所在省通信管理局进行调查依法依规严肃处理", 21 | "楚天都市报记者采访了解到,对于进口冷链食品,武汉已经采取史上最严措施,进行“红区”管理,严格执行证明查验制度,确保冷冻冷藏肉等冻品的安全。", 22 | "新华社受权于18日全文播发修改后的《中华人民共和国立法法》修改后的立法法分为“总则”“法律”“行政法规”“地方性法规自治条例和单行条例规章”“适用与备案审查”“附则”等6章共计105条"] 23 | for text in 
test_data: 24 | with torch.no_grad(): 25 | print(bert_model.generate(text, beam_size=3)) 26 | print("\n") 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /test/bert_english_autotitle_test.py: -------------------------------------------------------------------------------- 1 | ## 英文自动摘要测试文件 2 | import torch 3 | import glob 4 | import json 5 | from rouge import Rouge 6 | from bert_seq2seq import load_bert 7 | from transformers import AutoTokenizer 8 | 9 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 10 | word2idx = tokenizer.get_vocab() 11 | auto_title_model = "./state_dict/bert_english_auto_title_model.bin" 12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | maxlen = 256 14 | 15 | if __name__ == "__main__": 16 | model_name = "bert" # 选择模型名字 17 | # 定义模型 18 | bert_model = load_bert(word2idx, tokenizer=tokenizer, model_name=model_name) 19 | bert_model.set_device(device) 20 | bert_model.eval() 21 | ## 加载训练的模型参数~ 22 | bert_model.load_all_params(model_path=auto_title_model, device=device) 23 | rouge = Rouge() 24 | test_file = glob.glob("./corpus/english_autotitle_test/*.json") 25 | num_file = len(test_file) 26 | rouge_1_item = [0.0, 0.0, 0.0] 27 | with open("./auto_title_res.txt", "a+") as fw: 28 | for s_file in test_file : 29 | with open(s_file, "r") as f: 30 | c = f.read() 31 | j = json.loads(c) 32 | title = j["Title"] 33 | text = j["abstract"] 34 | out = bert_model.generate(text, beam_size=3, out_max_length=100, max_length=maxlen) 35 | print(out) 36 | fw.write(title + "\t" + out + "\t" + text + "\n") 37 | 38 | rouge_score = rouge.get_scores(title, out) 39 | print(rouge_score) 40 | rouge_1 = rouge_score[0]["rouge-1"] 41 | rouge_1_item[0] += rouge_1["f"] 42 | rouge_1_item[1] += rouge_1["p"] 43 | rouge_1_item[2] += rouge_1["r"] 44 | # print(rouge_score[0]["rouge-2"]) 45 | # print(rouge_score[0]["rouge-l"]) 46 | for i in range(len(rouge_1_item)): 47 | rouge_1_item[i] = rouge_1_item[i] / num_file 48 | 49 | 50 | print(rouge_1_item) 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /test/english_t5_test.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 4 | from bert_seq2seq.extend_model_method import ExtendModel 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | if __name__ == "__main__": 9 | 10 | tokenizer = AutoTokenizer.from_pretrained("/Users/xingzhaohu/Downloads/t5_test") 11 | model = AutoModelForSeq2SeqLM.from_pretrained("/Users/xingzhaohu/Downloads/t5_test") 12 | model.eval() 13 | model.to(device) 14 | model = ExtendModel(model, tokenizer, bos_id=0, eos_id=1) 15 | print(model.sample_generate_encoder_decoder("translate English to German: That is good", out_max_length=300, add_eos=True)) 16 | 17 | 18 | -------------------------------------------------------------------------------- /test/get_bert_embedding.py: -------------------------------------------------------------------------------- 1 | ## 使用bert对一个句子进行编码 2 | 3 | import torch 4 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 5 | from bert_seq2seq import load_bert 6 | 7 | model_path = "./state_dict/roberta_wwm_pytorch_model.bin" 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | if __name__ == "__main__": 11 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置 12 | 
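# (model_class="embedding" below makes bert_model(text) return the encoder's
# hidden states instead of task logits, which is why the loop at the bottom
# prints a shape rather than a prediction. The hidden size of 768 is an
# assumption based on the standard BERT-base configuration, not stated here:
#   emb = bert_model("一个测试句子")   # -> torch.Size([1, seq_len, 768])
#   cls_vec = emb[0, 0]                # [CLS] vector as a sentence embedding)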
model_name = "roberta" # 选择模型名字 13 | # 加载字典 14 | word2idx = load_chinese_base_vocab(vocab_path) 15 | # 定义模型 16 | bert_model = load_bert(word2idx, model_name=model_name, model_class="embedding") 17 | bert_model.set_device(device) 18 | bert_model.eval() 19 | ## 加载训练的模型参数~ 20 | bert_model.load_pretrain_params(model_path) 21 | 22 | test_data = ["针对央视3·15晚会曝光的电信行业乱象,工信部在公告中表示将严查央视3·15晚会曝光通信违规违法行为,工信部称已约谈三大运营商有关负责人并连夜责成三大运营商和所在省通信管理局进行调查依法依规严肃处理", 23 | "楚天都市报记者采访了解到,对于进口冷链食品,武汉已经采取史上最严措施,进行“红区”管理,严格执行证明查验制度,确保冷冻冷藏肉等冻品的安全。", 24 | "新华社受权于18日全文播发修改后的《中华人民共和国立法法》修改后的立法法分为“总则”“法律”“行政法规”“地方性法规自治条例和单行条例规章”“适用与备案审查”“附则”等6章共计105条"] 25 | for text in test_data: 26 | with torch.no_grad(): 27 | print(bert_model(text).shape) 28 | print("\n") 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /test/gpt_ancient_translation_test.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from bert_seq2seq import load_gpt 4 | from bert_seq2seq import load_chinese_base_vocab 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | vocab_path = "./state_dict/gpt2通用中文模型/vocab.txt" 9 | model_path = "./state_dict/gpt_ancient_trans_model.bin" 10 | 11 | if __name__ == "__main__": 12 | word2ix = load_chinese_base_vocab(vocab_path) 13 | model = load_gpt(word2ix) 14 | model.eval() 15 | model.set_device(device) 16 | model.load_all_params(model_path) 17 | 18 | print(model.sample_generate("余忆童稚时,能张目对日。", out_max_length=300, add_eos=True)) -------------------------------------------------------------------------------- /test/gpt_english_story_test.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from bert_seq2seq import load_gpt 4 | from transformers import AutoTokenizer 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | model_path = "./state_dict/gpt_auto_story.bin" 9 | 10 | if __name__ == "__main__": 11 | tokenizer = AutoTokenizer.from_pretrained("pranavpsv/gpt2-genre-story-generator") 12 | word2ix = tokenizer.get_vocab() 13 | model = load_gpt(word2ix, tokenizer=tokenizer) 14 | model.eval() 15 | model.set_device(device) 16 | model.load_all_params(model_path, device=device) 17 | 18 | print(model.sample_generate_english("Strong Winds", out_max_length=300, add_eos=True)) -------------------------------------------------------------------------------- /test/gpt_explain_dream_test.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from bert_seq2seq import load_gpt 4 | from bert_seq2seq import load_chinese_base_vocab 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | vocab_path = "./state_dict/gpt_vocab.txt" 9 | model_path = "./state_dict/gpt_explain_dream_model.bin" 10 | 11 | if __name__ == "__main__": 12 | word2ix = load_chinese_base_vocab(vocab_path) 13 | model = load_gpt(word2ix) 14 | model.eval() 15 | model.set_device(device) 16 | model.load_all_params(model_path) 17 | 18 | print(model.sample_generate("梦见天气很好", out_max_length=300, add_eos=True)) -------------------------------------------------------------------------------- /test/gpt_test_english.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from bert_seq2seq import load_gpt 4 | from transformers import AutoTokenizer 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 
| model_path = "./state_dict/english_gpt_model/english_gpt_story.bin" 9 | 10 | if __name__ == "__main__": 11 | tokenizer = AutoTokenizer.from_pretrained("pranavpsv/gpt2-genre-story-generator") 12 | word2ix = tokenizer.get_vocab() 13 | model = load_gpt(word2ix, tokenizer=tokenizer) 14 | model.eval() 15 | model.set_device(device) 16 | model.load_pretrain_params(model_path) 17 | 18 | print(model.sample_generate_english("Nice weather today", out_max_length=300)) -------------------------------------------------------------------------------- /test/nezha_auto_title_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 3 | from bert_seq2seq import load_bert 4 | 5 | auto_title_model = "./state_dict/nezha_auto_title.bin" 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | if __name__ == "__main__": 9 | vocab_path = "./state_dict/nezha-base-www/vocab.txt" # roberta模型字典的位置 10 | model_name = "nezha" # 选择模型名字 11 | # 加载字典 12 | word2idx = load_chinese_base_vocab(vocab_path, simplfied=False) 13 | # 定义模型 14 | bert_model = load_bert(word2idx, model_name=model_name) 15 | bert_model.set_device(device) 16 | bert_model.eval() 17 | ## 加载训练的模型参数~ 18 | bert_model.load_all_params(model_path=auto_title_model, device=device) 19 | 20 | test_data = ["针对央视3·15晚会曝光的电信行业乱象,工信部在公告中表示将严查央视3·15晚会曝光通信违规违法行为,工信部称已约谈三大运营商有关负责人并连夜责成三大运营商和所在省通信管理局进行调查依法依规严肃处理", 21 | "楚天都市报记者采访了解到,对于进口冷链食品,武汉已经采取史上最严措施,进行“红区”管理,严格执行证明查验制度,确保冷冻冷藏肉等冻品的安全。", 22 | "新华社受权于18日全文播发修改后的《中华人民共和国立法法》修改后的立法法分为“总则”“法律”“行政法规”“地方性法规自治条例和单行条例规章”“适用与备案审查”“附则”等6章共计105条"] 23 | for text in test_data: 24 | with torch.no_grad(): 25 | print(bert_model.generate(text, beam_size=3)) 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /test/nezha_relation_extract_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import json 4 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 5 | from bert_seq2seq import load_bert 6 | 7 | relation_extrac_model = "./state_dict/nezha_relation_extract.bin" 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置 10 | model_name = "nezha" # 选择模型名字 11 | # model_path = "./state_dict/bert-base-chinese-pytorch_model.bin" # roberta模型位 12 | # 加载字典 13 | word2idx = load_chinese_base_vocab(vocab_path, simplfied=False) 14 | tokenizer = Tokenizer(word2idx) 15 | idx2word = {v: k for k, v in word2idx.items()} 16 | 17 | predicate2id, id2predicate = {}, {} 18 | with open('./corpus/三元组抽取/all_50_schemas') as f: 19 | for l in f: 20 | l = json.loads(l) 21 | if l['predicate'] not in predicate2id: 22 | id2predicate[len(predicate2id)] = l['predicate'] 23 | predicate2id[l['predicate']] = len(predicate2id) 24 | 25 | 26 | def search(pattern, sequence): 27 | """从sequence中寻找子串pattern 28 | 如果找到,返回第一个下标;否则返回-1。 29 | """ 30 | n = len(pattern) 31 | for i in range(len(sequence)): 32 | if sequence[i:i + n] == pattern: 33 | return i 34 | return -1 35 | 36 | def search_subject(token_ids, subject_labels): 37 | # subject_labels: (lens, 2) 38 | if type(subject_labels) is torch.Tensor: 39 | subject_labels = subject_labels.numpy() 40 | if type(token_ids) is torch.Tensor: 41 | token_ids = token_ids.cpu().numpy() 42 | subjects = [] 43 | subject_ids = [] 44 | start = -1 45 | end = -1 46 | for i in 
range(len(token_ids)): 47 | if subject_labels[i, 0] > 0.5: 48 | start = i 49 | for j in range(len(token_ids)): 50 | if subject_labels[j, 1] > 0.5: 51 | subject_labels[j, 1] = 0 52 | end = j 53 | break 54 | if start == -1 or end == -1: 55 | continue 56 | subject = "" 57 | for k in range(start, end + 1): 58 | subject += idx2word[token_ids[k]] 59 | # print(subject) 60 | subject_ids.append([start, end]) 61 | start = -1 62 | end = -1 63 | subjects.append(subject) 64 | 65 | return subjects, subject_ids 66 | 67 | def search_object(token_ids, object_labels): 68 | objects = [] 69 | if type(object_labels) is torch.Tensor: 70 | object_labels = object_labels.numpy() 71 | if type(token_ids) is torch.Tensor: 72 | token_ids = token_ids.cpu().numpy() 73 | # print(object_labels.sum()) 74 | start = np.where(object_labels[:, :, 0] > 0.5) 75 | end = np.where(object_labels[:, :, 1] > 0.5) 76 | # print(start) 77 | # print(end) 78 | for _start, predicate1 in zip(*start): 79 | for _end, predicate2 in zip(*end): 80 | if _start <= _end and predicate1 == predicate2: 81 | object_text = "" 82 | for k in range(_start, _end + 1): 83 | # print(token_ids(k)) 84 | object_text += idx2word[token_ids[k]] 85 | objects.append( 86 | (id2predicate[predicate1], object_text) 87 | ) 88 | break 89 | 90 | return objects 91 | 92 | if __name__ == "__main__": 93 | 94 | # 定义模型 95 | bert_model = load_bert(word2idx, model_class="relation_extrac", model_name=model_name, target_size=len(predicate2id)) 96 | bert_model.eval() 97 | bert_model.set_device(device) 98 | # ## 加载预训练的模型参数~ 99 | checkpoint = torch.load(relation_extrac_model, map_location="cpu") 100 | # print(checkpoint) 101 | bert_model.load_all_params(model_path=relation_extrac_model, device=device) 102 | text = ["查尔斯·阿兰基斯(Charles Aránguiz),1989年4月17日出生于智利圣地亚哥,智利职业足球运动员,司职中场,效力于德国足球甲级联赛勒沃库森足球俱乐部", 103 | "《星空黑夜传奇》是连载于起点中文网的网络小说,作者是啤酒的罪孽", 104 | "《李烈钧自述》是2011年11月1日人民日报出版社出版的图书,作者是李烈钧", 105 | "杨铁心和郭啸天兄弟二人在牛家村的农屋里喝酒,他们的岳飞大将军在风波亭被害之事,二人希望能够像岳飞大将军一样精忠报国。"] 106 | 107 | for d in text: 108 | with torch.no_grad(): 109 | token_ids_test, segment_ids = tokenizer.encode(d, max_length=256) 110 | token_ids_test = torch.tensor(token_ids_test, device=device).view(1, -1) 111 | # 先预测subject 112 | pred_subject = bert_model.predict_subject(token_ids_test) 113 | pred_subject = pred_subject.squeeze(0) 114 | subject_texts, subject_idss = search_subject(token_ids_test[0], pred_subject.cpu()) 115 | if len(subject_texts) == 0: 116 | print("no subject predicted~") 117 | for sub_text, sub_ids in zip(subject_texts, subject_idss): 118 | print("subject is " + str(sub_text)) 119 | sub_ids = torch.tensor(sub_ids, device=device).view(1, -1) 120 | # print("sub_ids shape is " + str(sub_ids)) 121 | object_p_pred = bert_model.predict_object_predicate(token_ids_test, sub_ids) 122 | res = search_object(token_ids_test[0], object_p_pred.squeeze(0).cpu()) 123 | print("p and obj is " + str(res)) 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /test/poem_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 3 | from bert_seq2seq import load_bert 4 | 5 | auto_title_model = "./state_dict/bert_model_poem_ci_duilian.bin" 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | if __name__ == "__main__": 9 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置 10 | model_name = "roberta" # 选择模型名字 11 | # model_path = 
"./state_dict/bert-base-chinese-pytorch_model.bin" # roberta模型位 12 | # 加载字典 13 | word2idx = load_chinese_base_vocab(vocab_path, simplfied=False) 14 | # 定义模型 15 | bert_model = load_bert(word2idx, model_name=model_name) 16 | bert_model.set_device(device) 17 | bert_model.eval() 18 | # ## 加载预训练的模型参数~ 19 | checkpoint = torch.load(auto_title_model, map_location="cpu") 20 | # print(checkpoint) 21 | bert_model.load_all_params(model_path=auto_title_model, device=device) 22 | test_data = ["江山竞秀,万里风光入画图##对联"] 23 | with torch.no_grad(): 24 | for text in test_data: 25 | if text[-1] == "句" or text[-1] == "诗": 26 | print(bert_model.generate(text, beam_size=3, is_poem=True)) 27 | else: 28 | print(bert_model.generate(text, beam_size=3, is_poem=False)) 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /test/relation_extract_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import numpy as np 4 | import json 5 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 6 | from bert_seq2seq import load_bert 7 | 8 | relation_extrac_model = "./state_dict/bert_model_relation_extrac.bin" 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置 11 | model_name = "roberta" # 选择模型名字 12 | # 加载字典 13 | word2idx = load_chinese_base_vocab(vocab_path, simplfied=False) 14 | tokenizer = Tokenizer(word2idx) 15 | idx2word = {v: k for k, v in word2idx.items()} 16 | 17 | predicate2id, id2predicate = {}, {} 18 | with open('./corpus/三元组抽取/all_50_schemas') as f: 19 | for l in f: 20 | l = json.loads(l) 21 | if l['predicate'] not in predicate2id: 22 | id2predicate[len(predicate2id)] = l['predicate'] 23 | predicate2id[l['predicate']] = len(predicate2id) 24 | 25 | 26 | def search(pattern, sequence): 27 | """从sequence中寻找子串pattern 28 | 如果找到,返回第一个下标;否则返回-1。 29 | """ 30 | n = len(pattern) 31 | for i in range(len(sequence)): 32 | if sequence[i:i + n] == pattern: 33 | return i 34 | return -1 35 | 36 | def search_subject(token_ids, subject_labels): 37 | # subject_labels: (lens, 2) 38 | if type(subject_labels) is torch.Tensor: 39 | subject_labels = subject_labels.numpy() 40 | if type(token_ids) is torch.Tensor: 41 | token_ids = token_ids.cpu().numpy() 42 | subjects = [] 43 | subject_ids = [] 44 | start = -1 45 | end = -1 46 | for i in range(len(token_ids)): 47 | if subject_labels[i, 0] > 0.5: 48 | start = i 49 | for j in range(len(token_ids)): 50 | if subject_labels[j, 1] > 0.5: 51 | subject_labels[j, 1] = 0 52 | end = j 53 | break 54 | if start == -1 or end == -1: 55 | continue 56 | subject = "" 57 | for k in range(start, end + 1): 58 | subject += idx2word[token_ids[k]] 59 | # print(subject) 60 | subject_ids.append([start, end]) 61 | start = -1 62 | end = -1 63 | subjects.append(subject) 64 | 65 | return subjects, subject_ids 66 | 67 | def search_object(token_ids, object_labels): 68 | objects = [] 69 | if type(object_labels) is torch.Tensor: 70 | object_labels = object_labels.numpy() 71 | if type(token_ids) is torch.Tensor: 72 | token_ids = token_ids.cpu().numpy() 73 | # print(object_labels.sum()) 74 | start = np.where(object_labels[:, :, 0] > 0.5) 75 | end = np.where(object_labels[:, :, 1] > 0.5) 76 | # print(start) 77 | # print(end) 78 | for _start, predicate1 in zip(*start): 79 | for _end, predicate2 in zip(*end): 80 | if _start <= _end and predicate1 == predicate2: 81 | object_text = "" 82 | for k in range(_start, _end + 1): 83 | # 
print(token_ids(k)) 84 | object_text += idx2word[token_ids[k]] 85 | objects.append( 86 | (id2predicate[predicate1], object_text) 87 | ) 88 | break 89 | 90 | return objects 91 | 92 | if __name__ == "__main__": 93 | 94 | # 定义模型 95 | bert_model = load_bert(word2idx, model_class="relation_extrac", model_name=model_name, target_size=len(predicate2id)) 96 | bert_model.eval() 97 | bert_model.set_device(device) 98 | # ## 加载预训练的模型参数~ 99 | checkpoint = torch.load(relation_extrac_model, map_location="cpu") 100 | # print(checkpoint) 101 | bert_model.load_all_params(model_path=relation_extrac_model, device=device) 102 | text = ["查尔斯·阿兰基斯(Charles Aránguiz),1989年4月17日出生于智利圣地亚哥,智利职业足球运动员,司职中场,效力于德国足球甲级联赛勒沃库森足球俱乐部", 103 | "《星空黑夜传奇》是连载于起点中文网的网络小说,作者是啤酒的罪孽", 104 | "《李烈钧自述》是2011年11月1日人民日报出版社出版的图书,作者是李烈钧", 105 | "杨铁心和郭啸天兄弟二人在牛家村的农屋里喝酒,他们的岳飞大将军在风波亭被害之事,二人希望能够像岳飞大将军一样精忠报国。"] 106 | 107 | for d in text: 108 | with torch.no_grad(): 109 | token_ids_test, segment_ids = tokenizer.encode(d, max_length=256) 110 | token_ids_test = torch.tensor(token_ids_test, device=device).view(1, -1) 111 | # 先预测subject 112 | pred_subject = bert_model.predict_subject(token_ids_test) 113 | pred_subject = pred_subject.squeeze(0) 114 | subject_texts, subject_idss = search_subject(token_ids_test[0], pred_subject.cpu()) 115 | if len(subject_texts) == 0: 116 | print("no subject predicted~") 117 | for sub_text, sub_ids in zip(subject_texts, subject_idss): 118 | print("subject is " + str(sub_text)) 119 | sub_ids = torch.tensor(sub_ids, device=device).view(1, -1) 120 | # print("sub_ids shape is " + str(sub_ids)) 121 | object_p_pred = bert_model.predict_object_predicate(token_ids_test, sub_ids) 122 | res = search_object(token_ids_test[0], object_p_pred.squeeze(0).cpu()) 123 | print("p and obj is " + str(res)) 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /test/semantic_matching_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 3 | from bert_seq2seq import load_bert 4 | 5 | target = ["0", "1"] 6 | 7 | cls_model = "./state_dict/bert_semantic_matching.bin" 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | if __name__ == "__main__": 11 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置 12 | model_name = "roberta" # 选择模型名字 13 | # 加载字典 14 | word2idx = load_chinese_base_vocab(vocab_path, simplfied=False) 15 | tokenizer = Tokenizer(word2idx) 16 | # 定义模型 17 | bert_model = load_bert(word2idx, model_name=model_name, model_class="cls", target_size=len(target)) 18 | bert_model.set_device(device) 19 | bert_model.eval() 20 | ## 加载训练的模型参数~ 21 | bert_model.load_all_params(model_path=cls_model, device=device) 22 | test_data = ["你是不是我仇人#你是俺的仇人吗", 23 | "这个就没意思了#我没别的意思", 24 | "查一下我的家在哪里#家在哪里?"] 25 | for text in test_data: 26 | with torch.no_grad(): 27 | text_ids, _ = tokenizer.encode(text) 28 | text_ids = torch.tensor(text_ids, device=device).view(1, -1) 29 | print(text + " -> res is " + str(target[torch.argmax(bert_model(text_ids)).item()])) 30 | -------------------------------------------------------------------------------- /test/t5_chinese_autotitle_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bert_seq2seq.tokenizer import load_chinese_base_vocab, T5PegasusTokenizer 3 | from bert_seq2seq.extend_model_method import ExtendModel 4 | import glob 5 | 6 | device = 
torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | from bert_seq2seq.t5_ch import T5Model 9 | 10 | vocab_path = "./state_dict/t5-chinese/vocab.txt" 11 | model_path = "./state_dict/t5_autotile.bin" 12 | word2idx = load_chinese_base_vocab(vocab_path) 13 | 14 | model = T5Model(word2idx, size="base") 15 | model.set_device(device) 16 | model.load_all_params(model_path) 17 | model.eval() 18 | 19 | all_txt = glob.glob("./*.txt") 20 | print(all_txt) 21 | for t in all_txt: 22 | with open(t, encoding="utf-8") as f: 23 | content = f.read() 24 | out = model.sample_generate_encoder_decoder(content) 25 | print(out) 26 | 27 | # for t in all_txt: 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /test/t5_chinese_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bert_seq2seq.tokenizer import load_chinese_base_vocab, T5PegasusTokenizer 3 | from transformers.models.mt5.modeling_mt5 import MT5ForConditionalGeneration 4 | from bert_seq2seq.extend_model_method import ExtendModel 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | # transformer t5 代码 9 | model_path = './state_dict/t5-chinese' 10 | model = MT5ForConditionalGeneration.from_pretrained(model_path) 11 | word2ix = load_chinese_base_vocab("./state_dict/t5-chinese/vocab.txt") 12 | tokenizer = T5PegasusTokenizer(word2ix) 13 | model.eval() 14 | model = ExtendModel(model, tokenizer=tokenizer, bos_id=word2ix["[CLS]"], eos_id=word2ix["[SEP]"], device=device) 15 | text = '从那之后,一发不可收拾。此后的近百年间,一共有十七位新娘在与君山一带失踪。有时十几年相安无事,有时短短一个月内失踪两名。一个恐怖传说迅速传开:与君山里住着一位鬼新郎,若是他看中了一位女子,便会在她出嫁的路上将她掳走,再把送亲的队伍吃掉。' 16 | out = model.sample_generate_encoder_decoder(text) 17 | print(out) 18 | 19 | # 加载自己t5代码 20 | from bert_seq2seq.t5_ch import T5Model 21 | vocab_path = "./state_dict/t5-chinese/vocab.txt" 22 | model = T5Model(vocab_path, size="base") 23 | model.set_device(device) 24 | model.load_pretrain_params("./state_dict/t5-chinese/pytorch_model.bin") 25 | model.eval() 26 | text = '从那之后,一发不可收拾。此后的近百年间,一共有十七位新娘在与君山一带失踪。有时十几年相安无事,有时短短一个月内失踪两名。一个恐怖传说迅速传开:与君山里住着一位鬼新郎,若是他看中了一位女子,便会在她出嫁的路上将她掳走,再把送亲的队伍吃掉。' 27 | out = model.sample_generate_encoder_decoder(text) 28 | print(out) 29 | 30 | 31 | -------------------------------------------------------------------------------- /test/test_paddle/bert_couplet_test_paddle.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("/home/bert_seq2seq/paddle_model") 3 | import paddle 4 | import numpy as np 5 | from paddle import nn 6 | from paddlenlp.transformers import AutoTokenizer, AutoModel, BertModel, BertTokenizer, BPETokenizer, CTRLTokenizer 7 | from paddlenlp.transformers.bert.modeling import BertSeq2Seq 8 | import paddle.nn.functional as F 9 | from tqdm import tqdm 10 | from paddle.io import Dataset 11 | 12 | 13 | class Predictor: 14 | 15 | def __init__(self, model: BertModel, tokenizer: BertTokenizer, out_max_length=100, beam_size=1, max_length=512, ): 16 | self.out_max_length = out_max_length 17 | self.beam_size = beam_size 18 | self.max_length = max_length 19 | self.tokenizer = tokenizer 20 | self.model = model 21 | 22 | def generate(self, text): 23 | self.model.eval() 24 | input_max_length = self.max_length - self.out_max_length 25 | tokenizer_out = self.tokenizer.encode(text, max_seq_len=input_max_length) 26 | vocab = self.tokenizer.vocab 27 | token_ids = tokenizer_out["input_ids"] 28 | token_type_ids = 
tokenizer_out["token_type_ids"] 29 | token_ids = paddle.to_tensor(token_ids).reshape([1, -1]) 30 | token_type_ids = paddle.to_tensor(token_type_ids).reshape([1, -1]) 31 | 32 | # print(f"token_ids is {token_ids}") 33 | out_puts_ids = self.beam_search(token_ids, token_type_ids, vocab, beam_size=self.beam_size) 34 | # print(out_puts_ids) 35 | tokens = self.tokenizer.convert_ids_to_tokens(out_puts_ids) 36 | 37 | return self.tokenizer.convert_tokens_to_string(tokens) 38 | 39 | def beam_search(self, token_ids, token_type_ids, word2ix, beam_size=1,): 40 | """ 41 | beam-search操作 42 | """ 43 | sep_id = word2ix["[SEP]"] 44 | 45 | # 用来保存输出序列 46 | # output_ids = paddle.empty([1, 0]).astype("int") 47 | output_ids = None 48 | # output_ids = np.empty([1, 0]).astype(np.int) 49 | # 用来保存累计得分 50 | with paddle.no_grad(): 51 | output_scores = np.zeros([token_ids.shape[0]]) 52 | for step in range(self.out_max_length): 53 | if step == 0: 54 | scores = self.model(token_ids, token_type_ids) 55 | # 重复beam-size次 输入ids 56 | token_ids = np.tile(token_ids.reshape([1, -1]), [beam_size, 1]) 57 | token_type_ids = np.tile(token_type_ids.reshape([1, -1]), [beam_size, 1]) 58 | else: 59 | scores = self.model(new_input_ids, new_token_type_ids) 60 | 61 | logit_score = F.log_softmax(scores[:, -1], axis=-1).numpy() 62 | 63 | logit_score = output_scores.reshape([-1, 1]) + logit_score # 累计得分 64 | ## 取topk的时候我们是展平了然后再去调用topk函数 65 | # 展平 66 | logit_score = logit_score.reshape([-1]) 67 | hype_pos = np.argpartition(logit_score, -beam_size, axis=-1)[-beam_size:] 68 | hype_score = logit_score[hype_pos] 69 | indice1 = (hype_pos // scores.shape[-1]).reshape([-1]) # 行索引 70 | indice2 = (hype_pos % scores.shape[-1]).astype(np.int).reshape([-1, 1]) # 列索引 71 | 72 | output_scores = hype_score 73 | if output_ids is None: 74 | output_ids = indice2.reshape([beam_size, 1]) 75 | else : 76 | output_ids = np.concatenate([output_ids[indice1], indice2], axis=1).astype(np.int) 77 | 78 | new_input_ids = np.concatenate([token_ids, output_ids], axis=1) 79 | new_token_type_ids = np.concatenate([token_type_ids, np.ones_like(output_ids)], axis=1) 80 | 81 | end_counts = (output_ids == sep_id).sum(1) # 统计出现的end标记 82 | best_one = output_scores.argmax() 83 | if end_counts[best_one] == 1: 84 | # 说明出现终止了~ 85 | return output_ids[best_one][:-1] 86 | else : 87 | # 保留未完成部分 88 | flag = (end_counts < 1) # 标记未完成序列 89 | if not flag.all(): # 如果有已完成的 90 | token_ids = token_ids[flag] 91 | token_type_ids = token_type_ids[flag] 92 | new_input_ids = new_input_ids[flag] 93 | new_token_type_ids = new_token_type_ids[flag] 94 | output_ids = output_ids[flag] # 扔掉已完成序列 95 | output_scores = output_scores[flag] # 扔掉已完成序列 96 | beam_size = flag.sum() # topk相应变化 97 | 98 | return output_ids[output_scores.argmax()] 99 | 100 | 101 | 102 | 103 | def returnForecast(data,modelAddress = None,tokenizerAddress = None): 104 | 105 | if modelAddress is None: 106 | model = BertSeq2Seq.from_pretrained('./model') 107 | 108 | else: 109 | model = BertSeq2Seq.from_pretrained(modelAddress) 110 | 111 | 112 | if tokenizerAddress is None: 113 | tokenizer = BertTokenizer.from_pretrained('./tokenizer') 114 | 115 | else: 116 | tokenizer = BertTokenizer.from_pretrained(tokenizerAddress) 117 | 118 | predictor = Predictor(model, tokenizer, beam_size=2, out_max_length=40, max_length=512) 119 | 120 | 121 | import collections 122 | 123 | 124 | OutCouplet = {} 125 | 126 | OutCouplet = collections.OrderedDict() 127 | 128 | for simply_in in data: 129 | 130 | out = predictor.generate(simply_in) 131 | OutCouplet[simply_in] = 
out 132 | 133 | return OutCouplet 134 | 135 | def forecastForCouplet(data,modelAddress = None,tokenizerAddress = None): 136 | 137 | Couplet = returnForecast(data,modelAddress,tokenizerAddress) 138 | 139 | for Uplink,Downlink in Couplet.items(): 140 | 141 | print(f"上联:{Uplink},下联:{Downlink}。") 142 | 143 | 144 | 145 | 146 | 147 | if __name__ == '__main__': 148 | 149 | test_data = ["床前明月光", "万里悲秋常作客","广汉飞霞诗似玉", "执政为民,德行天下","春回大地万事新"] 150 | 151 | forecastForCouplet(test_data) 152 | 153 | 154 | 155 | 156 | -------------------------------------------------------------------------------- /test/test_paddle/roberta_autotitle_test_paddle.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("/home/bert_seq2seq/paddle_model") 3 | import paddle 4 | import numpy as np 5 | from paddle import nn 6 | from paddlenlp.transformers import AutoTokenizer, AutoModel, BertModel, BertTokenizer, CTRLTokenizer,RobertaTokenizer,RobertaModel 7 | from paddlenlp.transformers.roberta.modeling import robertaSeq2Seq 8 | from paddlenlp.transformers.roberta.modeling import RobertaPretrainedModel 9 | import paddle.nn.functional as F 10 | from tqdm import tqdm 11 | from paddle.io import Dataset 12 | 13 | class Predictor: 14 | 15 | def __init__(self, model: BertModel, tokenizer: BertTokenizer, out_max_length=100, beam_size=1, max_length=512, ): 16 | self.out_max_length = out_max_length 17 | self.beam_size = beam_size 18 | self.max_length = max_length 19 | self.tokenizer = tokenizer 20 | self.model = model 21 | 22 | def generate(self, text): 23 | self.model.eval() 24 | input_max_length = self.max_length - self.out_max_length 25 | tokenizer_out = self.tokenizer.encode(text, max_seq_len=input_max_length) 26 | vocab = self.tokenizer.vocab 27 | token_ids = tokenizer_out["input_ids"] 28 | token_type_ids = tokenizer_out["token_type_ids"] 29 | token_ids = paddle.to_tensor(token_ids).reshape([1, -1]) 30 | token_type_ids = paddle.to_tensor(token_type_ids).reshape([1, -1]) 31 | 32 | # print(f"token_ids is {token_ids}") 33 | out_puts_ids = self.beam_search(token_ids, token_type_ids, vocab, beam_size=self.beam_size) 34 | # print(out_puts_ids) 35 | tokens = self.tokenizer.convert_ids_to_tokens(out_puts_ids) 36 | 37 | return self.tokenizer.convert_tokens_to_string(tokens) 38 | 39 | def beam_search(self, token_ids, token_type_ids, word2ix, beam_size=1, ): 40 | """ 41 | beam-search操作 42 | """ 43 | sep_id = word2ix["[SEP]"] 44 | 45 | # 用来保存输出序列 46 | # output_ids = paddle.empty([1, 0]).astype("int") 47 | output_ids = None 48 | # output_ids = np.empty([1, 0]).astype(np.int) 49 | # 用来保存累计得分 50 | with paddle.no_grad(): 51 | output_scores = np.zeros([token_ids.shape[0]]) 52 | for step in range(self.out_max_length): 53 | if step == 0: 54 | scores = self.model(token_ids, token_type_ids) 55 | # 重复beam-size次 输入ids 56 | token_ids = np.tile(token_ids.reshape([1, -1]), [beam_size, 1]) 57 | token_type_ids = np.tile(token_type_ids.reshape([1, -1]), [beam_size, 1]) 58 | else: 59 | scores = self.model(new_input_ids, new_token_type_ids) 60 | 61 | logit_score = F.log_softmax(scores[:, -1], axis=-1).numpy() 62 | 63 | logit_score = output_scores.reshape([-1, 1]) + logit_score # 累计得分 64 | ## 取topk的时候我们是展平了然后再去调用topk函数 65 | # 展平 66 | logit_score = logit_score.reshape([-1]) 67 | hype_pos = np.argpartition(logit_score, -beam_size, axis=-1)[-beam_size:] 68 | hype_score = logit_score[hype_pos] 69 | indice1 = (hype_pos // scores.shape[-1]).reshape([-1]) # 行索引 70 | indice2 = (hype_pos % 
scores.shape[-1]).astype(int).reshape([-1, 1]) # 列索引 71 | 72 | output_scores = hype_score 73 | if output_ids is None: 74 | output_ids = indice2.reshape([beam_size, 1]) 75 | else: 76 | output_ids = np.concatenate([output_ids[indice1], indice2], axis=1).astype(int) 77 | 78 | new_input_ids = np.concatenate([token_ids, output_ids], axis=1) 79 | new_token_type_ids = np.concatenate([token_type_ids, np.ones_like(output_ids)], axis=1) 80 | 81 | end_counts = (output_ids == sep_id).sum(1) # 统计出现的end标记 82 | best_one = output_scores.argmax() 83 | if end_counts[best_one] == 1: 84 | # 说明出现终止了~ 85 | return output_ids[best_one][:-1] 86 | else: 87 | # 保留未完成部分 88 | flag = (end_counts < 1) # 标记未完成序列 89 | if not flag.all(): # 如果有已完成的 90 | token_ids = token_ids[flag] 91 | token_type_ids = token_type_ids[flag] 92 | new_input_ids = new_input_ids[flag] 93 | new_token_type_ids = new_token_type_ids[flag] 94 | output_ids = output_ids[flag] # 扔掉已完成序列 95 | output_scores = output_scores[flag] # 扔掉已完成序列 96 | beam_size = flag.sum() # topk相应变化 97 | 98 | return output_ids[output_scores.argmax()] 99 | 100 | def returnForecast(data,modelAddress = None,tokenizerAddress = None): 101 | 102 | if modelAddress is None: 103 | model = robertaSeq2Seq.from_pretrained('./model') 104 | 105 | else: 106 | model = robertaSeq2Seq.from_pretrained(modelAddress) 107 | 108 | 109 | if tokenizerAddress is None: 110 | tokenizer = RobertaTokenizer.from_pretrained('./tokenizer') 111 | 112 | else: 113 | tokenizer = RobertaTokenizer.from_pretrained(tokenizerAddress) 114 | 115 | predictor = Predictor(model, tokenizer, beam_size=2, out_max_length=40, max_length=512) 116 | 117 | 118 | import collections 119 | 120 | 121 | OutCouplet = {} 122 | 123 | OutCouplet = collections.OrderedDict() 124 | 125 | for simply_in in data: 126 | 127 | out = predictor.generate(simply_in) 128 | OutCouplet[simply_in] = out 129 | 130 | return OutCouplet 131 | 132 | def forecastForAutoTitle(data,modelAddress = None,tokenizerAddress = None): 133 | 134 | Couplet = returnForecast(data,modelAddress,tokenizerAddress) 135 | 136 | for Uplink,Downlink in Couplet.items(): 137 | 138 | print(f"新闻:{Uplink},标题:{Downlink}。") 139 | 140 | 141 | 142 | 143 | 144 | if __name__ == '__main__': 145 | 146 | test_data = [ 147 | "本文总结了十个可穿戴产品的设计原则而这些原则同样也是笔者认为是这个行业最吸引人的地方1为人们解决重复性问题2从人开始而不是从机器开始3要引起注意但不要刻意4提升用户能力而不是取代人", 148 | "2007年乔布斯向人们展示iPhone并宣称它将会改变世界还有人认为他在夸大其词然而在8年后以iPhone为代表的触屏智能手机已经席卷全球各个角落未来智能手机将会成为真正的个人电脑为人类发展做出更大的贡献", 149 | "雅虎发布2014年第四季度财报并推出了免税方式剥离其持有的阿里巴巴集团15%股权的计划打算将这一价值约400亿美元的宝贵投资分配给股东截止发稿前雅虎股价上涨了大约7%至5145美元"] 150 | 151 | forecastForAutoTitle(test_data) 152 | 153 | 154 | 155 | 156 | --------------------------------------------------------------------------------
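Note: the paddle test scripts in this directory share one flattened top-k trick in `Predictor.beam_search`: the cumulative score of each live beam is broadcast against the next-step log-probabilities, the resulting (beam_size, vocab) matrix is flattened, and `np.argpartition` picks the k best flat positions, whose integer quotient recovers the parent beam (`indice1`) and whose remainder is the new token id (`indice2`). The standalone toy below illustrates just that index arithmetic for beam_size=2 over a 5-token vocabulary; all numbers are made up for the example and nothing here comes from the repo itself.

import numpy as np

beam_size, vocab_size = 2, 5
output_scores = np.array([-0.5, -1.2])           # cumulative log-score of each live beam
logit_score = np.log(np.array([                  # toy next-token probabilities, one row per beam
    [0.1, 0.5, 0.2, 0.1, 0.1],
    [0.6, 0.1, 0.1, 0.1, 0.1],
]))

scores = output_scores.reshape([-1, 1]) + logit_score  # accumulate: shape (beam_size, vocab_size)
flat = scores.reshape([-1])                            # flatten before taking top-k
hype_pos = np.argpartition(flat, -beam_size)[-beam_size:]
indice1 = hype_pos // vocab_size   # row index: which parent beam each winner extends
indice2 = hype_pos % vocab_size    # column index: which token id each winner appends
print(indice1, indice2, flat[hype_pos])
# the two survivors are beam 0 + token 1 and beam 1 + token 0 (argpartition order is arbitrary)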
/test/test_paddle/roberta_math_test_paddle.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("/home/bert_seq2seq/paddle_model") 3 | import paddle 4 | import numpy as np 5 | from paddle import nn 6 | from paddlenlp.transformers import AutoTokenizer, AutoModel, BertModel, BertTokenizer, CTRLTokenizer, RobertaTokenizer, \ 7 | RobertaModel 8 | from paddlenlp.transformers.roberta.modeling import robertaSeq2Seq 9 | from paddlenlp.transformers.roberta.modeling import RobertaPretrainedModel 10 | import paddle.nn.functional as F 11 | from tqdm import tqdm 12 | from paddle.io import Dataset 13 | import json 14 | import re 15 | 16 | 17 | 18 | class Predictor: 19 | 20 | def __init__(self, model: BertModel, tokenizer: BertTokenizer, out_max_length=100, beam_size=1, max_length=512, ): 21 | self.out_max_length = out_max_length 22 | self.beam_size = beam_size 23 | self.max_length = max_length 24 | self.tokenizer = tokenizer 25 | self.model = model 26 | 27 | def generate(self, text): 28 | self.model.eval() 29 | input_max_length = self.max_length - self.out_max_length 30 | tokenizer_out = self.tokenizer.encode(text, max_seq_len=input_max_length) 31 | vocab = self.tokenizer.vocab 32 | token_ids = tokenizer_out["input_ids"] 33 | token_type_ids = tokenizer_out["token_type_ids"] 34 | token_ids = paddle.to_tensor(token_ids).reshape([1, -1]) 35 | token_type_ids = paddle.to_tensor(token_type_ids).reshape([1, -1]) 36 | 37 | # print(f"token_ids is {token_ids}") 38 | out_puts_ids = self.beam_search(token_ids, token_type_ids, vocab, beam_size=self.beam_size) 39 | # print(out_puts_ids) 40 | tokens = self.tokenizer.convert_ids_to_tokens(out_puts_ids) 41 | 42 | return self.tokenizer.convert_tokens_to_string(tokens) 43 | 44 | def beam_search(self, token_ids, token_type_ids, word2ix, beam_size=1,): 45 | """ 46 | beam-search操作 47 | """ 48 | sep_id = word2ix["[SEP]"] 49 | 50 | # 用来保存输出序列 51 | # output_ids = paddle.empty([1, 0]).astype("int") 52 | output_ids = None 53 | # output_ids = np.empty([1, 0]).astype(np.int) 54 | # 用来保存累计得分 55 | with paddle.no_grad(): 56 | output_scores = np.zeros([token_ids.shape[0]]) 57 | for step in range(self.out_max_length): 58 | if step == 0: 59 | scores = self.model(token_ids, token_type_ids) 60 | # 重复beam-size次 输入ids 61 | token_ids = np.tile(token_ids.reshape([1, -1]), [beam_size, 1]) 62 | token_type_ids = np.tile(token_type_ids.reshape([1, -1]), [beam_size, 1]) 63 | else: 64 | scores = self.model(new_input_ids, new_token_type_ids) 65 | 66 | logit_score = F.log_softmax(scores[:, -1], axis=-1).numpy() 67 | 68 | logit_score = output_scores.reshape([-1, 1]) + logit_score # 累计得分 69 | ## 取topk的时候我们是展平了然后再去调用topk函数 70 | # 展平 71 | logit_score = logit_score.reshape([-1]) 72 | hype_pos = np.argpartition(logit_score, -beam_size, axis=-1)[-beam_size:] 73 | hype_score = logit_score[hype_pos] 74 | indice1 = (hype_pos // scores.shape[-1]).reshape([-1]) # 行索引 75 | indice2 = (hype_pos % scores.shape[-1]).astype(int).reshape([-1, 1]) # 列索引 76 | 77 | output_scores = hype_score 78 | if output_ids is None: 79 | output_ids = indice2.reshape([beam_size, 1]) 80 | else : 81 | output_ids = np.concatenate([output_ids[indice1], indice2], axis=1).astype(int) 82 | 83 | new_input_ids = np.concatenate([token_ids, output_ids], axis=1) 84 | new_token_type_ids = np.concatenate([token_type_ids, np.ones_like(output_ids)], axis=1) 85 | 86 | end_counts = (output_ids == sep_id).sum(1) # 统计出现的end标记 87 | best_one = output_scores.argmax() 88 | if end_counts[best_one] == 1: 89 | # 说明出现终止了~ 90 | return output_ids[best_one][:-1] 91 | else : 92 | # 保留未完成部分 93 | flag = (end_counts < 1) # 标记未完成序列 94 | if not flag.all(): # 如果有已完成的 95 | token_ids = token_ids[flag] 96 | token_type_ids = token_type_ids[flag] 97 | new_input_ids = new_input_ids[flag] 98 | new_token_type_ids = new_token_type_ids[flag] 99 | output_ids = output_ids[flag] # 扔掉已完成序列 100 | output_scores = output_scores[flag] # 扔掉已完成序列 101 | beam_size = flag.sum() # topk相应变化 102 | 103 | return output_ids[output_scores.argmax()] 104 | 105 | 106 | 107 | 108 | 109 | 110 | def returnForecast(data,modelAddress = None,tokenizerAddress = None): 111 | 112 | if modelAddress is None: 113 | model = robertaSeq2Seq.from_pretrained('./model') 114 | 115 | else: 116 | model = robertaSeq2Seq.from_pretrained(modelAddress) 117 | 118 | 
119 | if tokenizerAddress is None: 120 | tokenizer = RobertaTokenizer.from_pretrained('./tokenizer') 121 | 122 | else: 123 | tokenizer = RobertaTokenizer.from_pretrained(tokenizerAddress) 124 | 125 | predictor = Predictor(model, tokenizer, beam_size=2, out_max_length=40, max_length=512) 126 | 127 | 128 | import collections 129 | 130 | 131 | OutCouplet = {} 132 | 133 | OutCouplet = collections.OrderedDict() 134 | 135 | for simply_in in data: 136 | 137 | out = predictor.generate(simply_in) 138 | OutCouplet[simply_in] = out 139 | 140 | return OutCouplet 141 | 142 | def forecastForMath(data,modelAddress = None,tokenizerAddress = None): 143 | 144 | Couplet = returnForecast(data,modelAddress,tokenizerAddress) 145 | 146 | for Uplink,Downlink in Couplet.items(): 147 | 148 | print(f"问题:{Uplink},算式:{Downlink}。") 149 | 150 | 151 | 152 | 153 | 154 | if __name__ == '__main__': 155 | 156 | test_data = [ 157 | "王艳家买了一台洗衣机和一台电冰箱,一共花了6000元,电冰箱的价钱是洗衣机的3/5,求洗衣机的价钱.", 158 | "六1班原来男生占总数的2/5,又转来5名男生,现在男生占总数的5/11,女生有多少人?", 159 | "两个相同的数相乘,积是3600,这个数是多少."] 160 | 161 | forecastForMath(test_data) 162 | 163 | 164 | 165 | 166 | 167 | -------------------------------------------------------------------------------- /test/做数学题_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 3 | from bert_seq2seq import load_bert 4 | 5 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置 6 | 7 | model_name = "roberta" # 选择模型名字 8 | model_path = "./state_dict/bert_math_ques_model.bin" 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | if __name__ == "__main__": 12 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置 13 | model_name = "roberta" # 选择模型名字 14 | # 加载字典 15 | word2idx = load_chinese_base_vocab(vocab_path, simplfied=False) 16 | tokenizer = Tokenizer(word2idx) 17 | # 定义模型 18 | bert_model = load_bert(word2idx, model_name=model_name, model_class="seq2seq") 19 | bert_model.set_device(device) 20 | bert_model.eval() 21 | ## 加载训练的模型参数~ 22 | bert_model.load_all_params(model_path=model_path, device=device) 23 | test_data = ["王艳家买了一台洗衣机和一台电冰箱,一共花了6000元,电冰箱的价钱是洗衣机的3/5,求洗衣机的价钱.", 24 | "六1班原来男生占总数的2/5,又转来5名男生,现在男生占总数的5/11,女生有多少人?", 25 | "两个相同的数相乘,积是3600,这个数是多少.", 26 | "1加1等于几"] 27 | for text in test_data: 28 | with torch.no_grad(): 29 | print(bert_model.generate(text, beam_size=3)) -------------------------------------------------------------------------------- /test/新闻标题文本分类_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 3 | from bert_seq2seq import load_bert 4 | 5 | target = ["财经", "彩票", "房产", "股票", "家居", "教育", "科技", "社会", "时尚", "时政", "体育", "星座", "游戏", "娱乐"] 6 | 7 | cls_model = "./state_dict/bert_multi_classify_model.bin" 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | if __name__ == "__main__": 11 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置 12 | model_name = "roberta" # 选择模型名字 13 | # 加载字典 14 | word2idx = load_chinese_base_vocab(vocab_path, simplfied=False) 15 | tokenizer = Tokenizer(word2idx) 16 | # 定义模型 17 | bert_model = load_bert(word2idx, model_name=model_name, model_class="cls", target_size=len(target)) 18 | bert_model.set_device(device) 19 | bert_model.eval() 20 | ## 加载训练的模型参数~ 21 | bert_model.load_all_params(model_path=cls_model, device=device) 22 | test_data = 
["编剧梁馨月讨稿酬六六何念助阵 公司称协商解决", 23 | "西班牙BBVA第三季度净利降至15.7亿美元", 24 | "基金巨亏30亿 欲打开云天系跌停自救"] 25 | for text in test_data: 26 | with torch.no_grad(): 27 | text, text_ids = tokenizer.encode(text) 28 | text = torch.tensor(text, device=device).view(1, -1) 29 | print(target[torch.argmax(bert_model(text)).item()]) 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /test/粗粒度ner_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 3 | from bert_seq2seq import load_bert 4 | 5 | target = ["O", "B-LOC", "I-LOC", "B-PER", "I-PER", "B-ORG", "I-ORG"] 6 | 7 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置 8 | 9 | model_name = "roberta" # 选择模型名字 10 | model_path = "./state_dict/bert_粗粒度ner_crf.bin" 11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | 13 | def viterbi_decode(nodes, trans): 14 | """ 15 | 维特比算法 解码 16 | nodes: (seq_len, target_size) 17 | trans: (target_size, target_size) 18 | """ 19 | with torch.no_grad(): 20 | scores = nodes[0] 21 | scores[1:] -= 100000 # 刚开始标签肯定是"O" 22 | target_size = nodes.shape[1] 23 | seq_len = nodes.shape[0] 24 | labels = torch.arange(0, target_size).view(1, -1) 25 | path = labels 26 | for l in range(1, seq_len): 27 | scores = scores.view(-1, 1) 28 | M = scores + trans + nodes[l].view(1, -1) 29 | scores, ids = M.max(0) 30 | path = torch.cat((path[:, ids], labels), dim=0) 31 | # print(scores) 32 | # print(scores) 33 | return path[:, scores.argmax()] 34 | 35 | def ner_print(model, test_data, device="cpu"): 36 | model.eval() 37 | idxtword = {v: k for k, v in word2idx.items()} 38 | 39 | tokenier = Tokenizer(word2idx) 40 | trans = model.state_dict()["crf_layer.trans"] 41 | for text in test_data: 42 | decode = [] 43 | text_encode, text_ids = tokenier.encode(text) 44 | 45 | text_tensor = torch.tensor(text_encode, device=device).view(1, -1) 46 | out = model(text_tensor).squeeze(0) # 其实是nodes 47 | labels = viterbi_decode(out, trans) 48 | starting = False 49 | for l in labels: 50 | if l > 0: 51 | label = target[l.item()] 52 | if label[0] == "B": 53 | decode.append(label[2: ]) 54 | starting = True 55 | elif starting: 56 | decode.append(label[2: ]) 57 | else: 58 | starting = False 59 | decode.append("O") 60 | else : 61 | decode.append("O") 62 | flag = 0 63 | 64 | res = {} 65 | text_decode = [idxtword[i] for i in text_encode] 66 | for index, each_entity in enumerate(decode): 67 | if each_entity != "O": 68 | if flag != each_entity: 69 | # cur_text = "".join([text[t] for t in mapping[index]]) 70 | cur_text = text_decode[index] 71 | if each_entity in res.keys(): 72 | res[each_entity].append(cur_text) 73 | else : 74 | res[each_entity] = [cur_text] 75 | flag = each_entity 76 | elif flag == each_entity: 77 | res[each_entity][-1] += text_decode[index] 78 | # res[each_entity][-1] += "".join([text[t] for t in mapping[index]]) 79 | else : 80 | flag = 0 81 | print(res) 82 | 83 | if __name__ == "__main__": 84 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置 85 | model_name = "roberta" # 选择模型名字 86 | # 加载字典 87 | word2idx = load_chinese_base_vocab(vocab_path, simplfied=False) 88 | tokenizer = Tokenizer(word2idx) 89 | # 定义模型 90 | bert_model = load_bert(word2idx, model_name=model_name, model_class="sequence_labeling_crf", target_size=len(target)) 91 | bert_model.set_device(device) 92 | bert_model.eval() 93 | ## 加载训练的模型参数~ 94 | bert_model.load_all_params(model_path=model_path, 
device=device) 95 | test_data = ["日寇在京掠夺文物详情。", 96 | "以书结缘,把欧美,港台流行的食品类食谱汇集一堂。", 97 | "明天天津下雨,不知道杨永康主任还能不能来学校吃个饭。", 98 | "美国的华莱士,我和他谈笑风生", 99 | "看包公断案的戏" 100 | ] 101 | ner_print(bert_model, test_data, device=device) 102 | -------------------------------------------------------------------------------- /test/细粒度ner_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 3 | from bert_seq2seq import load_bert 4 | 5 | target = ["other", "address", "book", "company", "game", "government", "movie", "name", "organization", "position", "scene"] 6 | 7 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置 8 | 9 | model_name = "roberta" # 选择模型名字 10 | model_path = "./state_dict/细粒度_bert_ner_model_crf.bin" 11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | 13 | def viterbi_decode(nodes, trans): 14 | """ 15 | 维特比算法 解码 16 | nodes: (seq_len, target_size) 17 | trans: (target_size, target_size) 18 | """ 19 | with torch.no_grad(): 20 | scores = nodes[0] 21 | scores[1:] -= 100000 # 刚开始标签肯定是"O" 22 | target_size = nodes.shape[1] 23 | seq_len = nodes.shape[0] 24 | labels = torch.arange(0, target_size).view(1, -1) 25 | path = labels 26 | for l in range(1, seq_len): 27 | scores = scores.view(-1, 1) 28 | M = scores + trans + nodes[l].view(1, -1) 29 | scores, ids = M.max(0) 30 | path = torch.cat((path[:, ids], labels), dim=0) 31 | 32 | return path[:, scores.argmax()] 33 | 34 | def ner_print(model, test_data, device="cpu"): 35 | model.eval() 36 | idxtword = {v: k for k, v in word2idx.items()} 37 | tokenizer = Tokenizer(word2idx) 38 | trans = model.state_dict()["crf_layer.trans"] 39 | for text in test_data: 40 | decode = [] 41 | text_encode, _ = tokenizer.encode(text) 42 | text_tensor = torch.tensor(text_encode, device=device).view(1, -1) 43 | out = model(text_tensor).squeeze(0) # 其实是nodes 44 | labels = viterbi_decode(out, trans) 45 | starting = False 46 | for l in labels: 47 | if l > 0: 48 | label = target[l.item()] 49 | decode.append(label) 50 | else : 51 | decode.append("other") 52 | flag = 0 53 | res = {} 54 | # print(decode) 55 | # print(text) 56 | decode_text = [idxtword[i] for i in text_encode] 57 | for index, each_entity in enumerate(decode): 58 | if each_entity != "other": 59 | if flag != each_entity: 60 | # cur_text = "".join([text[t] for t in mapping[index]]) 61 | 62 | cur_text = decode_text[index] 63 | if each_entity in res.keys(): 64 | res[each_entity].append(cur_text) 65 | else : 66 | res[each_entity] = [cur_text] 67 | flag = each_entity 68 | elif flag == each_entity: 69 | res[each_entity][-1] += decode_text[index] 70 | else : 71 | flag = 0 72 | print(res) 73 | 74 | 75 | if __name__ == "__main__": 76 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置 77 | model_name = "roberta" # 选择模型名字 78 | # 加载字典 79 | word2idx = load_chinese_base_vocab(vocab_path, simplfied=False) 80 | tokenizer = Tokenizer(word2idx) 81 | # 定义模型 82 | bert_model = load_bert(word2idx, model_name=model_name, model_class="sequence_labeling_crf", target_size=len(target)) 83 | bert_model.set_device(device) 84 | bert_model.eval() 85 | ## 加载训练的模型参数~ 86 | bert_model.load_all_params(model_path=model_path, device=device) 87 | # test_data = ["在广州经营小古董珠宝店的潘凝已经收藏了200多款泰迪熊,其中不少更是老牌泰迪熊厂商史蒂夫、赫曼。", 88 | # "2009年1月,北京市长郭金龙在其政府工作报告中曾明确提出,限价房不停建", 89 | # "昨天,记者连线农业银行亳州市支行办公室主任沈伦,他表示,亳州市支行已经对此事进行了讨论和研究", 90 | # "他们又有会怎样的读书经历。曾经留学海外的香港《号外杂志》主编、著名城市文化学者和作家陈冠中先生" 91 | # ] 92 | test_data 
= ["曹操南征荆州,刘表之子刘琮投降,刘备领军民十余万避难,于当阳遭遇曹军追兵,惨败。", 93 | "赵哲妮(曾用名赵哲),女,1953年11月19日出生,汉族,初中文化程度,户籍所在地河南省漯河市源汇区。"] 94 | ner_print(bert_model, test_data, device=device) --------------------------------------------------------------------------------