├── .DS_Store
├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── bert_seq2seq.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── .vscode
│   └── settings.json
├── LICENSE
├── README.md
├── bert_seq2seq
│   ├── .DS_Store
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── bart_chinese.cpython-37.pyc
│   │   ├── basic_bert.cpython-37.pyc
│   │   ├── bert_cls_classifier.cpython-37.pyc
│   │   ├── bert_cls_multi_classifier.cpython-37.pyc
│   │   ├── bert_cls_multi_seq2seq.cpython-37.pyc
│   │   ├── bert_encoder.cpython-37.pyc
│   │   ├── bert_relation_extraction.cpython-37.pyc
│   │   ├── bert_seq_labeling.cpython-37.pyc
│   │   ├── bert_seq_labeling_crf.cpython-37.pyc
│   │   ├── config.cpython-37.pyc
│   │   ├── extend_model_method.cpython-37.pyc
│   │   ├── gpt2_generate_model.cpython-37.pyc
│   │   ├── helper.cpython-37.pyc
│   │   ├── seq2seq_model.cpython-37.pyc
│   │   ├── simbert_model.cpython-37.pyc
│   │   ├── t5_ch.cpython-37.pyc
│   │   ├── tokenizer.cpython-37.pyc
│   │   └── utils.cpython-37.pyc
│   ├── bart_chinese.py
│   ├── basic_bert.py
│   ├── bert_cls_classifier.py
│   ├── bert_cls_multi_classifier.py
│   ├── bert_cls_multi_seq2seq.py
│   ├── bert_relation_extraction.py
│   ├── bert_seq_labeling.py
│   ├── bert_seq_labeling_crf.py
│   ├── config.py
│   ├── dataset.py
│   ├── extend_model_method.py
│   ├── gpt2_generate_model.py
│   ├── helper.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── bart_model.cpython-37.pyc
│   │   │   ├── bert_model.cpython-37.pyc
│   │   │   ├── crf.cpython-37.pyc
│   │   │   ├── gpt2_model.cpython-37.pyc
│   │   │   ├── nezha_model.cpython-37.pyc
│   │   │   ├── roberta_model.cpython-37.pyc
│   │   │   └── t5_model.cpython-37.pyc
│   │   ├── bart_model.py
│   │   ├── bert_model.py
│   │   ├── crf.py
│   │   ├── gpt2_model.py
│   │   ├── nezha_model.py
│   │   ├── roberta_model.py
│   │   └── t5_model.py
│   ├── paddle_model
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── collate.py
│   │   │   ├── iterator.py
│   │   │   ├── sampler.py
│   │   │   ├── tokenizer.py
│   │   │   └── vocab.py
│   │   ├── transformers
│   │   │   ├── __init__.py
│   │   │   ├── attention_utils.py
│   │   │   ├── bert
│   │   │   │   ├── __init__.py
│   │   │   │   ├── modeling.py
│   │   │   │   └── tokenizer.py
│   │   │   ├── generation_utils.py
│   │   │   ├── gpt
│   │   │   │   ├── __init__.py
│   │   │   │   ├── modeling.py
│   │   │   │   └── tokenizer.py
│   │   │   ├── model_utils.py
│   │   │   ├── nezha
│   │   │   │   ├── __init__.py
│   │   │   │   ├── modeling.py
│   │   │   │   └── tokenizer.py
│   │   │   ├── optimization.py
│   │   │   ├── roberta
│   │   │   │   ├── README.md
│   │   │   │   ├── __init__.py
│   │   │   │   ├── modeling.py
│   │   │   │   └── tokenizer.py
│   │   │   ├── tokenizer_utils.py
│   │   │   └── utils.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── batch_sampler.py
│   │       ├── downloader.py
│   │       ├── env.py
│   │       ├── log.py
│   │       ├── profiler.py
│   │       └── tools.py
│   ├── seq2seq_model.py
│   ├── simbert_model.py
│   ├── t5_ch.py
│   ├── tokenizer.py
│   └── utils.py
├── examples
│   ├── .DS_Store
│   ├── bart_auto_title_train.py
│   ├── examples_paddle
│   │   ├── bert_couplet_paddle_train.py
│   │   ├── mBart_translation_en_ro.py
│   │   ├── roberta_autotitle_paddle_train.py
│   │   └── roberta_math_paddle_train.py
│   ├── gpt2_ancient_translation_train.py
│   ├── gpt2_english_story_train.py
│   ├── gpt2_explain_dream_train.py
│   ├── gpt2_generate_article.py
│   ├── nezha_auto_title_train.py
│   ├── nezha_couplets_train.py
│   ├── nezha_relation_extract_train.py
│   ├── readme.md
│   ├── relationship_classify_train.py
│   ├── roberta_THUCNews_auto_title.py
│   ├── roberta_auto_title_train.py
│   ├── roberta_coarsness_NER_CRF_train.py
│   ├── roberta_coarsness_NER_train.py
│   ├── roberta_couplets_train.py
│   ├── roberta_fine_grained_NER_CRF_train.py
│   ├── roberta_large_auto_article_gen.py
│   ├── roberta_large_auto_title_train.py
│   ├── roberta_math_ques_train.py
│   ├── roberta_medical_ner_train.py
│   ├── roberta_news_classification_train.py
│   ├── roberta_participle_CRF_train.py
│   ├── roberta_poem_train.py
│   ├── roberta_relation_extract_train.py
│   ├── roberta_semantic_matching_train.py
│   ├── simbert_train.py
│   ├── t5_ancient_translation_train.py
│   └── t5_auto_title_train.py
├── img
│   ├── .DS_Store
│   ├── fenci.png
│   ├── ner-input.png
│   ├── ner-out.png
│   └── ner.jpg
├── setup.py
└── test
    ├── auto_title_test.py
    ├── bert_english_autotitle_test.py
    ├── english_t5_test.py
    ├── get_bert_embedding.py
    ├── gpt_ancient_translation_test.py
    ├── gpt_english_story_test.py
    ├── gpt_explain_dream_test.py
    ├── gpt_test_english.py
    ├── nezha_auto_title_test.py
    ├── nezha_relation_extract_test.py
    ├── poem_test.py
    ├── relation_extract_test.py
    ├── semantic_matching_test.py
    ├── t5_chinese_autotitle_test.py
    ├── t5_chinese_test.py
    ├── test_paddle
    │   ├── bert_couplet_test_paddle.py
    │   ├── roberta_autotitle_test_paddle.py
    │   └── roberta_math_test_paddle.py
    ├── 做数学题_test.py
    ├── 新闻标题文本分类_test.py
    ├── 粗粒度ner_test.py
    └── 细粒度ner_test.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | nouse
2 | corpus
3 | dist
4 | 写文章
5 | 代码
6 | state_dict
7 | bert_seq2seq.egg-info
8 | .idea
9 | .vscode
10 |
11 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Default ignored files
3 | /workspace.xml
--------------------------------------------------------------------------------
/.idea/bert_seq2seq.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.pythonPath": "/Users/xingzhaohu/.local/share/virtualenvs/ml-5foBrNl9/bin/python",
3 | "git.ignoreLimitWarning": true
4 | }
--------------------------------------------------------------------------------
/bert_seq2seq/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/.DS_Store
--------------------------------------------------------------------------------
/bert_seq2seq/__init__.py:
--------------------------------------------------------------------------------
1 | from .tokenizer import load_chinese_base_vocab, Tokenizer
2 | from .utils import load_bert, load_gpt
3 | from .t5_ch import T5PegasusTokenizer, T5Model
--------------------------------------------------------------------------------
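
A minimal sketch of the package's public API above; the vocab path is a placeholder, and load_bert's keyword arguments follow the repo's example scripts rather than anything shown in this file:

    from bert_seq2seq import load_chinese_base_vocab, Tokenizer, load_bert

    word2idx = load_chinese_base_vocab("./state_dict/vocab.txt")  # placeholder path
    tokenizer = Tokenizer(word2idx)
    model = load_bert(word2idx, model_name="roberta", model_class="seq2seq")  # assumed kwargs
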
/bert_seq2seq/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/__pycache__/bart_chinese.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/__pycache__/bart_chinese.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/__pycache__/basic_bert.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/__pycache__/basic_bert.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/__pycache__/bert_cls_classifier.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/__pycache__/bert_cls_classifier.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/__pycache__/bert_cls_multi_classifier.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/__pycache__/bert_cls_multi_classifier.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/__pycache__/bert_cls_multi_seq2seq.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/__pycache__/bert_cls_multi_seq2seq.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/__pycache__/bert_encoder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/__pycache__/bert_encoder.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/__pycache__/bert_relation_extraction.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/__pycache__/bert_relation_extraction.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/__pycache__/bert_seq_labeling.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/__pycache__/bert_seq_labeling.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/__pycache__/bert_seq_labeling_crf.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/__pycache__/bert_seq_labeling_crf.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/__pycache__/config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/__pycache__/config.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/__pycache__/extend_model_method.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/__pycache__/extend_model_method.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/__pycache__/gpt2_generate_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/__pycache__/gpt2_generate_model.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/__pycache__/helper.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/__pycache__/helper.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/__pycache__/seq2seq_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/__pycache__/seq2seq_model.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/__pycache__/simbert_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/__pycache__/simbert_model.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/__pycache__/t5_ch.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/__pycache__/t5_ch.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/__pycache__/tokenizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/__pycache__/tokenizer.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/bart_chinese.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from bert_seq2seq.model.bart_model import BartConfig, BartForConditionalGeneration, BartModel, shift_tokens_right
4 | from bert_seq2seq.tokenizer import Tokenizer,load_chinese_base_vocab
5 | from bert_seq2seq.basic_bert import BasicBart
6 | from bert_seq2seq.seq2seq_model import top_k_top_p_filtering
7 | import torch.nn.functional as F
8 | import torch.nn as nn
9 |
10 | class BartGenerationModel(BasicBart):
11 |
12 | def __init__(self, word2idx):
13 | super().__init__()
14 | config = BartConfig()
15 | self.config = config
16 | self.model = BartModel(config)
17 | self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
18 |
19 | self.word2idx = word2idx
20 | self.tokenizer = Tokenizer(self.word2idx)
21 | self.bos_id = self.word2idx["[CLS]"]
22 | self.eos_id = self.word2idx["[SEP]"]
23 | self.unk_id = self.word2idx["[UNK]"]
24 |
25 |     def forward(self, input_ids, decoder_input_ids=None, labels=None):
26 |         input_ids = input_ids.to(self.device)
27 |         if labels is not None:
28 |             labels = labels.to(self.device)
29 |             # derive decoder inputs from the labels when none are given
30 |             if decoder_input_ids is None:
31 |                 decoder_input_ids = shift_tokens_right(
32 |                     labels, self.config.pad_token_id, self.config.decoder_start_token_id
33 |                 )
34 |         decoder_input_ids = decoder_input_ids.to(self.device)
35 |
36 | decoder_out, _ = self.model(
37 | input_ids,
38 | decoder_input_ids=decoder_input_ids,
39 | )
40 |
41 | lm_logits = self.lm_head(decoder_out)
42 | target_mask = (decoder_input_ids > 0).float().view(-1)
43 | masked_lm_loss = None
44 | if labels is not None:
45 | loss_fct = nn.CrossEntropyLoss()
46 | masked_lm_loss = (loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) * target_mask).sum() / target_mask.sum()
47 |
48 | output = (lm_logits,)
49 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
50 |
51 |
52 | def sample_generate_encoder_decoder(self, text, input_max_length=256, out_max_length=200, top_k=30, top_p=0.0, add_eos=True):
53 |
54 | token_out = self.tokenizer.encode(text, max_length=input_max_length)
55 | if len(token_out) == 2:
56 | token_ids = token_out[0]
57 | else:
58 | token_ids = token_out
59 | if not add_eos:
60 | token_ids = token_ids[:-1]
61 | token_ids = torch.tensor(token_ids, device=self.device, dtype=torch.long).view(1, -1)
62 | output_ids = []
63 |
64 | input_decoder_ids = torch.tensor(self.bos_id, device=self.device, dtype=torch.long).view(1, -1)
65 | with torch.no_grad():
66 | for step in range(out_max_length):
67 |                 scores = self.lm_head(self.model(input_ids=token_ids, decoder_input_ids=input_decoder_ids)[0])  # project decoder hidden states to vocab logits, matching forward()
68 | logit_score = torch.log_softmax(scores[:, -1], dim=-1).squeeze(0)
69 | logit_score[self.unk_id] = -float('Inf')
70 | filtered_logits = top_k_top_p_filtering(logit_score, top_k=top_k, top_p=top_p)
71 | next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
72 | if self.eos_id == next_token.item():
73 | break
74 | output_ids.append(next_token.item())
75 | input_decoder_ids = torch.cat((input_decoder_ids, next_token.long().unsqueeze(0)), dim=1)
76 |
77 | return self.tokenizer.decode(output_ids)
--------------------------------------------------------------------------------
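
A hedged sketch of driving the sampling loop defined above; the vocab and checkpoint paths are placeholders:

    import torch
    from bert_seq2seq.tokenizer import load_chinese_base_vocab
    from bert_seq2seq.bart_chinese import BartGenerationModel

    word2idx = load_chinese_base_vocab("./state_dict/vocab.txt")  # placeholder path
    model = BartGenerationModel(word2idx)
    # model.load_state_dict(torch.load("./state_dict/bart.bin"))  # placeholder checkpoint
    model.eval()
    print(model.sample_generate_encoder_decoder("一段待摘要的正文", top_k=30, top_p=0.9))
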
/bert_seq2seq/bert_cls_classifier.py:
--------------------------------------------------------------------------------
1 | ## BERT encoder classification model
2 | import torch
3 | import torch.nn as nn
4 | from bert_seq2seq.tokenizer import Tokenizer
5 | from bert_seq2seq.basic_bert import BasicBert
6 |
7 | class BertClsClassifier(BasicBert):
8 |     """Sentence-level classifier on top of the BERT [CLS] pooled vector.
9 |     """
10 | def __init__(self, word2ix, target_size, model_name="roberta"):
11 | super(BertClsClassifier, self).__init__(word2ix=word2ix, model_name=model_name)
12 | self.target_size = target_size
13 | self.final_dense = nn.Linear(self.config.hidden_size, self.target_size)
14 |
15 | def compute_loss(self, predictions, labels):
16 | """
17 |         Compute the cross-entropy loss.
18 |         predictions: (batch_size, target_size)
19 | """
20 | predictions = predictions.view(-1, self.target_size)
21 | labels = labels.view(-1)
22 | loss = nn.CrossEntropyLoss(reduction="mean")
23 | return loss(predictions, labels)
24 |
25 | def forward(self, text, position_enc=None, labels=None, use_layer_num=-1):
26 |
27 | text = text.to(self.device)
28 | if position_enc is not None:
29 | position_enc = position_enc.to(self.device)
30 | if labels is not None:
31 | labels = labels.to(self.device)
32 | all_layers, pooled_out = self.bert(text,
33 | output_all_encoded_layers=True)
34 | if use_layer_num != -1:
35 | pooled_out = all_layers[use_layer_num][:, 0]
36 |
37 | predictions = self.final_dense(pooled_out)
38 |         if labels is not None:
39 |             ## compute the loss
40 |             loss = self.compute_loss(predictions, labels)
41 |             return predictions, loss
42 |         else:
43 |             return predictions
44 |
--------------------------------------------------------------------------------
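
compute_loss above flattens the logits to (batch_size, target_size) and applies plain cross-entropy; a self-contained replay of that contract:

    import torch
    import torch.nn as nn

    target_size = 3
    predictions = torch.randn(4, target_size)  # [CLS] logits for a batch of 4
    labels = torch.tensor([0, 2, 1, 2])        # integer class ids
    loss_fct = nn.CrossEntropyLoss(reduction="mean")
    loss = loss_fct(predictions.view(-1, target_size), labels.view(-1))
    print(loss.item())
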
/bert_seq2seq/bert_cls_multi_classifier.py:
--------------------------------------------------------------------------------
1 | ## BERT encoder multi-label classification model
2 |
3 | import torch
4 | import torch.nn as nn
5 | from bert_seq2seq.tokenizer import Tokenizer
6 | from bert_seq2seq.basic_bert import BasicBert
7 |
8 | class BertClsMultiClassifier(BasicBert):
9 |     """Multi-label classifier: an independent logit per label, trained with BCE.
10 |     """
11 | def __init__(self, word2ix, target_size, model_name="roberta"):
12 | super(BertClsMultiClassifier, self).__init__(word2ix=word2ix, model_name=model_name)
13 | self.target_size = target_size
14 | self.final_dense = nn.Linear(self.config.hidden_size, self.target_size)
15 |
16 | def compute_loss(self, predictions, labels):
17 | """
18 |         Compute the multi-label BCE loss.
19 |         predictions: (batch_size, target_size)
20 | """
21 | # predictions = torch.sigmoid(predictions)
22 | batch_size = predictions.shape[0]
23 | # predictions = predictions.view(-1)
24 | # labels = labels.view(-1)
25 | loss = nn.BCEWithLogitsLoss(reduction="none")
26 | return loss(predictions, labels).sum() / batch_size
27 |
28 | def forward(self, text, position_enc=None, labels=None, use_layer_num=-1):
29 |
30 | text = text.to(self.device)
31 | if position_enc is not None:
32 | position_enc = position_enc.to(self.device)
33 | if labels is not None:
34 | labels = labels.to(self.device)
35 | all_layers, pooled_out = self.bert(text,
36 | output_all_encoded_layers=True)
37 |
38 | if use_layer_num != -1:
39 | pooled_out = all_layers[use_layer_num][:, 0]
40 |
41 | predictions = self.final_dense(pooled_out)
42 |
43 |         if labels is not None:
44 |             ## compute the loss
45 |             loss = self.compute_loss(predictions, labels)
46 |             return predictions, loss
47 |         else:
48 |             return predictions
49 |
--------------------------------------------------------------------------------
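
The multi-label head sums BCEWithLogitsLoss over every label and divides by batch size; a standalone check of that reduction:

    import torch
    import torch.nn as nn

    logits = torch.randn(4, 5)                    # 4 samples, 5 independent labels
    labels = torch.randint(0, 2, (4, 5)).float()  # multi-hot targets
    loss_fct = nn.BCEWithLogitsLoss(reduction="none")
    loss = loss_fct(logits, labels).sum() / logits.shape[0]
    print(loss.item())
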
/bert_seq2seq/bert_relation_extraction.py:
--------------------------------------------------------------------------------
1 | ## BERT relation extraction model
2 | import torch
3 | import torch.nn as nn
4 | from bert_seq2seq.tokenizer import load_chinese_base_vocab, Tokenizer
5 | from bert_seq2seq.basic_bert import BasicBert
6 |
7 | class BertRelationExtrac(BasicBert):
8 |     """Joint extraction: predict subject spans, then object spans per predicate conditioned on the subject vector.
9 |     """
10 | def __init__(self, word2ix, predicate_num, model_name="roberta"):
11 | super(BertRelationExtrac, self).__init__(word2ix=word2ix, model_name=model_name)
12 |
13 | self.predicate_num = predicate_num
14 | self.subject_pred_norm = nn.LayerNorm(self.config.hidden_size)
15 | self.subject_pred = nn.Linear(self.config.hidden_size, 2)
16 | self.activation = nn.Sigmoid()
17 | self.object_pred = nn.Linear(self.config.hidden_size, 2 * self.predicate_num)
18 |
19 |
20 | def binary_crossentropy(self, labels, pred):
21 | labels = labels.float()
22 | loss = (-labels) * torch.log(pred) - (1.0 - labels) * torch.log(1.0 - pred)
23 | return loss
24 |
25 | def compute_total_loss(self, subject_pred, object_pred, subject_labels, object_labels):
26 | """
27 |         Compute the combined subject and object losses.
28 | """
29 | subject_loss = self.binary_crossentropy(subject_labels, subject_pred)
30 | subject_loss = torch.mean(subject_loss, dim=2)
31 | subject_loss = (subject_loss * self.target_mask).sum() / self.target_mask.sum()
32 |
33 | object_loss = self.binary_crossentropy(object_labels, object_pred)
34 | object_loss = torch.mean(object_loss, dim=3).sum(dim=2)
35 | object_loss = (object_loss * self.target_mask).sum() / self.target_mask.sum()
36 |
37 | return subject_loss + object_loss
38 |
39 | def extrac_subject(self, output, subject_ids):
40 |         ## gather the hidden states of the subject's start and end tokens
41 | batch_size = output.shape[0]
42 | hidden_size = output.shape[-1]
43 | start_end = torch.gather(output, index=subject_ids.unsqueeze(-1).expand((batch_size, 2, hidden_size)), dim=1)
44 | subject = torch.cat((start_end[:, 0], start_end[:, 1]), dim=-1)
45 | return subject
46 |
47 | def forward(self, text, subject_ids, position_enc=None, subject_labels=None, object_labels=None, use_layer_num=-1):
48 | if use_layer_num != -1:
49 |
50 |             raise Exception("use_layer_num currently only supports -1")
51 |         # compute the target mask
52 | text = text.to(self.device)
53 | subject_ids = subject_ids.to(self.device)
54 | self.target_mask = (text > 0).float()
55 | enc_layers, _ = self.bert(text,
56 | output_all_encoded_layers=True)
57 |
58 | squence_out = enc_layers[use_layer_num]
59 |
60 | tokens_hidden_state, _ = self.cls(squence_out)
61 |
62 | subject_pred_out = self.subject_pred(self.subject_pred_norm(tokens_hidden_state))
63 |
64 | subject_pred_act = self.activation(subject_pred_out)
65 |
66 | subject_pred_act = subject_pred_act**2
67 |
68 | subject_vec = self.extrac_subject(tokens_hidden_state, subject_ids)
69 | object_layer_norm = self.layer_norm_cond([tokens_hidden_state, subject_vec])
70 | object_pred_out = self.object_pred(object_layer_norm)
71 | object_pred_act = self.activation(object_pred_out)
72 |
73 | object_pred_act = object_pred_act**4
74 |
75 | batch_size, seq_len, target_size = object_pred_act.shape
76 |
77 | object_pred_act = object_pred_act.reshape((batch_size, seq_len, int(target_size/2), 2))
78 | predictions = object_pred_act
79 |         if subject_labels is not None and object_labels is not None:
80 |             ## compute the loss
81 |             subject_labels = subject_labels.to(self.device)
82 |             object_labels = object_labels.to(self.device)
83 |             loss = self.compute_total_loss(subject_pred_act, object_pred_act, subject_labels, object_labels)
84 |             return predictions, loss
85 |         else:
86 |             return predictions
87 |
88 |     def predict_subject(self, text, use_layer_num=-1):
89 |         if use_layer_num != -1:
90 |
91 |             raise Exception("use_layer_num currently only supports -1")
92 | text = text.to(self.device)
93 |
94 | self.target_mask = (text > 0).float()
95 | enc_layers, _ = self.bert(text, output_all_encoded_layers=True)
96 | squence_out = enc_layers[use_layer_num]
97 | tokens_hidden_state, _ = self.cls(squence_out)
98 | subject_pred_out = self.subject_pred(self.subject_pred_norm(tokens_hidden_state))
99 | subject_pred_act = self.activation(subject_pred_out)
100 |
101 | subject_pred_act = subject_pred_act**2
102 |
103 | # subject_pred_act = (subject_pred_act > 0.5).long()
104 | return subject_pred_act
105 |
106 | def predict_object_predicate(self, text, subject_ids, use_layer_num=-1):
107 | if use_layer_num != -1:
108 |
109 |             raise Exception("use_layer_num currently only supports -1")
110 |         # compute the target mask
111 | text = text.to(self.device)
112 | subject_ids = subject_ids.to(self.device)
113 |
114 | enc_layers, _ = self.bert(text, output_all_encoded_layers=True)
115 | squence_out = enc_layers[use_layer_num]
116 | tokens_hidden_state, _ = self.cls(squence_out)
117 |
118 | subject_vec = self.extrac_subject(tokens_hidden_state, subject_ids)
119 | object_layer_norm = self.layer_norm_cond([tokens_hidden_state, subject_vec])
120 | object_pred_out = self.object_pred(object_layer_norm)
121 | object_pred_act = self.activation(object_pred_out)
122 |
123 | object_pred_act = object_pred_act**4
124 |
125 | batch_size, seq_len, target_size = object_pred_act.shape
126 | object_pred_act = object_pred_act.view((batch_size, seq_len, int(target_size/2), 2))
127 | predictions = object_pred_act
128 | return predictions
--------------------------------------------------------------------------------
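
extrac_subject above picks out the hidden states of the subject's start and end tokens with torch.gather; a dummy-tensor replay of that indexing:

    import torch

    batch_size, seq_len, hidden = 2, 6, 4
    output = torch.randn(batch_size, seq_len, hidden)  # encoder hidden states
    subject_ids = torch.tensor([[1, 3], [0, 2]])       # (batch, 2): start/end positions
    start_end = torch.gather(
        output, dim=1,
        index=subject_ids.unsqueeze(-1).expand(batch_size, 2, hidden))
    subject = torch.cat((start_end[:, 0], start_end[:, 1]), dim=-1)
    print(subject.shape)                               # torch.Size([2, 8])
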
/bert_seq2seq/bert_seq_labeling.py:
--------------------------------------------------------------------------------
1 | ## BERT token-level sequence labeling model
2 | import torch
3 | import torch.nn as nn
4 | from bert_seq2seq.tokenizer import load_chinese_base_vocab, Tokenizer
5 | from bert_seq2seq.basic_bert import BasicBert
6 |
7 | class BertSeqLabeling(BasicBert):
8 |     """Token-level sequence labeling head on top of the BERT encoder.
9 |     """
10 | def __init__(self, word2ix, target_size, model_name="roberta"):
11 | super(BertSeqLabeling, self).__init__(word2ix=word2ix, model_name=model_name)
12 | self.target_size = target_size
13 |
14 |
15 | self.final_dense = nn.Linear(self.config.hidden_size, self.target_size)
16 |
17 | def compute_loss(self, predictions, labels):
18 | """
19 |         Compute the padding-masked cross-entropy loss.
20 |         predictions: (batch_size, seq_len, target_size)
21 | """
22 | predictions = predictions.view(-1, self.target_size)
23 | labels = labels.view(-1)
24 | self.target_mask = self.target_mask.view(-1)
25 | loss = nn.CrossEntropyLoss(reduction="none")
26 | return (loss(predictions, labels) * self.target_mask).sum() / self.target_mask.sum()
27 |
28 | def forward(self, text, position_enc=None, labels=None, use_layer_num=-1):
29 | if use_layer_num != -1:
30 | if use_layer_num < 0 or use_layer_num > 7:
31 |                 # out of range
32 |                 raise Exception("invalid layer choice: this bert base model has 8 layers, so only 0-7 are allowed; the default -1 takes the last layer")
33 | self.target_mask = (text > 0).float().to(self.device)
34 | text = text.to(self.device)
35 | if position_enc is not None:
36 | position_enc = position_enc.to(self.device)
37 | if labels is not None:
38 | labels = labels.to(self.device)
39 |
40 | enc_layers, _ = self.bert(text,
41 | output_all_encoded_layers=True)
42 | squence_out = enc_layers[use_layer_num]
43 |
44 | tokens_hidden_state, _ = self.cls(squence_out)
45 | predictions = self.final_dense(tokens_hidden_state)
46 |         if labels is not None:
47 |             ## compute the loss
48 |             loss = self.compute_loss(predictions, labels)
49 |             return predictions, loss
50 |         else:
51 |             return predictions
52 |
--------------------------------------------------------------------------------
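
The loss above is token-level cross-entropy with padding positions masked out by target_mask; a standalone replay with dummy shapes:

    import torch
    import torch.nn as nn

    batch, seq_len, target_size = 2, 5, 3
    predictions = torch.randn(batch, seq_len, target_size)
    labels = torch.randint(0, target_size, (batch, seq_len))
    target_mask = torch.tensor([[1, 1, 1, 0, 0],
                                [1, 1, 1, 1, 1]]).float().view(-1)

    loss = nn.CrossEntropyLoss(reduction="none")(
        predictions.view(-1, target_size), labels.view(-1))
    loss = (loss * target_mask).sum() / target_mask.sum()
    print(loss.item())
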
/bert_seq2seq/bert_seq_labeling_crf.py:
--------------------------------------------------------------------------------
1 | ## BERT sequence labeling model with a CRF layer
2 | import torch
3 | import torch.nn as nn
4 | from bert_seq2seq.tokenizer import Tokenizer
5 | from bert_seq2seq.model.crf import CRFLayer
6 | from bert_seq2seq.basic_bert import BasicBert
7 |
8 | class BertSeqLabelingCRF(BasicBert):
9 |     """Sequence labeling head with a CRF decoding layer on top of BERT.
10 |     """
11 | def __init__(self, word2ix, target_size, model_name="roberta"):
12 | super(BertSeqLabelingCRF, self).__init__(word2ix=word2ix, model_name=model_name)
13 | self.target_size = target_size
14 |
15 | self.final_dense = nn.Linear(self.config.hidden_size, self.target_size)
16 | self.crf_layer = CRFLayer(self.target_size)
17 |
18 | def compute_loss(self, predictions, labels):
19 | """
20 |         Compute the CRF negative log-likelihood loss.
21 | """
22 | loss = self.crf_layer(predictions, labels, self.target_mask)
23 |
24 | return loss.mean()
25 |
26 | def forward(self, text, position_enc=None, labels=None, use_layer_num=-1):
27 | if use_layer_num != -1:
28 |             # out of range
29 |             raise Exception("use_layer_num currently only supports -1")
30 |         # compute the target mask
31 | self.target_mask = (text > 0).float().to(self.device)
32 | text = text.to(self.device)
33 | if position_enc is not None :
34 | position_enc = position_enc.to(self.device)
35 | if labels is not None :
36 | labels = labels.to(self.device)
37 | enc_layers, _ = self.bert(text,
38 | output_all_encoded_layers=True)
39 | squence_out = enc_layers[use_layer_num]
40 |
41 | tokens_hidden_state, _ = self.cls(squence_out)
42 | # print(cls_token)
43 | predictions = self.final_dense(tokens_hidden_state)
44 |
45 |         if labels is not None:
46 |             ## compute the loss
47 |             loss = self.compute_loss(predictions, labels)
48 |             return predictions, loss
49 |         else:
50 |             return predictions
51 |
--------------------------------------------------------------------------------
/bert_seq2seq/config.py:
--------------------------------------------------------------------------------
1 |
2 | max_length = 256
3 |
4 | yayun_list = [
5 | "东同铜桐筒童僮瞳中衷忠虫终戎崇嵩弓躬宫融雄熊穹穷冯风枫丰充隆空公功工攻蒙笼聋珑洪红鸿虹丛翁聪通蓬烘潼胧砻峒螽梦讧冻忡酆恫总侗窿懵庞种盅芎倥艨绒葱匆骢",
6 | "冬农宗钟龙舂松冲容蓉庸封胸雍浓重从逢缝踪茸峰锋烽蛩慵恭供淙侬松凶墉镛佣溶邛共憧喁邕壅纵龚枞脓淞匈汹禺蚣榕彤",
7 | "江扛窗邦缸降双庞逄腔撞幢桩淙豇",
8 | "支枝移为垂吹陂碑奇宜仪皮儿离施知驰池规危夷师姿迟眉悲之芝时诗棋旗辞词期祠基疑姬丝司葵医帷思滋持随痴维卮麋螭麾墀弥慈遗肌脂雌披嬉尸狸炊篱兹差疲茨卑亏蕤陲骑曦歧岐谁斯私窥熙欺疵赀笞羁彝颐资糜饥衰锥姨楣夔涯伊蓍追",
9 | "缁箕椎罴篪萎匙脾坻嶷治骊尸綦怡尼漪累牺饴而鸱推縻璃祁绥逵羲羸肢骐訾狮奇嗤咨堕其睢漓蠡噫馗辎胝鳍蛇陴淇淄丽筛厮氏痍貔比僖贻祺嘻鹂瓷琦嵋怩熹孜台蚩罹魑丕琪耆衰惟剂提禧居栀戏畸椅磁痿离佳虽仔寅委崎隋逶倭黎犁郦",
10 | "微薇晖徽挥韦围帏违霏菲妃绯飞非扉肥腓威畿机几讥矶稀希衣依沂巍归诽痱欷葳颀圻",
11 | "鱼渔初书舒居裾车渠余予誉舆胥狙锄疏蔬梳虚嘘徐猪闾庐驴诸除储如墟与畲疽苴于茹蛆且沮祛蜍榈淤好雎纾躇趄滁屠据匹咀衙涂虑",
12 | "虞愚娱隅刍无芜巫于盂衢儒濡襦须株诛蛛殊瑜榆谀愉腴区驱躯朱珠趋扶符凫雏敷夫肤纡输枢厨俱驹模谟蒲胡湖瑚乎壶狐弧孤辜姑觚菰徒途涂荼图屠奴呼吾七虞梧吴租卢鲈苏酥乌枯都铺禺诬竽吁瞿劬需俞逾觎揄萸臾渝岖镂娄夫孚桴俘迂姝拘摹糊鸪沽呱蛄驽逋舻垆徂孥泸栌嚅蚨诹扶母毋芙喁颅轳句邾洙麸机膜瓠恶芋呕驺喻枸侏龉葫懦帑拊",
13 | "齐蛴脐黎犁梨黧妻萋凄堤低氐诋题提荑缔折篦鸡稽兮奚嵇蹊倪霓西栖犀嘶撕梯鼙批挤迷泥溪圭闺睽奎携畦骊鹂儿",
14 | "佳街鞋牌柴钗差涯阶偕谐骸排乖怀淮豺侪埋霾斋娲蜗娃哇皆喈揩蛙楷槐俳",
15 | "灰恢魁隈回徊枚梅媒煤瑰雷催摧堆陪杯醅嵬推开哀埃台苔该才材财裁来莱栽哉灾猜胎孩虺崔裴培坏垓陔徕皑傀崃诙煨桅唉颏能茴酶偎隗咳",
16 | "真因茵辛新薪晨辰臣人仁神亲申伸绅身宾滨邻鳞麟珍尘陈春津秦频苹颦银垠筠巾民珉缗贫淳醇纯唇伦纶轮沦匀旬巡驯钧均臻榛姻寅彬鹑皴遵循振甄岷谆椿询恂峋莘堙屯呻粼磷辚濒闽豳逡填狺泯洵溱夤荀竣娠纫鄞抡畛嶙斌氤",
17 | "文闻纹云氛分纷芬焚坟群裙君军勤斤筋勋薰曛熏荤耘芸汾氲员欣芹殷昕贲郧雯蕲",
18 | "元原源园猿辕坦烦繁蕃樊翻萱喧冤言轩藩魂浑温孙门尊存蹲敦墩暾屯豚村盆奔论坤昏婚阍痕根恩吞沅媛援爰幡番反埙鸳宛掀昆琨鲲扪荪髡跟垠抡蕴犍袁怨蜿溷昆炖饨臀喷纯",
19 | "寒韩翰丹殚单安难餐滩坛檀弹残干肝竿乾阑栏澜兰看刊丸桓纨端湍酸团抟攒官观冠鸾銮栾峦欢宽盘蟠漫汗郸叹摊奸剜棺钻瘢谩瞒潘胖弁拦完莞獾拌掸萑倌繁曼馒鳗谰洹滦",
20 | "删潸关弯湾还环鹌鬟寰班斑颁般蛮颜菅攀顽山鳏艰闲娴悭孱潺殷扳讪患",
21 | "先前千阡笺天坚肩贤弦烟燕莲怜田填钿年颠巅牵妍研眠渊涓蠲编玄县泉迁仙鲜钱煎然延筵禅蝉缠连联涟篇偏便全宣镌穿川缘鸢铅捐旋娟船涎鞭专圆员乾虔愆骞权拳椽传焉跹溅舷咽零骈阗鹃翩扁平沿诠痊悛荃遄卷挛戋佃滇婵颛犍搴嫣癣澶单竣鄢扇键蜷棉",
22 | "萧箫挑貂刁凋雕迢条跳苕调枭浇聊辽寥撩僚寮尧幺宵消霄绡销超朝潮嚣樵谯骄娇焦蕉椒饶烧遥姚摇谣瑶韶昭招飚标杓镳瓢苗描猫要腰邀乔桥侨妖夭漂飘翘祧佻徼侥哨娆陶橇劭潇骁獠料硝灶鹞钊蛲峤轿荞嘹逍燎憔剽",
23 | "肴巢交郊茅嘲钞包胶爻苞梢蛟庖匏坳敲胞抛鲛崤铙炮哮捎茭淆泡跑咬啁教咆鞘剿刨佼抓姣唠",
24 | "豪毫操髦刀萄猱桃糟漕旄袍挠蒿涛皋号陶翱敖遭篙羔高嘈搔毛艘滔骚韬缫膏牢醪逃槽劳洮叨绸饕骜熬臊涝淘尻挑嚣捞嗥薅咎谣",
25 | "歌多罗河戈阿和波科柯陀娥蛾鹅萝荷过磨螺禾哥娑驼佗沱峨那苛诃珂轲莎蓑梭婆摩魔讹坡颇俄哦呵皤么涡窝茄迦伽磋跎番蹉搓驮献蝌箩锅倭罗嵯锣",
26 | "麻花霞家茶华沙车牙蛇瓜斜邪芽嘉瑕纱鸦遮叉葩奢楂琶衙赊涯夸巴加耶嗟遐笳差蟆蛙虾拿葭茄挝呀枷哑娲爬杷蜗爷芭鲨珈骅娃哇洼畲丫夸裟瘕些桠杈痂哆爹椰咤笆桦划迦揶吾佘",
27 | "阳杨扬香乡光昌堂章张王房芳长塘妆常凉霜藏场央泱鸯秧嫱床方浆觞梁娘庄黄仓皇装殇襄骧相湘箱缃创忘芒望尝偿樯枪坊囊郎唐狂强肠康冈苍匡荒遑行妨棠翔良航倡伥羌庆姜僵缰疆粮穰将墙桑刚祥详洋徉佯粱量羊伤汤鲂樟彰漳璋猖商防",
28 | "筐煌隍凰蝗惶璜廊浪裆沧纲亢吭潢钢丧盲簧忙茫傍汪臧琅当庠裳昂障糖疡锵杭邙赃滂禳攘瓤抢螳踉眶炀阊彭蒋亡殃蔷镶孀搪彷胱磅膀螃八庚更羹盲横觥彭棚亨英瑛烹平评京惊荆明盟鸣荣莹兵卿生甥笙牲檠擎鲸迎行衡耕萌氓宏闳茎莺樱泓橙筝争清情晴精睛菁旌晶盈瀛嬴营婴缨贞成盛城诚呈程声征正轻名令并倾萦琼赓撑瞠枪伧峥猩珩蘅铿嵘丁嘤鹦铮砰绷轰訇瞪侦顷榜抨趟坪请",
29 | "青经泾形刑邢型陉亭庭廷霆蜓停丁宁钉仃馨星腥醒惺娉灵棂龄铃苓伶零玲翎瓴囹聆听厅汀冥溟螟铭瓶屏萍荧萤荥扃町瞑暝",
30 | "蒸承丞惩陵凌绫冰膺鹰应蝇绳渑乘升胜兴缯凭仍兢矜征凝称登灯僧增曾憎层能棱朋鹏弘肱腾滕藤恒冯瞢扔誊",
31 | "尤邮优忧流留榴骝刘由油游猷悠攸牛修羞秋周州洲舟酬仇柔俦畴筹稠邱抽湫遒收鸠不愁休囚求裘球浮谋牟眸矛侯猴喉讴沤鸥瓯楼娄陬偷头投钩沟幽彪疣绸浏瘤犹啾酋售蹂揉搜叟邹貅泅球逑俅蜉桴罘欧搂抠髅蝼兜句妯惆呕缪繇偻篓馗区",
32 | "侵寻浔林霖临针箴斟沈深淫心琴禽擒钦衾吟今襟金音阴岑簪琳琛椹谌忱壬任黔歆禁喑森参淋郴妊湛",
33 | "覃潭谭参骖南男谙庵含涵函岚蚕探贪耽龛堪戡谈甘三酣篮柑惭蓝郯婪庵颔褴澹",
34 | "盐檐廉帘嫌严占髯谦奁纤签瞻蟾炎添兼缣尖潜阎镰粘淹箝甜恬拈暹詹渐歼黔沾苫占崦阉砭",
35 | "咸缄谗衔岩帆衫杉监凡馋芟喃嵌掺搀严"]
--------------------------------------------------------------------------------
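
yayun_list groups characters by classical rhyme category, presumably for the poem-generation examples; a small illustrative helper for looking up which group a character belongs to:

    from bert_seq2seq.config import yayun_list

    def find_yayun_group(char):
        # return the index of the rhyme group containing char, or -1
        for i, group in enumerate(yayun_list):
            if char in group:
                return i
        return -1

    print(find_yayun_group("东"))  # 0: the first rhyme group
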
/bert_seq2seq/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader, Dataset
3 |
4 |
5 | def gpt_collate_fn(batch):
6 | """
7 |     Dynamic padding; batch is a list of samples.
8 | """
9 |
10 | def padding(indice, max_length, pad_idx=0):
11 | """
12 |         padding helper
13 | """
14 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice]
15 | return torch.tensor(pad_indice)
16 |
17 | token_ids = [data["token_ids"] for data in batch]
18 | max_length = max([len(t) for t in token_ids])
19 |
20 | token_ids_padded = padding(token_ids, max_length)
21 | target_ids_padded = token_ids_padded.clone()
22 | target_ids_padded[target_ids_padded == 0] = -100
23 |
24 | return token_ids_padded, target_ids_padded
25 |
26 | def bert_seq2seq_collate_fn(batch):
27 | """
28 |     Dynamic padding; batch is a list of samples.
29 | """
30 |
31 | def padding(indice, max_length, pad_idx=0):
32 | """
33 |         padding helper
34 | """
35 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice]
36 | return torch.tensor(pad_indice)
37 |
38 | token_ids = [data["token_ids"] for data in batch]
39 | max_length = max([len(t) for t in token_ids])
40 | token_type_ids = [data["token_type_ids"] for data in batch]
41 |
42 | token_ids_padded = padding(token_ids, max_length)
43 | token_type_ids_padded = padding(token_type_ids, max_length)
44 | target_ids_padded = token_ids_padded[:, 1:].contiguous()
45 |
46 | return token_ids_padded, token_type_ids_padded, target_ids_padded
47 |
48 | def bert_cls_collate_fn(batch):
49 | """
50 |     Dynamic padding; batch is a list of samples.
51 | """
52 |
53 | def padding(indice, max_length, pad_idx=0):
54 | """
55 |         padding helper
56 | """
57 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice]
58 | return torch.tensor(pad_indice)
59 |
60 | token_ids = [data["token_ids"] for data in batch]
61 | max_length = max([len(t) for t in token_ids])
62 | token_type_ids = [data["token_type_ids"] for data in batch]
63 | target_ids = [data["target_id"] for data in batch]
64 | target_ids = torch.tensor(target_ids, dtype=torch.long)
65 |
66 | token_ids_padded = padding(token_ids, max_length)
67 | token_type_ids_padded = padding(token_type_ids, max_length)
68 | # target_ids_padded = token_ids_padded[:, 1:].contiguous()
69 |
70 | return token_ids_padded, token_type_ids_padded, target_ids
71 |
72 | def bert_squence_label_collate_fn(batch):
73 | """
74 |     Dynamic padding; batch is a list of samples.
75 | """
76 |
77 | def padding(indice, max_length, pad_idx=0):
78 | """
79 |         padding helper
80 | """
81 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice]
82 | return torch.tensor(pad_indice)
83 |
84 | token_ids = [data["token_ids"] for data in batch]
85 |
86 | max_length = max([len(t) for t in token_ids])
87 | token_type_ids = [data["token_type_ids"] for data in batch]
88 | target_ids = [data["target_id"] for data in batch]
89 |
90 | token_ids_padded = padding(token_ids, max_length)
91 | token_type_ids_padded = padding(token_type_ids, max_length)
92 | # target_ids_padded = token_ids_padded[:, 1:].contiguous()
93 | target_ids_padded = padding(target_ids, max_length)
94 |
95 | return token_ids_padded, token_type_ids_padded, target_ids_padded
96 |
97 | class AbstractDataset(Dataset):
98 |
99 | def __init__(self, model_name, model_class, collate_fn=None) -> None:
100 | super().__init__()
101 | if model_name == "gpt2":
102 | self.collate_fn = gpt_collate_fn
103 | if model_name == "bert" or model_name == "roberta" or model_name == "roberta-large" or model_name == "nezha":
104 | if model_class == "seq2seq":
105 | self.collate_fn = bert_seq2seq_collate_fn
106 | elif model_class == "cls":
107 | self.collate_fn = bert_cls_collate_fn
108 | elif model_class == "sequence_labeling_crf" or model_class == "sequence_labeling":
109 | self.collate_fn = bert_squence_label_collate_fn
110 |
111 |
112 | if collate_fn is not None :
113 | self.collate_fn = collate_fn
114 |
115 |
116 |     def __getitem__(self, index):
117 |         raise NotImplementedError
118 |
119 |     def __len__(self):
120 |         raise NotImplementedError
--------------------------------------------------------------------------------
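
A sketch of wiring gpt_collate_fn into a DataLoader; the toy dataset below just produces the {"token_ids": ...} dicts the collate function expects:

    import torch
    from torch.utils.data import DataLoader, Dataset
    from bert_seq2seq.dataset import gpt_collate_fn

    class ToyDataset(Dataset):
        def __init__(self):
            self.samples = [[101, 7, 8, 102], [101, 5, 102]]
        def __getitem__(self, i):
            return {"token_ids": self.samples[i]}
        def __len__(self):
            return len(self.samples)

    loader = DataLoader(ToyDataset(), batch_size=2, collate_fn=gpt_collate_fn)
    token_ids, target_ids = next(iter(loader))
    print(token_ids)   # right-padded with 0
    print(target_ids)  # same ids, with padding replaced by -100
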
/bert_seq2seq/gpt2_generate_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from bert_seq2seq.seq2seq_model import top_k_top_p_filtering
5 | from bert_seq2seq.model.gpt2_model import GPT2LMHeadModel, GPT2Config
6 | from bert_seq2seq.basic_bert import BasicGPT
7 | from bert_seq2seq.tokenizer import Tokenizer
8 | import torch.nn.functional as F
9 |
10 | from bert_seq2seq.helper import RepetitionPenaltyLogitsProcessor, TemperatureLogitsProcessor, TopKLogitsProcessor, \
11 | TopPLogitsProcessor, ListProcessor
12 |
13 | class GPT2(BasicGPT):
14 | def __init__(self, word2ix, tokenizer=None,
15 | ):
16 | super().__init__()
17 | self.word2ix = word2ix
18 | if tokenizer is not None:
19 | self.tokenizer = tokenizer
20 | else:
21 | self.tokenizer = Tokenizer(word2ix)
22 | self.config = GPT2Config(len(word2ix))
23 | self.model = GPT2LMHeadModel(self.config)
24 |
25 | def sample_generate(self, text, input_max_length=256, out_max_length=200,
26 | top_k=30, top_p=1.0, add_eos=False, repetition_penalty=1.0,
27 | temperature=1.0):
28 |
29 | lp = [RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty),
30 | TemperatureLogitsProcessor(temperature=temperature),
31 | TopKLogitsProcessor(top_k=top_k),
32 | TopPLogitsProcessor(top_p=top_p)
33 | ]
34 |
35 | self.list_processor = ListProcessor(lp)
36 |
37 | token_ids, _ = self.tokenizer.encode(text, max_length=input_max_length)
38 | if not add_eos:
39 | token_ids = torch.tensor(token_ids, device=self.device, dtype=torch.long)[:-1].view(1, -1)
40 | else:
41 | token_ids = torch.tensor(token_ids, device=self.device, dtype=torch.long).view(1, -1)
42 |
43 | output_ids = []
44 | sep_id = self.word2ix["[SEP]"]
45 | with torch.no_grad():
46 | for step in range(out_max_length):
47 | _, scores = self.model(token_ids)
48 | logit_score = torch.log_softmax(scores[:, -1], dim=-1)
49 | logit_score[:, self.word2ix["[UNK]"]] = -float('Inf')
50 |
51 | filtered_logits = self.list_processor(token_ids, logit_score)
52 |
53 | # filtered_logits = top_k_top_p_filtering(logit_score, top_k=top_k, top_p=top_p)
54 | next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
55 | if sep_id == next_token.item():
56 | break
57 | output_ids.append(next_token.item())
58 | token_ids = torch.cat((token_ids, next_token.long()), dim=1)
59 |
60 | return self.tokenizer.decode(np.array(output_ids))
61 |
62 | def sample_generate_once(self, text, input_max_length=256, out_max_length=200, top_k=30, top_p=0.0, sep="。"):
63 |
64 | token_ids, _ = self.tokenizer.encode(text, max_length=input_max_length)
65 |         # no start or end symbols are added; the input is just a plain sentence.
66 | token_ids = torch.tensor(token_ids, device=self.device, dtype=torch.long)[1:-1].view(1, -1)
67 |
68 |
69 | output_ids = []
70 |         sep_id = self.word2ix[sep]  # stop once the sentence-ending separator is generated
71 | with torch.no_grad():
72 | for step in range(out_max_length):
73 | _, scores = self.model(token_ids)
74 | logit_score = torch.log_softmax(scores[:, -1], dim=-1).squeeze(0)
75 | logit_score[self.word2ix["[UNK]"]] = -float('Inf')
76 | filtered_logits = top_k_top_p_filtering(logit_score, top_k=top_k, top_p=top_p)
77 | next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
78 | if sep_id == next_token.item():
79 | break
80 | output_ids.append(next_token.item())
81 | token_ids = torch.cat((token_ids, next_token.long().unsqueeze(0)), dim=1)
82 |
83 | return self.tokenizer.decode(np.array(output_ids))
84 |
85 | def sample_generate_english(self, text, input_max_length=256, out_max_length=200, top_k=30, top_p=0.0, add_eos=False):
86 |
87 | token_ids = self.tokenizer.encode(text, max_length=input_max_length, truncation=True)
88 | if add_eos:
89 | token_ids = token_ids + [self.word2ix[""]]
90 | token_ids = torch.tensor(token_ids, device=self.device, dtype=torch.long).view(1, -1)
91 | output_ids = []
92 | sep_id = self.word2ix[""]
93 | with torch.no_grad():
94 | for step in range(out_max_length):
95 | _, scores = self.model(token_ids)
96 | # print(scores.shape)
97 | logit_score = torch.log_softmax(scores[:, -1], dim=-1).squeeze(0)
98 | # print(logit_score.shape)
99 | logit_score[self.word2ix["unk"]] = -float('Inf')
100 | filtered_logits = top_k_top_p_filtering(logit_score, top_k=top_k, top_p=top_p)
101 | next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
102 | if sep_id == next_token.item():
103 | break
104 | # pass
105 | output_ids.append(next_token.item())
106 | token_ids = torch.cat((token_ids, next_token.long().unsqueeze(0)), dim=1)
107 |
108 | return self.tokenizer.decode(output_ids)
109 |
110 |
111 | def _make_causal_mask(self, input_ids_shape: torch.Size):
112 |
113 | bsz, tgt_len = input_ids_shape
114 | mask = torch.full((tgt_len, tgt_len), 0.0).to(self.device)
115 | mask_cond = torch.arange(mask.size(-1)).to(self.device)
116 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 1.0)
117 |
118 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)
119 |
120 |
121 | def forward(self, x, labels=None):
122 | if labels is not None:
123 | labels = labels.to(self.device)
124 | x = x.to(self.device)
125 | # input_ids = torch.tensor([[1, 2, 3, 5, -100], [4, 5, 6, -100, -100]])
126 | attention_mask = self._make_causal_mask(x.shape)
127 | pad_mask = (labels != -100).float()
128 | attention_mask = attention_mask * pad_mask.unsqueeze(1).unsqueeze(1)
129 |
130 | loss, lm_logit = self.model(x, labels=labels, attention_mask=attention_mask)
131 |
132 | return loss, lm_logit
--------------------------------------------------------------------------------
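
What _make_causal_mask builds: a lower-triangular (causal) attention mask. A standalone CPU replay for a length-4 input:

    import torch

    tgt_len = 4
    mask = torch.full((tgt_len, tgt_len), 0.0)
    mask_cond = torch.arange(tgt_len)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 1.0)
    print(mask)  # ones on and below the diagonal: each position attends only to its past
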
/bert_seq2seq/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/model/__init__.py
--------------------------------------------------------------------------------
/bert_seq2seq/model/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/model/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/model/__pycache__/bart_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/model/__pycache__/bart_model.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/model/__pycache__/bert_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/model/__pycache__/bert_model.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/model/__pycache__/crf.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/model/__pycache__/crf.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/model/__pycache__/gpt2_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/model/__pycache__/gpt2_model.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/model/__pycache__/nezha_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/model/__pycache__/nezha_model.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/model/__pycache__/roberta_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/model/__pycache__/roberta_model.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/model/__pycache__/t5_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/model/__pycache__/t5_model.cpython-37.pyc
--------------------------------------------------------------------------------
/bert_seq2seq/model/crf.py:
--------------------------------------------------------------------------------
1 | ## crf layer
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class CRFLayer(nn.Module):
7 |     """Linear-chain CRF layer with learned label-transition scores.
8 |     """
9 | def __init__(self, output_dim):
10 | super(CRFLayer, self).__init__()
11 |
12 | self.output_dim = output_dim
13 | self.trans = nn.Parameter(torch.Tensor(output_dim, output_dim))
14 | self.trans.data.uniform_(-0.1, 0.1)
15 |
16 | def compute_loss(self, y_pred, y_true, mask):
17 | """
18 |         Compute the CRF loss (negative log-likelihood).
19 | """
20 | y_pred = y_pred * mask
21 | y_true = y_true * mask
22 | target_score = self.target_score(y_pred, y_true)
23 | log_norm = self.log_norm_step(y_pred, mask)
24 |         log_norm = self.logsumexp(log_norm, dim=1)  # reduce to a scalar per sequence
25 | return log_norm - target_score
26 |
27 | def forward(self, y_pred, y_true, mask):
28 | """
29 | y_true: [[1, 2, 3], [2, 3, 0] ]
30 | mask: [[1, 1, 1], [1, 1, 0]]
31 | """
32 | if y_pred.shape[0] != mask.shape[0] or y_pred.shape[1] != mask.shape[1]:
33 |             raise Exception("mask shape does not match y_pred shape")
34 | mask = mask.reshape((mask.shape[0], mask.shape[1], 1))
35 | mask = mask.float()
36 | y_true = y_true.reshape(y_pred.shape[:-1])
37 | y_true = y_true.long()
38 | y_true_onehot = F.one_hot(y_true, self.output_dim)
39 | y_true_onehot = y_true_onehot.float()
40 |
41 | return self.compute_loss(y_pred, y_true_onehot, mask)
42 |
43 | def target_score(self, y_pred, y_true):
44 | """
45 |         Emission (state-label) score plus transition score.
46 | y_true: (batch, seq_len, out_dim)
47 | y_pred: (batch, seq_len, out_dim)
48 | """
49 | # print(y_pred.shape)
50 | # print(y_true.shape)
51 | point_score = torch.einsum("bni,bni->b", y_pred, y_true)
52 | trans_score = torch.einsum("bni,ij,bnj->b", y_true[:, :-1], self.trans, y_true[:, 1: ])
53 |
54 | return point_score + trans_score
55 |
56 | def log_norm_step(self, y_pred, mask):
57 | """
58 |         Compute the log normalization factor log Z(X) by forward recursion.
59 | """
60 |         state = y_pred[:, 0]  # initial log Z(X)
61 | y_pred = y_pred[:, 1: ].contiguous()
62 | mask = mask[:, 1:].contiguous()
63 | batch, seq_len, out_dim = y_pred.shape
64 | for t in range(seq_len):
65 | cur_mask = mask[:, t]
66 | state = torch.unsqueeze(state, 2) # (batch, out_dim, 1)
67 | g = torch.unsqueeze(self.trans, 0) # (1, out_dim, out_dim)
68 | outputs = self.logsumexp(state + g, dim=1) # batch, out_dim
69 | outputs = outputs + y_pred[:, t]
70 | outputs = cur_mask * outputs + (1 - cur_mask) * state.squeeze(-1)
71 | state = outputs
72 |
73 | return outputs
74 |
75 | def logsumexp(self, x, dim=None, keepdim=False):
76 | """
77 |         Numerically stable log-sum-exp (avoids overflow).
78 | """
79 | if dim is None:
80 | x, dim = x.view(-1), 0
81 | xm, _ = torch.max(x, dim, keepdim=True)
82 | out = xm + torch.log(torch.sum(torch.exp(x - xm), dim=dim, keepdim=True))
83 | return out if keepdim else out.squeeze(dim)
84 |
85 |
--------------------------------------------------------------------------------
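
The logsumexp above subtracts the per-row maximum before exponentiating, which is what keeps the CRF's normalization from overflowing; a standalone comparison with the naive formula:

    import torch

    x = torch.tensor([1000.0, 1000.0])
    naive = torch.log(torch.exp(x).sum())             # inf: exp(1000) overflows
    xm = x.max()
    stable = xm + torch.log(torch.exp(x - xm).sum())  # tensor(1000.6931)
    print(naive, stable)
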
/bert_seq2seq/paddle_model/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .collate import *
16 | from .vocab import *
17 | from .sampler import *
18 | from .tokenizer import *
19 |
--------------------------------------------------------------------------------
/bert_seq2seq/paddle_model/data/iterator.py:
--------------------------------------------------------------------------------
1 | # Iterator for NLP Dataset
2 |
--------------------------------------------------------------------------------
/bert_seq2seq/paddle_model/data/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import jieba
16 | from .vocab import Vocab
17 |
18 |
19 | def get_idx_from_word(word, word_to_idx, unk_word):
20 | if word in word_to_idx:
21 | return word_to_idx[word]
22 | return word_to_idx[unk_word]
23 |
24 |
25 | class BaseTokenizer(object):
26 | def __init__(self, vocab):
27 | self.vocab = vocab
28 |
29 | def get_tokenizer(self):
30 | return self.tokenizer
31 |
32 | def cut(self, sentence):
33 | pass
34 |
35 | def encode(self, sentence):
36 | pass
37 |
38 |
39 | class JiebaTokenizer(BaseTokenizer):
40 | """
41 |     Constructs a tokenizer based on `jieba <https://github.com/fxsjy/jieba>`__.
42 | It supports :meth:`cut` method to split the text to tokens, and :meth:`encode`
43 |     method to convert text to token ids.
44 |
45 | Args:
46 | vocab(paddlenlp.data.Vocab): An instance of :class:`paddlenlp.data.Vocab`.
47 | """
48 |
49 | def __init__(self, vocab):
50 | super(JiebaTokenizer, self).__init__(vocab)
51 | self.tokenizer = jieba.Tokenizer()
52 | # initialize tokenizer
53 | self.tokenizer.FREQ = {key: 1 for key in self.vocab.token_to_idx.keys()}
54 | self.tokenizer.total = len(self.tokenizer.FREQ)
55 | self.tokenizer.initialized = True
56 |
57 | def cut(self, sentence, cut_all=False, use_hmm=True):
58 | """
59 | The method used to cut the text to tokens.
60 |
61 | Args:
62 |             sentence(str): The text that needs to be cut.
63 | cut_all(bool, optional): Whether to use the full mode. If True,
64 | using full mode that gets all the possible words from the
65 | sentence, which is fast but not accurate. If False, using
66 | accurate mode that attempts to cut the sentence into the most
67 | accurate segmentations, which is suitable for text analysis.
68 | Default: False.
69 | use_hmm(bool, optional): Whether to use the HMM model. Default: True.
70 |
71 | Returns:
72 | list[str]: A list of tokens.
73 |
74 | Example:
75 | .. code-block:: python
76 |
77 | from paddlenlp.data import Vocab, JiebaTokenizer
78 | # The vocab file. The sample file can be downloaded firstly.
79 | # wget https://bj.bcebos.com/paddlenlp/data/senta_word_dict.txt
80 | vocab_file_path = './senta_word_dict.txt'
81 | # Initialize the Vocab
82 | vocab = Vocab.load_vocabulary(
83 | vocab_file_path,
84 | unk_token='[UNK]',
85 | pad_token='[PAD]')
86 | tokenizer = JiebaTokenizer(vocab)
87 |
88 | tokens = tokenizer.cut('我爱你中国')
89 | print(tokens)
90 | # ['我爱你', '中国']
91 | """
92 | return self.tokenizer.lcut(sentence, cut_all, use_hmm)
93 |
94 | def encode(self, sentence, cut_all=False, use_hmm=True):
95 | """
96 | The method used to convert the text to ids. It will firstly call
97 | :meth:`cut` method to cut the text to tokens. Then, convert tokens to
98 | ids using `vocab`.
99 |
100 | Args:
101 |             sentence(str): The text that needs to be cut.
102 | cut_all(bool, optional): Whether to use the full mode. If True,
103 | using full mode that gets all the possible words from the
104 | sentence, which is fast but not accurate. If False, using
105 | accurate mode that attempts to cut the sentence into the most
106 | accurate segmentations, which is suitable for text analysis.
107 | Default: False.
108 | use_hmm(bool, optional): Whether to use the HMM model. Default: True.
109 |
110 | Returns:
111 | list[int]: A list of ids.
112 |
113 | Example:
114 | .. code-block:: python
115 |
116 | from paddlenlp.data import Vocab, JiebaTokenizer
117 | # The vocab file. The sample file can be downloaded firstly.
118 | # wget https://bj.bcebos.com/paddlenlp/data/senta_word_dict.txt
119 | vocab_file_path = './senta_word_dict.txt'
120 | # Initialize the Vocab
121 | vocab = Vocab.load_vocabulary(
122 | vocab_file_path,
123 | unk_token='[UNK]',
124 | pad_token='[PAD]')
125 | tokenizer = JiebaTokenizer(vocab)
126 |
127 | ids = tokenizer.encode('我爱你中国')
128 | print(ids)
129 | # [1170578, 575565]
130 | """
131 | words = self.cut(sentence, cut_all, use_hmm)
132 | return [
133 | get_idx_from_word(word, self.vocab.token_to_idx,
134 | self.vocab.unk_token) for word in words
135 | ]
136 |
--------------------------------------------------------------------------------
/bert_seq2seq/paddle_model/transformers/__init__.py:
--------------------------------------------------------------------------------
1 | # # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 | # #
3 | # # Licensed under the Apache License, Version 2.0 (the "License");
4 | # # you may not use this file except in compliance with the License.
5 | # # You may obtain a copy of the License at
6 | # #
7 | # # http://www.apache.org/licenses/LICENSE-2.0
8 | # #
9 | # # Unless required by applicable law or agreed to in writing, software
10 | # # distributed under the License is distributed on an "AS IS" BASIS,
11 | # # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # # See the License for the specific language governing permissions and
13 | # # limitations under the License.
14 | #
15 | from .model_utils import PretrainedModel, register_base_model
16 | from .tokenizer_utils import PretrainedTokenizer, BPETokenizer, tokenize_chinese_chars, is_chinese_char, AddedToken
17 | from .attention_utils import create_bigbird_rand_mask_idx_list
18 | #
19 | from .bert.modeling import *
20 | from .bert.tokenizer import *
21 | # from .bert_japanese.tokenizer import *
22 | # from .ernie.modeling import *
23 | # from .ernie.tokenizer import *
24 | from .gpt.modeling import *
25 | from .gpt.tokenizer import *
26 | from .roberta.modeling import *
27 | from .roberta.tokenizer import *
28 | # from .electra.modeling import *
29 | # from .electra.tokenizer import *
30 | # from .transformer.modeling import *
31 | # from .ernie_gen.modeling import ErnieForGeneration
32 | # from .optimization import *
33 | # from .ppminilm.modeling import *
34 | # from .ppminilm.tokenizer import *
35 | # from .bigbird.modeling import *
36 | # from .bigbird.tokenizer import *
37 | # from .unified_transformer.modeling import *
38 | # from .unified_transformer.tokenizer import *
39 | # from .ernie_ctm.modeling import *
40 | # from .ernie_ctm.tokenizer import *
41 | # from .tinybert.modeling import *
42 | # from .tinybert.tokenizer import *
43 | # from .distilbert.modeling import *
44 | # from .distilbert.tokenizer import *
45 | # from .skep.modeling import *
46 | # from .skep.tokenizer import *
47 | # from .xlnet.modeling import *
48 | # from .xlnet.tokenizer import *
49 | # from .albert.modeling import *
50 | # from .albert.tokenizer import *
51 | # from .ernie_gram.modeling import *
52 | # from .ernie_gram.tokenizer import *
53 | from .nezha.modeling import *
54 | from .nezha.tokenizer import *
55 | # from .ernie_doc.modeling import *
56 | # from .ernie_doc.tokenizer import *
57 | # from .bart.modeling import *
58 | # from .bart.tokenizer import *
59 | # from .roformer.modeling import *
60 | # from .roformer.tokenizer import *
61 | # from .blenderbot.modeling import *
62 | # from .blenderbot.tokenizer import *
63 | # from .blenderbot_small.modeling import *
64 | # from .blenderbot_small.tokenizer import *
65 | # from .unimo.modeling import *
66 | # from .unimo.tokenizer import *
67 | # from .squeezebert.modeling import *
68 | # from .squeezebert.tokenizer import *
69 | # from .convbert.modeling import *
70 | # from .convbert.tokenizer import *
71 | # from .mpnet.modeling import *
72 | # from .mpnet.tokenizer import *
73 | # from .auto.modeling import *
74 | # from .auto.tokenizer import *
75 | # from .ctrl.modeling import *
76 | # from .ctrl.tokenizer import *
77 | # from .layoutlmv2.modeling import *
78 | # from .layoutlmv2.tokenizer import *
79 | # from .layoutxlm.modeling import *
80 | # from .layoutxlm.tokenizer import *
81 | # from .layoutlm.modeling import *
82 | # from .layoutlm.tokenizer import *
83 | # from .t5.modeling import *
84 | # from .t5.tokenizer import *
85 | # from .mbart.modeling import *
86 | # from .mbart.tokenizer import *
87 | # from .reformer.modeling import *
88 | # from .reformer.tokenizer import *
89 | # from .mobilebert.modeling import *
90 | # from .mobilebert.tokenizer import *
91 | # from .chinesebert.modeling import *
92 | # from .chinesebert.tokenizer import *
93 | # from .funnel.modeling import *
94 | # from .funnel.tokenizer import *
95 | # from .ernie_m.modeling import *
96 | # from .ernie_m.tokenizer import *
97 |
--------------------------------------------------------------------------------
/bert_seq2seq/paddle_model/transformers/bert/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/paddle_model/transformers/bert/__init__.py
--------------------------------------------------------------------------------
/bert_seq2seq/paddle_model/transformers/gpt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/paddle_model/transformers/gpt/__init__.py
--------------------------------------------------------------------------------
/bert_seq2seq/paddle_model/transformers/nezha/__init__.py:
--------------------------------------------------------------------------------
1 | from .modeling import *
2 | from .tokenizer import *
3 |
--------------------------------------------------------------------------------
/bert_seq2seq/paddle_model/transformers/roberta/README.md:
--------------------------------------------------------------------------------
1 | # RoBERTa
2 |
--------------------------------------------------------------------------------
/bert_seq2seq/paddle_model/transformers/roberta/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/paddle_model/transformers/roberta/__init__.py
--------------------------------------------------------------------------------
/bert_seq2seq/paddle_model/transformers/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import functools
16 | import inspect
17 |
18 | import paddle
19 | from paddle.nn import Layer
20 |
21 |
22 | def fn_args_to_dict(func, *args, **kwargs):
23 | """
24 |     Inspect function `func` and its call arguments, and extract a
25 |     dict mapping between argument names and values.
26 | """
27 | if hasattr(inspect, 'getfullargspec'):
28 | (spec_args, spec_varargs, spec_varkw, spec_defaults, _, _,
29 | _) = inspect.getfullargspec(func)
30 | else:
31 | (spec_args, spec_varargs, spec_varkw,
32 | spec_defaults) = inspect.getargspec(func)
33 | # add positional argument values
34 | init_dict = dict(zip(spec_args, args))
35 | # add default argument values
36 | kwargs_dict = dict(zip(spec_args[-len(spec_defaults):],
37 | spec_defaults)) if spec_defaults else {}
38 | kwargs_dict.update(kwargs)
39 | init_dict.update(kwargs_dict)
40 | return init_dict
41 |
42 |
43 | class InitTrackerMeta(type(Layer)):
44 | """
45 |     This metaclass wraps the `__init__` method of a class to add an `init_config`
46 |     attribute to instances of that class; `init_config` is a dict that tracks
47 |     the initial configuration. If the class has a `_wrap_init` method, it is
48 |     hooked after `__init__` and called as `_wrap_init(self, init_fn, init_args)`.
49 |     Since InitTrackerMeta is used as the metaclass for pretrained model classes,
50 |     which are always subclasses of Layer, and `type(Layer)` is not `type`, it uses
51 |     `type(Layer)` rather than `type` as its base class to avoid metaclass
52 |     inheritance conflicts.
53 | """
54 |
55 | def __init__(cls, name, bases, attrs):
56 | init_func = cls.__init__
57 |         # If attrs has `__init__`, wrap it using the accessible `_wrap_init`.
58 |         # Otherwise, no need to wrap again since the super cls has been wrapped.
59 | # TODO: remove reduplicated tracker if using super cls `__init__`
60 | help_func = getattr(cls, '_wrap_init',
61 | None) if '__init__' in attrs else None
62 | cls.__init__ = InitTrackerMeta.init_and_track_conf(init_func, help_func)
63 | super(InitTrackerMeta, cls).__init__(name, bases, attrs)
64 |
65 | @staticmethod
66 | def init_and_track_conf(init_func, help_func=None):
67 | """
68 |         Wraps `init_func`, the `__init__` method of a class, to add an `init_config`
69 |         attribute for instances of that class.
70 | Args:
71 | init_func (callable): It should be the `__init__` method of a class.
72 | help_func (callable, optional): If provided, it would be hooked after
73 |             `init_func` and called as `_wrap_init(self, init_func, *init_args, **init_kwargs)`.
74 | Default None.
75 |
76 | Returns:
77 | function: the wrapped function
78 | """
79 |
80 | @functools.wraps(init_func)
81 | def __impl__(self, *args, **kwargs):
82 | # keep full configuration
83 | init_func(self, *args, **kwargs)
84 |             # helper registered via `_wrap_init`
85 | if help_func:
86 | help_func(self, init_func, *args, **kwargs)
87 | self.init_config = kwargs
88 | if args:
89 | kwargs['init_args'] = args
90 | kwargs['init_class'] = self.__class__.__name__
91 |
92 | return __impl__
93 |
--------------------------------------------------------------------------------
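
The init-tracking trick above is easier to see without the paddle machinery. Below is a minimal, self-contained sketch of the same idea; `InitTracker` and `ToyModel` are illustrative names for this note, not part of the library:

    import functools

    class InitTracker(type):
        # simplified InitTrackerMeta: wrap __init__ so every instance
        # records the configuration it was constructed with
        def __init__(cls, name, bases, attrs):
            if '__init__' in attrs:
                init_func = cls.__init__

                @functools.wraps(init_func)
                def wrapped(self, *args, **kwargs):
                    init_func(self, *args, **kwargs)
                    self.init_config = kwargs      # same dict, mutated below
                    if args:
                        kwargs['init_args'] = args
                    kwargs['init_class'] = self.__class__.__name__

                cls.__init__ = wrapped
            super().__init__(name, bases, attrs)

    class ToyModel(metaclass=InitTracker):
        def __init__(self, hidden_size=768, num_layers=12):
            self.hidden_size = hidden_size
            self.num_layers = num_layers

    m = ToyModel(hidden_size=512)
    print(m.init_config)  # {'hidden_size': 512, 'init_class': 'ToyModel'}
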
/bert_seq2seq/paddle_model/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/bert_seq2seq/paddle_model/utils/__init__.py
--------------------------------------------------------------------------------
/bert_seq2seq/paddle_model/utils/env.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | '''
15 | This module is used to store environment variables for PaddleNLP.
16 | PPNLP_HOME --> the root directory for storing PaddleNLP related data. Defaults to ~/.paddlenlp.
17 |    Users can change the default value through the PPNLP_HOME environment variable.
18 | ├─ MODEL_HOME --> stores model files.
19 | └─ DATA_HOME --> stores automatically downloaded datasets.
20 | '''
21 | import os
22 |
23 |
24 | def _get_user_home():
25 | return os.path.expanduser('~')
26 |
27 |
28 | def _get_ppnlp_home():
29 | if 'PPNLP_HOME' in os.environ:
30 | home_path = os.environ['PPNLP_HOME']
31 | if os.path.exists(home_path):
32 | if os.path.isdir(home_path):
33 | return home_path
34 | else:
35 | raise RuntimeError(
36 | 'The environment variable PPNLP_HOME {} is not a directory.'.
37 | format(home_path))
38 | else:
39 | return home_path
40 | return os.path.join(_get_user_home(), '.paddlenlp')
41 |
42 |
43 | def _get_sub_home(directory, parent_home=_get_ppnlp_home()):
44 | home = os.path.join(parent_home, directory)
45 | if not os.path.exists(home):
46 | os.makedirs(home)
47 | return home
48 |
49 |
50 | USER_HOME = _get_user_home()
51 | PPNLP_HOME = _get_ppnlp_home()
52 | MODEL_HOME = _get_sub_home('models')
53 | DATA_HOME = _get_sub_home('datasets')
54 | DOWNLOAD_SERVER = "http://paddlepaddle.org.cn/paddlehub"
55 | FAILED_STATUS = -1
56 | SUCCESS_STATUS = 0
57 |
--------------------------------------------------------------------------------
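
Because PPNLP_HOME, MODEL_HOME and DATA_HOME are computed at import time (note the default argument of `_get_sub_home`), the environment variable must be set before the module is imported. A small usage sketch; the cache path is illustrative:

    import os
    os.environ['PPNLP_HOME'] = '/data/ppnlp_cache'  # set before importing the module

    from bert_seq2seq.paddle_model.utils.env import MODEL_HOME, DATA_HOME
    print(MODEL_HOME)  # /data/ppnlp_cache/models
    print(DATA_HOME)   # /data/ppnlp_cache/datasets
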
/bert_seq2seq/paddle_model/utils/log.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import contextlib
17 | import copy
18 | import functools
19 | import logging
20 | import os
21 | import sys
22 | import time
23 | import threading
24 | from typing import List
25 |
26 | import colorlog
27 | from colorama import Fore
28 |
29 | loggers = {}
30 |
31 | log_config = {
32 | 'DEBUG': {
33 | 'level': 10,
34 | 'color': 'purple'
35 | },
36 | 'INFO': {
37 | 'level': 20,
38 | 'color': 'green'
39 | },
40 | 'TRAIN': {
41 | 'level': 21,
42 | 'color': 'cyan'
43 | },
44 | 'EVAL': {
45 | 'level': 22,
46 | 'color': 'blue'
47 | },
48 | 'WARNING': {
49 | 'level': 30,
50 | 'color': 'yellow'
51 | },
52 | 'ERROR': {
53 | 'level': 40,
54 | 'color': 'red'
55 | },
56 | 'CRITICAL': {
57 | 'level': 50,
58 | 'color': 'bold_red'
59 | }
60 | }
61 |
62 |
63 | class Logger(object):
64 | '''
65 |     Default logger in PaddleNLP.
66 |
67 | Args:
68 | name(str) : Logger name, default is 'PaddleNLP'
69 | '''
70 |
71 | def __init__(self, name: str=None):
72 | name = 'PaddleNLP' if not name else name
73 | self.logger = logging.getLogger(name)
74 |
75 | for key, conf in log_config.items():
76 | logging.addLevelName(conf['level'], key)
77 | self.__dict__[key] = functools.partial(self.__call__, conf['level'])
78 | self.__dict__[key.lower()] = functools.partial(self.__call__,
79 | conf['level'])
80 |
81 | self.format = colorlog.ColoredFormatter(
82 | '%(log_color)s[%(asctime)-15s] [%(levelname)8s]%(reset)s - %(message)s',
83 | log_colors={
84 | key: conf['color']
85 | for key, conf in log_config.items()
86 | })
87 |
88 | self.handler = logging.StreamHandler()
89 | self.handler.setFormatter(self.format)
90 |
91 | self.logger.addHandler(self.handler)
92 | self.logLevel = 'DEBUG'
93 | self.logger.setLevel(logging.DEBUG)
94 | self.logger.propagate = False
95 | self._is_enable = True
96 |
97 | def disable(self):
98 | self._is_enable = False
99 |
100 | def enable(self):
101 | self._is_enable = True
102 |
103 | @property
104 | def is_enable(self) -> bool:
105 | return self._is_enable
106 |
107 |     def __call__(self, log_level: int, msg: str):
108 | if not self.is_enable:
109 | return
110 |
111 | self.logger.log(log_level, msg)
112 |
113 | @contextlib.contextmanager
114 | def use_terminator(self, terminator: str):
115 | old_terminator = self.handler.terminator
116 | self.handler.terminator = terminator
117 | yield
118 | self.handler.terminator = old_terminator
119 |
120 | @contextlib.contextmanager
121 | def processing(self, msg: str, interval: float=0.1):
122 | '''
123 |         Continuously print a rotating progress indicator while the wrapped block runs.
124 |
125 | Args:
126 | msg(str): Message to be printed.
127 | interval(float): Rotation interval. Default to 0.1.
128 | '''
129 | end = False
130 |
131 | def _printer():
132 | index = 0
133 | flags = ['\\', '|', '/', '-']
134 | while not end:
135 | flag = flags[index % len(flags)]
136 | with self.use_terminator('\r'):
137 | self.info('{}: {}'.format(msg, flag))
138 | time.sleep(interval)
139 | index += 1
140 |
141 | t = threading.Thread(target=_printer)
142 | t.start()
143 | yield
144 |         end = True
145 |         t.join()  # wait for the printer thread to observe `end` and exit
146 |
147 | logger = Logger()
148 |
--------------------------------------------------------------------------------
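
A short usage sketch of the Logger above. The `train`/`eval` methods exist because every entry of `log_config` is registered as both an upper-case and lower-case method on the instance:

    import time
    from bert_seq2seq.paddle_model.utils.log import logger

    logger.info("loading vocab")
    logger.train("step 100, loss 2.31")  # custom TRAIN level (21), colored cyan
    logger.eval("bleu 28.4")             # custom EVAL level (22), colored blue

    with logger.processing("downloading weights"):
        time.sleep(1)  # stand-in for slow work; a spinner is printed meanwhile
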
/bert_seq2seq/paddle_model/utils/profiler.py:
--------------------------------------------------------------------------------
1 | # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import sys
16 | import paddle
17 |
18 | # A global variable to record the number of calling times for profiler
19 | # functions. It is used to specify the tracing range of training steps.
20 | _profiler_step_id = 0
21 |
22 | # A global variable to avoid parsing from string every time.
23 | _profiler_options = None
24 |
25 |
26 | class ProfilerOptions(object):
27 | '''
28 | Use a string to initialize a ProfilerOptions.
29 |     The string should be in the format: "key1=value1;key2=value2;key3=value3".
30 | For example:
31 | "profile_path=model.profile"
32 | "batch_range=[50, 60]; profile_path=model.profile"
33 | "batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile"
34 |
35 |     ProfilerOptions supports the following key-value pairs:
36 |       batch_range - an integer list, e.g. [100, 110].
37 | state - a string, the optional values are 'CPU', 'GPU' or 'All'.
38 |       sorted_key - a string, the optional values are 'calls', 'total',
39 |                    'max', 'min' or 'ave'.
40 | tracer_option - a string, the optional values are 'Default', 'OpDetail',
41 | 'AllOpDetail'.
42 | profile_path - a string, the path to save the serialized profile data,
43 | which can be used to generate a timeline.
44 | exit_on_finished - a boolean.
45 | '''
46 |
47 | def __init__(self, options_str):
48 | assert isinstance(options_str, str)
49 |
50 | self._options = {
51 | 'batch_range': [10, 20],
52 | 'state': 'All',
53 | 'sorted_key': 'total',
54 | 'tracer_option': 'Default',
55 | 'profile_path': '/tmp/profile',
56 | 'exit_on_finished': True
57 | }
58 | self._parse_from_string(options_str)
59 |
60 | def _parse_from_string(self, options_str):
61 | for kv in options_str.replace(' ', '').split(';'):
62 | key, value = kv.split('=')
63 | if key == 'batch_range':
64 | value_list = value.replace('[', '').replace(']', '').split(',')
65 | value_list = list(map(int, value_list))
66 | if len(value_list) >= 2 and value_list[0] >= 0 and value_list[
67 | 1] > value_list[0]:
68 | self._options[key] = value_list
69 | elif key == 'exit_on_finished':
70 | self._options[key] = value.lower() in ("yes", "true", "t", "1")
71 | elif key in [
72 | 'state', 'sorted_key', 'tracer_option', 'profile_path'
73 | ]:
74 | self._options[key] = value
75 |
76 | def __getitem__(self, name):
77 | if self._options.get(name, None) is None:
78 | raise ValueError(
79 | "ProfilerOptions does not have an option named %s." % name)
80 | return self._options[name]
81 |
82 |
83 | def add_profiler_step(options_str=None):
84 | '''
85 | Enable the operator-level timing using PaddlePaddle's profiler.
86 |     The profiler uses an independent variable to count the profiler steps.
87 | One call of this function is treated as a profiler step.
88 |
89 | Args:
90 |         options_str - a string used to initialize the ProfilerOptions.
91 |             Default is None, in which case the profiler is disabled.
92 | '''
93 | if options_str is None:
94 | return
95 |
96 | global _profiler_step_id
97 | global _profiler_options
98 |
99 | if _profiler_options is None:
100 | _profiler_options = ProfilerOptions(options_str)
101 |
102 | if _profiler_step_id == _profiler_options['batch_range'][0]:
103 | paddle.utils.profiler.start_profiler(_profiler_options['state'],
104 | _profiler_options['tracer_option'])
105 | elif _profiler_step_id == _profiler_options['batch_range'][1]:
106 | paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'],
107 | _profiler_options['profile_path'])
108 | if _profiler_options['exit_on_finished']:
109 | sys.exit(0)
110 |
111 | _profiler_step_id += 1
112 |
--------------------------------------------------------------------------------
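
A usage sketch of `add_profiler_step`: call it once per training step. With the options below, profiling covers steps 50 to 60; since `exit_on_finished` defaults to True, the process exits once the profile is written:

    from bert_seq2seq.paddle_model.utils.profiler import add_profiler_step

    options = "batch_range=[50, 60]; profile_path=model.profile"
    for step in range(100):
        add_profiler_step(options)
        # ... one training step (forward/backward/update) would run here
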
/bert_seq2seq/paddle_model/utils/tools.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | import paddle
17 | from .log import logger
18 |
19 |
20 | def static_params_to_dygraph(model, static_tensor_dict):
21 |     """Simple tool to convert static-graph parameters to a dygraph parameter dict.
22 | 
23 |     **NOTE** The model must support both static graph and dygraph mode.
24 | 
25 |     Args:
26 |         model (nn.Layer): the model of a neural network.
27 |         static_tensor_dict (dict): the parameters saved in static mode, usually
28 |             loaded by `paddle.static.load_program_state`.
29 | 
30 |     Returns:
31 |         [tensor dict]: a state dict in the same format as the dygraph mode.
32 | """
33 | state_dict = model.state_dict()
34 | # static_tensor_dict = paddle.static.load_program_state(static_params_path)
35 |
36 | ret_dict = dict()
37 | for n, p in state_dict.items():
38 | if p.name not in static_tensor_dict:
39 |             logger.info("%s parameter is missing from your state dict." % n)
40 | continue
41 | ret_dict[n] = static_tensor_dict[p.name]
42 |
43 | return ret_dict
44 |
45 |
46 | def dygraph_params_to_static(model, dygraph_tensor_dict, topo=None):
47 |     """Simple tool to convert dygraph parameters to a static-graph parameter dict.
48 | 
49 |     **NOTE** The model must support both static graph and dygraph mode.
50 | 
51 |     Args:
52 |         model (nn.Layer): the model of a neural network.
53 |         dygraph_tensor_dict (dict): the parameters saved in dygraph mode.
54 | 
55 |     Returns:
56 |         [tensor dict]: a state dict in the same format as the static-graph mode.
57 | """
58 | state_dict = model.state_dict()
59 |
60 | ret_dict = dict()
61 | for name, parm in state_dict.items():
62 | if name not in dygraph_tensor_dict:
63 |             logger.info("%s parameter is missing from your state dict." % name)
64 | continue
65 |
66 | tensor = dygraph_tensor_dict[name]
67 | if parm.is_distributed:
68 | assert topo is not None
69 | for dim, v in enumerate(tensor.shape):
70 | if parm.shape[dim] != v:
71 | break
72 |
73 | splited = np.split(
74 | tensor, topo.mp_info.size, axis=dim)[topo.mp_info.rank]
75 | ret_dict[parm.name] = splited
76 | else:
77 | ret_dict[parm.name] = tensor
78 |
79 | return ret_dict
80 |
81 |
82 | class TimeCostAverage(object):
83 | """
84 |     Simple tool for calculating the average time cost during training and inference.
85 | """
86 |
87 | def __init__(self):
88 | self.reset()
89 |
90 | def reset(self):
91 | """
92 |         Reset the recorder state, and reset the `cnt` to zero.
93 | """
94 | self.cnt = 0
95 | self.total_time = 0
96 |
97 | def record(self, usetime):
98 | """
99 |         Record the time cost of the current step and increment the `cnt`.
100 | """
101 | self.cnt += 1
102 | self.total_time += usetime
103 |
104 | def get_average(self):
105 | """
106 |         Return the average time cost since recording started.
107 | """
108 | if self.cnt == 0:
109 | return 0
110 | return self.total_time / self.cnt
111 |
112 |
113 | def get_env_device():
114 | """
115 | Return the device name of running enviroment.
116 | """
117 | if paddle.is_compiled_with_cuda():
118 | return 'gpu'
119 | elif paddle.is_compiled_with_npu():
120 | return 'npu'
121 | elif paddle.is_compiled_with_rocm():
122 | return 'rocm'
123 | elif paddle.is_compiled_with_xpu():
124 | return 'xpu'
125 | return 'cpu'
126 |
127 |
128 | def compare_version(version, pair_version):
129 | """
130 | Args:
131 | version (str): The first version string needed to be compared.
132 | The format of version string should be as follow : "xxx.yyy.zzz".
133 | pair_version (str): The second version string needed to be compared.
134 | The format of version string should be as follow : "xxx.yyy.zzz".
135 | Returns:
136 |         int: The result of the comparison. 1 means version > pair_version; 0 means
137 |             version == pair_version; -1 means version < pair_version.
138 |
139 | Examples:
140 |         >>> compare_version("2.2.1", "2.2.0")
141 |         1
142 |         >>> compare_version("2.2.0", "2.2.0")
143 |         0
144 |         >>> compare_version("2.2.0-rc0", "2.2.0")
145 |         -1
146 |         >>> compare_version("2.3.0-rc0", "2.2.0")
147 |         1
148 | """
149 | version = version.strip()
150 | pair_version = pair_version.strip()
151 | if version == pair_version:
152 | return 0
153 | version_list = version.split(".")
154 | pair_version_list = pair_version.split(".")
155 | for version_code, pair_version_code in zip(version_list, pair_version_list):
156 | if not version_code.isnumeric():
157 | return -1
158 | if not pair_version_code.isnumeric():
159 | return 1
160 | if int(version_code) > int(pair_version_code):
161 | return 1
162 | elif int(version_code) < int(pair_version_code):
163 | return -1
164 | return 0
165 |
--------------------------------------------------------------------------------
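
A usage sketch of `TimeCostAverage` for per-step timing; the sleep is a stand-in for real work:

    import time
    from bert_seq2seq.paddle_model.utils.tools import TimeCostAverage

    batch_cost = TimeCostAverage()
    for step in range(10):
        tic = time.time()
        time.sleep(0.01)  # stand-in for one training step
        batch_cost.record(time.time() - tic)

    print("avg step time: %.4f s" % batch_cost.get_average())
    batch_cost.reset()  # start a fresh measurement window
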
/bert_seq2seq/t5_ch.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from bert_seq2seq.model.t5_model import T5ForConditionalGeneration, T5Config, T5SmallConfig
4 | from bert_seq2seq.tokenizer import T5PegasusTokenizer,load_chinese_base_vocab
5 | from bert_seq2seq.basic_bert import BasicT5
6 | from bert_seq2seq.seq2seq_model import top_k_top_p_filtering
7 | import torch.nn.functional as F
8 |
9 | class T5Model(BasicT5):
10 |
11 | def __init__(self, word2idx, size="base"):
12 | super().__init__()
13 | if size == "base":
14 | config = T5Config()
15 | elif size == "small":
16 | config = T5SmallConfig()
17 | else:
18 |             raise Exception("unsupported model size: expected 'base' or 'small'")
19 | self.model = T5ForConditionalGeneration(config)
20 |
21 | self.word2idx = word2idx
22 | self.tokenizer = T5PegasusTokenizer(self.word2idx)
23 | self.bos_id = self.word2idx["[CLS]"]
24 | self.eos_id = self.word2idx["[SEP]"]
25 | self.unk_id = self.word2idx["[UNK]"]
26 |
27 | def forward(self, input_ids, decoder_input_ids, labels=None):
28 | input_ids = input_ids.to(self.device)
29 | decoder_input_ids = decoder_input_ids.to(self.device)
30 | if labels is not None:
31 | labels = labels.to(self.device)
32 | return self.model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, labels=labels)
33 |
34 |
35 | def sample_generate_encoder_decoder(self, text, input_max_length=256, out_max_length=200, top_k=30, top_p=1.0, add_eos=True):
36 |
37 | token_out = self.tokenizer.encode(text, max_length=input_max_length)
38 | if len(token_out) == 2:
39 | token_ids = token_out[0]
40 | else:
41 | token_ids = token_out
42 | if not add_eos:
43 | token_ids = token_ids[:-1]
44 | token_ids = torch.tensor(token_ids, device=self.device, dtype=torch.long).view(1, -1)
45 | output_ids = []
46 |
47 | input_decoder_ids = torch.tensor(self.bos_id, device=self.device, dtype=torch.long).view(1, -1)
48 | with torch.no_grad():
49 | for step in range(out_max_length):
50 | scores = self.model(input_ids=token_ids, decoder_input_ids=input_decoder_ids)[0]
51 | logit_score = torch.log_softmax(scores[:, -1], dim=-1).squeeze(0)
52 | logit_score[self.unk_id] = -float('Inf')
53 | filtered_logits = top_k_top_p_filtering(logit_score, top_k=top_k, top_p=top_p)
54 |
55 |                 filtered_logits_prob = F.softmax(filtered_logits, dim=-1)
56 | 
57 |                 next_token = torch.multinomial(filtered_logits_prob, num_samples=1)
58 | if self.eos_id == next_token.item():
59 | break
60 |
61 | output_ids.append(next_token.item())
62 | input_decoder_ids = torch.cat((input_decoder_ids, next_token.long().unsqueeze(0)), dim=1)
63 |
64 | return self.tokenizer.decode(output_ids)
--------------------------------------------------------------------------------
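
A usage sketch of T5Model; the vocab path is illustrative, and loading trained weights (the gpt/bert wrappers in this repo expose `load_pretrain_params` for this) is elided:

    from bert_seq2seq.tokenizer import load_chinese_base_vocab
    from bert_seq2seq.t5_ch import T5Model

    word2idx = load_chinese_base_vocab("./state_dict/t5-chinese/vocab.txt")  # illustrative path
    model = T5Model(word2idx, size="base")
    model.eval()
    # with trained weights loaded, sample a continuation for a Chinese prompt:
    print(model.sample_generate_encoder_decoder("今天天气很好", out_max_length=40, top_k=20))
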
/bert_seq2seq/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from bert_seq2seq.seq2seq_model import Seq2SeqModel
3 | from bert_seq2seq.bert_cls_classifier import BertClsClassifier
4 | from bert_seq2seq.bert_seq_labeling import BertSeqLabeling
5 | from bert_seq2seq.bert_seq_labeling_crf import BertSeqLabelingCRF
6 | from bert_seq2seq.bert_relation_extraction import BertRelationExtrac
7 | from bert_seq2seq.bert_cls_multi_classifier import BertClsMultiClassifier
8 | import torch.nn.functional as F
9 | from bert_seq2seq.bert_cls_multi_seq2seq import ClsMultiSeq2SeqModel
10 | from bert_seq2seq.simbert_model import SimBertModel
11 | from bert_seq2seq.gpt2_generate_model import GPT2
12 | from bert_seq2seq.basic_bert import BasicBert
13 |
14 |
15 | def load_bert(word2ix, tokenizer=None, model_name="roberta", model_class="seq2seq", target_size=0, target=None):
16 |     """
17 |     A unified interface for constructing the different models.
18 |     model_name: the pretrained architecture, e.g. "roberta", "bert", "nezha"
19 |     model_class : seq2seq or encoder (see the task-specific branches below)
20 |     """
21 |
22 | if model_class == "seq2seq":
23 | bert_model = Seq2SeqModel(word2ix, model_name=model_name, tokenizer=tokenizer)
24 | return bert_model
25 | elif model_class == "cls":
26 | if target_size == 0:
27 |             raise Exception("target_size must be provided to determine the number of classes to predict")
28 | bert_model = BertClsClassifier(word2ix, target_size, model_name=model_name)
29 | return bert_model
30 | elif model_class == "sequence_labeling":
31 |         ## sequence labeling model
32 |         if target_size == 0:
33 |             raise Exception("target_size must be provided to determine the number of classes to predict")
34 | bert_model = BertSeqLabeling(word2ix, target_size, model_name=model_name)
35 | return bert_model
36 | elif model_class == "sequence_labeling_crf":
37 |         # sequence labeling model with a CRF layer
38 |         if target_size == 0:
39 |             raise Exception("target_size must be provided to determine the number of classes to predict")
40 | bert_model = BertSeqLabelingCRF(word2ix, target_size, model_name=model_name)
41 | return bert_model
42 | elif model_class == "relation_extrac":
43 | if target_size == 0:
44 |             raise Exception("target_size must be provided to specify the number of predicate types")
45 | bert_model = BertRelationExtrac(word2ix, target_size, model_name=model_name)
46 | return bert_model
47 | elif model_class == "simbert":
48 | bert_model = SimBertModel(word2ix, model_name=model_name, tokenizer=tokenizer)
49 | return bert_model
50 | elif model_class == "multi_label_cls":
51 | bert_model = BertClsMultiClassifier(word2ix, target_size, model_name=model_name)
52 | return bert_model
53 | elif model_class == "multi_label_cls_seq2seq":
54 | bert_model = ClsMultiSeq2SeqModel(word2ix, target, model_name=model_name)
55 | return bert_model
56 | elif model_class == "embedding":
57 | bert_model = BasicBert(word2ix, model_name=model_name, tokenizer=tokenizer)
58 | return bert_model
59 |     else:
60 |         raise Exception("unknown model_class: " + str(model_class))
61 |
62 |
63 | def load_gpt(word2ix, tokenizer=None):
64 | model = GPT2(word2ix, tokenizer=tokenizer)
65 | return model
66 |
67 |
68 |
--------------------------------------------------------------------------------
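
A usage sketch of the `load_bert` interface above; the vocab path is illustrative:

    from bert_seq2seq import load_bert, load_chinese_base_vocab

    word2idx = load_chinese_base_vocab("./state_dict/roberta/vocab.txt")  # illustrative path

    seq2seq_model = load_bert(word2idx, model_name="roberta", model_class="seq2seq")
    cls_model = load_bert(word2idx, model_name="roberta", model_class="cls", target_size=10)
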
/examples/bart_auto_title_train.py:
--------------------------------------------------------------------------------
1 | ## model url : https://huggingface.co/fnlp/bart-base-chinese
2 | import torch
3 | import time
4 | import glob
5 | from torch.utils.data import Dataset, DataLoader
6 | from tqdm import tqdm
7 | from bert_seq2seq.extend_model_method import ExtendModel
8 |
9 | from transformers import BertTokenizer, BartForConditionalGeneration, Text2TextGenerationPipeline
10 |
11 | src_dir = './corpus/auto_title/train.src'
12 | tgt_dir = './corpus/auto_title/train.tgt'
13 |
14 | vocab_path = "./state_dict/bart-chinese" ## vocab
15 | model_path = "./state_dict/bart-chinese" ## pretrained weights
16 | 
17 | model_save_path = "./state_dict/bart_autotile.bin" ## where to save the trained model
18 | batch_size = 8
19 | lr = 1e-5
20 |
21 | tokenizer = BertTokenizer.from_pretrained(vocab_path)
22 | word2idx = tokenizer.vocab
23 | model = BartForConditionalGeneration.from_pretrained(model_path)
24 |
25 | def read_file():
26 | src = []
27 | tgt = []
28 |
29 | with open(src_dir,'r',encoding='utf-8') as f:
30 | lines = f.readlines()
31 |
32 | for line in lines:
33 | src.append(line.strip('\n').lower())
34 |
35 | with open(tgt_dir,'r',encoding='utf-8') as f:
36 | lines = f.readlines()
37 | for line in lines:
38 | tgt.append(line.strip('\n').lower())
39 |
40 | return src, tgt
41 |
42 |
43 | class SeqDataset(Dataset):
44 | """
45 |     Defines how samples are fetched for this particular dataset.
46 |     """
47 | 
48 |     def __init__(self, sents_src, sents_tgt):
49 |         ## the init function typically loads all the data
50 |         super(SeqDataset, self).__init__()
51 |         # read the raw data
52 |         self.sents_src = sents_src
53 |         self.sents_tgt = sents_tgt
54 | 
55 |         self.idx2word = {k: v for v, k in word2idx.items()}
56 | 
57 |     def __getitem__(self, i):
58 |         ## fetch a single sample
59 | # print(i)
60 | src = self.sents_src[i]
61 | tgt = self.sents_tgt[i]
62 | token_ids_src = tokenizer.encode(src, max_length=256)
63 | token_ids_tgt = tokenizer.encode(tgt, max_length=256)
64 | output = {
65 | "token_ids_src": token_ids_src,
66 | "token_ids_tgt": token_ids_tgt,
67 | }
68 | return output
69 |
70 | def __len__(self):
71 | return len(self.sents_src)
72 |
73 |
74 | def collate_fn(batch):
75 | """
76 |     Dynamic padding; `batch` is a list of samples.
77 |     """
78 | 
79 |     def padding(indice, max_length, pad_idx=0):
80 |         """
81 |         pad helper
82 | """
83 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice]
84 | return torch.tensor(pad_indice)
85 |
86 | token_ids_src = [data["token_ids_src"] for data in batch]
87 | max_length_src = max([len(t) for t in token_ids_src])
88 | token_ids_tgt = [data["token_ids_tgt"] for data in batch]
89 | max_length_tgt = max([len(t) for t in token_ids_tgt])
90 |
91 | token_ids_padded = padding(token_ids_src, max_length_src)
92 | target_ids_padded = padding(token_ids_tgt, max_length_tgt)
93 | labels_ids = target_ids_padded.clone()
94 | target_ids_padded = target_ids_padded[:, :-1].contiguous()
95 | labels_ids = labels_ids[:, 1:].contiguous()
96 |
97 | return token_ids_padded, target_ids_padded, labels_ids
98 |
99 |
100 | class Trainer:
101 | def __init__(self):
102 |         # load the data
103 |         self.sents_src, self.sents_tgt = read_file()
104 | 
105 |         # check whether a GPU is available
106 |         self.device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
107 |         print("device: " + str(self.device))
108 |         # define the model
109 |         self.model = ExtendModel(model, tokenizer=tokenizer, bos_id=word2idx["[CLS]"], eos_id=word2idx["[SEP]"], device=self.device)
110 | 
111 |         # move the model to the compute device (GPU or CPU)
112 |         self.model.to(self.device)
113 |         # self.model.set_device(self.device)
114 |         # declare the parameters to optimize
115 |         self.optim_parameters = list(self.model.parameters())
116 |         self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3)
117 |         # declare the custom data loader
118 |         dataset = SeqDataset(self.sents_src, self.sents_tgt)
119 |         self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
120 |
121 | def train(self, epoch):
122 |         # train for one epoch
123 | self.model.train()
124 | self.iteration(epoch, dataloader=self.dataloader, train=True)
125 |
126 | def save(self, save_path):
127 | """
128 |         save the model
129 | """
130 | self.model.save_all_params(save_path)
131 | print("{} saved!".format(save_path))
132 |
133 | def iteration(self, epoch, dataloader, train=True):
134 | total_loss = 0
135 | report_loss = 0
136 |         start_time = time.time() ## record the current time
137 | step = 0
138 | for token_ids, target_ids, labels_ids in tqdm(dataloader, total=len(dataloader)):
139 | step += 1
140 | token_ids = token_ids.to(self.device)
141 | target_ids = target_ids.to(self.device)
142 | labels_ids = labels_ids.to(self.device)
143 | if step % 100 == 0:
144 | # self.save(model_save_path)
145 | self.model.eval()
146 | test_data = ["本文总结了十个可穿戴产品的设计原则,而这些原则同样也是笔者认为是这个行业最吸引人的地方:1为人们解决重复性问题,2从人开始而不是从机器开始,3要引起注意但不要刻意,4提升用户能力而不是取代人",
147 | "2007年乔布斯向人们展示iPhone并宣称它将会改变世界,还有人认为他在夸大其词然而在8年后以iPhone为代表的触屏智能手机已经席卷全球各个角落,未来智能手机将会成为真正的个人电脑为人类发展做出更大的贡献",
148 | "雅虎发布2014年第四季度财报并推出了免税方式剥离其持有的阿里巴巴集团15%股权的计划打算将这一价值约400亿美元的宝贵投资分配给股东截止发稿前雅虎股价上涨了大约7%至5145美元"]
149 |
150 | for text in test_data:
151 | print(self.model.sample_generate_encoder_decoder(text, add_eos=True, top_k=20))
152 | self.model.train()
153 | print("report loss is " + str(report_loss))
154 | report_loss = 0
155 |
156 |             # because target labels are passed in, the loss is computed and returned
157 |             loss = self.model(token_ids, labels=labels_ids, decoder_input_ids=target_ids)[0]
158 |             # backpropagation
159 |             if train:
160 |                 # clear the previous gradients
161 |                 self.optimizer.zero_grad()
162 |                 # backpropagate to obtain new gradients
163 |                 loss.backward()
164 |                 # update the model parameters with the new gradients
165 |                 self.optimizer.step()
166 | 
167 |             # accumulate for the current epoch's average loss
168 |             total_loss += loss.item()
169 |             report_loss += loss.item()
170 |
171 | end_time = time.time()
172 | spend_time = end_time - start_time
173 |         # print training info
174 |         print("epoch is " + str(epoch) + ". loss is " + str(total_loss) + ". spend time is " + str(spend_time))
175 |         # save the model
176 | # self.save(model_save_path)
177 |
178 |
179 | if __name__ == '__main__':
180 |
181 | trainer = Trainer()
182 | train_epoches = 10
183 | for epoch in range(train_epoches):
184 |         # train one epoch
185 | trainer.train(epoch)
186 |
--------------------------------------------------------------------------------
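
The collate_fn above performs the standard shift for encoder-decoder training: the decoder input drops the last token and the labels drop the first, so the logit at position i is trained to predict token i+1. A worked sketch with hypothetical ids:

    import torch

    tgt = torch.tensor([[101, 8, 9, 10, 102]])    # [CLS] w1 w2 w3 [SEP]; ids are hypothetical
    decoder_input_ids = tgt[:, :-1].contiguous()  # [[101, 8, 9, 10]]
    labels = tgt[:, 1:].contiguous()              # [[8, 9, 10, 102]]
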
/examples/examples_paddle/mBart_translation_en_ro.py:
--------------------------------------------------------------------------------
1 | from paddlenlp.transformers import MBartForConditionalGeneration,MBartTokenizer
2 | import paddle
3 | from paddle.io import Dataset
4 | import argparse
5 | from tqdm import tqdm
6 |
7 | parse = argparse.ArgumentParser()
8 | parse.add_argument("--model_name",type=str)
9 | parse.add_argument("--epoches",type=int)
10 | parse.add_argument("--test_step",type=int)
11 | parse.add_argument("--save_step",type=int)
12 | parse.add_argument("--batch_size",type=int)
13 | parse.add_argument("--src_lang",type=str)
14 | parse.add_argument("--tgt_lang",type=str)
15 | parse.add_argument("--datapath_en",type=str)
16 | parse.add_argument("--datapath_ro",type=str)
17 | parse.add_argument("--max_length",type=int)
18 | opt = parse.parse_args()
19 |
20 | def read_dataset(datapath_en,datapath_ro):
21 | dataset_en = []
22 | dataset_ro = []
23 | with open(datapath_en) as f:
24 | for line in f:
25 | dataset_en.append(line.strip("\n"))
26 | with open(datapath_ro) as ff:
27 | for line in ff:
28 | dataset_ro.append(line.strip("\n"))
29 | return dataset_en,dataset_ro
30 |
31 | class EnRoDataset(Dataset):
32 | def __init__(self,dataset_en,dataset_ro):
33 | super(EnRoDataset,self).__init__()
34 | self.dataset_en = dataset_en
35 | self.dataset_ro = dataset_ro
36 | def __getitem__(self,index):
37 | data_for_en = self.dataset_en[index]
38 | data_for_ro = self.dataset_ro[index]
39 | input_ids = tokenizer.encode(data_for_en)["input_ids"]
40 | decoder_input_ids = [tokenizer.lang_code_to_id[opt.tgt_lang]]+tokenizer.encode(data_for_ro)["input_ids"][:-1]
41 | output = {
42 | "input_ids":input_ids,
43 | "decoder_input_ids":decoder_input_ids
44 | }
45 | return output
46 | def __len__(self):
47 | return len(self.dataset_en)
48 | @staticmethod
49 | def collate_fn(batch):
50 | def padding(indice,max_length=50,pad_idx=tokenizer.pad_token_id):
51 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice]
52 | return paddle.to_tensor(pad_indice)
53 |
54 | input_ids = [data["input_ids"] for data in batch]
55 | decoder_input_ids = [data["decoder_input_ids"] for data in batch]
56 | max_length_of_input_ids = max([len(text) for text in input_ids])
57 | max_length_of_decoder_input_ids = max([len(text) for text in decoder_input_ids])
58 |
59 | input_ids_padded = padding(input_ids,max_length_of_input_ids)
60 | decoder_input_ids_padded = padding(decoder_input_ids,max_length_of_decoder_input_ids)
61 | return input_ids_padded,decoder_input_ids_padded
62 |
63 |
64 | model = MBartForConditionalGeneration.from_pretrained(opt.model_name)
65 | tokenizer = MBartTokenizer.from_pretrained(opt.model_name,src_lang=opt.src_lang,tgt_lang=opt.tgt_lang)
66 | dataset_en,dataset_ro = read_dataset(opt.datapath_en,opt.datapath_ro)
67 |
68 | dataset = EnRoDataset(
69 | dataset_en,
70 | dataset_ro)
71 | dataloader = paddle.io.DataLoader(
72 | dataset,
73 | batch_size=opt.batch_size,
74 | shuffle=True,
75 | collate_fn=dataset.collate_fn)
76 | optimizer = paddle.optimizer.AdamW(
77 | learning_rate=1e-5,
78 | parameters=model.parameters(),
79 | weight_decay=1e-5)
80 |
81 | def calculate_loss(logits,label):
82 | return paddle.nn.functional.cross_entropy(logits.reshape([-1,tokenizer.vocab_size]),label.reshape([-1]))
83 |
84 | def generate_text_for_test(text,target_language,max_length):
85 | with paddle.no_grad():
86 | input_ids = paddle.to_tensor(tokenizer.encode(text)["input_ids"]).unsqueeze(0)
87 | bos_id = tokenizer.lang_code_to_id[target_language]
88 | outputs, _ = model.generate(
89 | input_ids=input_ids,
90 | forced_bos_token_id=bos_id,
91 | decode_strategy="beam_search",
92 | num_beams=4,
93 |             max_length=max_length)
94 | return tokenizer.convert_ids_to_string(outputs[0].numpy().tolist()[1:-1])
95 |
96 | def train():
97 |
98 | global_step = 1
99 | report_loss = 0
100 |
101 | for epoch in range(opt.epoches):
102 | for input_ids, decoder_input_ids in tqdm(dataloader(), total=len(dataloader)):
103 | model.train()
104 | if global_step % opt.test_step == 0:
105 | model.eval()
106 | texts = ["election of Vice-Presidents of the European Parliament ( deadline for submitting nominations ) : see Minutes","agenda for next sitting : see Minutes"]
107 | for text in texts:
108 | print("English:",text)
109 |                     print("Romanian:",generate_text_for_test(text,opt.tgt_lang,opt.max_length))
110 | print("loss is {}".format(report_loss))
111 | report_loss = 0
112 | model.train()
113 | if global_step % opt.save_step == 0:
114 | pass
115 | logits = model(input_ids=input_ids,decoder_input_ids=decoder_input_ids)
116 | loss = calculate_loss(logits[:,:-2],decoder_input_ids[:,1:-1])
117 | report_loss = report_loss + loss.item()
118 | loss.backward()
119 | optimizer.step()
120 | optimizer.clear_grad()
121 | global_step = global_step + 1
122 |
123 |
124 |
125 | if __name__ == "__main__":
126 | train()
127 |
128 | # python new.py --model_name "mbart-large-en-ro" --epoches 3 --test_step 10 --save_step 10000 --batch_size 3 --src_lang "en_XX" --tgt_lang "ro_RO" --datapath_en "./wmt16_en_ro/corpus.en" --datapath_ro "./wmt16_en_ro/corpus.ro" --max_length 128
129 |
130 |
--------------------------------------------------------------------------------
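
The decoder_input_ids built in EnRoDataset follow mBART's convention: decoding starts from the target language code, and the target sequence is shifted right by dropping its final token. A worked sketch with hypothetical ids:

    ro_ids = [15, 16, 17, 2]  # an encoded Romanian sentence ending in </s>; ids are hypothetical
    ro_lang_code = 250020     # illustrative id for "ro_RO"
    decoder_input_ids = [ro_lang_code] + ro_ids[:-1]
    print(decoder_input_ids)  # [250020, 15, 16, 17]
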
/examples/gpt2_ancient_translation_train.py:
--------------------------------------------------------------------------------
1 | ## classical-to-modern Chinese translation with gpt2
2 | from bert_seq2seq import load_gpt
3 | import torch
4 | from tqdm import tqdm
5 | import time
6 | import glob
7 | from torch.utils.data import Dataset, DataLoader
8 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
9 |
10 | vocab_path = "./state_dict/gpt2通用中文模型/vocab.txt"
11 | model_path = "./state_dict/gpt2通用中文模型/pytorch_model.bin"
12 | model_save_path = "./state_dict/gpt_ancient_trans_model.bin"
13 | batch_size = 8
14 | lr = 1e-5
15 | word2idx = load_chinese_base_vocab(vocab_path)
16 |
17 | def read_corpus():
18 | """
19 |     read the raw data
20 | """
21 | src = []
22 | tgt = []
23 | data_path = glob.glob("./corpus/文言文翻译/*")
24 | for p in data_path:
25 | dir = p.split("/")[:-1]
26 | dir = "/".join(dir)
27 | # print(dir)
28 | name = p.split("/")[-1]
29 | if "翻译" in name:
30 |             # found a translation file
31 | tgt_name = name
32 | src_name = name[:-2]
33 | with open(dir + "/" + src_name) as fs:
34 | lines = fs.readlines()
35 | for line in lines:
36 | src.append(line.strip("\n").strip())
37 |
38 | with open(dir + "/" + tgt_name) as ft:
39 | lines = ft.readlines()
40 | for line in lines:
41 | tgt.append(line.strip("\n").strip())
42 |
43 | else:
44 | pass
45 |
46 | return src, tgt
47 |
48 | class SeqDataset(Dataset):
49 | """
50 |     Defines how samples are fetched for this particular dataset.
51 |     """
52 | 
53 |     def __init__(self, sents_src, sents_tgt):
54 |         ## the init function typically loads all the data
55 |         super(SeqDataset, self).__init__()
56 |         # read the raw data
57 |         # self.sents_src, self.sents_tgt = read_corpus(poem_corpus_dir)
58 |         self.sents_src = sents_src
59 |         self.sents_tgt = sents_tgt
60 | 
61 |         self.idx2word = {k: v for v, k in word2idx.items()}
62 |         self.tokenizer = Tokenizer(word2idx)
63 | 
64 |     def __getitem__(self, i):
65 |         ## fetch a single sample
66 | # print(i)
67 | src = self.sents_src[i]
68 | tgt = self.sents_tgt[i]
69 | token_ids, _ = self.tokenizer.encode(src, tgt, max_length=256)
70 | output = {
71 | "token_ids": token_ids,
72 | }
73 | return output
74 |
75 | def __len__(self):
76 | return len(self.sents_src)
77 |
78 |
79 | def collate_fn(batch):
80 | """
81 |     Dynamic padding; `batch` is a list of samples.
82 |     """
83 | 
84 |     def padding(indice, max_length, pad_idx=0):
85 |         """
86 |         pad helper
87 | """
88 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice]
89 | return torch.tensor(pad_indice)
90 |
91 | token_ids = [data["token_ids"] for data in batch]
92 | max_length = max([len(t) for t in token_ids])
93 |
94 | token_ids_padded = padding(token_ids, max_length)
95 | target_ids_padded = token_ids_padded.clone()
96 | target_ids_padded[target_ids_padded == 0] = -100
97 |
98 | return token_ids_padded, target_ids_padded
99 |
100 |
101 | class Trainer:
102 | def __init__(self):
103 |         # load the data
104 |         self.sents_src, self.sents_tgt = read_corpus()
105 | 
106 |         # check whether a GPU is available
107 |         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
108 |         print("device: " + str(self.device))
109 |         # define the model
110 |         self.gpt_model = load_gpt(word2idx)
111 |         ## load the pretrained model parameters
112 |         self.gpt_model.load_pretrain_params(model_path)
113 |         # move the model to the compute device (GPU or CPU)
114 |         self.gpt_model.set_device(self.device)
115 |         # declare the parameters to optimize
116 |         self.optim_parameters = list(self.gpt_model.parameters())
117 |         self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3)
118 |         # declare the custom data loader
119 | dataset = SeqDataset(self.sents_src, self.sents_tgt)
120 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
121 |
122 | def train(self, epoch):
123 |         # train for one epoch
124 | self.gpt_model.train()
125 | self.iteration(epoch, dataloader=self.dataloader, train=True)
126 |
127 | def save(self, save_path):
128 | """
129 |         save the model
130 | """
131 | self.gpt_model.save_all_params(save_path)
132 | print("{} saved!".format(save_path))
133 |
134 | def iteration(self, epoch, dataloader, train=True):
135 | total_loss = 0
136 | report_loss = 0
137 |         start_time = time.time() ## record the current time
138 | step = 0
139 | # for token_ids, target_ids in tqdm(dataloader, position=0, leave=True):
140 | for token_ids, target_ids in dataloader:
141 | step += 1
142 | if step % 4000 == 0:
143 | self.gpt_model.eval()
144 | test_data = ["遂入颍川。", "会日暝,结陈相持。", "一言兴邦,斯近之矣。"]
145 | for text in test_data:
146 | print(self.gpt_model.sample_generate(text, add_eos=True))
147 | self.gpt_model.train()
148 | print("report loss is " + str(report_loss))
149 | report_loss = 0.0
150 |
151 |             # because target labels are passed in, the loss is computed and returned
152 |             loss, _ = self.gpt_model(token_ids,
153 |                                         labels=target_ids,
154 |                                         )
155 |             # backpropagation
156 |             if train:
157 |                 # clear the previous gradients
158 |                 self.optimizer.zero_grad()
159 |                 # backpropagate to obtain new gradients
160 |                 loss.backward()
161 |                 # update the model parameters with the new gradients
162 |                 self.optimizer.step()
163 | 
164 |             # accumulate for the current epoch's average loss
165 | total_loss += loss.item()
166 | report_loss += loss.item()
167 |
168 | end_time = time.time()
169 | spend_time = end_time - start_time
170 |         # print training info
171 |         print("epoch is " + str(epoch) + ". loss is " + str(total_loss) + ". spend time is " + str(spend_time))
172 |         # save the model
173 | self.save(model_save_path)
174 |
175 |
176 | if __name__ == '__main__':
177 |
178 | trainer = Trainer()
179 | train_epoches = 100
180 | for epoch in range(train_epoches):
181 |         # train one epoch
182 | trainer.train(epoch)
183 |
--------------------------------------------------------------------------------
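
The collate_fn above masks padded positions by setting them to -100, the default ignore_index of torch's cross-entropy, so pad tokens contribute nothing to the loss. A worked sketch with hypothetical ids:

    import torch

    batch = torch.tensor([[101, 7, 8, 102, 0, 0]])  # two trailing pad (id 0) positions
    labels = batch.clone()
    labels[labels == 0] = -100
    print(labels.tolist())  # [[101, 7, 8, 102, -100, -100]]
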
/examples/gpt2_english_story_train.py:
--------------------------------------------------------------------------------
1 | # English story generation with a gpt2 model: given an opening, continue the story with five sentences
2 | import torch
3 | from tqdm import tqdm
4 | import time
5 | import pandas as pd
6 | from torch.utils.data import Dataset, DataLoader
7 | from bert_seq2seq import load_gpt
8 | from transformers import AutoTokenizer
9 | tokenizer = AutoTokenizer.from_pretrained("pranavpsv/gpt2-genre-story-generator")
10 | word2ix = tokenizer.get_vocab()
11 | data_path = "./corpus/英文讲故事数据集/train.csv"
12 | model_path = "./state_dict/english_gpt_model/english_gpt_story.bin"
13 | model_save_path = "./state_dict/gpt_auto_story.bin"
14 | batch_size = 8
15 | lr = 1e-5
16 | maxlen = 256
17 |
18 | def load_data():
19 | sents_src = []
20 | sents_tgt = []
21 | df = pd.read_csv(data_path)
22 | for i, row in df.iterrows():
23 | sents_src.append(row[1])
24 | tgt = ""
25 | for j in range(2, 7):
26 | tgt += row[j]
27 | sents_tgt.append(tgt)
28 |
29 | return sents_src, sents_tgt
30 |
31 | class GPTDataset(Dataset):
32 | """
33 |     Defines how samples are fetched for this particular dataset.
34 | """
35 |
36 | def __init__(self):
37 |         ## the init function typically loads all the data
38 |         super(GPTDataset, self).__init__()
39 |         ## get all the file names
40 | self.sents_src, self.sents_tgt = load_data()
41 | self.tokenizer = tokenizer
42 |
43 | def __getitem__(self, i):
44 |         ## fetch a single sample
45 |
46 | src_d = self.sents_src[i]
47 | tgt_d = self.sents_tgt[i]
48 | src_ids = self.tokenizer.encode(src_d) + [self.tokenizer.eos_token_id]
49 | tgt_ids = self.tokenizer.encode(tgt_d) + [self.tokenizer.eos_token_id]
50 | output = {
51 | "token_ids": src_ids + tgt_ids,
52 | }
53 | return output
54 |
55 |
56 |
57 | def __len__(self):
58 | return len(self.sents_src)
59 |
60 |
61 | def collate_fn(batch):
62 | """
63 |     Dynamic padding; `batch` is a list of samples.
64 |     """
65 | 
66 |     def padding(indice, max_length, pad_idx=0):
67 |         """
68 |         pad helper
69 | """
70 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice]
71 | return torch.tensor(pad_indice)
72 |
73 | token_ids = [data["token_ids"] for data in batch]
74 | max_length = max([len(t) for t in token_ids])
75 |
76 |     token_ids_padded = padding(token_ids, max_length, pad_idx=word2ix["<|endoftext|>"])  # GPT-2 has no pad token; reuse <|endoftext|>
77 |     token_target_padded = token_ids_padded.clone()
78 |     token_target_padded[token_target_padded == word2ix["<|endoftext|>"]] = -100
79 | return token_ids_padded, token_target_padded
80 |
81 |
82 | class Trainer:
83 | def __init__(self):
84 |         # check whether a GPU is available
85 |         # self.device = torch.device("cpu")
86 |         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
87 |         print("device: " + str(self.device))
88 |         # define the model
89 |         self.model = load_gpt(word2ix, tokenizer=tokenizer)
90 |         self.model.load_pretrain_params(model_path)
91 |         # load the already-trained model and continue training
92 | 
93 |         # move the model to the compute device (GPU or CPU)
94 |         self.model.set_device(self.device)
95 |         # declare the parameters to optimize
96 |         self.optim_parameters = list(self.model.parameters())
97 |         self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3)
98 |         # declare the custom data loader
99 | dataset = GPTDataset()
100 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
101 |
102 | def train(self, epoch):
103 |         # train for one epoch
104 | self.model.train()
105 | self.iteration(epoch, dataloader=self.dataloader, train=True)
106 |
107 | def save(self, save_path):
108 | """
109 |         save the model
110 | """
111 | self.model.save_all_params(save_path)
112 | print("{} saved!".format(save_path))
113 |
114 | def iteration(self, epoch, dataloader, train=True):
115 | total_loss = 0
116 |         start_time = time.time() ## record the current time
117 | step = 0
118 | report_loss = 0
119 | for token_ids, token_target in tqdm(dataloader, position=0, leave=True):
120 | step += 1
121 | if step % 1000 == 0:
122 | self.model.eval()
123 | print(self.model.sample_generate_english("David Drops the Weight", out_max_length=300, add_eos=True))
124 | print("loss is " + str(report_loss))
125 | report_loss = 0
126 | self.model.train()
127 | if step % 6000 == 0:
128 | self.save(model_save_path)
129 |
130 |             # because target labels are passed in, the loss is computed and returned
131 |             loss, pred_logit = self.model(token_ids, labels=token_target)
132 |             report_loss += loss.item()
133 |             # backpropagation
134 |             if train:
135 |                 # clear the previous gradients
136 |                 self.optimizer.zero_grad()
137 |                 # backpropagate to obtain new gradients
138 |                 loss.backward()
139 |                 # update the model parameters with the new gradients
140 |                 self.optimizer.step()
141 | 
142 |             # accumulate for the current epoch's average loss
143 | total_loss += loss.item()
144 |
145 | end_time = time.time()
146 | spend_time = end_time - start_time
147 |         # print training info
148 | print("epoch is " + str(epoch) + ". loss is " + str(total_loss) + ". spend time is " + str(spend_time))
149 |
150 |
151 | if __name__ == '__main__':
152 |
153 | trainer = Trainer()
154 | train_epoches = 20
155 |
156 | for epoch in range(train_epoches):
157 |         # train one epoch
158 | trainer.train(epoch)
--------------------------------------------------------------------------------
/examples/gpt2_explain_dream_train.py:
--------------------------------------------------------------------------------
1 | ## dream interpretation ("Zhougong jiemeng") with a gpt2 model
2 | from bert_seq2seq import load_gpt
3 | import torch
4 | from tqdm import tqdm
5 | import pandas as pd
6 | import time
7 | from torch.utils.data import Dataset, DataLoader
8 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
9 |
10 | vocab_path = "./state_dict/gpt_vocab.txt"
11 | model_path = "./state_dict/gpt_pytorch_model.bin"
12 | model_save_path = "./state_dict/gpt_explain_dream_model.bin"
13 | batch_size = 16
14 | lr = 1e-5
15 | data_path = "./corpus/周公解梦/dream_data.csv"
16 | word2idx = load_chinese_base_vocab(vocab_path)
17 |
18 | def read_corpus():
19 | """
20 |     read the raw data
21 | """
22 | sents_src = []
23 | sents_tgt = []
24 |
25 | df = pd.read_csv(data_path, delimiter="\t")
26 | for i, row in df.iterrows():
27 | # print(row)
28 | json_s = eval(row[0])
29 | sents_src.append(json_s["dream"])
30 | sents_tgt.append(json_s["decode"])
31 |
32 | return sents_src, sents_tgt
33 |
34 |
35 | class SeqDataset(Dataset):
36 | """
37 |     Defines how samples are fetched for this particular dataset.
38 |     """
39 | 
40 |     def __init__(self, sents_src, sents_tgt):
41 |         ## the init function typically loads all the data
42 |         super(SeqDataset, self).__init__()
43 |         # read the raw data
44 |         # self.sents_src, self.sents_tgt = read_corpus(poem_corpus_dir)
45 |         self.sents_src = sents_src
46 |         self.sents_tgt = sents_tgt
47 | 
48 |         self.idx2word = {k: v for v, k in word2idx.items()}
49 |         self.tokenizer = Tokenizer(word2idx)
50 | 
51 |     def __getitem__(self, i):
52 |         ## fetch a single sample
53 | # print(i)
54 | src = self.sents_src[i]
55 | tgt = self.sents_tgt[i]
56 | token_ids, _ = self.tokenizer.encode(src, tgt)
57 | output = {
58 | "token_ids": token_ids,
59 | }
60 | return output
61 |
62 | def __len__(self):
63 | return len(self.sents_src)
64 |
65 |
66 | def collate_fn(batch):
67 | """
68 |     Dynamic padding; `batch` is a list of samples.
69 |     """
70 | 
71 |     def padding(indice, max_length, pad_idx=0):
72 |         """
73 |         pad helper
74 | """
75 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice]
76 | return torch.tensor(pad_indice)
77 |
78 | token_ids = [data["token_ids"] for data in batch]
79 | max_length = max([len(t) for t in token_ids])
80 |
81 | token_ids_padded = padding(token_ids, max_length)
82 | target_ids_padded = token_ids_padded.clone()
83 | target_ids_padded[target_ids_padded == 0] = -100
84 |
85 | return token_ids_padded, target_ids_padded
86 |
87 |
88 | class Trainer:
89 | def __init__(self):
90 |         # load the data
91 |         self.sents_src, self.sents_tgt = read_corpus()
92 | 
93 |         # check whether a GPU is available
94 |         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
95 |         print("device: " + str(self.device))
96 |         # define the model
97 |         self.gpt_model = load_gpt(word2idx)
98 |         ## load the pretrained model parameters
99 |         self.gpt_model.load_pretrain_params(model_path)
100 |         # move the model to the compute device (GPU or CPU)
101 |         self.gpt_model.set_device(self.device)
102 |         # declare the parameters to optimize
103 |         self.optim_parameters = list(self.gpt_model.parameters())
104 |         self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3)
105 |         # declare the custom data loader
106 | dataset = SeqDataset(self.sents_src, self.sents_tgt)
107 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
108 |
109 | def train(self, epoch):
110 |         # train for one epoch
111 | self.gpt_model.train()
112 | self.iteration(epoch, dataloader=self.dataloader, train=True)
113 |
114 | def save(self, save_path):
115 | """
116 |         save the model
117 | """
118 | self.gpt_model.save_all_params(save_path)
119 | print("{} saved!".format(save_path))
120 |
121 | def iteration(self, epoch, dataloader, train=True):
122 | total_loss = 0
123 |         start_time = time.time() ## record the current time
124 | step = 0
125 | for token_ids, target_ids in tqdm(dataloader, position=0, leave=True):
126 | step += 1
127 | if step % 1000 == 0:
128 | self.gpt_model.eval()
129 | test_data = ["梦见领袖", "梦见海盗抢东西", "梦见和自己的导师谈话"]
130 | for text in test_data:
131 | print(self.gpt_model.sample_generate(text, add_eos=True))
132 | self.gpt_model.train()
133 |
134 |             # because target labels are passed in, the loss is computed and returned
135 |             loss, _ = self.gpt_model(token_ids,
136 |                                         labels=target_ids,
137 |                                         )
138 |             # backpropagation
139 |             if train:
140 |                 # clear the previous gradients
141 |                 self.optimizer.zero_grad()
142 |                 # backpropagate to obtain new gradients
143 |                 loss.backward()
144 |                 # update the model parameters with the new gradients
145 |                 self.optimizer.step()
146 | 
147 |             # accumulate for the current epoch's average loss
148 | total_loss += loss.item()
149 |
150 | end_time = time.time()
151 | spend_time = end_time - start_time
152 |         # print training info
153 |         print("epoch is " + str(epoch) + ". loss is " + str(total_loss) + ". spend time is " + str(spend_time))
154 |         # save the model
155 | self.save(model_save_path)
156 |
157 |
158 | if __name__ == '__main__':
159 |
160 | trainer = Trainer()
161 | train_epoches = 10
162 | for epoch in range(train_epoches):
163 |         # train one epoch
164 | trainer.train(epoch)
165 |
--------------------------------------------------------------------------------
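A note on the collate_fn above: the target tensor is a copy of the padded inputs with every pad position overwritten by -100, which is the default ignore_index of PyTorch's cross-entropy loss, so padding never contributes to the training signal. A minimal, self-contained sketch of that mechanism (token ids and vocab size are made up for illustration):
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | 
4 | # Two sequences of different lengths, padded with 0 as in collate_fn.
5 | token_ids = [[101, 7, 8, 9, 102], [101, 5, 102]]
6 | max_len = max(len(t) for t in token_ids)
7 | padded = torch.tensor([t + [0] * (max_len - len(t)) for t in token_ids])
8 | targets = padded.clone()
9 | targets[targets == 0] = -100  # -100 is the default ignore_index of cross_entropy
10 | 
11 | # Fake next-token logits over a 200-word vocab (the model would produce these).
12 | logits = torch.randn(2, max_len - 1, 200)
13 | loss = F.cross_entropy(logits.reshape(-1, 200), targets[:, 1:].reshape(-1))
14 | print(loss)  # the masked pad positions are simply skipped in the average
--------------------------------------------------------------------------------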
/examples/nezha_couplets_train.py:
--------------------------------------------------------------------------------
1 | ## Example: automatic couplet generation
2 | import torch
3 | from tqdm import tqdm
4 | import time
5 | from torch.utils.data import Dataset, DataLoader
6 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
7 | from bert_seq2seq import load_bert
8 |
9 | vocab_path = "./state_dict/nezha-base-www/vocab.txt"  # path to the nezha vocab file
10 | model_name = "nezha"  # model name
11 | model_path = "./state_dict/nezha-base-www/pytorch_model.bin"  # path to the nezha checkpoint
12 | recent_model_path = ""  # set this to resume training from an already-trained model
13 | model_save_path = "./nezha-duilian.bin"
14 | batch_size = 16
15 | lr = 1e-5
16 | data_dir = "./corpus/对联"
17 | word2idx = load_chinese_base_vocab(vocab_path)
18 |
19 |
20 | def read_corpus(dir_path):
21 | """
22 |     Load the raw data.
23 | """
24 | sents_src = []
25 | sents_tgt = []
26 | in_path = dir_path + "/in.txt"
27 | out_path = dir_path + "/out.txt"
28 | with open(in_path, "r", encoding="utf-8") as f:
29 | lines = f.readlines()
30 | for line in lines:
31 | sents_src.append(line.strip())
32 | with open(out_path, "r", encoding="utf-8") as f:
33 | lines = f.readlines()
34 | for line in lines:
35 | sents_tgt.append(line.strip())
36 |
37 | return sents_src, sents_tgt
38 |
39 |
40 | class BertDataset(Dataset):
41 | """
42 |     Defines how samples are fetched for this particular dataset.
43 | """
44 |
45 | def __init__(self, sents_src, sents_tgt):
46 |         ## init usually loads all the data
47 |         super(BertDataset, self).__init__()
48 |         # Load the raw data
49 | # self.sents_src, self.sents_tgt = read_corpus(poem_corpus_dir)
50 | self.sents_src = sents_src
51 | self.sents_tgt = sents_tgt
52 |
53 | self.idx2word = {k: v for v, k in word2idx.items()}
54 | self.tokenizer = Tokenizer(word2idx)
55 |
56 | def __getitem__(self, i):
57 |         ## Fetch a single sample
58 | # print(i)
59 | src = self.sents_src[i]
60 | tgt = self.sents_tgt[i]
61 | token_ids, token_type_ids = self.tokenizer.encode(src, tgt)
62 | output = {
63 | "token_ids": token_ids,
64 | "token_type_ids": token_type_ids,
65 | }
66 | return output
67 |
68 | def __len__(self):
69 | return len(self.sents_src)
70 |
71 |
72 | def collate_fn(batch):
73 | """
74 |     Dynamic padding; batch is a list of samples.
75 | """
76 |
77 | def padding(indice, max_length, pad_idx=0):
78 | """
79 |         Padding helper.
80 | """
81 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice]
82 | return torch.tensor(pad_indice)
83 |
84 | token_ids = [data["token_ids"] for data in batch]
85 | max_length = max([len(t) for t in token_ids])
86 | token_type_ids = [data["token_type_ids"] for data in batch]
87 |
88 | token_ids_padded = padding(token_ids, max_length)
89 | token_type_ids_padded = padding(token_type_ids, max_length)
90 | target_ids_padded = token_ids_padded[:, 1:].contiguous()
91 |
92 | return token_ids_padded, token_type_ids_padded, target_ids_padded
93 |
94 |
95 | class Trainer:
96 | def __init__(self):
97 |         # Load the data
98 |         self.sents_src, self.sents_tgt = read_corpus(data_dir)
99 | 
100 |         # Use the GPU if one is available
101 |         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
102 |         print("device: " + str(self.device))
103 |         # Build the model
104 |         self.bert_model = load_bert(word2idx, model_name=model_name)
105 |         ## Load the pretrained parameters
106 |         self.bert_model.load_pretrain_params(model_path)
107 |         # Move the model to the compute device (GPU or CPU)
108 |         self.bert_model.set_device(self.device)
109 |         # Collect the parameters to optimize
110 |         self.optim_parameters = list(self.bert_model.parameters())
111 |         self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3)
112 |         # Build the custom data loader
113 | dataset = BertDataset(self.sents_src, self.sents_tgt)
114 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
115 |
116 | def train(self, epoch):
117 |         # Train for one epoch
118 | self.bert_model.train()
119 | self.iteration(epoch, dataloader=self.dataloader, train=True)
120 |
121 | def save(self, save_path):
122 | """
123 |         Save the model.
124 | """
125 | self.bert_model.save_all_params(save_path)
126 | print("{} saved!".format(save_path))
127 |
128 | def iteration(self, epoch, dataloader, train=True):
129 | total_loss = 0
130 |         start_time = time.time()  ## record the start time
131 | step = 0
132 | for token_ids, token_type_ids, target_ids in tqdm(dataloader, position=0, leave=True):
133 | step += 1
134 | if step % 5000 == 0:
135 | self.bert_model.eval()
136 | test_data = ["花海总涵功德水", "广汉飞霞诗似玉", "执政为民,德行天下"]
137 | for text in test_data:
138 | print(self.bert_model.generate(text, beam_size=3))
139 | self.bert_model.train()
140 |
141 |             # Because target labels are passed in, the model computes and returns the loss
142 |             predictions, loss = self.bert_model(token_ids,
143 |                                                 token_type_ids,
144 |                                                 labels=target_ids,
145 |                                                 )
146 |             # Backpropagation
147 |             if train:
148 |                 # Clear the previous gradients
149 |                 self.optimizer.zero_grad()
150 |                 # Backpropagate to get fresh gradients
151 |                 loss.backward()
152 |                 # Update the parameters with those gradients
153 |                 self.optimizer.step()
154 | 
155 |             # Accumulate the loss for this epoch's total
156 |             total_loss += loss.item()
157 | 
158 |         end_time = time.time()
159 |         spend_time = end_time - start_time
160 |         # Print training info
161 |         print("epoch is " + str(epoch) + ". loss is " + str(total_loss) + ". spend time is " + str(spend_time))
162 |         # Save the model
163 |         self.save(model_save_path)
164 |
165 |
166 | if __name__ == '__main__':
167 |
168 | trainer = Trainer()
169 | train_epoches = 10
170 | for epoch in range(train_epoches):
171 |         # Train one epoch
172 | trainer.train(epoch)
173 |
--------------------------------------------------------------------------------
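The collate_fn in this file shows the UniLM-style seq2seq setup shared by the bert/roberta/nezha examples: the tokenizer packs source and target into one sequence, and the labels are simply the input ids shifted left by one, so the model at step t is trained to predict token t+1. A small self-contained illustration with made-up token ids:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | # Hypothetical packed pair: [CLS] src [SEP] tgt [SEP]; token_type_id 0 marks
4 | # the source segment and 1 the target segment (ids are invented).
5 | token_ids = torch.tensor([[101, 11, 12, 13, 102, 21, 22, 23, 102]])
6 | token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1]])
7 | 
8 | # Same shift as in collate_fn: the label at step t is the token at t + 1.
9 | target_ids = token_ids[:, 1:].contiguous()
10 | for t in range(target_ids.shape[1]):
11 |     print(token_ids[0, : t + 1].tolist(), "->", target_ids[0, t].item())
--------------------------------------------------------------------------------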
/examples/readme.md:
--------------------------------------------------------------------------------
1 | ## Training script guide
2 | 
3 | ### roberta / bert
4 | 1. [roberta_THUCNews_auto_title.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_THUCNews_auto_title.py) Automatic titling (summarization) on the THUCNews dataset, which is fairly large.
5 | 2. [roberta_auto_title_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_auto_title_train.py) Automatic titling on a smaller dataset.
6 | 3. [roberta_math_ques_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_math_ques_train.py) Automatically solving primary-school math problems.
7 | 4. [relationship_classify_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/relationship_classify_train.py) Person-relationship classification.
8 | 5. [roberta_semantic_matching_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_semantic_matching_train.py) Semantic matching.
9 | 6. [roberta_relation_extract_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_relation_extract_train.py) Triple extraction.
10 | 7. [roberta_poem_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_poem_train.py) Automatic poem writing with roberta.
11 | 8. [roberta_participle_CRF_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_participle_CRF_train.py) Chinese word segmentation.
12 | 9. [roberta_medical_ner_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_medical_ner_train.py) NER on medical data.
13 | 10. [roberta_couplets_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_couplets_train.py) Couplet generation with roberta.
14 | 11. [roberta_news_classification_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_news_classification_train.py) Text classification.
15 | 12. [roberta_coarsness_NER_CRF_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_coarsness_NER_CRF_train.py) Coarse-grained NER with roberta + CRF.
16 | 13. [roberta_coarsness_NER_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_coarsness_NER_train.py) Coarse-grained NER with roberta.
17 | 14. [roberta_fine_grained_NER_CRF_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_fine_grained_NER_CRF_train.py) Fine-grained NER with BERT + CRF.
18 | 15. [roberta_large_auto_title_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/roberta_large_auto_title_train.py) Automatic titling with roberta-large.
19 | 
20 | ### nezha
21 | 1. [nezha_auto_title_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/nezha_auto_title_train.py) Automatic titling (summarization) with Huawei's nezha model.
22 | 2. [nezha_relation_extract_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/nezha_relation_extract_train.py) Relation extraction with nezha.
23 | 3. [nezha_couplets_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/nezha_couplets_train.py) Automatic couplet generation with nezha.
24 | 
25 | ### T5
26 | 1. [t5_ancient_translation_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/t5_ancient_translation_train.py) Classical-to-modern Chinese translation with t5.
27 | 2. [t5_auto_title_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/t5_auto_title_train.py) Automatic titling with t5.
28 | 
29 | ### GPT-2
30 | 1. [gpt2_generate_article.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/gpt2_generate_article.py) Article generation with GPT-2.
31 | 2. [gpt2_explain_dream_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/gpt2_explain_dream_train.py) GPT-2 on the Zhou Gong dream-interpretation dataset.
32 | 3. [gpt2_ancient_translation_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/gpt2_ancient_translation_train.py) Classical Chinese translation with gpt2.
33 | 4. [gpt2_english_story_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/gpt2_english_story_train.py) English story generation with gpt2.
34 | 
35 | ### Simbert
36 | 1. [simbert_train.py](https://github.com/920232796/bert_seq2seq/blob/master/examples/simbert_train.py) Generating similar sentences with SimBert.
--------------------------------------------------------------------------------
/examples/roberta_THUCNews_auto_title.py:
--------------------------------------------------------------------------------
1 | ## THUCNews raw dataset
2 | import torch
3 | from tqdm import tqdm
4 | import time
5 | import glob
6 | from torch.utils.data import Dataset, DataLoader
7 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
8 | from bert_seq2seq import load_bert
9 |
10 | vocab_path = "./state_dict/roberta_wwm_vocab.txt"  # path to the roberta vocab file
11 | word2idx = load_chinese_base_vocab(vocab_path)
12 | model_name = "roberta"  # model name
13 | model_path = "./state_dict/roberta_wwm_pytorch_model.bin"  # path to the checkpoint
14 | recent_model_path = "./state_dict/bert_auto_title_model.bin"  # set this to resume training from an already-trained model
15 | model_save_path = "./state_dict/bert_auto_title_model.bin"
16 | batch_size = 16
17 | lr = 1e-5
18 | maxlen = 256
19 |
20 | class BertDataset(Dataset):
21 | """
22 |     Defines how samples are fetched for this particular dataset.
23 | """
24 | def __init__(self) :
25 |         ## init usually loads all the data
26 | super(BertDataset, self).__init__()
27 |         ## Collect all file names
28 | self.txts = glob.glob('./corpus/THUCNews/*/*.txt')
29 | self.idx2word = {k: v for v, k in word2idx.items()}
30 | self.tokenizer = Tokenizer(word2idx)
31 |
32 | def __getitem__(self, i):
33 |         ## Fetch a single sample
34 | # print(i)
35 | text_name = self.txts[i]
36 | with open(text_name, "r", encoding="utf-8") as f:
37 | text = f.read()
38 | text = text.split('\n')
39 | if len(text) > 1:
40 | title = text[0]
41 | content = '\n'.join(text[1:])
42 | token_ids, token_type_ids = self.tokenizer.encode(
43 | content, title, max_length=maxlen
44 | )
45 | output = {
46 | "token_ids": token_ids,
47 | "token_type_ids": token_type_ids,
48 | }
49 | return output
50 |
51 |         return self.__getitem__((i + 1) % len(self.txts))  # skip malformed files, wrapping so the index stays in range
52 |
53 | def __len__(self):
54 |
55 | return len(self.txts)
56 |
57 | def collate_fn(batch):
58 | """
59 |     Dynamic padding; batch is a list of samples.
60 | """
61 |
62 | def padding(indice, max_length, pad_idx=0):
63 | """
64 |         Padding helper.
65 | """
66 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice]
67 | return torch.tensor(pad_indice)
68 |
69 | token_ids = [data["token_ids"] for data in batch]
70 | max_length = max([len(t) for t in token_ids])
71 | token_type_ids = [data["token_type_ids"] for data in batch]
72 |
73 | token_ids_padded = padding(token_ids, max_length)
74 | token_type_ids_padded = padding(token_type_ids, max_length)
75 | target_ids_padded = token_ids_padded[:, 1:].contiguous()
76 |
77 | return token_ids_padded, token_type_ids_padded, target_ids_padded
78 |
79 | class Trainer:
80 | def __init__(self):
81 |         # Use the GPU if one is available
82 |         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
83 |         print("device: " + str(self.device))
84 |         # Build the model
85 |         self.bert_model = load_bert(word2idx, model_name=model_name)
86 |         ## Load the pretrained parameters
87 | 
88 |         self.bert_model.load_pretrain_params(model_path)
89 |         # To resume training, load an already-trained model here instead
90 | 
91 |         # Move the model to the compute device (GPU or CPU)
92 |         self.bert_model.set_device(self.device)
93 |         # Collect the parameters to optimize
94 |         self.optim_parameters = list(self.bert_model.parameters())
95 |         self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3)
96 |         # Build the custom data loader
97 | dataset = BertDataset()
98 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
99 |
100 | def train(self, epoch):
101 |         # Train for one epoch
102 | self.bert_model.train()
103 | self.iteration(epoch, dataloader=self.dataloader, train=True)
104 |
105 | def save(self, save_path):
106 | """
107 |         Save the model.
108 | """
109 | self.bert_model.save_all_params(save_path)
110 | print("{} saved!".format(save_path))
111 |
112 | def iteration(self, epoch, dataloader, train=True):
113 | total_loss = 0
114 |         start_time = time.time()  ## record the start time
115 | step = 0
116 | report_loss = 0
117 |         for token_ids, token_type_ids, target_ids in tqdm(dataloader, position=0, leave=True):
118 | step += 1
119 | if step % 1000 == 0:
120 | self.bert_model.eval()
121 |                 test_data = ["夏天来临,皮肤在强烈紫外线的照射下,晒伤不可避免,因此,晒后及时修复显得尤为重要,否则可能会造成长期伤害。专家表示,选择晒后护肤品要慎重,芦荟凝胶是最安全,有效的一种选择,晒伤严重者,还请及时就医。",
122 |                              "2007年乔布斯向人们展示iPhone并宣称它将会改变世界还有人认为他在夸大其词然而在8年后以iPhone为代表的触屏智能手机已经席卷全球各个角落未来智能手机将会成为真正的个人电脑为人类发展做出更大的贡献",
123 |                              "8月28日,网络爆料称,华住集团旗下连锁酒店用户数据疑似发生泄露。从卖家发布的内容看,数据包含华住旗下汉庭、禧玥、桔子、宜必思等10余个品牌酒店的住客信息。泄露的信息包括华住官网注册资料、酒店入住登记的身份信息及酒店开房记录,住客姓名、手机号、邮箱、身份证号、登录账号密码等。卖家对这个约5亿条数据打包出售。第三方安全平台威胁猎人对信息出售者提供的三万条数据进行验证,认为数据真实性非常高。当天下午,华住集团发声明称,已在内部迅速开展核查,并第一时间报警。当晚,上海警方消息称,接到华住集团报案,警方已经介入调查。"]
124 | for text in test_data:
125 | print(self.bert_model.generate(text, beam_size=3))
126 | print("loss is " + str(report_loss))
127 | report_loss = 0
128 | # self.eval(epoch)
129 | self.bert_model.train()
130 | if step % 8000 == 0:
131 | self.save(model_save_path)
132 |
133 |             # Because target labels are passed in, the model computes and returns the loss
134 |             predictions, loss = self.bert_model(token_ids,
135 |                                                 token_type_ids,
136 |                                                 labels=target_ids,
137 | 
138 |                                                 )
139 |             report_loss += loss.item()
140 |             # Backpropagation
141 |             if train:
142 |                 # Clear the previous gradients
143 |                 self.optimizer.zero_grad()
144 |                 # Backpropagate to get fresh gradients
145 |                 loss.backward()
146 |                 # Update the parameters with those gradients
147 |                 self.optimizer.step()
148 | 
149 |             # Accumulate the loss for this epoch's total
150 |             total_loss += loss.item()
151 | 
152 |         end_time = time.time()
153 |         spend_time = end_time - start_time
154 |         # Print training info
155 |         print("epoch is " + str(epoch) + ". loss is " + str(total_loss) + ". spend time is " + str(spend_time))
156 |         # Save the model
157 |         self.save(model_save_path)
158 |
159 | if __name__ == '__main__':
160 |
161 | trainer = Trainer()
162 | train_epoches = 20
163 |
164 | for epoch in range(train_epoches):
165 |         # Train one epoch
166 | trainer.train(epoch)
--------------------------------------------------------------------------------
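The __getitem__ above skips a malformed file by recursing to the next index (wrapping around so the index stays in range). An alternative that avoids recursion entirely is to filter out title-only files once at construction time; a sketch, at the cost of reading every file during init, which is slow on a corpus as large as THUCNews:
--------------------------------------------------------------------------------
1 | import glob
2 | 
3 | class FilteredTxtPaths:
4 |     """Keep only files with a title line plus at least one content line."""
5 | 
6 |     def __init__(self, pattern="./corpus/THUCNews/*/*.txt"):
7 |         self.txts = []
8 |         for path in glob.glob(pattern):
9 |             with open(path, "r", encoding="utf-8") as f:
10 |                 if len(f.read().split("\n")) > 1:
11 |                     self.txts.append(path)
12 | 
13 |     def __len__(self):
14 |         return len(self.txts)
--------------------------------------------------------------------------------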
/examples/roberta_couplets_train.py:
--------------------------------------------------------------------------------
1 | ## Example: automatic couplet generation
2 | import sys
3 | import torch
4 | from tqdm import tqdm
5 | import time
6 | from torch.utils.data import Dataset, DataLoader
7 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
8 | from bert_seq2seq import load_bert
9 |
10 | vocab_path = "./state_dict/roberta_wwm_vocab.txt"  # path to the roberta vocab file
11 | model_name = "roberta"  # model name
12 | model_path = "./state_dict/roberta_wwm_pytorch_model.bin"  # path to the roberta checkpoint
13 | recent_model_path = ""  # set this to resume training from an already-trained model
14 | model_save_path = "./bert_duilian_model.bin"
15 | batch_size = 16
16 | lr = 1e-5
17 | data_dir = "./corpus/对联"
18 | word2idx = load_chinese_base_vocab(vocab_path)
19 |
20 | def read_corpus(dir_path):
21 | """
22 |     Load the raw data.
23 | """
24 | sents_src = []
25 | sents_tgt = []
26 | in_path = dir_path + "/in.txt"
27 | out_path = dir_path + "/out.txt"
28 | with open(in_path, "r", encoding="utf-8") as f:
29 | lines = f.readlines()
30 | for line in lines:
31 | sents_src.append(line.strip())
32 | with open(out_path, "r", encoding="utf-8") as f:
33 | lines = f.readlines()
34 | for line in lines:
35 | sents_tgt.append(line.strip())
36 |
37 | return sents_src, sents_tgt
38 |
39 | class BertDataset(Dataset):
40 | """
41 |     Defines how samples are fetched for this particular dataset.
42 | """
43 | def __init__(self, sents_src, sents_tgt) :
44 |         ## init usually loads all the data
45 |         super(BertDataset, self).__init__()
46 |         # Load the raw data
47 | # self.sents_src, self.sents_tgt = read_corpus(poem_corpus_dir)
48 | self.sents_src = sents_src
49 | self.sents_tgt = sents_tgt
50 |
51 | self.idx2word = {k: v for v, k in word2idx.items()}
52 | self.tokenizer = Tokenizer(word2idx)
53 |
54 | def __getitem__(self, i):
55 |         ## Fetch a single sample
56 | # print(i)
57 | src = self.sents_src[i]
58 | tgt = self.sents_tgt[i]
59 | token_ids, token_type_ids = self.tokenizer.encode(src, tgt)
60 | output = {
61 | "token_ids": token_ids,
62 | "token_type_ids": token_type_ids,
63 | }
64 | return output
65 |
66 | def __len__(self):
67 |
68 | return len(self.sents_src)
69 |
70 | def collate_fn(batch):
71 | """
72 |     Dynamic padding; batch is a list of samples.
73 | """
74 |
75 | def padding(indice, max_length, pad_idx=0):
76 | """
77 |         Padding helper.
78 | """
79 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice]
80 | return torch.tensor(pad_indice)
81 |
82 | token_ids = [data["token_ids"] for data in batch]
83 | max_length = max([len(t) for t in token_ids])
84 | token_type_ids = [data["token_type_ids"] for data in batch]
85 |
86 | token_ids_padded = padding(token_ids, max_length)
87 | token_type_ids_padded = padding(token_type_ids, max_length)
88 | target_ids_padded = token_ids_padded[:, 1:].contiguous()
89 |
90 | return token_ids_padded, token_type_ids_padded, target_ids_padded
91 |
92 | class Trainer:
93 | def __init__(self):
94 |         # Load the data
95 |         self.sents_src, self.sents_tgt = read_corpus(data_dir)
96 | 
97 |         # Use the GPU if one is available
98 |         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
99 |         print("device: " + str(self.device))
100 |         # Build the model
101 |         self.bert_model = load_bert(word2idx, model_name=model_name)
102 |         ## Load the pretrained parameters
103 |         self.bert_model.load_pretrain_params(model_path)
104 |         # Move the model to the compute device (GPU or CPU)
105 |         self.bert_model.set_device(self.device)
106 |         # Collect the parameters to optimize
107 |         self.optim_parameters = list(self.bert_model.parameters())
108 |         self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3)
109 |         # Build the custom data loader
110 | dataset = BertDataset(self.sents_src, self.sents_tgt)
111 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
112 |
113 | def train(self, epoch):
114 |         # Train for one epoch
115 | self.bert_model.train()
116 | self.iteration(epoch, dataloader=self.dataloader, train=True)
117 |
118 | def save(self, save_path):
119 | """
120 |         Save the model.
121 | """
122 | self.bert_model.save_all_params(save_path)
123 | print("{} saved!".format(save_path))
124 |
125 | def iteration(self, epoch, dataloader, train=True):
126 | total_loss = 0
127 |         start_time = time.time()  ## record the start time
128 | step = 0
129 |         for token_ids, token_type_ids, target_ids in tqdm(dataloader, position=0, leave=True):
130 | step += 1
131 | if step % 5000 == 0:
132 | self.bert_model.eval()
133 | test_data = ["花海总涵功德水", "广汉飞霞诗似玉", "执政为民,德行天下"]
134 | for text in test_data:
135 | print(self.bert_model.generate(text, beam_size=3))
136 | self.bert_model.train()
137 |
138 |             # Because target labels are passed in, the model computes and returns the loss
139 |             predictions, loss = self.bert_model(token_ids,
140 |                                                 token_type_ids,
141 |                                                 labels=target_ids,
142 |                                                 )
143 |             # Backpropagation
144 |             if train:
145 |                 # Clear the previous gradients
146 |                 self.optimizer.zero_grad()
147 |                 # Backpropagate to get fresh gradients
148 |                 loss.backward()
149 |                 # Update the parameters with those gradients
150 |                 self.optimizer.step()
151 | 
152 |             # Accumulate the loss for this epoch's total
153 |             total_loss += loss.item()
154 | 
155 |         end_time = time.time()
156 |         spend_time = end_time - start_time
157 |         # Print training info
158 |         print("epoch is " + str(epoch) + ". loss is " + str(total_loss) + ". spend time is " + str(spend_time))
159 |         # Save the model
160 |         self.save(model_save_path)
161 |
162 | if __name__ == '__main__':
163 |
164 | trainer = Trainer()
165 | train_epoches = 10
166 | for epoch in range(train_epoches):
167 |         # Train one epoch
168 | trainer.train(epoch)
169 |
--------------------------------------------------------------------------------
/examples/roberta_news_classification_train.py:
--------------------------------------------------------------------------------
1 | ## Example: text classification
2 | import torch
3 | from tqdm import tqdm
4 | import time
5 | from torch.utils.data import Dataset, DataLoader
6 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
7 | from bert_seq2seq import load_bert
8 |
9 | target = ["财经", "彩票", "房产", "股票", "家居", "教育", "科技", "社会", "时尚", "时政", "体育", "星座", "游戏", "娱乐"]
10 |
11 | data_path = "./corpus/新闻标题文本分类/Train.txt"
12 | vocab_path = "./state_dict/roberta_wwm_vocab.txt"  # path to the roberta vocab file
13 | model_name = "roberta"  # model name
14 | model_path = "./state_dict/roberta_wwm_pytorch_model.bin"  # path to the roberta checkpoint
15 | recent_model_path = ""  # set this to resume training from an already-trained model
16 | model_save_path = "./bert_multi_classify_model.bin"
17 | batch_size = 16
18 | lr = 1e-5
19 | # Load the vocab
20 | word2idx = load_chinese_base_vocab(vocab_path)
21 |
22 | def read_corpus():
23 | """
24 |     Load the raw data.
25 | """
26 | sents_src = []
27 | sents_tgt = []
28 |
29 | with open(data_path) as f:
30 | lines = f.readlines()
31 | for line in lines:
32 | line = line.split("\t")
33 | sents_tgt.append(int(line[0]))
34 | sents_src.append(line[2])
35 | return sents_src, sents_tgt
36 |
37 | ## Custom dataset
38 | class NLUDataset(Dataset):
39 | """
40 |     Defines how samples are fetched for this particular dataset.
41 | """
42 | def __init__(self, sents_src, sents_tgt) :
43 |         ## init usually loads all the data
44 |         super(NLUDataset, self).__init__()
45 |         # Load the raw data
46 | # self.sents_src, self.sents_tgt = read_corpus(poem_corpus_dir)
47 | self.sents_src = sents_src
48 | self.sents_tgt = sents_tgt
49 |
50 | self.idx2word = {k: v for v, k in word2idx.items()}
51 | self.tokenizer = Tokenizer(word2idx)
52 |
53 | def __getitem__(self, i):
54 |         ## Fetch a single sample
55 | # print(i)
56 | src = self.sents_src[i]
57 | tgt = self.sents_tgt[i]
58 | token_ids, token_type_ids = self.tokenizer.encode(src)
59 | output = {
60 | "token_ids": token_ids,
61 | "token_type_ids": token_type_ids,
62 | "target_id": tgt
63 | }
64 | return output
65 |
66 | def __len__(self):
67 | return len(self.sents_src)
68 |
69 | def collate_fn(batch):
70 | """
71 |     Dynamic padding; batch is a list of samples.
72 | """
73 |
74 | def padding(indice, max_length, pad_idx=0):
75 | """
76 |         Padding helper.
77 | """
78 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice]
79 | return torch.tensor(pad_indice)
80 |
81 | token_ids = [data["token_ids"] for data in batch]
82 | max_length = max([len(t) for t in token_ids])
83 | token_type_ids = [data["token_type_ids"] for data in batch]
84 | target_ids = [data["target_id"] for data in batch]
85 | target_ids = torch.tensor(target_ids, dtype=torch.long)
86 |
87 | token_ids_padded = padding(token_ids, max_length)
88 | token_type_ids_padded = padding(token_type_ids, max_length)
89 | # target_ids_padded = token_ids_padded[:, 1:].contiguous()
90 |
91 | return token_ids_padded, token_type_ids_padded, target_ids
92 |
93 | class Trainer:
94 | def __init__(self):
95 |         # Load the data
96 |         self.sents_src, self.sents_tgt = read_corpus()
97 |         self.tokenizer = Tokenizer(word2idx)
98 |         # Use the GPU if one is available
99 |         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
100 |         print("device: " + str(self.device))
101 |         # Build the model
102 |         self.bert_model = load_bert(word2idx, model_name=model_name, model_class="cls", target_size=len(target))
103 |         ## Load the pretrained parameters
104 |         self.bert_model.load_pretrain_params(model_path)
105 |         # Move the model to the compute device (GPU or CPU)
106 |         self.bert_model.set_device(self.device)
107 |         # Collect the parameters to optimize
108 |         self.optim_parameters = list(self.bert_model.parameters())
109 |         self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3)
110 |         # Build the custom data loader
111 | dataset = NLUDataset(self.sents_src, self.sents_tgt)
112 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
113 |
114 | def train(self, epoch):
115 |         # Train for one epoch
116 | self.bert_model.train()
117 | self.iteration(epoch, dataloader=self.dataloader, train=True)
118 |
119 | def save(self, save_path):
120 | """
121 |         Save the model.
122 | """
123 | self.bert_model.save_all_params(save_path)
124 | print("{} saved!".format(save_path))
125 |
126 | def iteration(self, epoch, dataloader, train=True):
127 | total_loss = 0
128 |         start_time = time.time()  ## record the start time
129 | step = 0
130 |         for token_ids, token_type_ids, target_ids in tqdm(dataloader, position=0, leave=True):
131 | step += 1
132 | if step % 2000 == 0:
133 | self.bert_model.eval()
134 | test_data = ["编剧梁馨月讨稿酬六六何念助阵 公司称协商解决", "西班牙BBVA第三季度净利降至15.7亿美元", "基金巨亏30亿 欲打开云天系跌停自救"]
135 | for text in test_data:
136 |                     token_ids_test, _ = self.tokenizer.encode(text)
137 |                     token_ids_test = torch.tensor(token_ids_test, device=self.device).view(1, -1)
138 |                     print(target[torch.argmax(self.bert_model(token_ids_test)).item()])
139 | self.bert_model.train()
140 |
141 |             # Because target labels are passed in, the model computes and returns the loss
142 |             predictions, loss = self.bert_model(token_ids,
143 |                                                 labels=target_ids,
144 |                                                 )
145 |             # Backpropagation
146 |             if train:
147 |                 # Clear the previous gradients
148 |                 self.optimizer.zero_grad()
149 |                 # Backpropagate to get fresh gradients
150 |                 loss.backward()
151 |                 # Update the parameters with those gradients
152 |                 self.optimizer.step()
153 | 
154 |             # Accumulate the loss for this epoch's total
155 |             total_loss += loss.item()
156 | 
157 |         end_time = time.time()
158 |         spend_time = end_time - start_time
159 |         # Print training info
160 |         print("epoch is " + str(epoch) + ". loss is " + str(total_loss) + ". spend time is " + str(spend_time))
161 |         # Save the model
162 |         self.save(model_save_path)
163 |
164 | if __name__ == '__main__':
165 |
166 | trainer = Trainer()
167 | train_epoches = 10
168 | for epoch in range(train_epoches):
169 |         # Train one epoch
170 | trainer.train(epoch)
--------------------------------------------------------------------------------
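To round the example off, a minimal inference sketch for the classifier trained above (a sketch, assuming the training run has produced ./bert_multi_classify_model.bin; word2idx, model_name, and target are the globals defined in the script):
--------------------------------------------------------------------------------
1 | import torch
2 | from bert_seq2seq import Tokenizer, load_bert
3 | 
4 | # word2idx, model_name, and target come from the training script above.
5 | device = torch.device("cpu")
6 | bert_model = load_bert(word2idx, model_name=model_name, model_class="cls", target_size=len(target))
7 | bert_model.set_device(device)
8 | bert_model.eval()
9 | bert_model.load_all_params(model_path="./bert_multi_classify_model.bin", device=device)
10 | 
11 | tokenizer = Tokenizer(word2idx)
12 | token_ids, _ = tokenizer.encode("基金巨亏30亿 欲打开云天系跌停自救")
13 | token_ids = torch.tensor(token_ids, device=device).view(1, -1)
14 | with torch.no_grad():
15 |     print(target[torch.argmax(bert_model(token_ids)).item()])  # predicted label name
--------------------------------------------------------------------------------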
/examples/roberta_semantic_matching_train.py:
--------------------------------------------------------------------------------
1 | # https://tianchi.aliyun.com/competition/entrance/531851/information
2 | import torch
3 | from tqdm import tqdm
4 | import time
5 | from torch.utils.data import Dataset, DataLoader
6 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
7 | from bert_seq2seq import load_bert
8 |
9 | target = [0, 1]
10 | train_path = "./data/语义匹配/train.tsv"
11 | test_path = "./data/语义匹配/test.tsv"
12 | vocab_path = "./state_dict/roberta_wwm_vocab.txt"  # path to the roberta vocab file
13 | model_name = "roberta"  # model name
14 | model_path = "./state_dict/roberta_wwm_pytorch_model.bin"  # path to the roberta checkpoint
15 | recent_model_path = ""  # set this to resume training from an already-trained model
16 | model_save_path = "./bert_semantic_matching.bin"
17 | batch_size = 16
18 | lr = 1e-5
19 | # Load the vocab
20 | word2idx = load_chinese_base_vocab(vocab_path)
21 |
22 | def read_corpus(data_path):
23 | """
24 |     Load the raw data.
25 | """
26 | sents_src = []
27 | sents_tgt = []
28 |
29 | with open(data_path) as f:
30 | lines = f.readlines()
31 | for line in lines:
32 | line = line.split("\t")
33 | sents_tgt.append(int(line[2]))
34 | sents_src.append(line[0] + "#" +line[1])
35 | return sents_src, sents_tgt
36 |
37 | ## Custom dataset
38 | class NLUDataset(Dataset):
39 | """
40 |     Defines how samples are fetched for this particular dataset.
41 | """
42 | def __init__(self, sents_src, sents_tgt) :
43 |         ## init usually loads all the data
44 |         super(NLUDataset, self).__init__()
45 |         # Load the raw data
46 | self.sents_src = sents_src
47 | self.sents_tgt = sents_tgt
48 |
49 | self.idx2word = {k: v for v, k in word2idx.items()}
50 | self.tokenizer = Tokenizer(word2idx)
51 |
52 | def __getitem__(self, i):
53 |         ## Fetch a single sample
54 | # print(i)
55 | src = self.sents_src[i]
56 | tgt = self.sents_tgt[i]
57 | token_ids, token_type_ids = self.tokenizer.encode(src)
58 | output = {
59 | "token_ids": token_ids,
60 | "token_type_ids": token_type_ids,
61 | "target_id": tgt
62 | }
63 | return output
64 |
65 | def __len__(self):
66 | return len(self.sents_src)
67 |
68 | def collate_fn(batch):
69 | """
70 |     Dynamic padding; batch is a list of samples.
71 | """
72 |
73 | def padding(indice, max_length, pad_idx=0):
74 | """
75 |         Padding helper.
76 | """
77 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice]
78 | return torch.tensor(pad_indice)
79 |
80 | token_ids = [data["token_ids"] for data in batch]
81 | max_length = max([len(t) for t in token_ids])
82 | token_type_ids = [data["token_type_ids"] for data in batch]
83 | target_ids = [data["target_id"] for data in batch]
84 | target_ids = torch.tensor(target_ids, dtype=torch.long)
85 |
86 | token_ids_padded = padding(token_ids, max_length)
87 | token_type_ids_padded = padding(token_type_ids, max_length)
88 | # target_ids_padded = token_ids_padded[:, 1:].contiguous()
89 |
90 | return token_ids_padded, token_type_ids_padded, target_ids
91 |
92 | class Trainer:
93 | def __init__(self):
94 |         # Load the data
95 |         self.sents_src, self.sents_tgt = read_corpus(train_path)
96 |         self.tokenizer = Tokenizer(word2idx)
97 |         # Use the GPU if one is available
98 |         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
99 |         print("device: " + str(self.device))
100 |         # Build the model
101 |         self.bert_model = load_bert(word2idx, model_name=model_name, model_class="cls", target_size=len(target))
102 |         ## Load the pretrained parameters
103 |         self.bert_model.load_pretrain_params(model_path)
104 |         # Move the model to the compute device (GPU or CPU)
105 |         self.bert_model.set_device(self.device)
106 |         # Collect the parameters to optimize
107 |         self.optim_parameters = list(self.bert_model.parameters())
108 |         self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3)
109 |         # Build the custom data loader
110 | dataset = NLUDataset(self.sents_src, self.sents_tgt)
111 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
112 |
113 | def train(self, epoch):
114 |         # Train for one epoch
115 | self.bert_model.train()
116 | self.iteration(epoch, dataloader=self.dataloader, train=True)
117 |
118 | def save(self, save_path):
119 | """
120 |         Save the model.
121 | """
122 | self.bert_model.save_all_params(save_path)
123 | print("{} saved!".format(save_path))
124 |
125 | def iteration(self, epoch, dataloader, train=True):
126 | total_loss = 0
127 |         start_time = time.time()  ## record the start time
128 | step = 0
129 |         for token_ids, token_type_ids, target_ids in tqdm(dataloader, position=0, leave=True):
130 | step += 1
131 | if step % 2000 == 0:
132 | self.bert_model.eval()
133 | test_data = ["后悔了吗#你有没有后悔", "打开自动横屏#开启移动数据", "我觉得你很聪明#你聪明我是这么觉得"]
134 | for text in test_data:
135 |                     token_ids_test, _ = self.tokenizer.encode(text)
136 |                     token_ids_test = torch.tensor(token_ids_test, device=self.device).view(1, -1)
137 |                     print(target[torch.argmax(self.bert_model(token_ids_test)).item()])
138 | self.bert_model.train()
139 |                 # Save the model
140 | self.save(model_save_path)
141 |
142 |             # Because target labels are passed in, the model computes and returns the loss
143 |             predictions, loss = self.bert_model(token_ids,
144 |                                                 labels=target_ids,
145 |                                                 )
146 |             # Backpropagation
147 |             if train:
148 |                 # Clear the previous gradients
149 |                 self.optimizer.zero_grad()
150 |                 # Backpropagate to get fresh gradients
151 |                 loss.backward()
152 |                 # Update the parameters with those gradients
153 |                 self.optimizer.step()
154 | 
155 |             # Accumulate the loss for this epoch's total
156 |             total_loss += loss.item()
157 | 
158 |         end_time = time.time()
159 |         spend_time = end_time - start_time
160 |         # Print training info
161 |         print("epoch is " + str(epoch) + ". loss is " + str(total_loss) + ". spend time is " + str(spend_time))
162 |
163 |
164 | if __name__ == '__main__':
165 |
166 | trainer = Trainer()
167 | train_epoches = 10
168 | for epoch in range(train_epoches):
169 |         # Train one epoch
170 | trainer.train(epoch)
--------------------------------------------------------------------------------
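Note how read_corpus joins the two sentences with a literal "#" and encodes the pair as a single segment. The repo's Tokenizer also accepts two texts, as the seq2seq scripts above use it, in which case the returned token_type_ids mark the second sentence; a sketch of that alternative encoding:
--------------------------------------------------------------------------------
1 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
2 | 
3 | word2idx = load_chinese_base_vocab("./state_dict/roberta_wwm_vocab.txt")
4 | tokenizer = Tokenizer(word2idx)
5 | 
6 | # Two-segment encoding: [CLS] sent_a [SEP] sent_b [SEP]
7 | token_ids, token_type_ids = tokenizer.encode("后悔了吗", "你有没有后悔")
8 | print(token_ids)       # ids for both sentences in one sequence
9 | print(token_type_ids)  # 0 over the first segment, 1 over the second
--------------------------------------------------------------------------------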
/examples/t5_ancient_translation_train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import glob
4 | from torch.utils.data import Dataset, DataLoader
5 | from bert_seq2seq import T5PegasusTokenizer, load_chinese_base_vocab
6 | from bert_seq2seq import T5Model
7 |
8 | vocab_path = "./state_dict/t5-chinese/vocab.txt"
9 | model_path = "./state_dict/t5-chinese/pytorch_model.bin"
10 | model_save_path = "./state_dict/t5_ancient_trans_model.bin"
11 | batch_size = 8
12 | lr = 1e-5
13 | word2idx = load_chinese_base_vocab(vocab_path)
14 | tokenizer = T5PegasusTokenizer(word2idx)
15 |
16 |
17 | def read_corpus():
18 | """
19 |     Load the raw data.
20 | """
21 | src = []
22 | tgt = []
23 | data_path = glob.glob("./corpus/文言文翻译/*")
24 | for p in data_path:
25 | dir = p.split("/")[:-1]
26 | dir = "/".join(dir)
27 | # print(dir)
28 | name = p.split("/")[-1]
29 | if "翻译" in name:
30 |             # Found a translation file
31 | tgt_name = name
32 | src_name = name[:-2]
33 | with open(dir + "/" + src_name) as fs:
34 | lines = fs.readlines()
35 | for line in lines:
36 | src.append(line.strip("\n").strip())
37 |
38 | with open(dir + "/" + tgt_name) as ft:
39 | lines = ft.readlines()
40 | for line in lines:
41 | tgt.append(line.strip("\n").strip())
42 |
43 | else:
44 | pass
45 |
46 | return src, tgt
47 |
48 | class SeqDataset(Dataset):
49 | """
50 |     Defines how samples are fetched for this particular dataset.
51 | """
52 |
53 | def __init__(self, sents_src, sents_tgt):
54 |         ## init usually loads all the data
55 |         super(SeqDataset, self).__init__()
56 |         # Load the raw data
57 | # self.sents_src, self.sents_tgt = read_corpus(poem_corpus_dir)
58 | self.sents_src = sents_src
59 | self.sents_tgt = sents_tgt
60 |
61 | self.idx2word = {k: v for v, k in word2idx.items()}
62 |
63 | def __getitem__(self, i):
64 |         ## Fetch a single sample
65 | # print(i)
66 | src = self.sents_src[i]
67 | tgt = self.sents_tgt[i]
68 | token_ids_src, _ = tokenizer.encode(src, max_length=256)
69 | token_ids_tgt, _ = tokenizer.encode(tgt, max_length=256)
70 | output = {
71 | "token_ids_src": token_ids_src,
72 | "token_ids_tgt": token_ids_tgt,
73 | }
74 | return output
75 |
76 | def __len__(self):
77 | return len(self.sents_src)
78 |
79 |
80 | def collate_fn(batch):
81 | """
82 |     Dynamic padding; batch is a list of samples.
83 | """
84 |
85 | def padding(indice, max_length, pad_idx=0):
86 | """
87 |         Padding helper.
88 | """
89 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice]
90 | return torch.tensor(pad_indice)
91 |
92 | token_ids_src = [data["token_ids_src"] for data in batch]
93 | max_length_src = max([len(t) for t in token_ids_src])
94 | token_ids_tgt = [data["token_ids_tgt"] for data in batch]
95 | max_length_tgt = max([len(t) for t in token_ids_tgt])
96 |
97 | token_ids_padded = padding(token_ids_src, max_length_src)
98 | target_ids_padded = padding(token_ids_tgt, max_length_tgt)
99 | labels_ids = target_ids_padded.clone()
100 | labels_ids[labels_ids == 0] = -100
101 | target_ids_padded = target_ids_padded[:, :-1].contiguous()
102 | labels_ids = labels_ids[:, 1:].contiguous()
103 |
104 | return token_ids_padded, target_ids_padded, labels_ids
105 |
106 |
107 | class Trainer:
108 | def __init__(self):
109 |         # Load the data
110 |         self.sents_src, self.sents_tgt = read_corpus()
111 | 
112 |         # Use the GPU if one is available
113 |         # self.device = torch.device("cpu")
114 |         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
115 |         print("device: " + str(self.device))
116 |         # Build the model
117 |         self.model = T5Model(word2idx)
118 |         ## Load the pretrained parameters
119 |         self.model.load_pretrain_params(model_path)
120 |         # Move the model to the compute device (GPU or CPU)
121 |         self.model.set_device(self.device)
122 |         # Collect the parameters to optimize
123 |         self.optim_parameters = list(self.model.parameters())
124 |         self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3)
125 |         # Build the custom data loader
126 | dataset = SeqDataset(self.sents_src, self.sents_tgt)
127 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
128 |
129 | def train(self, epoch):
130 |         # Train for one epoch
131 | self.model.train()
132 | self.iteration(epoch, dataloader=self.dataloader, train=True)
133 |
134 | def save(self, save_path):
135 | """
136 |         Save the model.
137 | """
138 | self.model.save_all_params(save_path)
139 | print("{} saved!".format(save_path))
140 |
141 | def iteration(self, epoch, dataloader, train=True):
142 | total_loss = 0
143 | report_loss = 0
144 |         start_time = time.time()  ## record the start time
145 | step = 0
146 | for token_ids, target_ids, labels_ids in dataloader:
147 | step += 1
148 | # print(token_ids.shape)
149 | # print(target_ids.shape)
150 | # print(labels_ids.shape)
151 | if step % 4000 == 0:
152 | self.save(model_save_path)
153 | self.model.eval()
154 | test_data = ["遂入颍川。", "会日暝,结陈相持。", "一言兴邦,斯近之矣。"]
155 | for text in test_data:
156 | print(self.model.sample_generate_encoder_decoder(text, add_eos=True))
157 | self.model.train()
158 | print("report loss is " + str(report_loss))
159 |
160 |             # Because target labels are passed in, the model computes and returns the loss
161 |             loss = self.model(token_ids, labels=labels_ids, decoder_input_ids=target_ids)[0]
162 |             # Backpropagation
163 |             if train:
164 |                 # Clear the previous gradients
165 |                 self.optimizer.zero_grad()
166 |                 # Backpropagate to get fresh gradients
167 |                 loss.backward()
168 |                 # Update the parameters with those gradients
169 |                 self.optimizer.step()
170 | 
171 |             # Accumulate the loss for this epoch's total
172 |             total_loss += loss.item()
173 |             report_loss += loss.item()
174 | 
175 |         end_time = time.time()
176 |         spend_time = end_time - start_time
177 |         # Print training info
178 |         print("epoch is " + str(epoch) + ". loss is " + str(total_loss) + ". spend time is " + str(spend_time))
179 |         # Save the model
180 |         self.save(model_save_path)
181 |
182 |
183 | if __name__ == '__main__':
184 |
185 | trainer = Trainer()
186 | train_epoches = 10
187 | for epoch in range(train_epoches):
188 |         # Train one epoch
189 | trainer.train(epoch)
190 |
--------------------------------------------------------------------------------
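The collate_fn above implements standard teacher forcing for the encoder-decoder model: the decoder input is the target sequence without its last token, the labels are the same sequence without its first token, and pad positions in the labels are set to -100 so the loss ignores them. A tiny worked example with made-up ids (0 is the pad id):
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | tgt = torch.tensor([[5, 6, 7, 8, 0, 0]])     # padded target sequence
4 | labels = tgt.clone()
5 | labels[labels == 0] = -100                    # ignore pad positions in the loss
6 | decoder_input_ids = tgt[:, :-1].contiguous()  # decoder reads tokens 0..n-2
7 | labels = labels[:, 1:].contiguous()           # and is trained to emit 1..n-1
8 | print(decoder_input_ids.tolist())  # [[5, 6, 7, 8, 0]]
9 | print(labels.tolist())             # [[6, 7, 8, -100, -100]]
--------------------------------------------------------------------------------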
/examples/t5_auto_title_train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import glob
4 | from torch.utils.data import Dataset, DataLoader
5 | from bert_seq2seq import T5PegasusTokenizer, load_chinese_base_vocab
6 | from bert_seq2seq import T5Model
7 | from tqdm import tqdm
8 |
9 | src_dir = './corpus/auto_title/train.src'
10 | tgt_dir = './corpus/auto_title/train.tgt'
11 |
12 | vocab_path = "./state_dict/t5-chinese/vocab.txt"  ## vocab file
13 | model_path = "./state_dict/t5-chinese/pytorch_model.bin"  ## pretrained weights
14 | 
15 | model_save_path = "./state_dict/t5_autotile.bin"  ## where to save the fine-tuned model
16 | batch_size = 1
17 | lr = 1e-5
18 | word2idx = load_chinese_base_vocab(vocab_path)
19 | tokenizer = T5PegasusTokenizer(word2idx)
20 |
21 | def read_file(src_dir, tgt_dir):
22 | src = []
23 | tgt = []
24 |
25 | with open(src_dir,'r',encoding='utf-8') as f:
26 | lines = f.readlines()
27 |
28 | for line in lines:
29 | src.append(line.strip('\n').lower())
30 |
31 | with open(tgt_dir,'r',encoding='utf-8') as f:
32 | lines = f.readlines()
33 | for line in lines:
34 | tgt.append(line.strip('\n').lower())
35 |
36 | return src, tgt
37 |
38 |
39 | class SeqDataset(Dataset):
40 | """
41 |     Defines how samples are fetched for this particular dataset.
42 | """
43 |
44 | def __init__(self, sents_src, sents_tgt):
45 |         ## init usually loads all the data
46 |         super(SeqDataset, self).__init__()
47 |         # Load the raw data
48 | # self.sents_src, self.sents_tgt = read_corpus(poem_corpus_dir)
49 | self.sents_src = sents_src
50 | self.sents_tgt = sents_tgt
51 |
52 | self.idx2word = {k: v for v, k in word2idx.items()}
53 |
54 | def __getitem__(self, i):
55 |         ## Fetch a single sample
56 | # print(i)
57 | src = self.sents_src[i]
58 | tgt = self.sents_tgt[i]
59 | token_ids_src, _ = tokenizer.encode(src, max_length=256)
60 | token_ids_tgt, _ = tokenizer.encode(tgt, max_length=256)
61 | output = {
62 | "token_ids_src": token_ids_src,
63 | "token_ids_tgt": token_ids_tgt,
64 | }
65 | return output
66 |
67 | def __len__(self):
68 | return len(self.sents_src)
69 |
70 |
71 | def collate_fn(batch):
72 | """
73 |     Dynamic padding; batch is a list of samples.
74 | """
75 |
76 | def padding(indice, max_length, pad_idx=0):
77 | """
78 |         Padding helper.
79 | """
80 | pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice]
81 | return torch.tensor(pad_indice)
82 |
83 | token_ids_src = [data["token_ids_src"] for data in batch]
84 | max_length_src = max([len(t) for t in token_ids_src])
85 | token_ids_tgt = [data["token_ids_tgt"] for data in batch]
86 | max_length_tgt = max([len(t) for t in token_ids_tgt])
87 |
88 | token_ids_padded = padding(token_ids_src, max_length_src)
89 | target_ids_padded = padding(token_ids_tgt, max_length_tgt)
90 | labels_ids = target_ids_padded.clone()
91 | labels_ids[labels_ids == 0] = -100
92 | target_ids_padded = target_ids_padded[:, :-1].contiguous()
93 | labels_ids = labels_ids[:, 1:].contiguous()
94 |
95 | return token_ids_padded, target_ids_padded, labels_ids
96 |
97 |
98 | class Trainer:
99 | def __init__(self):
100 |         # Load the data
101 |         self.sents_src, self.sents_tgt = read_file(src_dir, tgt_dir)
102 | 
103 |         # Use the GPU if one is available
104 |         # self.device = torch.device("cpu")
105 |         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
106 |         print("device: " + str(self.device))
107 |         # Build the model
108 |         self.model = T5Model(word2idx)
109 |         ## Load the pretrained parameters
110 |         self.model.load_pretrain_params(model_path)
111 |         # self.model.load_all_params(save_model_path)
112 |         # Move the model to the compute device (GPU or CPU)
113 |         self.model.set_device(self.device)
114 |         # Collect the parameters to optimize
115 |         self.optim_parameters = list(self.model.parameters())
116 |         self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3)
117 |         # Build the custom data loader
118 | dataset = SeqDataset(self.sents_src, self.sents_tgt)
119 | self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
120 |
121 | def train(self, epoch):
122 |         # Train for one epoch
123 | self.model.train()
124 | self.iteration(epoch, dataloader=self.dataloader, train=True)
125 |
126 | def save(self, save_path):
127 | """
128 |         Save the model.
129 | """
130 | self.model.save_all_params(save_path)
131 | print("{} saved!".format(save_path))
132 |
133 | def iteration(self, epoch, dataloader, train=True):
134 | total_loss = 0
135 | report_loss = 0
136 |         start_time = time.time()  ## record the start time
137 | step = 0
138 | for token_ids, target_ids, labels_ids in tqdm(dataloader, total=len(dataloader)):
139 | step += 1
140 | if step % 100 == 0:
141 | self.save(model_save_path)
142 | self.model.eval()
143 | test_data = ["本文总结了十个可穿戴产品的设计原则,而这些原则同样也是笔者认为是这个行业最吸引人的地方:1为人们解决重复性问题,2从人开始而不是从机器开始,3要引起注意但不要刻意,4提升用户能力而不是取代人",
144 | "2007年乔布斯向人们展示iPhone并宣称它将会改变世界,还有人认为他在夸大其词然而在8年后以iPhone为代表的触屏智能手机已经席卷全球各个角落,未来智能手机将会成为真正的个人电脑为人类发展做出更大的贡献",
145 | "雅虎发布2014年第四季度财报并推出了免税方式剥离其持有的阿里巴巴集团15%股权的计划打算将这一价值约400亿美元的宝贵投资分配给股东截止发稿前雅虎股价上涨了大约7%至5145美元"]
146 |
147 | for text in test_data:
148 | print(self.model.sample_generate_encoder_decoder(text, add_eos=True, top_k=5))
149 | self.model.train()
150 | print("report loss is " + str(report_loss))
151 | report_loss = 0
152 |
153 |             # Because target labels are passed in, the model computes and returns the loss
154 |             loss = self.model(token_ids, labels=labels_ids, decoder_input_ids=target_ids)[0]
155 |             # Backpropagation
156 |             if train:
157 |                 # Clear the previous gradients
158 |                 self.optimizer.zero_grad()
159 |                 # Backpropagate to get fresh gradients
160 |                 loss.backward()
161 |                 # Update the parameters with those gradients
162 |                 self.optimizer.step()
163 | 
164 |             # Accumulate the loss for this epoch's total
165 |             total_loss += loss.item()
166 |             report_loss += loss.item()
167 | 
168 |         end_time = time.time()
169 |         spend_time = end_time - start_time
170 |         # Print training info
171 |         print("epoch is " + str(epoch) + ". loss is " + str(total_loss) + ". spend time is " + str(spend_time))
172 |         # Save the model
173 |         self.save(model_save_path)
174 |
175 |
176 | if __name__ == '__main__':
177 |
178 | trainer = Trainer()
179 | train_epoches = 10
180 | for epoch in range(train_epoches):
181 |         # Train one epoch
182 | trainer.train(epoch)
183 |
--------------------------------------------------------------------------------
/img/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/img/.DS_Store
--------------------------------------------------------------------------------
/img/fenci.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/img/fenci.png
--------------------------------------------------------------------------------
/img/ner-input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/img/ner-input.png
--------------------------------------------------------------------------------
/img/ner-out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/img/ner-out.png
--------------------------------------------------------------------------------
/img/ner.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/920232796/bert_seq2seq/c7988b01e3e69d66a061b28974ff9cc8fc4a36de/img/ner.jpg
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='bert_seq2seq',
5 | version='2.3.5',
6 | description='use torch to do bert_seq2seq task',
7 | long_description='bert_seq2seq: https://github.com/920232796/bert_seq2seq',
8 | license='Apache License 2.0',
9 | url='https://github.com/920232796/bert_seq2seq',
10 | author='xingzhaohu',
11 | author_email='920232796@qq.com',
12 | packages=find_packages()
13 | )
14 |
--------------------------------------------------------------------------------
/test/auto_title_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
3 | from bert_seq2seq import load_bert
4 |
5 | auto_title_model = "./state_dict/bert_auto_title_model2.bin"
6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 |
8 | if __name__ == "__main__":
9 |     vocab_path = "../state_dict/roberta_wwm_vocab.txt"  # path to the roberta vocab file
10 |     model_name = "roberta"  # model name
11 |     # Load the vocab
12 |     word2idx = load_chinese_base_vocab(vocab_path)
13 |     # Build the model
14 |     bert_model = load_bert(word2idx, model_name=model_name)
15 |     bert_model.set_device(device)
16 |     bert_model.eval()
17 |     ## Load the fine-tuned parameters
18 |     bert_model.load_all_params(model_path=auto_title_model, device=device)
19 |
20 | test_data = ["针对央视3·15晚会曝光的电信行业乱象,工信部在公告中表示将严查央视3·15晚会曝光通信违规违法行为,工信部称已约谈三大运营商有关负责人并连夜责成三大运营商和所在省通信管理局进行调查依法依规严肃处理",
21 | "楚天都市报记者采访了解到,对于进口冷链食品,武汉已经采取史上最严措施,进行“红区”管理,严格执行证明查验制度,确保冷冻冷藏肉等冻品的安全。",
22 | "新华社受权于18日全文播发修改后的《中华人民共和国立法法》修改后的立法法分为“总则”“法律”“行政法规”“地方性法规自治条例和单行条例规章”“适用与备案审查”“附则”等6章共计105条"]
23 | for text in test_data:
24 | with torch.no_grad():
25 | print(bert_model.generate(text, beam_size=3))
26 | print("\n")
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/test/bert_english_autotitle_test.py:
--------------------------------------------------------------------------------
1 | ## English auto-titling test script
2 | import torch
3 | import glob
4 | import json
5 | from rouge import Rouge
6 | from bert_seq2seq import load_bert
7 | from transformers import AutoTokenizer
8 |
9 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
10 | word2idx = tokenizer.get_vocab()
11 | auto_title_model = "./state_dict/bert_english_auto_title_model.bin"
12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13 | maxlen = 256
14 |
15 | if __name__ == "__main__":
16 |     model_name = "bert"  # model name
17 |     # Build the model
18 |     bert_model = load_bert(word2idx, tokenizer=tokenizer, model_name=model_name)
19 |     bert_model.set_device(device)
20 |     bert_model.eval()
21 |     ## Load the fine-tuned parameters
22 |     bert_model.load_all_params(model_path=auto_title_model, device=device)
23 | rouge = Rouge()
24 | test_file = glob.glob("./corpus/english_autotitle_test/*.json")
25 | num_file = len(test_file)
26 | rouge_1_item = [0.0, 0.0, 0.0]
27 | with open("./auto_title_res.txt", "a+") as fw:
28 |         for s_file in test_file:
29 | with open(s_file, "r") as f:
30 | c = f.read()
31 | j = json.loads(c)
32 | title = j["Title"]
33 | text = j["abstract"]
34 | out = bert_model.generate(text, beam_size=3, out_max_length=100, max_length=maxlen)
35 | print(out)
36 | fw.write(title + "\t" + out + "\t" + text + "\n")
37 |
38 |                 rouge_score = rouge.get_scores(out, title)  # get_scores expects (hypothesis, reference)
39 | print(rouge_score)
40 | rouge_1 = rouge_score[0]["rouge-1"]
41 | rouge_1_item[0] += rouge_1["f"]
42 | rouge_1_item[1] += rouge_1["p"]
43 | rouge_1_item[2] += rouge_1["r"]
44 | # print(rouge_score[0]["rouge-2"])
45 | # print(rouge_score[0]["rouge-l"])
46 | for i in range(len(rouge_1_item)):
47 | rouge_1_item[i] = rouge_1_item[i] / num_file
48 |
49 |
50 | print(rouge_1_item)
51 |
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------
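The bookkeeping above accumulates ROUGE-1 f/p/r over all test files and divides by the file count at the end. A compact sketch of the same pattern (sentences are made up; rouge's get_scores takes the hypothesis first, the reference second):
--------------------------------------------------------------------------------
1 | from rouge import Rouge
2 | 
3 | rouge = Rouge()
4 | pairs = [("the cat sat", "a cat sat down"), ("dogs bark", "the dog barked")]  # (hyp, ref)
5 | 
6 | totals = {"f": 0.0, "p": 0.0, "r": 0.0}
7 | for hyp, ref in pairs:
8 |     score = rouge.get_scores(hyp, ref)[0]["rouge-1"]
9 |     for k in totals:
10 |         totals[k] += score[k]
11 | avg = {k: v / len(pairs) for k, v in totals.items()}
12 | print(avg)  # mean ROUGE-1 f/p/r across the pairs
--------------------------------------------------------------------------------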
/test/english_t5_test.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4 | from bert_seq2seq.extend_model_method import ExtendModel
5 |
6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 |
8 | if __name__ == "__main__":
9 |
10 | tokenizer = AutoTokenizer.from_pretrained("/Users/xingzhaohu/Downloads/t5_test")
11 | model = AutoModelForSeq2SeqLM.from_pretrained("/Users/xingzhaohu/Downloads/t5_test")
12 | model.eval()
13 | model.to(device)
14 | model = ExtendModel(model, tokenizer, bos_id=0, eos_id=1)
15 | print(model.sample_generate_encoder_decoder("translate English to German: That is good", out_max_length=300, add_eos=True))
16 |
17 |
18 |
--------------------------------------------------------------------------------
/test/get_bert_embedding.py:
--------------------------------------------------------------------------------
1 | ## Encode a sentence with BERT
2 |
3 | import torch
4 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
5 | from bert_seq2seq import load_bert
6 |
7 | model_path = "./state_dict/roberta_wwm_pytorch_model.bin"
8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 |
10 | if __name__ == "__main__":
11 |     vocab_path = "./state_dict/roberta_wwm_vocab.txt"  # path to the roberta vocab file
12 |     model_name = "roberta"  # model name
13 |     # Load the vocab
14 |     word2idx = load_chinese_base_vocab(vocab_path)
15 |     # Build the model
16 |     bert_model = load_bert(word2idx, model_name=model_name, model_class="embedding")
17 |     bert_model.set_device(device)
18 |     bert_model.eval()
19 |     ## Load the pretrained parameters
20 |     bert_model.load_pretrain_params(model_path)
21 |
22 | test_data = ["针对央视3·15晚会曝光的电信行业乱象,工信部在公告中表示将严查央视3·15晚会曝光通信违规违法行为,工信部称已约谈三大运营商有关负责人并连夜责成三大运营商和所在省通信管理局进行调查依法依规严肃处理",
23 | "楚天都市报记者采访了解到,对于进口冷链食品,武汉已经采取史上最严措施,进行“红区”管理,严格执行证明查验制度,确保冷冻冷藏肉等冻品的安全。",
24 | "新华社受权于18日全文播发修改后的《中华人民共和国立法法》修改后的立法法分为“总则”“法律”“行政法规”“地方性法规自治条例和单行条例规章”“适用与备案审查”“附则”等6章共计105条"]
25 | for text in test_data:
26 | with torch.no_grad():
27 | print(bert_model(text).shape)
28 | print("\n")
29 |
30 |
31 |
32 |
--------------------------------------------------------------------------------
/test/gpt_ancient_translation_test.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from bert_seq2seq import load_gpt
4 | from bert_seq2seq import load_chinese_base_vocab
5 |
6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 |
8 | vocab_path = "./state_dict/gpt2通用中文模型/vocab.txt"
9 | model_path = "./state_dict/gpt_ancient_trans_model.bin"
10 |
11 | if __name__ == "__main__":
12 | word2ix = load_chinese_base_vocab(vocab_path)
13 | model = load_gpt(word2ix)
14 | model.eval()
15 | model.set_device(device)
16 | model.load_all_params(model_path)
17 |
18 | print(model.sample_generate("余忆童稚时,能张目对日。", out_max_length=300, add_eos=True))
--------------------------------------------------------------------------------
/test/gpt_english_story_test.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from bert_seq2seq import load_gpt
4 | from transformers import AutoTokenizer
5 |
6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 |
8 | model_path = "./state_dict/gpt_auto_story.bin"
9 |
10 | if __name__ == "__main__":
11 | tokenizer = AutoTokenizer.from_pretrained("pranavpsv/gpt2-genre-story-generator")
12 | word2ix = tokenizer.get_vocab()
13 | model = load_gpt(word2ix, tokenizer=tokenizer)
14 | model.eval()
15 | model.set_device(device)
16 | model.load_all_params(model_path, device=device)
17 |
18 | print(model.sample_generate_english("Strong Winds", out_max_length=300, add_eos=True))
--------------------------------------------------------------------------------
/test/gpt_explain_dream_test.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from bert_seq2seq import load_gpt
4 | from bert_seq2seq import load_chinese_base_vocab
5 |
6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 |
8 | vocab_path = "./state_dict/gpt_vocab.txt"
9 | model_path = "./state_dict/gpt_explain_dream_model.bin"
10 |
11 | if __name__ == "__main__":
12 | word2ix = load_chinese_base_vocab(vocab_path)
13 | model = load_gpt(word2ix)
14 | model.eval()
15 | model.set_device(device)
16 | model.load_all_params(model_path)
17 |
18 | print(model.sample_generate("梦见天气很好", out_max_length=300, add_eos=True))
--------------------------------------------------------------------------------
/test/gpt_test_english.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from bert_seq2seq import load_gpt
4 | from transformers import AutoTokenizer
5 |
6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 |
8 | model_path = "./state_dict/english_gpt_model/english_gpt_story.bin"
9 |
10 | if __name__ == "__main__":
11 | tokenizer = AutoTokenizer.from_pretrained("pranavpsv/gpt2-genre-story-generator")
12 | word2ix = tokenizer.get_vocab()
13 | model = load_gpt(word2ix, tokenizer=tokenizer)
14 | model.eval()
15 | model.set_device(device)
16 | model.load_pretrain_params(model_path)
17 |
18 | print(model.sample_generate_english("Nice weather today", out_max_length=300))
--------------------------------------------------------------------------------
/test/nezha_auto_title_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
3 | from bert_seq2seq import load_bert
4 |
5 | auto_title_model = "./state_dict/nezha_auto_title.bin"
6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 |
8 | if __name__ == "__main__":
9 |     vocab_path = "./state_dict/nezha-base-www/vocab.txt"  # path to the nezha vocab file
10 |     model_name = "nezha"  # model name
11 |     # Load the vocab
12 |     word2idx = load_chinese_base_vocab(vocab_path, simplfied=False)
13 |     # Build the model
14 |     bert_model = load_bert(word2idx, model_name=model_name)
15 |     bert_model.set_device(device)
16 |     bert_model.eval()
17 |     ## Load the fine-tuned parameters
18 |     bert_model.load_all_params(model_path=auto_title_model, device=device)
19 |
20 | test_data = ["针对央视3·15晚会曝光的电信行业乱象,工信部在公告中表示将严查央视3·15晚会曝光通信违规违法行为,工信部称已约谈三大运营商有关负责人并连夜责成三大运营商和所在省通信管理局进行调查依法依规严肃处理",
21 | "楚天都市报记者采访了解到,对于进口冷链食品,武汉已经采取史上最严措施,进行“红区”管理,严格执行证明查验制度,确保冷冻冷藏肉等冻品的安全。",
22 | "新华社受权于18日全文播发修改后的《中华人民共和国立法法》修改后的立法法分为“总则”“法律”“行政法规”“地方性法规自治条例和单行条例规章”“适用与备案审查”“附则”等6章共计105条"]
23 | for text in test_data:
24 | with torch.no_grad():
25 | print(bert_model.generate(text, beam_size=3))
26 |
27 |
28 |
29 |
--------------------------------------------------------------------------------
/test/nezha_relation_extract_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import json
4 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
5 | from bert_seq2seq import load_bert
6 |
7 | relation_extrac_model = "./state_dict/nezha_relation_extract.bin"
8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 | vocab_path = "./state_dict/roberta_wwm_vocab.txt"  # path to the vocab file
10 | model_name = "nezha"  # model name
11 | # model_path = "./state_dict/bert-base-chinese-pytorch_model.bin"  # path to the checkpoint
12 | # Load the vocab
13 | word2idx = load_chinese_base_vocab(vocab_path, simplfied=False)
14 | tokenizer = Tokenizer(word2idx)
15 | idx2word = {v: k for k, v in word2idx.items()}
16 |
17 | predicate2id, id2predicate = {}, {}
18 | with open('./corpus/三元组抽取/all_50_schemas') as f:
19 | for l in f:
20 | l = json.loads(l)
21 | if l['predicate'] not in predicate2id:
22 | id2predicate[len(predicate2id)] = l['predicate']
23 | predicate2id[l['predicate']] = len(predicate2id)
24 |
25 |
26 | def search(pattern, sequence):
27 | """从sequence中寻找子串pattern
28 | 如果找到,返回第一个下标;否则返回-1。
29 | """
30 | n = len(pattern)
31 | for i in range(len(sequence)):
32 | if sequence[i:i + n] == pattern:
33 | return i
34 | return -1
35 |
36 | def search_subject(token_ids, subject_labels):
37 | # subject_labels: (seq_len, 2)
38 | if type(subject_labels) is torch.Tensor:
39 | subject_labels = subject_labels.numpy()
40 | if type(token_ids) is torch.Tensor:
41 | token_ids = token_ids.cpu().numpy()
42 | subjects = []
43 | subject_ids = []
44 | start = -1
45 | end = -1
46 | for i in range(len(token_ids)):
47 | if subject_labels[i, 0] > 0.5:
48 | start = i
49 | for j in range(i, len(token_ids)): # look for the matching end at or after the start
50 | if subject_labels[j, 1] > 0.5:
51 | subject_labels[j, 1] = 0
52 | end = j
53 | break
54 | if start == -1 or end == -1:
55 | continue
56 | subject = ""
57 | for k in range(start, end + 1):
58 | subject += idx2word[token_ids[k]]
59 | # print(subject)
60 | subject_ids.append([start, end])
61 | start = -1
62 | end = -1
63 | subjects.append(subject)
64 |
65 | return subjects, subject_ids
66 |
67 | def search_object(token_ids, object_labels):
68 | objects = []
69 | if type(object_labels) is torch.Tensor:
70 | object_labels = object_labels.numpy()
71 | if type(token_ids) is torch.Tensor:
72 | token_ids = token_ids.cpu().numpy()
73 | # print(object_labels.sum())
74 | start = np.where(object_labels[:, :, 0] > 0.5)
75 | end = np.where(object_labels[:, :, 1] > 0.5)
76 | # print(start)
77 | # print(end)
78 | for _start, predicate1 in zip(*start):
79 | for _end, predicate2 in zip(*end):
80 | if _start <= _end and predicate1 == predicate2:
81 | object_text = ""
82 | for k in range(_start, _end + 1):
83 | # print(token_ids(k))
84 | object_text += idx2word[token_ids[k]]
85 | objects.append(
86 | (id2predicate[predicate1], object_text)
87 | )
88 | break
89 |
90 | return objects
91 |
92 | if __name__ == "__main__":
93 |
94 | # build the model
95 | bert_model = load_bert(word2idx, model_class="relation_extrac", model_name=model_name, target_size=len(predicate2id))
96 | bert_model.eval()
97 | bert_model.set_device(device)
98 | ## load the fine-tuned model parameters (load_all_params reads the checkpoint itself)
99 | # checkpoint = torch.load(relation_extrac_model, map_location="cpu")
100 | # print(checkpoint)
101 | bert_model.load_all_params(model_path=relation_extrac_model, device=device)
102 | text = ["查尔斯·阿兰基斯(Charles Aránguiz),1989年4月17日出生于智利圣地亚哥,智利职业足球运动员,司职中场,效力于德国足球甲级联赛勒沃库森足球俱乐部",
103 | "《星空黑夜传奇》是连载于起点中文网的网络小说,作者是啤酒的罪孽",
104 | "《李烈钧自述》是2011年11月1日人民日报出版社出版的图书,作者是李烈钧",
105 | "杨铁心和郭啸天兄弟二人在牛家村的农屋里喝酒,他们的岳飞大将军在风波亭被害之事,二人希望能够像岳飞大将军一样精忠报国。"]
106 |
107 | for d in text:
108 | with torch.no_grad():
109 | token_ids_test, segment_ids = tokenizer.encode(d, max_length=256)
110 | token_ids_test = torch.tensor(token_ids_test, device=device).view(1, -1)
111 | # predict subjects first
112 | pred_subject = bert_model.predict_subject(token_ids_test)
113 | pred_subject = pred_subject.squeeze(0)
114 | subject_texts, subject_idss = search_subject(token_ids_test[0], pred_subject.cpu())
115 | if len(subject_texts) == 0:
116 | print("no subject predicted~")
117 | for sub_text, sub_ids in zip(subject_texts, subject_idss):
118 | print("subject is " + str(sub_text))
119 | sub_ids = torch.tensor(sub_ids, device=device).view(1, -1)
120 | # print("sub_ids shape is " + str(sub_ids))
121 | object_p_pred = bert_model.predict_object_predicate(token_ids_test, sub_ids)
122 | res = search_object(token_ids_test[0], object_p_pred.squeeze(0).cpu())
123 | print("p and obj is " + str(res))
124 |
125 |
126 |
127 |
--------------------------------------------------------------------------------
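
Note: the two decoders above implement the usual pointer-style scheme for joint
extraction: subject_labels is a (seq_len, 2) matrix of start/end probabilities,
and object_labels is (seq_len, num_predicates, 2), giving every predicate its
own start/end channels. A toy check of that layout (not from the repo; it runs
inside this script once word2idx, idx2word and predicate2id are loaded):

    import numpy as np

    seq_len = 5
    toy_subject = np.zeros((seq_len, 2))
    toy_subject[1, 0] = 1.0    # column 0: a subject starts at position 1
    toy_subject[2, 1] = 1.0    # column 1: that subject ends at position 2

    toy_object = np.zeros((seq_len, len(predicate2id), 2))
    toy_object[3, 0, 0] = 1.0  # predicate 0: an object starts at position 3
    toy_object[3, 0, 1] = 1.0  # predicate 0: and ends at position 3

    toy_ids = list(word2idx.values())[:seq_len]   # any valid token ids will do
    print(search_subject(toy_ids, toy_subject))   # one subject, ids [[1, 2]]
    print(search_object(toy_ids, toy_object))     # [(id2predicate[0], <one token>)]
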
/test/poem_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
3 | from bert_seq2seq import load_bert
4 |
5 | auto_title_model = "./state_dict/bert_model_poem_ci_duilian.bin"
6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 |
8 | if __name__ == "__main__":
9 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # path to the roberta vocab file
10 | model_name = "roberta" # model type to load
11 | # model_path = "./state_dict/bert-base-chinese-pytorch_model.bin" # roberta model location
12 | # load the vocabulary
13 | word2idx = load_chinese_base_vocab(vocab_path, simplfied=False)
14 | # build the model
15 | bert_model = load_bert(word2idx, model_name=model_name)
16 | bert_model.set_device(device)
17 | bert_model.eval()
18 | ## load the fine-tuned model parameters (load_all_params reads the checkpoint itself)
19 | # checkpoint = torch.load(auto_title_model, map_location="cpu")
20 | # print(checkpoint)
21 | bert_model.load_all_params(model_path=auto_title_model, device=device)
22 | test_data = ["江山竞秀,万里风光入画图##对联"]
23 | with torch.no_grad():
24 | for text in test_data:
25 | if text[-1] == "句" or text[-1] == "诗":
26 | print(bert_model.generate(text, beam_size=3, is_poem=True))
27 | else:
28 | print(bert_model.generate(text, beam_size=3, is_poem=False))
29 |
30 |
31 |
32 |
--------------------------------------------------------------------------------
/test/relation_extract_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import numpy as np
4 | import json
5 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
6 | from bert_seq2seq import load_bert
7 |
8 | relation_extrac_model = "./state_dict/bert_model_relation_extrac.bin"
9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # path to the roberta vocab file
11 | model_name = "roberta" # model type to load
12 | # load the vocabulary
13 | word2idx = load_chinese_base_vocab(vocab_path, simplfied=False)
14 | tokenizer = Tokenizer(word2idx)
15 | idx2word = {v: k for k, v in word2idx.items()}
16 |
17 | predicate2id, id2predicate = {}, {}
18 | with open('./corpus/三元组抽取/all_50_schemas') as f:
19 | for l in f:
20 | l = json.loads(l)
21 | if l['predicate'] not in predicate2id:
22 | id2predicate[len(predicate2id)] = l['predicate']
23 | predicate2id[l['predicate']] = len(predicate2id)
24 |
25 |
26 | def search(pattern, sequence):
27 | """从sequence中寻找子串pattern
28 | 如果找到,返回第一个下标;否则返回-1。
29 | """
30 | n = len(pattern)
31 | for i in range(len(sequence)):
32 | if sequence[i:i + n] == pattern:
33 | return i
34 | return -1
35 |
36 | def search_subject(token_ids, subject_labels):
37 | # subject_labels: (seq_len, 2)
38 | if type(subject_labels) is torch.Tensor:
39 | subject_labels = subject_labels.numpy()
40 | if type(token_ids) is torch.Tensor:
41 | token_ids = token_ids.cpu().numpy()
42 | subjects = []
43 | subject_ids = []
44 | start = -1
45 | end = -1
46 | for i in range(len(token_ids)):
47 | if subject_labels[i, 0] > 0.5:
48 | start = i
49 | for j in range(i, len(token_ids)): # look for the matching end at or after the start
50 | if subject_labels[j, 1] > 0.5:
51 | subject_labels[j, 1] = 0
52 | end = j
53 | break
54 | if start == -1 or end == -1:
55 | continue
56 | subject = ""
57 | for k in range(start, end + 1):
58 | subject += idx2word[token_ids[k]]
59 | # print(subject)
60 | subject_ids.append([start, end])
61 | start = -1
62 | end = -1
63 | subjects.append(subject)
64 |
65 | return subjects, subject_ids
66 |
67 | def search_object(token_ids, object_labels):
68 | objects = []
69 | if type(object_labels) is torch.Tensor:
70 | object_labels = object_labels.numpy()
71 | if type(token_ids) is torch.Tensor:
72 | token_ids = token_ids.cpu().numpy()
73 | # print(object_labels.sum())
74 | start = np.where(object_labels[:, :, 0] > 0.5)
75 | end = np.where(object_labels[:, :, 1] > 0.5)
76 | # print(start)
77 | # print(end)
78 | for _start, predicate1 in zip(*start):
79 | for _end, predicate2 in zip(*end):
80 | if _start <= _end and predicate1 == predicate2:
81 | object_text = ""
82 | for k in range(_start, _end + 1):
83 | # print(token_ids(k))
84 | object_text += idx2word[token_ids[k]]
85 | objects.append(
86 | (id2predicate[predicate1], object_text)
87 | )
88 | break
89 |
90 | return objects
91 |
92 | if __name__ == "__main__":
93 |
94 | # build the model
95 | bert_model = load_bert(word2idx, model_class="relation_extrac", model_name=model_name, target_size=len(predicate2id))
96 | bert_model.eval()
97 | bert_model.set_device(device)
98 | ## load the fine-tuned model parameters (load_all_params reads the checkpoint itself)
99 | # checkpoint = torch.load(relation_extrac_model, map_location="cpu")
100 | # print(checkpoint)
101 | bert_model.load_all_params(model_path=relation_extrac_model, device=device)
102 | text = ["查尔斯·阿兰基斯(Charles Aránguiz),1989年4月17日出生于智利圣地亚哥,智利职业足球运动员,司职中场,效力于德国足球甲级联赛勒沃库森足球俱乐部",
103 | "《星空黑夜传奇》是连载于起点中文网的网络小说,作者是啤酒的罪孽",
104 | "《李烈钧自述》是2011年11月1日人民日报出版社出版的图书,作者是李烈钧",
105 | "杨铁心和郭啸天兄弟二人在牛家村的农屋里喝酒,他们的岳飞大将军在风波亭被害之事,二人希望能够像岳飞大将军一样精忠报国。"]
106 |
107 | for d in text:
108 | with torch.no_grad():
109 | token_ids_test, segment_ids = tokenizer.encode(d, max_length=256)
110 | token_ids_test = torch.tensor(token_ids_test, device=device).view(1, -1)
111 | # predict subjects first
112 | pred_subject = bert_model.predict_subject(token_ids_test)
113 | pred_subject = pred_subject.squeeze(0)
114 | subject_texts, subject_idss = search_subject(token_ids_test[0], pred_subject.cpu())
115 | if len(subject_texts) == 0:
116 | print("no subject predicted~")
117 | for sub_text, sub_ids in zip(subject_texts, subject_idss):
118 | print("subject is " + str(sub_text))
119 | sub_ids = torch.tensor(sub_ids, device=device).view(1, -1)
120 | # print("sub_ids shape is " + str(sub_ids))
121 | object_p_pred = bert_model.predict_object_predicate(token_ids_test, sub_ids)
122 | res = search_object(token_ids_test[0], object_p_pred.squeeze(0).cpu())
123 | print("p and obj is " + str(res))
124 |
125 |
126 |
127 |
--------------------------------------------------------------------------------
/test/semantic_matching_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
3 | from bert_seq2seq import load_bert
4 |
5 | target = ["0", "1"]
6 |
7 | cls_model = "./state_dict/bert_semantic_matching.bin"
8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 |
10 | if __name__ == "__main__":
11 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # path to the roberta vocab file
12 | model_name = "roberta" # model type to load
13 | # load the vocabulary
14 | word2idx = load_chinese_base_vocab(vocab_path, simplfied=False)
15 | tokenizer = Tokenizer(word2idx)
16 | # build the model
17 | bert_model = load_bert(word2idx, model_name=model_name, model_class="cls", target_size=len(target))
18 | bert_model.set_device(device)
19 | bert_model.eval()
20 | ## load the fine-tuned model parameters
21 | bert_model.load_all_params(model_path=cls_model, device=device)
22 | test_data = ["你是不是我仇人#你是俺的仇人吗",
23 | "这个就没意思了#我没别的意思",
24 | "查一下我的家在哪里#家在哪里?"]
25 | for text in test_data:
26 | with torch.no_grad():
27 | text_ids, _ = tokenizer.encode(text)
28 | text_ids = torch.tensor(text_ids, device=device).view(1, -1)
29 | print(text + " -> res is " + str(target[torch.argmax(bert_model(text_ids)).item()]))
30 |
--------------------------------------------------------------------------------
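
Note: `bert_model(text_ids)` above is assumed to return the raw class scores,
and the script keeps only the argmax. Under that same assumption, a small
extension of the loop body that also reports the model's confidence:

    import torch.nn.functional as F

    with torch.no_grad():
        logits = bert_model(text_ids)              # shape [1, len(target)]
        probs = F.softmax(logits, dim=-1).squeeze(0)
        pred = probs.argmax().item()
        print(f"{target[pred]} (p={probs[pred]:.3f})")
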
/test/t5_chinese_autotitle_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from bert_seq2seq.tokenizer import load_chinese_base_vocab, T5PegasusTokenizer
3 | from bert_seq2seq.extend_model_method import ExtendModel
4 | import glob
5 |
6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 |
8 | from bert_seq2seq.t5_ch import T5Model
9 |
10 | vocab_path = "./state_dict/t5-chinese/vocab.txt"
11 | model_path = "./state_dict/t5_autotile.bin"
12 | word2idx = load_chinese_base_vocab(vocab_path)
13 |
14 | model = T5Model(word2idx, size="base")
15 | model.set_device(device)
16 | model.load_all_params(model_path)
17 | model.eval()
18 |
19 | all_txt = glob.glob("./*.txt")
20 | print(all_txt)
21 | for t in all_txt:
22 | with open(t, encoding="utf-8") as f:
23 | content = f.read()
24 | out = model.sample_generate_encoder_decoder(content)
25 | print(out)
26 |
27 | # for t in all_txt:
28 |
29 |
30 |
31 |
32 |
--------------------------------------------------------------------------------
/test/t5_chinese_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from bert_seq2seq.tokenizer import load_chinese_base_vocab, T5PegasusTokenizer
3 | from transformers.models.mt5.modeling_mt5 import MT5ForConditionalGeneration
4 | from bert_seq2seq.extend_model_method import ExtendModel
5 |
6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 |
8 | # generate with the transformers MT5 implementation
9 | model_path = './state_dict/t5-chinese'
10 | model = MT5ForConditionalGeneration.from_pretrained(model_path)
11 | word2ix = load_chinese_base_vocab("./state_dict/t5-chinese/vocab.txt")
12 | tokenizer = T5PegasusTokenizer(word2ix)
13 | model.eval()
14 | model = ExtendModel(model, tokenizer=tokenizer, bos_id=word2ix["[CLS]"], eos_id=word2ix["[SEP]"], device=device)
15 | text = '从那之后,一发不可收拾。此后的近百年间,一共有十七位新娘在与君山一带失踪。有时十几年相安无事,有时短短一个月内失踪两名。一个恐怖传说迅速传开:与君山里住着一位鬼新郎,若是他看中了一位女子,便会在她出嫁的路上将她掳走,再把送亲的队伍吃掉。'
16 | out = model.sample_generate_encoder_decoder(text)
17 | print(out)
18 |
19 | # generate with this repo's own T5 implementation
20 | from bert_seq2seq.t5_ch import T5Model
21 | vocab_path = "./state_dict/t5-chinese/vocab.txt"
22 | model = T5Model(word2ix, size="base")  # T5Model expects the vocab dict (as above), not the path
23 | model.set_device(device)
24 | model.load_pretrain_params("./state_dict/t5-chinese/pytorch_model.bin")
25 | model.eval()
26 | text = '从那之后,一发不可收拾。此后的近百年间,一共有十七位新娘在与君山一带失踪。有时十几年相安无事,有时短短一个月内失踪两名。一个恐怖传说迅速传开:与君山里住着一位鬼新郎,若是他看中了一位女子,便会在她出嫁的路上将她掳走,再把送亲的队伍吃掉。'
27 | out = model.sample_generate_encoder_decoder(text)
28 | print(out)
29 |
30 |
31 |
--------------------------------------------------------------------------------
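
Note: in the first half of this script, MT5ForConditionalGeneration is a
standard transformers model, so its built-in beam search is available as well.
A hedged sketch that would run right after line 13, before the ExtendModel
wrapping rebinds `model` (it assumes T5PegasusTokenizer exposes the same
encode() as the repo's other tokenizers, returning (token_ids, segment_ids)):

    import torch

    token_ids, _ = tokenizer.encode(text, max_length=512)
    input_ids = torch.tensor(token_ids).view(1, -1)
    out_ids = model.generate(input_ids=input_ids,
                             decoder_start_token_id=word2ix["[CLS]"],
                             eos_token_id=word2ix["[SEP]"],
                             max_length=80, num_beams=3)
    ix2word = {v: k for k, v in word2ix.items()}
    print("".join(ix2word[i] for i in out_ids[0].tolist() if i in ix2word))
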
/test/test_paddle/bert_couplet_test_paddle.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("/home/bert_seq2seq/paddle_model")
3 | import paddle
4 | import numpy as np
5 | from paddle import nn
6 | from paddlenlp.transformers import AutoTokenizer, AutoModel, BertModel, BertTokenizer, BPETokenizer, CTRLTokenizer
7 | from paddlenlp.transformers.bert.modeling import BertSeq2Seq
8 | import paddle.nn.functional as F
9 | from tqdm import tqdm
10 | from paddle.io import Dataset
11 |
12 |
13 | class Predictor:
14 |
15 | def __init__(self, model: BertModel, tokenizer: BertTokenizer, out_max_length=100, beam_size=1, max_length=512, ):
16 | self.out_max_length = out_max_length
17 | self.beam_size = beam_size
18 | self.max_length = max_length
19 | self.tokenizer = tokenizer
20 | self.model = model
21 |
22 | def generate(self, text):
23 | self.model.eval()
24 | input_max_length = self.max_length - self.out_max_length
25 | tokenizer_out = self.tokenizer.encode(text, max_seq_len=input_max_length)
26 | vocab = self.tokenizer.vocab
27 | token_ids = tokenizer_out["input_ids"]
28 | token_type_ids = tokenizer_out["token_type_ids"]
29 | token_ids = paddle.to_tensor(token_ids).reshape([1, -1])
30 | token_type_ids = paddle.to_tensor(token_type_ids).reshape([1, -1])
31 |
32 | # print(f"token_ids is {token_ids}")
33 | out_puts_ids = self.beam_search(token_ids, token_type_ids, vocab, beam_size=self.beam_size)
34 | # print(out_puts_ids)
35 | tokens = self.tokenizer.convert_ids_to_tokens(out_puts_ids)
36 |
37 | return self.tokenizer.convert_tokens_to_string(tokens)
38 |
39 | def beam_search(self, token_ids, token_type_ids, word2ix, beam_size=1,):
40 | """
41 | beam-search操作
42 | """
43 | sep_id = word2ix["[SEP]"]
44 |
45 | # holds the generated output sequences
46 | # output_ids = paddle.empty([1, 0]).astype("int")
47 | output_ids = None
48 | # output_ids = np.empty([1, 0]).astype(np.int)
49 | # holds the accumulated log-probability scores
50 | with paddle.no_grad():
51 | output_scores = np.zeros([token_ids.shape[0]])
52 | for step in range(self.out_max_length):
53 | if step == 0:
54 | scores = self.model(token_ids, token_type_ids)
55 | # tile the input ids beam_size times
56 | token_ids = np.tile(token_ids.reshape([1, -1]), [beam_size, 1])
57 | token_type_ids = np.tile(token_type_ids.reshape([1, -1]), [beam_size, 1])
58 | else:
59 | scores = self.model(new_input_ids, new_token_type_ids)
60 |
61 | logit_score = F.log_softmax(scores[:, -1], axis=-1).numpy()
62 |
63 | logit_score = output_scores.reshape([-1, 1]) + logit_score # accumulated scores
64 | ## to take the top-k we first flatten the (beam, vocab) matrix,
65 | # then recover row/column indices below
66 | logit_score = logit_score.reshape([-1])
67 | hype_pos = np.argpartition(logit_score, -beam_size, axis=-1)[-beam_size:]
68 | hype_score = logit_score[hype_pos]
69 | indice1 = (hype_pos // scores.shape[-1]).reshape([-1]) # row index: which beam each candidate extends
70 | indice2 = (hype_pos % scores.shape[-1]).astype(int).reshape([-1, 1]) # column index: which token it appends
71 |
72 | output_scores = hype_score
73 | if output_ids is None:
74 | output_ids = indice2.reshape([beam_size, 1])
75 | else :
76 | output_ids = np.concatenate([output_ids[indice1], indice2], axis=1).astype(int)
77 |
78 | new_input_ids = np.concatenate([token_ids, output_ids], axis=1)
79 | new_token_type_ids = np.concatenate([token_type_ids, np.ones_like(output_ids)], axis=1)
80 |
81 | end_counts = (output_ids == sep_id).sum(1) # count [SEP] end markers per beam
82 | best_one = output_scores.argmax()
83 | if end_counts[best_one] == 1:
84 | # the best-scoring beam has produced [SEP]: generation is done
85 | return output_ids[best_one][:-1]
86 | else:
87 | # keep decoding the unfinished beams
88 | flag = (end_counts < 1) # marks unfinished sequences
89 | if not flag.all(): # if some beams have finished
90 | token_ids = token_ids[flag]
91 | token_type_ids = token_type_ids[flag]
92 | new_input_ids = new_input_ids[flag]
93 | new_token_type_ids = new_token_type_ids[flag]
94 | output_ids = output_ids[flag] # drop finished sequences
95 | output_scores = output_scores[flag] # drop finished sequences
96 | beam_size = flag.sum() # shrink the beam accordingly
97 |
98 | return output_ids[output_scores.argmax()]
99 |
100 |
101 |
102 |
103 | def returnForecast(data,modelAddress = None,tokenizerAddress = None):
104 |
105 | if modelAddress is None:
106 | model = BertSeq2Seq.from_pretrained('./model')
107 |
108 | else:
109 | model = BertSeq2Seq.from_pretrained(modelAddress)
110 |
111 |
112 | if tokenizerAddress is None:
113 | tokenizer = BertTokenizer.from_pretrained('./tokenizer')
114 |
115 | else:
116 | tokenizer = BertTokenizer.from_pretrained(tokenizerAddress)
117 |
118 | predictor = Predictor(model, tokenizer, beam_size=2, out_max_length=40, max_length=512)
119 |
120 |
121 | import collections
122 |
123 |
124 | # an OrderedDict keeps the outputs in input order
125 |
126 | OutCouplet = collections.OrderedDict()
127 |
128 | for simply_in in data:
129 |
130 | out = predictor.generate(simply_in)
131 | OutCouplet[simply_in] = out
132 |
133 | return OutCouplet
134 |
135 | def forecastForCouplet(data,modelAddress = None,tokenizerAddress = None):
136 |
137 | Couplet = returnForecast(data,modelAddress,tokenizerAddress)
138 |
139 | for Uplink,Downlink in Couplet.items():
140 |
141 | print(f"上联:{Uplink},下联:{Downlink}。")
142 |
143 |
144 |
145 |
146 |
147 | if __name__ == '__main__':
148 |
149 | test_data = ["床前明月光", "万里悲秋常作客","广汉飞霞诗似玉", "执政为民,德行天下","春回大地万事新"]
150 |
151 | forecastForCouplet(test_data)
152 |
153 |
154 |
155 |
156 |
--------------------------------------------------------------------------------
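
Note: the core of beam_search above is a top-k over the flattened (beam, vocab)
score matrix via np.argpartition (partial selection instead of a full sort),
after which integer division recovers the beam (row) and the modulo recovers
the token (column) for each surviving candidate. A standalone demo of just
that step:

    import numpy as np

    beam_size, vocab = 2, 5
    scores = np.array([[0.1, 0.4, 0.0, 0.2, 0.3],
                       [0.5, 0.0, 0.6, 0.1, 0.2]])
    flat = scores.reshape(-1)
    top = np.argpartition(flat, -beam_size)[-beam_size:]  # top-k, unsorted
    rows, cols = top // vocab, top % vocab
    print(rows, cols, flat[top])  # beams [1 1], tokens [0 2] (in some order)
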
/test/test_paddle/roberta_autotitle_test_paddle.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("/home/bert_seq2seq/paddle_model")
3 | import paddle
4 | import numpy as np
5 | from paddle import nn
6 | from paddlenlp.transformers import AutoTokenizer, AutoModel, BertModel, BertTokenizer, CTRLTokenizer,RobertaTokenizer,RobertaModel
7 | from paddlenlp.transformers.roberta.modeling import robertaSeq2Seq
8 | from paddlenlp.transformers.roberta.modeling import RobertaPretrainedModel
9 | import paddle.nn.functional as F
10 | from tqdm import tqdm
11 | from paddle.io import Dataset
12 |
13 | class Predictor:
14 |
15 | def __init__(self, model: BertModel, tokenizer: BertTokenizer, out_max_length=100, beam_size=1, max_length=512, ):
16 | self.out_max_length = out_max_length
17 | self.beam_size = beam_size
18 | self.max_length = max_length
19 | self.tokenizer = tokenizer
20 | self.model = model
21 |
22 | def generate(self, text):
23 | self.model.eval()
24 | input_max_length = self.max_length - self.out_max_length
25 | tokenizer_out = self.tokenizer.encode(text, max_seq_len=input_max_length)
26 | vocab = self.tokenizer.vocab
27 | token_ids = tokenizer_out["input_ids"]
28 | token_type_ids = tokenizer_out["token_type_ids"]
29 | token_ids = paddle.to_tensor(token_ids).reshape([1, -1])
30 | token_type_ids = paddle.to_tensor(token_type_ids).reshape([1, -1])
31 |
32 | # print(f"token_ids is {token_ids}")
33 | out_puts_ids = self.beam_search(token_ids, token_type_ids, vocab, beam_size=self.beam_size)
34 | # print(out_puts_ids)
35 | tokens = self.tokenizer.convert_ids_to_tokens(out_puts_ids)
36 |
37 | return self.tokenizer.convert_tokens_to_string(tokens)
38 |
39 | def beam_search(self, token_ids, token_type_ids, word2ix, beam_size=1, ):
40 | """
41 | beam-search操作
42 | """
43 | sep_id = word2ix["[SEP]"]
44 |
45 | # holds the generated output sequences
46 | # output_ids = paddle.empty([1, 0]).astype("int")
47 | output_ids = None
48 | # output_ids = np.empty([1, 0]).astype(np.int)
49 | # holds the accumulated log-probability scores
50 | with paddle.no_grad():
51 | output_scores = np.zeros([token_ids.shape[0]])
52 | for step in range(self.out_max_length):
53 | if step == 0:
54 | scores = self.model(token_ids, token_type_ids)
55 | # tile the input ids beam_size times
56 | token_ids = np.tile(token_ids.reshape([1, -1]), [beam_size, 1])
57 | token_type_ids = np.tile(token_type_ids.reshape([1, -1]), [beam_size, 1])
58 | else:
59 | scores = self.model(new_input_ids, new_token_type_ids)
60 |
61 | logit_score = F.log_softmax(scores[:, -1], axis=-1).numpy()
62 |
63 | logit_score = output_scores.reshape([-1, 1]) + logit_score # accumulated scores
64 | ## to take the top-k we first flatten the (beam, vocab) matrix,
65 | # then recover row/column indices below
66 | logit_score = logit_score.reshape([-1])
67 | hype_pos = np.argpartition(logit_score, -beam_size, axis=-1)[-beam_size:]
68 | hype_score = logit_score[hype_pos]
69 | indice1 = (hype_pos // scores.shape[-1]).reshape([-1]) # row index: which beam each candidate extends
70 | indice2 = (hype_pos % scores.shape[-1]).astype(int).reshape([-1, 1]) # column index: which token it appends
71 |
72 | output_scores = hype_score
73 | if output_ids is None:
74 | output_ids = indice2.reshape([beam_size, 1])
75 | else:
76 | output_ids = np.concatenate([output_ids[indice1], indice2], axis=1).astype(int)
77 |
78 | new_input_ids = np.concatenate([token_ids, output_ids], axis=1)
79 | new_token_type_ids = np.concatenate([token_type_ids, np.ones_like(output_ids)], axis=1)
80 |
81 | end_counts = (output_ids == sep_id).sum(1) # count [SEP] end markers per beam
82 | best_one = output_scores.argmax()
83 | if end_counts[best_one] == 1:
84 | # the best-scoring beam has produced [SEP]: generation is done
85 | return output_ids[best_one][:-1]
86 | else:
87 | # keep decoding the unfinished beams
88 | flag = (end_counts < 1) # marks unfinished sequences
89 | if not flag.all(): # if some beams have finished
90 | token_ids = token_ids[flag]
91 | token_type_ids = token_type_ids[flag]
92 | new_input_ids = new_input_ids[flag]
93 | new_token_type_ids = new_token_type_ids[flag]
94 | output_ids = output_ids[flag] # drop finished sequences
95 | output_scores = output_scores[flag] # drop finished sequences
96 | beam_size = flag.sum() # shrink the beam accordingly
97 |
98 | return output_ids[output_scores.argmax()]
99 |
100 | def returnForecast(data,modelAddress = None,tokenizerAddress = None):
101 |
102 | if modelAddress is None:
103 | model = robertaSeq2Seq.from_pretrained('./model')
104 |
105 | else:
106 | model = robertaSeq2Seq.from_pretrained(modelAddress)
107 |
108 |
109 | if tokenizerAddress is None:
110 | tokenizer = RobertaTokenizer.from_pretrained('./tokenizer')
111 |
112 | else:
113 | tokenizer = RobertaTokenizer.from_pretrained(tokenizerAddress)
114 |
115 | predictor = Predictor(model, tokenizer, beam_size=2, out_max_length=40, max_length=512)
116 |
117 |
118 | import collections
119 |
120 |
121 | # an OrderedDict keeps the outputs in input order
122 |
123 | OutCouplet = collections.OrderedDict()
124 |
125 | for simply_in in data:
126 |
127 | out = predictor.generate(simply_in)
128 | OutCouplet[simply_in] = out
129 |
130 | return OutCouplet
131 |
132 | def forecastForAutoTitle(data,modelAddress = None,tokenizerAddress = None):
133 |
134 | Couplet = returnForecast(data,modelAddress,tokenizerAddress)
135 |
136 | for Uplink,Downlink in Couplet.items():
137 |
138 | print(f"新闻:{Uplink},标题:{Downlink}。")
139 |
140 |
141 |
142 |
143 |
144 | if __name__ == '__main__':
145 |
146 | test_data = [
147 | "本文总结了十个可穿戴产品的设计原则而这些原则同样也是笔者认为是这个行业最吸引人的地方1为人们解决重复性问题2从人开始而不是从机器开始3要引起注意但不要刻意4提升用户能力而不是取代人",
148 | "2007年乔布斯向人们展示iPhone并宣称它将会改变世界还有人认为他在夸大其词然而在8年后以iPhone为代表的触屏智能手机已经席卷全球各个角落未来智能手机将会成为真正的个人电脑为人类发展做出更大的贡献",
149 | "雅虎发布2014年第四季度财报并推出了免税方式剥离其持有的阿里巴巴集团15%股权的计划打算将这一价值约400亿美元的宝贵投资分配给股东截止发稿前雅虎股价上涨了大约7%至5145美元"]
150 |
151 | forecastForAutoTitle(test_data)
152 |
153 |
154 |
155 |
156 |
--------------------------------------------------------------------------------
/test/test_paddle/roberta_math_test_paddle.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("/home/bert_seq2seq/paddle_model")
3 | import paddle
4 | import numpy as np
5 | from paddle import nn
6 | from paddlenlp.transformers import AutoTokenizer, AutoModel, BertModel, BertTokenizer, CTRLTokenizer, RobertaTokenizer, \
7 | RobertaModel
8 | from paddlenlp.transformers.roberta.modeling import robertaSeq2Seq
9 | from paddlenlp.transformers.roberta.modeling import RobertaPretrainedModel
10 | import paddle.nn.functional as F
11 | from tqdm import tqdm
12 | from paddle.io import Dataset
13 | import json
14 | import re
15 |
16 |
17 |
18 | class Predictor:
19 |
20 | def __init__(self, model: BertModel, tokenizer: BertTokenizer, out_max_length=100, beam_size=1, max_length=512, ):
21 | self.out_max_length = out_max_length
22 | self.beam_size = beam_size
23 | self.max_length = max_length
24 | self.tokenizer = tokenizer
25 | self.model = model
26 |
27 | def generate(self, text):
28 | self.model.eval()
29 | input_max_length = self.max_length - self.out_max_length
30 | tokenizer_out = self.tokenizer.encode(text, max_seq_len=input_max_length)
31 | vocab = self.tokenizer.vocab
32 | token_ids = tokenizer_out["input_ids"]
33 | token_type_ids = tokenizer_out["token_type_ids"]
34 | token_ids = paddle.to_tensor(token_ids).reshape([1, -1])
35 | token_type_ids = paddle.to_tensor(token_type_ids).reshape([1, -1])
36 |
37 | # print(f"token_ids is {token_ids}")
38 | out_puts_ids = self.beam_search(token_ids, token_type_ids, vocab, beam_size=self.beam_size)
39 | # print(out_puts_ids)
40 | tokens = self.tokenizer.convert_ids_to_tokens(out_puts_ids)
41 |
42 | return self.tokenizer.convert_tokens_to_string(tokens)
43 |
44 | def beam_search(self, token_ids, token_type_ids, word2ix, beam_size=1,):
45 | """
46 | beam-search操作
47 | """
48 | sep_id = word2ix["[SEP]"]
49 |
50 | # holds the generated output sequences
51 | # output_ids = paddle.empty([1, 0]).astype("int")
52 | output_ids = None
53 | # output_ids = np.empty([1, 0]).astype(np.int)
54 | # holds the accumulated log-probability scores
55 | with paddle.no_grad():
56 | output_scores = np.zeros([token_ids.shape[0]])
57 | for step in range(self.out_max_length):
58 | if step == 0:
59 | scores = self.model(token_ids, token_type_ids)
60 | # tile the input ids beam_size times
61 | token_ids = np.tile(token_ids.reshape([1, -1]), [beam_size, 1])
62 | token_type_ids = np.tile(token_type_ids.reshape([1, -1]), [beam_size, 1])
63 | else:
64 | scores = self.model(new_input_ids, new_token_type_ids)
65 |
66 | logit_score = F.log_softmax(scores[:, -1], axis=-1).numpy()
67 |
68 | logit_score = output_scores.reshape([-1, 1]) + logit_score # accumulated scores
69 | ## to take the top-k we first flatten the (beam, vocab) matrix,
70 | # then recover row/column indices below
71 | logit_score = logit_score.reshape([-1])
72 | hype_pos = np.argpartition(logit_score, -beam_size, axis=-1)[-beam_size:]
73 | hype_score = logit_score[hype_pos]
74 | indice1 = (hype_pos // scores.shape[-1]).reshape([-1]) # row index: which beam each candidate extends
75 | indice2 = (hype_pos % scores.shape[-1]).astype(int).reshape([-1, 1]) # column index: which token it appends
76 |
77 | output_scores = hype_score
78 | if output_ids is None:
79 | output_ids = indice2.reshape([beam_size, 1])
80 | else :
81 | output_ids = np.concatenate([output_ids[indice1], indice2], axis=1).astype(int)
82 |
83 | new_input_ids = np.concatenate([token_ids, output_ids], axis=1)
84 | new_token_type_ids = np.concatenate([token_type_ids, np.ones_like(output_ids)], axis=1)
85 |
86 | end_counts = (output_ids == sep_id).sum(1) # count [SEP] end markers per beam
87 | best_one = output_scores.argmax()
88 | if end_counts[best_one] == 1:
89 | # the best-scoring beam has produced [SEP]: generation is done
90 | return output_ids[best_one][:-1]
91 | else:
92 | # keep decoding the unfinished beams
93 | flag = (end_counts < 1) # marks unfinished sequences
94 | if not flag.all(): # if some beams have finished
95 | token_ids = token_ids[flag]
96 | token_type_ids = token_type_ids[flag]
97 | new_input_ids = new_input_ids[flag]
98 | new_token_type_ids = new_token_type_ids[flag]
99 | output_ids = output_ids[flag] # drop finished sequences
100 | output_scores = output_scores[flag] # drop finished sequences
101 | beam_size = flag.sum() # shrink the beam accordingly
102 |
103 | return output_ids[output_scores.argmax()]
104 |
105 |
106 |
107 |
108 |
109 |
110 | def returnForecast(data,modelAddress = None,tokenizerAddress = None):
111 |
112 | if modelAddress is None:
113 | model = robertaSeq2Seq.from_pretrained('./model')
114 |
115 | else:
116 | model = robertaSeq2Seq.from_pretrained(modelAddress)
117 |
118 |
119 | if tokenizerAddress is None:
120 | tokenizer = RobertaTokenizer.from_pretrained('./tokenizer')
121 |
122 | else:
123 | tokenizer = RobertaTokenizer.from_pretrained(tokenizerAddress)
124 |
125 | predictor = Predictor(model, tokenizer, beam_size=2, out_max_length=40, max_length=512)
126 |
127 |
128 | import collections
129 |
130 |
131 | # an OrderedDict keeps the outputs in input order
132 |
133 | OutCouplet = collections.OrderedDict()
134 |
135 | for simply_in in data:
136 |
137 | out = predictor.generate(simply_in)
138 | OutCouplet[simply_in] = out
139 |
140 | return OutCouplet
141 |
142 | def forecastForMath(data,modelAddress = None,tokenizerAddress = None):
143 |
144 | Couplet = returnForecast(data,modelAddress,tokenizerAddress)
145 |
146 | for Uplink,Downlink in Couplet.items():
147 |
148 | print(f"问题:{Uplink},算式:{Downlink}。")
149 |
150 |
151 |
152 |
153 |
154 | if __name__ == '__main__':
155 |
156 | test_data = [
157 | "王艳家买了一台洗衣机和一台电冰箱,一共花了6000元,电冰箱的价钱是洗衣机的3/5,求洗衣机的价钱.",
158 | "六1班原来男生占总数的2/5,又转来5名男生,现在男生占总数的5/11,女生有多少人?",
159 | "两个相同的数相乘,积是3600,这个数是多少."]
160 |
161 | forecastForMath(test_data)
162 |
163 |
164 |
165 |
166 |
167 |
--------------------------------------------------------------------------------
/test/做数学题_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
3 | from bert_seq2seq import load_bert
4 |
5 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # path to the roberta vocab file
6 |
7 | model_name = "roberta" # model type to load
8 | model_path = "./state_dict/bert_math_ques_model.bin"
9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10 |
11 | if __name__ == "__main__":
12 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # path to the roberta vocab file
13 | model_name = "roberta" # model type to load
14 | # load the vocabulary
15 | word2idx = load_chinese_base_vocab(vocab_path, simplfied=False)
16 | tokenizer = Tokenizer(word2idx)
17 | # build the model
18 | bert_model = load_bert(word2idx, model_name=model_name, model_class="seq2seq")
19 | bert_model.set_device(device)
20 | bert_model.eval()
21 | ## load the fine-tuned model parameters
22 | bert_model.load_all_params(model_path=model_path, device=device)
23 | test_data = ["王艳家买了一台洗衣机和一台电冰箱,一共花了6000元,电冰箱的价钱是洗衣机的3/5,求洗衣机的价钱.",
24 | "六1班原来男生占总数的2/5,又转来5名男生,现在男生占总数的5/11,女生有多少人?",
25 | "两个相同的数相乘,积是3600,这个数是多少.",
26 | "1加1等于几"]
27 | for text in test_data:
28 | with torch.no_grad():
29 | print(bert_model.generate(text, beam_size=3))
--------------------------------------------------------------------------------
/test/新闻标题文本分类_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
3 | from bert_seq2seq import load_bert
4 |
5 | target = ["财经", "彩票", "房产", "股票", "家居", "教育", "科技", "社会", "时尚", "时政", "体育", "星座", "游戏", "娱乐"]
6 |
7 | cls_model = "./state_dict/bert_multi_classify_model.bin"
8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 |
10 | if __name__ == "__main__":
11 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # path to the roberta vocab file
12 | model_name = "roberta" # model type to load
13 | # load the vocabulary
14 | word2idx = load_chinese_base_vocab(vocab_path, simplfied=False)
15 | tokenizer = Tokenizer(word2idx)
16 | # build the model
17 | bert_model = load_bert(word2idx, model_name=model_name, model_class="cls", target_size=len(target))
18 | bert_model.set_device(device)
19 | bert_model.eval()
20 | ## load the fine-tuned model parameters
21 | bert_model.load_all_params(model_path=cls_model, device=device)
22 | test_data = ["编剧梁馨月讨稿酬六六何念助阵 公司称协商解决",
23 | "西班牙BBVA第三季度净利降至15.7亿美元",
24 | "基金巨亏30亿 欲打开云天系跌停自救"]
25 | for text in test_data:
26 | with torch.no_grad():
27 | token_ids, _ = tokenizer.encode(text) # encode returns (token_ids, segment_ids); don't clobber text
28 | token_ids = torch.tensor(token_ids, device=device).view(1, -1)
29 | print(target[torch.argmax(bert_model(token_ids)).item()])
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/test/粗粒度ner_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
3 | from bert_seq2seq import load_bert
4 |
5 | target = ["O", "B-LOC", "I-LOC", "B-PER", "I-PER", "B-ORG", "I-ORG"]
6 |
7 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # path to the roberta vocab file
8 |
9 | model_name = "roberta" # model type to load
10 | model_path = "./state_dict/bert_粗粒度ner_crf.bin"
11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12 |
13 | def viterbi_decode(nodes, trans):
14 | """
15 | Viterbi decoding
16 | nodes: (seq_len, target_size)
17 | trans: (target_size, target_size)
18 | """
19 | with torch.no_grad():
20 | scores = nodes[0]
21 | scores[1:] -= 100000 # the first label must be "O"
22 | target_size = nodes.shape[1]
23 | seq_len = nodes.shape[0]
24 | labels = torch.arange(0, target_size).view(1, -1)
25 | path = labels
26 | for l in range(1, seq_len):
27 | scores = scores.view(-1, 1)
28 | M = scores + trans + nodes[l].view(1, -1)
29 | scores, ids = M.max(0)
30 | path = torch.cat((path[:, ids], labels), dim=0)
31 |
32 |
33 | return path[:, scores.argmax()]
34 |
35 | def ner_print(model, test_data, device="cpu"):
36 | model.eval()
37 | idxtword = {v: k for k, v in word2idx.items()}
38 |
39 | tokenizer = Tokenizer(word2idx)
40 | trans = model.state_dict()["crf_layer.trans"]
41 | for text in test_data:
42 | decode = []
43 | text_encode, _ = tokenizer.encode(text)
44 |
45 | text_tensor = torch.tensor(text_encode, device=device).view(1, -1)
46 | out = model(text_tensor).squeeze(0) # these are the emission scores (the "nodes")
47 | labels = viterbi_decode(out, trans)
48 | starting = False
49 | for l in labels:
50 | if l > 0:
51 | label = target[l.item()]
52 | if label[0] == "B":
53 | decode.append(label[2: ])
54 | starting = True
55 | elif starting:
56 | decode.append(label[2: ])
57 | else:
58 | starting = False
59 | decode.append("O")
60 | else :
61 | decode.append("O")
62 | flag = 0
63 |
64 | res = {}
65 | text_decode = [idxtword[i] for i in text_encode]
66 | for index, each_entity in enumerate(decode):
67 | if each_entity != "O":
68 | if flag != each_entity:
69 | # cur_text = "".join([text[t] for t in mapping[index]])
70 | cur_text = text_decode[index]
71 | if each_entity in res.keys():
72 | res[each_entity].append(cur_text)
73 | else :
74 | res[each_entity] = [cur_text]
75 | flag = each_entity
76 | elif flag == each_entity:
77 | res[each_entity][-1] += text_decode[index]
78 | # res[each_entity][-1] += "".join([text[t] for t in mapping[index]])
79 | else :
80 | flag = 0
81 | print(res)
82 |
83 | if __name__ == "__main__":
84 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # path to the roberta vocab file
85 | model_name = "roberta" # model type to load
86 | # load the vocabulary
87 | word2idx = load_chinese_base_vocab(vocab_path, simplfied=False)
88 | tokenizer = Tokenizer(word2idx)
89 | # build the model
90 | bert_model = load_bert(word2idx, model_name=model_name, model_class="sequence_labeling_crf", target_size=len(target))
91 | bert_model.set_device(device)
92 | bert_model.eval()
93 | ## load the fine-tuned model parameters
94 | bert_model.load_all_params(model_path=model_path, device=device)
95 | test_data = ["日寇在京掠夺文物详情。",
96 | "以书结缘,把欧美,港台流行的食品类食谱汇集一堂。",
97 | "明天天津下雨,不知道杨永康主任还能不能来学校吃个饭。",
98 | "美国的华莱士,我和他谈笑风生",
99 | "看包公断案的戏"
100 | ]
101 | ner_print(bert_model, test_data, device=device)
102 |
--------------------------------------------------------------------------------
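
Note: viterbi_decode above is the standard max-sum recursion: at every step it
adds the emission scores (nodes) and the transition scores (trans), keeps the
best predecessor per label, and finally backtracks through the stored paths.
A tiny sanity check with made-up scores (label 0 plays the role of "O"):

    import torch

    nodes = torch.tensor([[2.0, 0.0],   # step 0 prefers label 0 (forced anyway)
                          [0.0, 1.0],   # steps 1 and 2 prefer label 1
                          [0.0, 1.0]])
    trans = torch.tensor([[0.5, -0.5],  # staying on the same label scores higher
                          [-0.5, 0.5]])
    print(viterbi_decode(nodes, trans))  # best path: tensor([0, 1, 1])
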
/test/细粒度ner_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
3 | from bert_seq2seq import load_bert
4 |
5 | target = ["other", "address", "book", "company", "game", "government", "movie", "name", "organization", "position", "scene"]
6 |
7 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # path to the roberta vocab file
8 |
9 | model_name = "roberta" # model type to load
10 | model_path = "./state_dict/细粒度_bert_ner_model_crf.bin"
11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12 |
13 | def viterbi_decode(nodes, trans):
14 | """
15 | Viterbi decoding
16 | nodes: (seq_len, target_size)
17 | trans: (target_size, target_size)
18 | """
19 | with torch.no_grad():
20 | scores = nodes[0]
21 | scores[1:] -= 100000 # the first label must be the null tag ("other")
22 | target_size = nodes.shape[1]
23 | seq_len = nodes.shape[0]
24 | labels = torch.arange(0, target_size).view(1, -1)
25 | path = labels
26 | for l in range(1, seq_len):
27 | scores = scores.view(-1, 1)
28 | M = scores + trans + nodes[l].view(1, -1)
29 | scores, ids = M.max(0)
30 | path = torch.cat((path[:, ids], labels), dim=0)
31 |
32 | return path[:, scores.argmax()]
33 |
34 | def ner_print(model, test_data, device="cpu"):
35 | model.eval()
36 | idxtword = {v: k for k, v in word2idx.items()}
37 | tokenizer = Tokenizer(word2idx)
38 | trans = model.state_dict()["crf_layer.trans"]
39 | for text in test_data:
40 | decode = []
41 | text_encode, _ = tokenizer.encode(text)
42 | text_tensor = torch.tensor(text_encode, device=device).view(1, -1)
43 | out = model(text_tensor).squeeze(0) # these are the emission scores (the "nodes")
44 | labels = viterbi_decode(out, trans)
45 | starting = False
46 | for l in labels:
47 | if l > 0:
48 | label = target[l.item()]
49 | decode.append(label)
50 | else :
51 | decode.append("other")
52 | flag = 0
53 | res = {}
54 | # print(decode)
55 | # print(text)
56 | decode_text = [idxtword[i] for i in text_encode]
57 | for index, each_entity in enumerate(decode):
58 | if each_entity != "other":
59 | if flag != each_entity:
60 | # cur_text = "".join([text[t] for t in mapping[index]])
61 |
62 | cur_text = decode_text[index]
63 | if each_entity in res.keys():
64 | res[each_entity].append(cur_text)
65 | else :
66 | res[each_entity] = [cur_text]
67 | flag = each_entity
68 | elif flag == each_entity:
69 | res[each_entity][-1] += decode_text[index]
70 | else :
71 | flag = 0
72 | print(res)
73 |
74 |
75 | if __name__ == "__main__":
76 | vocab_path = "./state_dict/roberta_wwm_vocab.txt" # path to the roberta vocab file
77 | model_name = "roberta" # model type to load
78 | # load the vocabulary
79 | word2idx = load_chinese_base_vocab(vocab_path, simplfied=False)
80 | tokenizer = Tokenizer(word2idx)
81 | # build the model
82 | bert_model = load_bert(word2idx, model_name=model_name, model_class="sequence_labeling_crf", target_size=len(target))
83 | bert_model.set_device(device)
84 | bert_model.eval()
85 | ## load the fine-tuned model parameters
86 | bert_model.load_all_params(model_path=model_path, device=device)
87 | # test_data = ["在广州经营小古董珠宝店的潘凝已经收藏了200多款泰迪熊,其中不少更是老牌泰迪熊厂商史蒂夫、赫曼。",
88 | # "2009年1月,北京市长郭金龙在其政府工作报告中曾明确提出,限价房不停建",
89 | # "昨天,记者连线农业银行亳州市支行办公室主任沈伦,他表示,亳州市支行已经对此事进行了讨论和研究",
90 | # "他们又有会怎样的读书经历。曾经留学海外的香港《号外杂志》主编、著名城市文化学者和作家陈冠中先生"
91 | # ]
92 | test_data = ["曹操南征荆州,刘表之子刘琮投降,刘备领军民十余万避难,于当阳遭遇曹军追兵,惨败。",
93 | "赵哲妮(曾用名赵哲),女,1953年11月19日出生,汉族,初中文化程度,户籍所在地河南省漯河市源汇区。"]
94 | ner_print(bert_model, test_data, device=device)
--------------------------------------------------------------------------------
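
Note: ner_print in both NER scripts groups consecutive identical tags into
entities with a running `flag` variable. An equivalent, more compact helper
(a sketch, not from the repo) that returns (tag, text) spans:

    def group_entities(tokens, tags, null_tag="other"):
        # tokens: decoded characters; tags: one predicted label per token
        entities, cur_tag, cur_text = [], None, ""
        for tok, tag in zip(tokens, tags):
            if tag == null_tag:
                if cur_tag:
                    entities.append((cur_tag, cur_text))
                cur_tag, cur_text = None, ""
            elif tag == cur_tag:
                cur_text += tok
            else:
                if cur_tag:
                    entities.append((cur_tag, cur_text))
                cur_tag, cur_text = tag, tok
        if cur_tag:
            entities.append((cur_tag, cur_text))
        return entities

    # group_entities(list("赵哲妮女士"), ["name", "name", "name", "other", "other"])
    # -> [("name", "赵哲妮")]
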