├── code
│   ├── generation
│   │   ├── __init__.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── loss.py
│   │   │   ├── text.py
│   │   │   ├── optim.py
│   │   │   ├── utils.py
│   │   │   ├── dataset.py
│   │   │   ├── transformer_module.py
│   │   │   ├── trainer.py
│   │   │   ├── postprocessing.py
│   │   │   ├── trainer_kd.py
│   │   │   └── transformer_model.py
│   │   ├── run_dialog_kd.py
│   │   ├── run.py
│   │   ├── run_dialog.py
│   │   ├── metrics.py
│   │   ├── train.py
│   │   ├── train_kd.py
│   │   └── config.py
│   └── retrieval
│       ├── readme
│       ├── run.sh
│       ├── pytorch_pretrained_bert
│       │   ├── __init__.py
│       │   ├── convert_tf_checkpoint_to_pytorch.py
│       │   ├── convert_gpt2_checkpoint_to_pytorch.py
│       │   ├── convert_openai_checkpoint_to_pytorch.py
│       │   ├── __main__.py
│       │   ├── convert_transfo_xl_checkpoint_to_pytorch.py
│       │   ├── optimization_openai.py
│       │   ├── file_utils.py
│       │   ├── tokenization_gpt2.py
│       │   ├── tokenization_openai.py
│       │   ├── optimization.py
│       │   └── tokenization.py
│       └── metrics.py
└── README.md
/code/generation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/code/generation/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/code/retrieval/readme:
--------------------------------------------------------------------------------
1 | Modified from https://github.com/huggingface/transformers.
--------------------------------------------------------------------------------
/code/retrieval/run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ## for training
4 | python run_distilling.py \
5 | --task_name buy_data \
6 | --do_train \
7 | --do_eval \
8 | --do_lower_case \
9 | --data_dir $data_dir \
10 | --teacher_bert_model $teacher \
11 | --student_bert_model $student \
12 | --max_seq_length 64 \
13 | --train_batch_size 128 \
14 | --eval_batch_size 128 \
15 | --learning_rate 2e-5 \
16 | --num_train_epochs 10.0 \
17 | --output_dir $output \
18 | --temperature 1 --alpha 0.5
19 |
20 | ## for testing
21 | python run_ranker_test.py \
22 | --task_name bug_data \
23 | --do_eval \
24 | --do_lower_case \
25 | --data_dir $data_dir \
26 | --bert_model $model \
27 | --max_seq_length 64 \
28 | --train_batch_size 128 \
29 | --eval_batch_size 128 \
30 | --learning_rate 2e-5 \
31 | --num_train_epochs 10.0 \
32 | --output_dir $output
33 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dialogue Distillation
2 |
3 | Code and data for the EMNLP 2020 long paper "[Dialogue Distillation: Open-domain Dialogue Augmentation Using Unpaired Data](https://arxiv.org/abs/2009.09427)".
4 |
5 | ## Code
6 |
7 | `code/generation`: code for the generation-based dialogue model
8 | 
9 | `code/retrieval`: code for the retrieval-based dialogue model
10 |
11 | ## Data
12 |
13 | The data we used can be downloaded from this [link](https://drive.google.com/file/d/1mNQf7QydWGhxPE1-1IW0yfwSLJJ9zVG7/view?usp=sharing).
14 |
15 | ## Citation
16 |
17 | Please cite our EMNLP paper if you find our work useful :)
18 |
19 | @inproceedings{zhang2020distill,
20 | title={Dialogue Distillation: Open-domain Dialogue Augmentation Using Unpaired Data},
21 | author={Zhang, Rongsheng and Zheng, Yinhe and Shao, Jianzhi and Mao, Xiaoxi and Xi, Yadong and Huang, Minlie},
22 | booktitle={EMNLP},
23 | year={2020},
24 | url={https://arxiv.org/abs/2009.09427}
25 | }
--------------------------------------------------------------------------------
/code/retrieval/pytorch_pretrained_bert/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.6.1"
2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
3 | from .tokenization_openai import OpenAIGPTTokenizer
4 | from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
5 | # from .tokenization_gpt2 import GPT2Tokenizer # omit regex install error
6 |
7 | from .modeling import (BertConfig, BertModel, BertForPreTraining,
8 | BertForMaskedLM, BertForNextSentencePrediction,
9 | BertForSequenceClassification, BertForMultipleChoice,
10 | BertForTokenClassification, BertForQuestionAnswering,
11 | load_tf_weights_in_bert)
12 | from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
13 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
14 | load_tf_weights_in_openai_gpt)
15 | from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
16 | load_tf_weights_in_transfo_xl)
17 | from .modeling_gpt2 import (GPT2Config, GPT2Model,
18 | GPT2LMHeadModel, GPT2DoubleHeadsModel,
19 | load_tf_weights_in_gpt2)
20 |
21 | from .optimization import BertAdam
22 | from .optimization_openai import OpenAIAdam
23 |
24 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path
25 |
--------------------------------------------------------------------------------
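
The package's __init__ re-exports the tokenizers, models, and optimizers, so the retrieval scripts can import everything from the top level. Below is a minimal sketch (not part of this repo) of using the classes exported above; the checkpoint directory is a placeholder you would point at a downloaded BERT model.

# Minimal sketch, not from this repo: load a tokenizer, a sequence classifier and the
# BertAdam optimizer re-exported above. The checkpoint directory is a placeholder.
import torch
from pytorch_pretrained_bert import BertTokenizer, BertForSequenceClassification, BertAdam

model_dir = '/path/to/bert_checkpoint'  # hypothetical local checkpoint directory
tokenizer = BertTokenizer.from_pretrained(model_dir, do_lower_case=True)
model = BertForSequenceClassification.from_pretrained(model_dir, num_labels=2)
optimizer = BertAdam(model.parameters(), lr=2e-5, warmup=0.1, t_total=1000)

tokens = ['[CLS]'] + tokenizer.tokenize('你好') + ['[SEP]']
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
logits = model(input_ids)  # no labels passed, so the forward pass returns logits
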
/code/generation/model/loss.py:
--------------------------------------------------------------------------------
1 | # transformer_chatbot
2 | # Copyright (C) 2018 Golovanov, Tselousov
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 |
17 | import torch
18 | import torch.nn as nn
19 |
20 |
21 | class LabelSmoothingLoss(nn.Module):
22 | def __init__(self, n_labels, smoothing=0.0, ignore_index=-100, size_average=True):
23 | super(LabelSmoothingLoss, self).__init__()
24 | assert 0 <= smoothing <= 1
25 |
26 | self.ignore_index = ignore_index
27 | self.confidence = 1 - smoothing
28 |
29 | if smoothing > 0:
30 | print('using label smoothing')
31 | self.criterion = nn.KLDivLoss(size_average=size_average)
32 | n_ignore_idxs = 1 + (ignore_index >= 0)
33 | one_hot = torch.full((1, n_labels), fill_value=(smoothing / (n_labels - n_ignore_idxs)))
34 | if ignore_index >= 0:
35 | one_hot[0, ignore_index] = 0
36 | self.register_buffer('one_hot', one_hot)
37 | else:
38 | self.criterion = nn.NLLLoss(size_average=size_average, ignore_index=ignore_index)
39 |
40 | def forward(self, log_inputs, targets):
41 | if self.confidence < 1:
42 | tdata = targets.data
43 |
44 | tmp = self.one_hot.repeat(targets.shape[0], 1)
45 | tmp.scatter_(1, tdata.unsqueeze(1), self.confidence)
46 |
47 | if self.ignore_index >= 0:
48 | mask = torch.nonzero(tdata.eq(self.ignore_index)).squeeze(-1)
49 | if mask.numel() > 0:
50 | tmp.index_fill_(0, mask, 0)
51 |
52 | targets = tmp
53 |
54 | return self.criterion(log_inputs, targets)
55 |
--------------------------------------------------------------------------------
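
A usage sketch for LabelSmoothingLoss (sizes and values below are made up): the criterion expects log-probabilities over the label set and integer targets, with padding positions marked by ignore_index.

# Usage sketch with made-up sizes: feed log-probabilities of shape (batch, n_labels)
# and integer targets of shape (batch,); ignore_index marks padding positions.
import torch
import torch.nn.functional as F

n_labels, pad_id = 100, 1
criterion = LabelSmoothingLoss(n_labels, smoothing=0.1, ignore_index=pad_id)

logits = torch.randn(8, n_labels, requires_grad=True)  # unnormalized model outputs
targets = torch.randint(4, n_labels, (8,))             # gold token ids
loss = criterion(F.log_softmax(logits, dim=-1), targets)
loss.backward()
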
/code/generation/model/text.py:
--------------------------------------------------------------------------------
1 | # transformer_chatbot
2 | # Copyright (C) 2018 Golovanov, Tselousov
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 |
17 | class myVocab:
18 |     spl = '<spl>'
19 |     pad = '<pad>'
20 |     eos = '<eos>'
21 |     unk = '<unk>'
22 |
23 | def __init__(self, vocab_file):
24 | # TODO: add check for special tokens
25 | self.spec_tokens = [myVocab.spl, myVocab.pad, myVocab.eos, myVocab.unk]
26 | with open(vocab_file, 'r', encoding='utf8') as fr:
27 | vocab = [line.strip('\n').split()[0] for line in fr.readlines()]
28 | vocab = self.spec_tokens + vocab
29 | self.token2id = {t: i for i, t in enumerate(vocab)}
30 | self.id2token = {i: t for i, t in enumerate(vocab)}
31 |
32 | def __len__(self):
33 | return len(self.token2id)
34 |
35 | @property
36 | def n_special_tokens(self):
37 | return len(self.spec_tokens)
38 |
39 | @property
40 | def special_tokens_ids(self):
41 | return [self.token2id[t] for t in self.spec_tokens]
42 |
43 | @property
44 | def pad_id(self):
45 | return self.token2id[myVocab.pad]
46 |
47 | @property
48 | def spl_id(self):
49 | return self.token2id[myVocab.spl]
50 |
51 | @property
52 | def bos_id(self):
53 | return self.token2id[myVocab.eos]
54 |
55 | @property
56 | def eos_id(self):
57 | return self.token2id[myVocab.eos]
58 |
59 | def string2ids(self, string):
60 | tokens = string.split()
61 | ids = [self.token2id[t] for t in tokens if t in self.token2id]
62 | return ids
63 |
64 | def ids2string(self, ids):
65 | tokens = [self.id2token[id] for id in ids]
66 | return ''.join(tokens)
67 |
--------------------------------------------------------------------------------
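
A quick sketch of the vocabulary round trip (the tiny vocab file below is made up): myVocab reads one token per line, keeps only the first whitespace-separated column, and prepends the four special tokens; string2ids silently drops unknown tokens and ids2string joins characters back without spaces.

# Sketch with a made-up vocabulary file; only the first column of each line is used.
with open('tiny_vocab.txt', 'w', encoding='utf8') as fw:
    fw.write('你 123\n好 99\n世 7\n界 5\n')

vocab = myVocab('tiny_vocab.txt')
print(len(vocab), vocab.n_special_tokens)   # 8 tokens total, 4 of them special
ids = vocab.string2ids('你 好 世 界')
print(ids)                                  # ids start after the special tokens
print(vocab.ids2string(ids))                # '你好世界'
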
/code/retrieval/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert BERT checkpoint."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | import re
23 | import argparse
24 | import tensorflow as tf
25 | import torch
26 | import numpy as np
27 |
28 | from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert
29 |
30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
31 | # Initialise PyTorch model
32 | config = BertConfig.from_json_file(bert_config_file)
33 | print("Building PyTorch model from configuration: {}".format(str(config)))
34 | model = BertForPreTraining(config)
35 |
36 | # Load weights from tf checkpoint
37 | load_tf_weights_in_bert(model, tf_checkpoint_path)
38 |
39 | # Save pytorch-model
40 | print("Save PyTorch model to {}".format(pytorch_dump_path))
41 | torch.save(model.state_dict(), pytorch_dump_path)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | ## Required parameters
47 | parser.add_argument("--tf_checkpoint_path",
48 | default = None,
49 | type = str,
50 | required = True,
51 | help = "Path the TensorFlow checkpoint path.")
52 | parser.add_argument("--bert_config_file",
53 | default = None,
54 | type = str,
55 | required = True,
56 | help = "The config json file corresponding to the pre-trained BERT model. \n"
57 | "This specifies the model architecture.")
58 | parser.add_argument("--pytorch_dump_path",
59 | default = None,
60 | type = str,
61 | required = True,
62 | help = "Path to the output PyTorch model.")
63 | args = parser.parse_args()
64 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
65 | args.bert_config_file,
66 | args.pytorch_dump_path)
67 |
--------------------------------------------------------------------------------
/code/generation/run_dialog_kd.py:
--------------------------------------------------------------------------------
1 | from run_dialog_interface import dialog
2 | from metrics import calc_f1, calc_bleu, calc_distinct, calc_avg_len
3 |
4 | bz1 = 5; lp1 = 1.6
5 | bz2 = 5; lp2 = 1.0
6 | bz3 = 5; lp3 = 2.0
7 | bz4 = 5; lp4 = 1.6
8 | bz5 = 5; lp5 = 1.6
9 |
10 | crowded = dialog('/root/generation_with_augmentation/checkpoints/dialog_v2/crowded/last_checkpoint15', bz1, lp1)
11 | fake = dialog('/root/generation_with_augmentation/checkpoints/dialog_fake_v2/last_checkpoint7', bz2, lp2)
12 | fake_kd = dialog('/root/generation_with_augmentation/checkpoints/dialog_kd/fake/last_checkpoint28', bz3, lp3)
13 | crowded_fake = dialog('/root/generation_with_augmentation/checkpoints/dialog_v2/crowded_fake/last_checkpoint18', bz4, lp4)
14 | crowded_fake_kd = dialog('/root/generation_with_augmentation/checkpoints/dialog_kd/crowded_fake/last_checkpoint29', bz5, lp5)
15 |
16 |
17 | crowded_pairs = []
18 | fake_pairs = []
19 | fake_kd_pairs = []
20 | crowded_fake_pairs = []
21 | crowded_fake_kd_pairs = []
22 |
23 | if __name__ == '__main__':
24 | with open('dataset/dialog/test_1k.txt', 'r', encoding='utf8') as fr:
25 | lines = fr.readlines()
26 | cnt = 0
27 | with open('dataset/dialog/test_1k_output_kd_bm5_lp_diff_v2.txt', 'w', encoding='utf8') as fw:
28 | fw.write('question\tanswer\tcrowded\tfake\tfake_kd\tcrowded_fake\tcrowded_fake_kd\n')
29 | for line in lines:
30 | cnt += 1
31 | if cnt%100==0:
32 | print(cnt)
33 | q,a = line.strip('\n').split('\t')
34 | a1 = crowded.answer_beams(q)
35 | crowded_pairs.append([list(a1), list(a)])
36 | a2 = fake.answer_beams(q)
37 | fake_pairs.append([list(a2), list(a)])
38 | a3 = fake_kd.answer_beams(q)
39 | fake_kd_pairs.append([list(a3), list(a)])
40 | a4 = crowded_fake.answer_beams(q)
41 | crowded_fake_pairs.append([list(a4), list(a)])
42 | a5 = crowded_fake_kd.answer_beams(q)
43 | crowded_fake_kd_pairs.append([list(a5), list(a)])
44 | res = [q, a, a1, a2, a3, a4, a5]
45 | fw.write('\t'.join(res) + '\n')
46 | for res in [crowded_pairs, fake_pairs, fake_kd_pairs, crowded_fake_pairs, crowded_fake_kd_pairs]:
47 | f1 = calc_f1(res)
48 | bleu = calc_bleu(res)
49 | distinct = calc_distinct(res)
50 | avg_len = calc_avg_len(res)
51 | print('f1: ', f1)
52 | print('bleu: ', bleu)
53 | print('distinct: ', distinct)
54 | print('avg_len: ', avg_len)
55 | '''
56 | while True:
57 | message = input('>')
58 | print('crowded', crowded.answer_beams(message))
59 | print('crowded_fake', crowded_fake.answer_beams(message))
60 | print('fake', fake.answer_beams(message))
61 | '''
--------------------------------------------------------------------------------
/code/retrieval/pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert OpenAI GPT checkpoint."""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import argparse
20 | from io import open
21 |
22 | import torch
23 |
24 | from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME,
25 | GPT2Config,
26 | GPT2Model,
27 | load_tf_weights_in_gpt2)
28 |
29 |
30 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
31 | # Construct model
32 | if gpt2_config_file == "":
33 | config = GPT2Config()
34 | else:
35 | config = GPT2Config(gpt2_config_file)
36 | model = GPT2Model(config)
37 |
38 | # Load weights from numpy
39 | load_tf_weights_in_gpt2(model, gpt2_checkpoint_path)
40 |
41 | # Save pytorch-model
42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
45 | torch.save(model.state_dict(), pytorch_weights_dump_path)
46 | print("Save configuration file to {}".format(pytorch_config_dump_path))
47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
48 | f.write(config.to_json_string())
49 |
50 |
51 | if __name__ == "__main__":
52 | parser = argparse.ArgumentParser()
53 | ## Required parameters
54 | parser.add_argument("--gpt2_checkpoint_path",
55 | default = None,
56 | type = str,
57 | required = True,
58 | help = "Path the TensorFlow checkpoint path.")
59 | parser.add_argument("--pytorch_dump_folder_path",
60 | default = None,
61 | type = str,
62 | required = True,
63 | help = "Path to the output PyTorch model.")
64 | parser.add_argument("--gpt2_config_file",
65 | default = "",
66 | type = str,
67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n"
68 | "This specifies the model architecture.")
69 | args = parser.parse_args()
70 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path,
71 | args.gpt2_config_file,
72 | args.pytorch_dump_folder_path)
73 |
--------------------------------------------------------------------------------
/code/retrieval/pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert OpenAI GPT checkpoint."""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import argparse
20 | from io import open
21 |
22 | import torch
23 |
24 | from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME,
25 | OpenAIGPTConfig,
26 | OpenAIGPTModel,
27 | load_tf_weights_in_openai_gpt)
28 |
29 |
30 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
31 | # Construct model
32 | if openai_config_file == "":
33 | config = OpenAIGPTConfig()
34 | else:
35 | config = OpenAIGPTConfig(openai_config_file)
36 | model = OpenAIGPTModel(config)
37 |
38 | # Load weights from numpy
39 | load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path)
40 |
41 | # Save pytorch-model
42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
45 | torch.save(model.state_dict(), pytorch_weights_dump_path)
46 | print("Save configuration file to {}".format(pytorch_config_dump_path))
47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
48 | f.write(config.to_json_string())
49 |
50 |
51 | if __name__ == "__main__":
52 | parser = argparse.ArgumentParser()
53 | ## Required parameters
54 | parser.add_argument("--openai_checkpoint_folder_path",
55 | default = None,
56 | type = str,
57 | required = True,
58 | help = "Path the TensorFlow checkpoint path.")
59 | parser.add_argument("--pytorch_dump_folder_path",
60 | default = None,
61 | type = str,
62 | required = True,
63 | help = "Path to the output PyTorch model.")
64 | parser.add_argument("--openai_config_file",
65 | default = "",
66 | type = str,
67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n"
68 | "This specifies the model architecture.")
69 | args = parser.parse_args()
70 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path,
71 | args.openai_config_file,
72 | args.pytorch_dump_folder_path)
73 |
--------------------------------------------------------------------------------
/code/retrieval/metrics.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 |
4 | def mean_average_precision(sort_data):
5 | # to do
6 | count_1 = 0
7 | sum_precision = 0
8 | for index in range(len(sort_data)):
9 | if sort_data[index][1] == 1:
10 | count_1 += 1
11 | sum_precision += 1.0 * count_1 / (index + 1)
12 | return sum_precision / count_1 if count_1 != 0 else 0.0
13 |
14 |
15 | def mean_reciprocal_rank(sort_data):
16 |     sort_label = [s_d[1] for s_d in sort_data]
17 |     assert 1 in sort_label
18 |     return 1.0 / (1 + sort_label.index(1))
19 |
20 |
21 | def precision_at_position_1(sort_data):
22 | if sort_data[0][1] == 1:
23 | return 1
24 | else:
25 | return 0
26 |
27 |
28 | def recall_at_position_k_in_10(sort_data, k):
29 |     sort_label = [s_d[1] for s_d in sort_data]
30 |     select_label = sort_label[:k]
31 |     return 1.0 * select_label.count(1) / sort_label.count(1)
32 |
33 |
34 | def evaluation_one_session(data):
35 | sort_data = sorted(data, key=lambda x: x[0], reverse=True)
36 | m_a_p = mean_average_precision(sort_data)
37 | m_r_r = mean_reciprocal_rank(sort_data)
38 | p_1 = precision_at_position_1(sort_data)
39 | r_1 = recall_at_position_k_in_10(sort_data, 1)
40 | r_2 = recall_at_position_k_in_10(sort_data, 2)
41 | r_5 = recall_at_position_k_in_10(sort_data, 5)
42 | return m_a_p, m_r_r, p_1, r_1, r_2, r_5
43 |
44 |
45 | def evaluate(file_path, n=10):
46 | sum_m_a_p = 0
47 | sum_m_r_r = 0
48 | sum_p_1 = 0
49 | sum_r_1 = 0
50 | sum_r_2 = 0
51 | sum_r_5 = 0
52 |
53 | i = 0
54 | total_num = 0
55 | data = None
56 | with open(file_path, 'r') as infile:
57 | for line in infile:
58 | if i % n == 0:
59 | data = []
60 |
61 | tokens = line.strip().split('\t')
62 | if len(tokens) < 2:
63 | print('i', i, 'tokens', tokens)
64 | data.append((float(tokens[0]), int(tokens[1])))
65 |
66 | if i % n == n - 1:
67 | if 1 not in [s_d[1] for s_d in data]:
68 | continue
69 | total_num += 1
70 | m_a_p, m_r_r, p_1, r_1, r_2, r_5 = evaluation_one_session(data)
71 | sum_m_a_p += m_a_p
72 | sum_m_r_r += m_r_r
73 | sum_p_1 += p_1
74 | sum_r_1 += r_1
75 | sum_r_2 += r_2
76 | sum_r_5 += r_5
77 |
78 | i += 1
79 |
80 | print('total num: %s' % total_num)
81 |     print('MAP: {}'.format(1.0*sum_m_a_p/total_num))
82 | print('MRR: {}'.format(1.0*sum_m_r_r/total_num))
83 | print('P@1: {}'.format(1.0*sum_p_1/total_num))
84 | print('R{}@1: {}'.format(n, str(1.0*sum_r_1/total_num)))
85 | print('R{}@2: {}'.format(n, str(1.0*sum_r_2/total_num)))
86 | print('R{}@5: {}'.format(n, str(1.0*sum_r_5/total_num)))
87 | # print('R10@1: %s' %(1.0*sum_r_1/total_num))
88 | # print('R10@2: %s' %(1.0*sum_r_2/total_num))
89 | # print('R10@5: %s' %(1.0*sum_r_5/total_num))
90 | return (1.0 * sum_m_a_p / total_num, 1.0 * sum_m_r_r / total_num, 1.0 * sum_p_1 / total_num,
91 | 1.0 * sum_r_1 / total_num, 1.0 * sum_r_2 / total_num, 1.0 * sum_r_5 / total_num)
92 |
93 |
94 | if __name__ == '__main__':
95 | print(sys.argv)
96 |     result = evaluate(sys.argv[1], int(sys.argv[2]) if len(sys.argv) > 2 else 10)
97 | # for r in result:
98 | # print(r)
99 |
100 |
101 | """
102 | i = 0
103 | cnt = 0
104 | utterances, responses, labels = [], [], []
105 | for line in open('data/test.txt'):
106 | contexts = line.strip().split('\t')
107 | uttes, resp, l = contexts[1:-1], contexts[-1], contexts[0]
108 | uttes = [utte.split() for utte in uttes]
109 | labels.append(int(l))
110 | if i % 10 == 9:
111 | if 1 in labels[i-9:]:
112 | cnt += 10
113 | i += 1
114 | output:
115 | i: 10000
116 | cnt: 6670
117 | """
118 |
--------------------------------------------------------------------------------
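
evaluate() expects a plain text file with one candidate per line in the form score<TAB>label, where every block of n consecutive lines belongs to the same context. A toy sketch (file name and scores are made up) with n=2 candidates per context:

# Toy score file: "score<TAB>label" per candidate, n consecutive lines per context.
rows = [
    (0.9, 1), (0.1, 0),   # context 1: positive response scored highest
    (0.3, 0), (0.7, 1),   # context 2: positive response scored highest as well
]
with open('toy_scores.txt', 'w') as fw:
    for score, label in rows:
        fw.write('{}\t{}\n'.format(score, label))

map_, mrr, p1, r1, r2, r5 = evaluate('toy_scores.txt', n=2)
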
/code/generation/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | from model.utils import load_openai_weights_chinese, set_seed, f1_score
4 | from model.transformer_model import TransformerModel
5 | from model.text import myVocab
6 | from config import get_model_config_poem, get_test_config_poem
7 | import readline
8 |
9 |
10 | def main():
11 | model_config = get_model_config_poem()
12 | test_config = get_test_config_poem()
13 |
14 | set_seed(test_config.seed)
15 | device = torch.device(test_config.device)
16 |
17 | vocab = myVocab(model_config.vocab_path)
18 |
19 | transformer = TransformerModel(n_layers=model_config.n_layers,
20 | n_embeddings=len(vocab),
21 | n_pos_embeddings=model_config.n_pos_embeddings,
22 | embeddings_size=model_config.embeddings_size,
23 | padding_idx=vocab.pad_id,
24 | n_heads=model_config.n_heads,
25 | dropout=model_config.dropout,
26 | embed_dropout=model_config.embed_dropout,
27 | attn_dropout=model_config.attn_dropout,
28 | ff_dropout=model_config.ff_dropout,
29 | bos_id=vocab.bos_id,
30 | eos_id=vocab.eos_id,
31 | max_seq_len=model_config.max_seq_len,
32 | beam_size=model_config.beam_size,
33 | length_penalty=model_config.length_penalty,
34 | n_segments=model_config.n_segments,
35 | annealing_topk=model_config.annealing_topk,
36 | annealing=model_config.annealing,
37 | diversity_coef=model_config.diversity_coef,
38 | diversity_groups=model_config.diversity_groups)
39 |
40 | transformer = transformer.to(device)
41 | state_dict = torch.load(test_config.last_checkpoint_path, map_location=device)
42 | temp = dict(state_dict['model'])
43 | keys = list(temp.keys())
44 | for key in keys:
45 | # new_key = '.'.join([i for i in key.split('.') if i != 'module'])
46 | new_key = key.replace('.module', '')
47 | temp[new_key] = temp.pop(key)
48 | transformer.load_state_dict(temp)
49 | transformer.eval()
50 | print('Weights loaded from {}'.format(test_config.last_checkpoint_path))
51 |
52 |
53 | def answer(message):
54 | message = ' '.join(message)
55 | message = vocab.string2ids(message)
56 | message = [vocab.bos_id] + message + [vocab.eos_id]
57 | message = message[:60]
58 | # print(message)
59 | contexts = [torch.tensor([c], dtype=torch.long, device=device) for c in [message] if len(c) > 0]
60 | prediction = transformer.predict(contexts)[0]
61 | prediction_str = vocab.ids2string(prediction)
62 | return prediction_str
63 |
64 | def answer_beams(message):
65 | message = ' '.join(message)
66 | message = vocab.string2ids(message)
67 | message = [vocab.bos_id] + message + [vocab.eos_id]
68 | message = message[:20]
69 | # print(message)
70 | contexts = [torch.tensor([c], dtype=torch.long, device=device) for c in [message] if len(c) > 0]
71 | predictions = transformer.predict_beams(contexts)[0]
72 | prediction_strs = [vocab.ids2string(prediction) for prediction in predictions]
73 | return prediction_strs
74 |
75 | '''
76 | with open('data/test200_output_noinit_noweight.txt', 'w', encoding='utf8') as fw:
77 | with open('data/test200.txt', 'r', encoding='utf8') as fr:
78 | lines = fr.readlines()
79 | for line in lines:
80 | post, response = line.strip('\n').replace(' ', '').split('\t')
81 | ans = answer(post)
82 | fw.write('source:' + post + '\t' + 'target:' + response + '\t' + 'answer:' + ans + '\n')
83 | '''
84 | '''
85 | while True:
86 | message = input('>')
87 | ans = answer(message)
88 | print(ans)
89 | '''
90 |
91 | while True:
92 | message = input('>')
93 | ans = answer_beams(message)
94 | for i in ans:
95 | print(i)
96 |
97 |
98 | if __name__ == '__main__':
99 | main()
--------------------------------------------------------------------------------
/code/generation/run_dialog.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | from model.utils import load_openai_weights_chinese, set_seed, f1_score
4 | from model.transformer_model import TransformerModel
5 | from model.text import myVocab
6 | from config import get_model_config_dialog, get_test_config_dialog
7 | import readline
8 |
9 |
10 | def main():
11 | model_config = get_model_config_dialog()
12 | test_config = get_test_config_dialog()
13 |
14 | set_seed(test_config.seed)
15 | device = torch.device(test_config.device)
16 |
17 | vocab = myVocab(model_config.vocab_path)
18 |
19 | transformer = TransformerModel(n_layers=model_config.n_layers,
20 | n_embeddings=len(vocab),
21 | n_pos_embeddings=model_config.n_pos_embeddings,
22 | embeddings_size=model_config.embeddings_size,
23 | padding_idx=vocab.pad_id,
24 | n_heads=model_config.n_heads,
25 | dropout=model_config.dropout,
26 | embed_dropout=model_config.embed_dropout,
27 | attn_dropout=model_config.attn_dropout,
28 | ff_dropout=model_config.ff_dropout,
29 | bos_id=vocab.bos_id,
30 | eos_id=vocab.eos_id,
31 | max_seq_len=model_config.max_seq_len,
32 | beam_size=model_config.beam_size,
33 | length_penalty=model_config.length_penalty,
34 | n_segments=model_config.n_segments,
35 | annealing_topk=model_config.annealing_topk,
36 | annealing=model_config.annealing,
37 | diversity_coef=model_config.diversity_coef,
38 | diversity_groups=model_config.diversity_groups)
39 |
40 | transformer = transformer.to(device)
41 | state_dict = torch.load(test_config.last_checkpoint_path, map_location=device)
42 | temp = dict(state_dict['model'])
43 | keys = list(temp.keys())
44 | for key in keys:
45 | # new_key = '.'.join([i for i in key.split('.') if i != 'module'])
46 | new_key = key.replace('.module', '')
47 | temp[new_key] = temp.pop(key)
48 | transformer.load_state_dict(temp)
49 | transformer.eval()
50 | print('Weights loaded from {}'.format(test_config.last_checkpoint_path))
51 |
52 |
53 | def answer(message):
54 | message = ' '.join(message)
55 | message = vocab.string2ids(message)
56 | message = [vocab.bos_id] + message + [vocab.eos_id]
57 | message = message[:60]
58 | # print(message)
59 | contexts = [torch.tensor([c], dtype=torch.long, device=device) for c in [message] if len(c) > 0]
60 | prediction = transformer.predict(contexts)[0]
61 | prediction_str = vocab.ids2string(prediction)
62 | return prediction_str
63 |
64 | def answer_beams(message):
65 | message = ' '.join(message)
66 | message = vocab.string2ids(message)
67 | message = [vocab.bos_id] + message + [vocab.eos_id]
68 | message = message[:30]
69 | # print(message)
70 | contexts = [torch.tensor([c], dtype=torch.long, device=device) for c in [message] if len(c) > 0]
71 | predictions = transformer.predict_beams(contexts)[0]
72 | prediction_strs = [vocab.ids2string(prediction) for prediction in predictions]
73 | return prediction_strs
74 |
75 | '''
76 | with open('data/test200_output_noinit_noweight.txt', 'w', encoding='utf8') as fw:
77 | with open('data/test200.txt', 'r', encoding='utf8') as fr:
78 | lines = fr.readlines()
79 | for line in lines:
80 | post, response = line.strip('\n').replace(' ', '').split('\t')
81 | ans = answer(post)
82 | fw.write('source:' + post + '\t' + 'target:' + response + '\t' + 'answer:' + ans + '\n')
83 | '''
84 | '''
85 | while True:
86 | message = input('>')
87 | ans = answer(message)
88 | print(ans)
89 | '''
90 |
91 | while True:
92 | message = input('>')
93 | ans = answer_beams(message)
94 | for i in ans:
95 | print(i)
96 |
97 |
98 | if __name__ == '__main__':
99 | main()
--------------------------------------------------------------------------------
/code/retrieval/pytorch_pretrained_bert/__main__.py:
--------------------------------------------------------------------------------
1 | # coding: utf8
2 | def main():
3 | import sys
4 | if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
5 | "convert_tf_checkpoint_to_pytorch",
6 | "convert_openai_checkpoint",
7 | "convert_transfo_xl_checkpoint",
8 | "convert_gpt2_checkpoint",
9 | ]:
10 | print(
11 | "Should be used as one of: \n"
12 | ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
13 | ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
14 | ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
15 | ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`")
16 | else:
17 | if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
18 | try:
19 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
20 | except ImportError:
21 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
22 | "In that case, it requires TensorFlow to be installed. Please see "
23 | "https://www.tensorflow.org/install/ for installation instructions.")
24 | raise
25 |
26 | if len(sys.argv) != 5:
27 | # pylint: disable=line-too-long
28 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
29 | else:
30 | PYTORCH_DUMP_OUTPUT = sys.argv.pop()
31 | TF_CONFIG = sys.argv.pop()
32 | TF_CHECKPOINT = sys.argv.pop()
33 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
34 | elif sys.argv[1] == "convert_openai_checkpoint":
35 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
36 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
37 | PYTORCH_DUMP_OUTPUT = sys.argv[3]
38 | if len(sys.argv) == 5:
39 | OPENAI_GPT_CONFIG = sys.argv[4]
40 | else:
41 | OPENAI_GPT_CONFIG = ""
42 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
43 | OPENAI_GPT_CONFIG,
44 | PYTORCH_DUMP_OUTPUT)
45 | elif sys.argv[1] == "convert_transfo_xl_checkpoint":
46 | try:
47 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
48 | except ImportError:
49 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
50 | "In that case, it requires TensorFlow to be installed. Please see "
51 | "https://www.tensorflow.org/install/ for installation instructions.")
52 | raise
53 |
54 | if 'ckpt' in sys.argv[2].lower():
55 | TF_CHECKPOINT = sys.argv[2]
56 | TF_DATASET_FILE = ""
57 | else:
58 | TF_DATASET_FILE = sys.argv[2]
59 | TF_CHECKPOINT = ""
60 | PYTORCH_DUMP_OUTPUT = sys.argv[3]
61 | if len(sys.argv) == 5:
62 | TF_CONFIG = sys.argv[4]
63 | else:
64 | TF_CONFIG = ""
65 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
66 | else:
67 | try:
68 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
69 | except ImportError:
70 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
71 | "In that case, it requires TensorFlow to be installed. Please see "
72 | "https://www.tensorflow.org/install/ for installation instructions.")
73 | raise
74 |
75 | TF_CHECKPOINT = sys.argv[2]
76 | PYTORCH_DUMP_OUTPUT = sys.argv[3]
77 | if len(sys.argv) == 5:
78 | TF_CONFIG = sys.argv[4]
79 | else:
80 | TF_CONFIG = ""
81 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
82 | if __name__ == '__main__':
83 | main()
84 |
--------------------------------------------------------------------------------
/code/generation/metrics.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 |
4 | def get_dict(tokens, ngram, gdict=None):
5 | """
6 | get_dict
7 | """
8 | token_dict = {}
9 | if gdict is not None:
10 | token_dict = gdict
11 | tlen = len(tokens)
12 | for i in range(0, tlen - ngram + 1):
13 | ngram_token = "".join(tokens[i:(i + ngram)])
14 | if token_dict.get(ngram_token) is not None:
15 | token_dict[ngram_token] += 1
16 | else:
17 | token_dict[ngram_token] = 1
18 | return token_dict
19 |
20 |
21 | def count(pred_tokens, gold_tokens, ngram, result):
22 | """
23 | count
24 | """
25 | cover_count, total_count = result
26 | pred_dict = get_dict(pred_tokens, ngram)
27 | gold_dict = get_dict(gold_tokens, ngram)
28 | cur_cover_count = 0
29 | cur_total_count = 0
30 | for token, freq in pred_dict.items():
31 | if gold_dict.get(token) is not None:
32 | gold_freq = gold_dict[token]
33 | cur_cover_count += min(freq, gold_freq)
34 | cur_total_count += freq
35 | result[0] += cur_cover_count
36 | result[1] += cur_total_count
37 |
38 |
39 | def calc_bp(pair_list):
40 | """
41 | calc_bp
42 | """
43 | c_count = 0.0
44 | r_count = 0.0
45 | for pair in pair_list:
46 | pred_tokens, gold_tokens = pair
47 | c_count += len(pred_tokens)
48 | r_count += len(gold_tokens)
49 | bp = 1
50 | if c_count < r_count:
51 | bp = math.exp(1 - r_count / c_count)
52 | return bp
53 |
54 |
55 | def calc_cover_rate(pair_list, ngram):
56 | """
57 | calc_cover_rate
58 | """
59 | result = [0.0, 0.0] # [cover_count, total_count]
60 | for pair in pair_list:
61 | pred_tokens, gold_tokens = pair
62 | count(pred_tokens, gold_tokens, ngram, result)
63 | if result[1] == 0:
64 | cover_rate = 0
65 | else:
66 | cover_rate = result[0] / result[1]
67 | return cover_rate
68 |
69 |
70 | def calc_bleu(pair_list):
71 | """
72 | calc_bleu: [[predict, golden], ...]
73 | """
74 | bp = calc_bp(pair_list)
75 | print('bp: ', bp)
76 | cover_rate1 = calc_cover_rate(pair_list, 1)
77 | cover_rate2 = calc_cover_rate(pair_list, 2)
78 | cover_rate3 = calc_cover_rate(pair_list, 3)
79 | bleu1 = 0
80 | bleu2 = 0
81 | bleu3 = 0
82 | if cover_rate1 > 0:
83 | bleu1 = bp * math.exp(math.log(cover_rate1))
84 | if cover_rate2 > 0:
85 | bleu2 = bp * math.exp((math.log(cover_rate1) + math.log(cover_rate2)) / 2)
86 | if cover_rate3 > 0:
87 | bleu3 = bp * math.exp((math.log(cover_rate1) + math.log(cover_rate2) + math.log(cover_rate3)) / 3)
88 | return [bleu1, bleu2]
89 |
90 |
91 | def calc_distinct_ngram(pair_list, ngram):
92 | """
93 | calc_distinct_ngram
94 | """
95 | ngram_total = 0.0
96 | ngram_distinct_count = 0.0
97 | pred_dict = {}
98 | for predict_tokens, _ in pair_list:
99 | get_dict(predict_tokens, ngram, pred_dict)
100 | for key, freq in pred_dict.items():
101 | ngram_total += freq
102 | ngram_distinct_count += 1
103 | #if freq == 1:
104 | # ngram_distinct_count += freq
105 | if ngram_total == 0:
106 | return 0
107 | return ngram_distinct_count / ngram_total
108 |
109 |
110 | def calc_distinct(pair_list):
111 | """
112 | calc_distinct
113 | """
114 | distinct1 = calc_distinct_ngram(pair_list, 1)
115 | distinct2 = calc_distinct_ngram(pair_list, 2)
116 | return [distinct1, distinct2]
117 |
118 |
119 | def calc_f1(data):
120 | """
121 | calc_f1
122 | """
123 | golden_char_total = 0.0
124 | pred_char_total = 0.0
125 | hit_char_total = 0.0
126 | for response, golden_response in data:
127 | #golden_response = "".join(golden_response).decode("utf8")
128 | #response = "".join(response).decode("utf8")
129 | golden_response = "".join(golden_response)
130 | response = "".join(response)
131 | common = Counter(response) & Counter(golden_response)
132 | hit_char_total += sum(common.values())
133 | golden_char_total += len(golden_response)
134 | pred_char_total += len(response)
135 | if pred_char_total == 0:
136 | p = 0
137 | else:
138 | p = hit_char_total / pred_char_total
139 | if golden_char_total == 0:
140 | r = 0
141 | else:
142 | r = hit_char_total / golden_char_total
143 | if p + r == 0:
144 | f1 = 0
145 | else:
146 | f1 = 2 * p * r / (p + r)
147 | return f1
148 |
149 | def calc_avg_len(data):
150 | """
151 | calc average length
152 | """
153 | all_len = 0
154 | for response, golden_response in data:
155 | all_len += len(response)
156 | if len(data) == 0:
157 | return 0
158 | return all_len/len(data)
--------------------------------------------------------------------------------
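
All of these metrics consume a list of [predicted_tokens, golden_tokens] pairs; in this repo the tokens are Chinese characters, so plain strings can be split into character lists. A small sketch with made-up responses:

# Sketch with made-up responses: each entry is [predicted_tokens, golden_tokens].
pairs = [
    [list('今天天气很好'), list('今天天气不错')],
    [list('谢谢你'), list('不用谢')],
]
print('f1:', calc_f1(pairs))              # character-level F1
print('bleu:', calc_bleu(pairs))          # [BLEU-1, BLEU-2]
print('distinct:', calc_distinct(pairs))  # [distinct-1, distinct-2]
print('avg_len:', calc_avg_len(pairs))    # average predicted length
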
/code/retrieval/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert Transformer XL checkpoint and datasets."""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import argparse
20 | import os
21 | import sys
22 | from io import open
23 |
24 | import torch
25 |
26 | import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
27 | from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME,
28 | WEIGHTS_NAME,
29 | TransfoXLConfig,
30 | TransfoXLLMHeadModel,
31 | load_tf_weights_in_transfo_xl)
32 | from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME,
33 | VOCAB_NAME)
34 |
35 | if sys.version_info[0] == 2:
36 | import cPickle as pickle
37 | else:
38 | import pickle
39 |
40 | # We do this to be able to load python 2 datasets pickles
41 | # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
42 | data_utils.Vocab = data_utils.TransfoXLTokenizer
43 | data_utils.Corpus = data_utils.TransfoXLCorpus
44 | sys.modules['data_utils'] = data_utils
45 | sys.modules['vocabulary'] = data_utils
46 |
47 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
48 | transfo_xl_config_file,
49 | pytorch_dump_folder_path,
50 | transfo_xl_dataset_file):
51 | if transfo_xl_dataset_file:
52 | # Convert a pre-processed corpus (see original TensorFlow repo)
53 | with open(transfo_xl_dataset_file, "rb") as fp:
54 | corpus = pickle.load(fp, encoding="latin1")
55 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
56 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
57 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
58 | corpus_vocab_dict = corpus.vocab.__dict__
59 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)
60 |
61 | corpus_dict_no_vocab = corpus.__dict__
62 | corpus_dict_no_vocab.pop('vocab', None)
63 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME
64 | print("Save dataset to {}".format(pytorch_dataset_dump_path))
65 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)
66 |
67 | if tf_checkpoint_path:
68 | # Convert a pre-trained TensorFlow model
69 | config_path = os.path.abspath(transfo_xl_config_file)
70 | tf_path = os.path.abspath(tf_checkpoint_path)
71 |
72 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
73 | # Initialise PyTorch model
74 | if transfo_xl_config_file == "":
75 | config = TransfoXLConfig()
76 | else:
77 | config = TransfoXLConfig(transfo_xl_config_file)
78 | print("Building PyTorch model from configuration: {}".format(str(config)))
79 | model = TransfoXLLMHeadModel(config)
80 |
81 | model = load_tf_weights_in_transfo_xl(model, config, tf_path)
82 | # Save pytorch-model
83 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
84 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
85 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
86 | torch.save(model.state_dict(), pytorch_weights_dump_path)
87 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
88 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
89 | f.write(config.to_json_string())
90 |
91 |
92 | if __name__ == "__main__":
93 | parser = argparse.ArgumentParser()
94 | parser.add_argument("--pytorch_dump_folder_path",
95 | default = None,
96 | type = str,
97 | required = True,
98 | help = "Path to the folder to store the PyTorch model or dataset/vocab.")
99 | parser.add_argument("--tf_checkpoint_path",
100 | default = "",
101 | type = str,
102 | help = "An optional path to a TensorFlow checkpoint path to be converted.")
103 | parser.add_argument("--transfo_xl_config_file",
104 | default = "",
105 | type = str,
106 | help = "An optional config json file corresponding to the pre-trained BERT model. \n"
107 | "This specifies the model architecture.")
108 | parser.add_argument("--transfo_xl_dataset_file",
109 | default = "",
110 | type = str,
111 | help = "An optional dataset file to be converted in a vocabulary.")
112 | args = parser.parse_args()
113 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
114 | args.transfo_xl_config_file,
115 | args.pytorch_dump_folder_path,
116 | args.transfo_xl_dataset_file)
117 |
--------------------------------------------------------------------------------
/code/generation/model/optim.py:
--------------------------------------------------------------------------------
1 | # transformer_chatbot
2 | # Copyright (C) 2018 Golovanov, Tselousov
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 |
17 | import math
18 | import torch
19 |
20 |
21 | class Adam(torch.optim.Optimizer):
22 | """Implements Adam algorithm.
23 | This implementation is modified from torch.optim.Adam based on:
24 | `Fixed Weight Decay Regularization in Adam`
25 | (see https://arxiv.org/abs/1711.05101)
26 | It has been proposed in `Adam: A Method for Stochastic Optimization`_.
27 | Arguments:
28 | params (iterable): iterable of parameters to optimize or dicts defining
29 | parameter groups
30 | lr (float, optional): learning rate (default: 1e-3)
31 | betas (Tuple[float, float], optional): coefficients used for computing
32 | running averages of gradient and its square (default: (0.9, 0.999))
33 | eps (float, optional): term added to the denominator to improve
34 | numerical stability (default: 1e-8)
35 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
36 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this
37 | algorithm from the paper `On the Convergence of Adam and Beyond`_
38 | .. _Adam\: A Method for Stochastic Optimization:
39 | https://arxiv.org/abs/1412.6980
40 | .. _On the Convergence of Adam and Beyond:
41 | https://openreview.net/forum?id=ryQu7f-RZ
42 | """
43 |
44 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False):
45 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
46 | super(Adam, self).__init__(params, defaults)
47 |
48 | def step(self, closure=None):
49 | """Performs a single optimization step.
50 | Arguments:
51 | closure (callable, optional): A closure that reevaluates the model
52 | and returns the loss.
53 | """
54 | loss = None
55 | if closure is not None:
56 | loss = closure()
57 |
58 | for group in self.param_groups:
59 | for p in group['params']:
60 | if p.grad is None:
61 | continue
62 | grad = p.grad.data
63 | if grad.is_sparse:
64 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
65 | amsgrad = group['amsgrad']
66 |
67 | state = self.state[p]
68 |
69 | # State initialization
70 | if len(state) == 0:
71 | state['step'] = 0
72 | # Exponential moving average of gradient values
73 | state['exp_avg'] = torch.zeros_like(p.data)
74 | # Exponential moving average of squared gradient values
75 | state['exp_avg_sq'] = torch.zeros_like(p.data)
76 | if amsgrad:
77 | # Maintains max of all exp. moving avg. of sq. grad. values
78 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
79 |
80 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
81 | if amsgrad:
82 | max_exp_avg_sq = state['max_exp_avg_sq']
83 | beta1, beta2 = group['betas']
84 |
85 | state['step'] += 1
86 |
87 | # Decay the first and second moment running average coefficient
88 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
89 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
90 | if amsgrad:
91 | # Maintains the maximum of all 2nd moment running avg. till now
92 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
93 | # Use the max. for normalizing running avg. of gradient
94 | denom = max_exp_avg_sq.sqrt().add_(group['eps'])
95 | else:
96 | denom = exp_avg_sq.sqrt().add_(group['eps'])
97 |
98 | bias_correction1 = 1 - beta1 ** state['step']
99 | bias_correction2 = 1 - beta2 ** state['step']
100 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
101 |
102 | if group['weight_decay'] != 0:
103 | p.data.add_(-group['weight_decay'] * group['lr'], p.data)
104 |
105 | p.data.addcdiv_(-step_size, exp_avg, denom)
106 |
107 | return loss
108 |
109 |
110 | class NoamOpt:
111 | def __init__(self, embeddings_size, factor, warmup, optimizer):
112 | self.embeddings_size = embeddings_size
113 | self.factor = factor
114 | self.warmup = warmup
115 | self.optimizer = optimizer
116 |
117 | self._step = 1
118 |
119 | def state_dict(self):
120 | return {'step': self._step,
121 | 'optimizer': self.optimizer.state_dict()}
122 |
123 | def load_state_dict(self, state_dict):
124 | self._step = state_dict['step']
125 | self.optimizer.load_state_dict(state_dict['optimizer'])
126 |
127 | def zero_grad(self):
128 | return self.optimizer.zero_grad()
129 |
130 | @property
131 | def param_groups(self):
132 | return self.optimizer.param_groups
133 |
134 | def step(self):
135 | self._step += 1
136 | rate = self.rate()
137 | for p in self.optimizer.param_groups:
138 | p['lr'] = rate
139 | self.optimizer.step()
140 |
141 | def rate(self, step=None):
142 | if step is None:
143 | step = self._step
144 |
145 | return self.factor * (self.embeddings_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)))
--------------------------------------------------------------------------------
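
NoamOpt simply wraps a base optimizer and rewrites its learning rate with the Transformer warmup schedule on every step. A minimal sketch follows (module size, factor, and warmup are placeholders; torch.optim.Adam is swapped in here for brevity, and the file's own Adam can be used the same way):

# Minimal sketch with placeholder hyperparameters: NoamOpt sets
# lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5) before each update.
import torch
import torch.nn as nn

model = nn.Linear(256, 256)
base_optimizer = torch.optim.Adam(model.parameters(), lr=0.0)  # lr is overwritten by NoamOpt
optimizer = NoamOpt(embeddings_size=256, factor=1.0, warmup=4000, optimizer=base_optimizer)

loss = model(torch.randn(4, 256)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
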
/code/retrieval/pytorch_pretrained_bert/optimization_openai.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for OpenAI GPT model."""
16 |
17 | import math
18 | import torch
19 | from torch.optim import Optimizer
20 | from torch.optim.optimizer import required
21 | from torch.nn.utils import clip_grad_norm_
22 | import logging
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | def warmup_cosine(x, warmup=0.002):
27 | if x < warmup:
28 | return x/warmup
29 |     return 0.5 * (1.0 + math.cos(math.pi * x))
30 |
31 | def warmup_constant(x, warmup=0.002):
32 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to OpenAIAdam) training steps.
33 | Learning rate is 1. afterwards. """
34 | if x < warmup:
35 | return x/warmup
36 | return 1.0
37 |
38 | def warmup_linear(x, warmup=0.002):
39 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to OpenAIAdam) training step.
40 | After `t_total`-th training step, learning rate is zero. """
41 | if x < warmup:
42 | return x/warmup
43 | return max((x-1.)/(warmup-1.), 0)
44 |
45 | SCHEDULES = {
46 | 'warmup_cosine':warmup_cosine,
47 | 'warmup_constant':warmup_constant,
48 | 'warmup_linear':warmup_linear,
49 | }
50 |
51 |
52 | class OpenAIAdam(Optimizer):
53 | """Implements Open AI version of Adam algorithm with weight decay fix.
54 | """
55 | def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1,
56 | b1=0.9, b2=0.999, e=1e-8, weight_decay=0,
57 | vector_l2=False, max_grad_norm=-1, **kwargs):
58 | if lr is not required and lr < 0.0:
59 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
60 | if schedule not in SCHEDULES:
61 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
62 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
63 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
64 | if not 0.0 <= b1 < 1.0:
65 | raise ValueError("Invalid b1 parameter: {}".format(b1))
66 | if not 0.0 <= b2 < 1.0:
67 | raise ValueError("Invalid b2 parameter: {}".format(b2))
68 | if not e >= 0.0:
69 | raise ValueError("Invalid epsilon value: {}".format(e))
70 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
71 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
72 | max_grad_norm=max_grad_norm)
73 | super(OpenAIAdam, self).__init__(params, defaults)
74 |
75 | def get_lr(self):
76 | lr = []
77 | for group in self.param_groups:
78 | for p in group['params']:
79 | state = self.state[p]
80 | if len(state) == 0:
81 | return [0]
82 | if group['t_total'] != -1:
83 | schedule_fct = SCHEDULES[group['schedule']]
84 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
85 | else:
86 | lr_scheduled = group['lr']
87 | lr.append(lr_scheduled)
88 | return lr
89 |
90 | def step(self, closure=None):
91 | """Performs a single optimization step.
92 |
93 | Arguments:
94 | closure (callable, optional): A closure that reevaluates the model
95 | and returns the loss.
96 | """
97 | loss = None
98 | if closure is not None:
99 | loss = closure()
100 |
101 | warned_for_t_total = False
102 |
103 | for group in self.param_groups:
104 | for p in group['params']:
105 | if p.grad is None:
106 | continue
107 | grad = p.grad.data
108 | if grad.is_sparse:
109 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
110 |
111 | state = self.state[p]
112 |
113 | # State initialization
114 | if len(state) == 0:
115 | state['step'] = 0
116 | # Exponential moving average of gradient values
117 | state['exp_avg'] = torch.zeros_like(p.data)
118 | # Exponential moving average of squared gradient values
119 | state['exp_avg_sq'] = torch.zeros_like(p.data)
120 |
121 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
122 | beta1, beta2 = group['b1'], group['b2']
123 |
124 | state['step'] += 1
125 |
126 | # Add grad clipping
127 | if group['max_grad_norm'] > 0:
128 | clip_grad_norm_(p, group['max_grad_norm'])
129 |
130 | # Decay the first and second moment running average coefficient
131 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
132 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
133 | denom = exp_avg_sq.sqrt().add_(group['e'])
134 |
135 | bias_correction1 = 1 - beta1 ** state['step']
136 | bias_correction2 = 1 - beta2 ** state['step']
137 |
138 | if group['t_total'] != -1:
139 | schedule_fct = SCHEDULES[group['schedule']]
140 | progress = state['step']/group['t_total']
141 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
142 | # warning for exceeding t_total (only active with warmup_linear
143 | if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
144 | logger.warning(
145 | "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
146 | "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
147 | warned_for_t_total = True
148 | # end warning
149 | else:
150 | lr_scheduled = group['lr']
151 |
152 | step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
153 |
154 | p.data.addcdiv_(-step_size, exp_avg, denom)
155 |
156 | # Add weight decay at the end (fixed version)
157 | if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0:
158 | p.data.add_(-lr_scheduled * group['weight_decay'], p.data)
159 |
160 | return loss
161 |
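# ---- usage sketch (editorial addition, not part of optimization_openai.py) ----
# A minimal illustration of how OpenAIAdam is typically constructed; the toy
# nn.Linear model, step count, and hyper-parameter values here are assumptions.
import torch.nn as nn

model = nn.Linear(768, 768)                       # stand-in for a real model
num_train_steps = 10000
optimizer = OpenAIAdam(model.parameters(),
                       lr=6.25e-5,
                       schedule='warmup_linear',  # lr warms up, then decays linearly toward 0 at t_total
                       warmup=0.002,              # fraction of t_total spent warming up
                       t_total=num_train_steps,
                       weight_decay=0.01,
                       max_grad_norm=1.0)
# training loop: loss.backward(); optimizer.step(); optimizer.zero_grad()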
--------------------------------------------------------------------------------
/code/generation/model/utils.py:
--------------------------------------------------------------------------------
1 | # transformer_chatbot
2 | # Copyright (C) 2018 Golovanov, Tselousov
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 |
17 | import re
18 | import os
19 | import json
20 | import random
21 | from collections import namedtuple, Counter
22 |
23 | import torch
24 | import numpy as np
25 | from scipy.interpolate import RectBivariateSpline
26 | from torch.utils.checkpoint import checkpoint
27 |
28 |
29 | def set_seed(seed):
30 | torch.manual_seed(seed)
31 | torch.cuda.manual_seed(seed)
32 | random.seed(seed)
33 |
34 |
35 | def pad_sequence(sequences, batch_first=False, padding_value=0):
36 | # assuming trailing dimensions and type of all the Tensors
37 | # in sequences are same and fetching those from sequences[0]
38 | max_size = sequences[0].size()
39 | trailing_dims = max_size[1:]
40 | max_len = max([s.size(0) for s in sequences])
41 | if batch_first:
42 | out_dims = (len(sequences), max_len) + trailing_dims
43 | else:
44 | out_dims = (max_len, len(sequences)) + trailing_dims
45 |
46 | out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
47 | for i, tensor in enumerate(sequences):
48 | length = tensor.size(0)
49 | # use index notation to prevent duplicate references to the tensor
50 | if batch_first:
51 | out_tensor[i, :length, ...] = tensor
52 | else:
53 | out_tensor[:length, i, ...] = tensor
54 |
55 | return out_tensor
56 |
57 |
58 | def checkpoint_sequential(functions, segments, *inputs):
59 | def run_function(start, end, functions):
60 | def forward(*inputs):
61 | for j in range(start, end + 1):
62 | inputs = functions[j](*inputs)
63 | return inputs
64 | return forward
65 |
66 | if isinstance(functions, torch.nn.Sequential):
67 | functions = list(functions.children())
68 |
69 | segment_size = len(functions) // segments
70 | # the last chunk has to be non-volatile
71 | end = -1
72 | for start in range(0, segment_size * (segments - 1), segment_size):
73 | end = start + segment_size - 1
74 | inputs = checkpoint(run_function(start, end, functions), *inputs)
75 | if not isinstance(inputs, tuple):
76 | inputs = (inputs,)
77 | return run_function(end + 1, len(functions) - 1, functions)(*inputs)
78 |
79 |
80 | def f1_score(predictions, targets, average=True):
81 | def f1_score_items(pred_items, gold_items):
82 | common = Counter(gold_items) & Counter(pred_items)
83 | num_same = sum(common.values())
84 |
85 | if num_same == 0:
86 | return 0
87 |
88 | precision = num_same / len(pred_items)
89 | recall = num_same / len(gold_items)
90 | f1 = (2 * precision * recall) / (precision + recall)
91 |
92 | return f1
93 |
94 | scores = [f1_score_items(p, t) for p, t in zip(predictions, targets)]
95 |
96 | if average:
97 | return sum(scores) / len(scores)
98 |
99 | return scores
100 |
101 |
102 | def openai_transformer_config():
103 | class dotdict(dict):
104 | __getattr__ = dict.get
105 | __setattr__ = dict.__setitem__
106 | __delattr__ = dict.__delitem__
107 |
108 | cfg = dotdict({'n_layers': 12, 'n_embeddings': 40477, 'n_pos_embeddings': 512,
109 | 'embeddings_size': 768, 'n_heads': 12, 'dropout': 0.1,
110 | 'embed_dropout': 0.1, 'attn_dropout': 0.1, 'ff_dropout': 0.1})
111 |
112 | return cfg
113 |
114 |
115 | def load_openai_weights_chinese(model, directory):
116 | openai_model = torch.load(directory)
117 | openai_model.pop('decoder.pre_softmax.weight')
118 | b = list(openai_model.keys())
119 | for i in b:
120 | openai_model['decoder.' + i] = openai_model.pop(i)
121 | model.load_state_dict(openai_model)
122 |
123 |
124 | def load_openai_weights(model, directory, n_special_tokens=0):
125 | # TODO: add check of shapes
126 |
127 | parameters_names_path = os.path.join(directory, 'parameters_names.json')
128 | parameters_shapes_path = os.path.join(directory, 'parameters_shapes.json')
129 | parameters_weights_paths = [os.path.join(directory, 'params_{}.npy'.format(n)) for n in range(10)]
130 |
131 | with open(parameters_names_path, 'r') as parameters_names_file:
132 | parameters_names = json.load(parameters_names_file)
133 |
134 | with open(parameters_shapes_path, 'r') as parameters_shapes_file:
135 | parameters_shapes = json.load(parameters_shapes_file)
136 |
137 | parameters_weights = [np.load(path) for path in parameters_weights_paths]
138 | parameters_offsets = np.cumsum([np.prod(shape) for shape in parameters_shapes])
139 | parameters_weights = np.split(np.concatenate(parameters_weights, 0), parameters_offsets)[:-1]
140 | parameters_weights = [p.reshape(s) for p, s in zip(parameters_weights, parameters_shapes)]
141 |
142 | parameters_weights[1] = parameters_weights[1][1:] # skip 0 -
143 |
144 |
145 | if model.pos_embeddings.num_embeddings - 1 > parameters_weights[0].shape[0]:
146 | xx = np.linspace(0, parameters_weights[0].shape[0], model.pos_embeddings.num_embeddings - 1)
147 | new_kernel = RectBivariateSpline(np.arange(parameters_weights[0].shape[0]),
148 | np.arange(parameters_weights[0].shape[1]),
149 | parameters_weights[0])
150 | parameters_weights[0] = new_kernel(xx, np.arange(parameters_weights[0].shape[1]))
151 |
152 | parameters_weights[0] = parameters_weights[0][:model.pos_embeddings.num_embeddings - 1]
153 | parameters_weights[1] = parameters_weights[1][:model.embeddings.num_embeddings - n_special_tokens]
154 |
155 | model.pos_embeddings.weight.data[1:] = torch.from_numpy(parameters_weights[0])
156 | model.embeddings.weight.data[n_special_tokens:] = torch.from_numpy(parameters_weights[1])
157 |
158 |
159 | parameters_weights = parameters_weights[2:]
160 |
161 | for name, weights in zip(parameters_names, parameters_weights):
162 | name = name[6:] # skip "model/"
163 | assert name[-2:] == ':0'
164 | name = name[:-2]
165 | name = name.split('/')
166 |
167 | pointer = model
168 | for m_name in name:
169 | if re.fullmatch(r'[A-Za-z]+\d+', m_name):
170 | l = re.split(r'(\d+)', m_name)
171 | else:
172 | l = [m_name]
173 |
174 | pointer = getattr(pointer, l[0])
175 |
176 | if len(l) >= 2:
177 | num = int(l[1])
178 | pointer = pointer[num]
179 |
180 | if len(weights.shape) == 3: # conv1d to linear
181 | weights = weights[0].transpose((1, 0))
182 |
183 | pointer.data[...] = torch.from_numpy(weights)
184 |
--------------------------------------------------------------------------------
/code/generation/model/dataset.py:
--------------------------------------------------------------------------------
1 | # transformer_chatbot
2 | # Copyright (C) 2018 Golovanov, Tselousov
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 |
17 | from torch.utils.data import Dataset
18 | import random
19 |
20 | class S2sDataset_dialog(Dataset):
21 | def __init__(self, paths, vocab, max_lengths=2048):
22 | if isinstance(paths, str):
23 | print('path is str')
24 | paths = [paths]
25 |
26 | self.vocab = vocab
27 | self.max_lengths = max_lengths
28 | self.data = S2sDataset_dialog.make_dataset(paths, vocab, max_lengths)
29 | print(len(self.data))
30 |
31 | @staticmethod
32 | def make_dataset(paths, vocab, max_lengths):
33 | dataset = []
34 | for path in paths:
35 | with open(path, 'r', encoding='utf8') as fr:
36 | lines = fr.readlines()
37 | for line in lines:
38 | context, response = line.strip('\n').split('\t')
39 | context = vocab.string2ids(' '.join(context))
40 | response = vocab.string2ids(' '.join(response))
41 | dataset.append((context, response))
42 | random.shuffle(dataset)
43 | return dataset
44 |
45 | def __len__(self):
46 | return len(self.data)
47 |
48 | def __getitem__(self, idx):
49 | context, response = self.data[idx]
50 | context = [self.vocab.bos_id] + context + [self.vocab.eos_id]
51 | response = [self.vocab.bos_id] + response + [self.vocab.eos_id]
52 | context = context[-40:]
53 | response = response[:40]
54 | return context, response
55 |
56 | class S2sDataset_dialog_overlap(Dataset):
57 | def __init__(self, paths, vocab, max_lengths=2048):
58 | if isinstance(paths, str):
59 | print('path is str')
60 | paths = [paths]
61 |
62 | self.vocab = vocab
63 | self.max_lengths = max_lengths
64 | self.data = S2sDataset_dialog_overlap.make_dataset(paths, vocab, max_lengths)
65 | print(len(self.data))
66 | print(self.data[0])
67 |
68 | @staticmethod
69 | def make_dataset(paths, vocab, max_lengths):
70 | dataset = []
71 | for path in paths:
72 | with open(path, 'r', encoding='utf8') as fr:
73 | lines = fr.readlines()
74 | for line in lines:
75 | context, response, overlap = line.strip('\n').split('\t')
76 | context = vocab.string2ids(' '.join(context))
77 | response = vocab.string2ids(' '.join(response))
78 | overlap = int(overlap)
79 | if overlap < 2:
80 | overlap_sym = vocab.token2id['']
81 | elif overlap < 4:
82 | overlap_sym = vocab.token2id['']
83 | elif overlap < 6:
84 | overlap_sym = vocab.token2id['']
85 | else:
86 | overlap_sym = vocab.token2id['']
87 | dataset.append((context, response, overlap_sym))
88 | random.shuffle(dataset)
89 | return dataset
90 |
91 | def __len__(self):
92 | return len(self.data)
93 |
94 | def __getitem__(self, idx):
95 | context, response, overlap = self.data[idx]
96 | context = [self.vocab.bos_id] + context[-37:] + [overlap] + [self.vocab.eos_id]
97 | response = [self.vocab.bos_id] + response[-38:] + [self.vocab.eos_id]
98 | #context = context[-40:]
99 | #response = response[:40]
100 | return context, response
101 |
102 | class S2sDataset_poem(Dataset):
103 | def __init__(self, paths, vocab, max_lengths=2048):
104 | if isinstance(paths, str):
105 | paths = [paths]
106 |
107 | self.vocab = vocab
108 | self.max_lengths = max_lengths
109 | self.data = S2sDataset_poem.make_dataset_wu(paths[0], vocab, max_lengths)
110 |
111 | @staticmethod
112 | def make_dataset_xiandai(paths, vocab, max_lengths):
113 | dataset = []
114 | with open(paths, 'r', encoding='utf8') as fr:
115 | lines = fr.readlines()
116 | for line in lines:
117 | context, response = line.strip('\n').split('\t')
118 | context = vocab.string2ids(context)
119 | response = vocab.string2ids(response)
120 | dataset.append((context[:18], response[:118]))
121 | return dataset
122 |
123 | @staticmethod
124 | def make_dataset_wu(paths, vocab, max_lengths):
125 | dataset = []
126 | with open(paths, 'r', encoding='utf8') as fr:
127 | lines = fr.readlines()
128 | for line in lines:
129 | target, source = line.strip('\n').split('\t')
130 | source = ' '.join([' '.join(i) for i in source.split()])
131 | target = ' '.join(target)
132 | context = vocab.string2ids(source)
133 | response = vocab.string2ids(target)
134 | dataset.append((context[:18], response[:30]))
135 | return dataset
136 |
137 | def __len__(self):
138 | return len(self.data)
139 |
140 | def __getitem__(self, idx):
141 | context, response = self.data[idx]
142 | context = [self.vocab.bos_id] + context + [self.vocab.eos_id]
143 | response = [self.vocab.bos_id] + response + [self.vocab.eos_id]
144 | # context = context[:20]
145 | # response = response[:120]
146 | return context, response
147 |
148 | class S2sDataset_meme(Dataset):
149 | def __init__(self, paths, vocab, max_lengths=2048):
150 | if isinstance(paths, str):
151 | paths = [paths]
152 |
153 | self.vocab = vocab
154 | self.max_lengths = max_lengths
155 | self.data = S2sDataset_meme.make_dataset(paths[0], vocab, max_lengths)
156 |
157 | @staticmethod
158 | def make_dataset(paths, vocab, max_lengths):
159 | dataset = []
160 | with open(paths, 'r', encoding='utf8') as fr:
161 | lines = fr.readlines()
162 | for line in lines:
163 | context, response = line.strip('\n').split('\t')
164 | context = vocab.string2ids(context)
165 | response = vocab.string2ids(response)
166 | dataset.append((context, response))
167 | return dataset
168 |
169 | def __len__(self):
170 | return len(self.data)
171 |
172 | def __getitem__(self, idx):
173 | context, response = self.data[idx]
174 | context = [self.vocab.bos_id] + context + [self.vocab.eos_id]
175 | response = [self.vocab.bos_id] + response + [self.vocab.eos_id]
176 | context = context[:30]
177 | response = response[:30]
178 | return context, response
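# ---- usage sketch (editorial addition, not part of dataset.py) ----
# Illustrates the tab-separated "context<TAB>response" format S2sDataset_dialog
# expects; the stub vocab below is a hypothetical stand-in for model.text.myVocab.
import os
import tempfile

class StubVocab:
    bos_id, eos_id = 1, 2
    def string2ids(self, s):
        return [10 + i for i, _ in enumerate(s.split())]   # fake ids, one per whitespace-separated token

with tempfile.NamedTemporaryFile('w', suffix='.tsv', delete=False, encoding='utf8') as f:
    f.write('你 好\t很 高 兴\n')                             # one context/response pair per line
ds = S2sDataset_dialog(f.name, StubVocab())
context, response = ds[0]                                   # both sides wrapped in bos/eos ids, truncated to 40
os.remove(f.name)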
--------------------------------------------------------------------------------
/code/generation/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import os
4 | from model.utils import load_openai_weights_chinese, set_seed, f1_score
5 | from model.transformer_model import TransformerModel
6 | from model.trainer import Trainer
7 | from model.text import myVocab
8 | from model.dataset import S2sDataset_dialog
9 | from config import get_model_config_dialog, get_trainer_config_dialog
10 | from torch.nn.parallel import DistributedDataParallel
11 | import argparse
12 |
13 |
14 | def main():
15 | model_config = get_model_config_dialog()
16 | trainer_config = get_trainer_config_dialog()
17 |
18 | set_seed(trainer_config.seed)
19 | device = torch.device(trainer_config.device)
20 | # zrs
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument("--local_rank", type=int, default=-1)
23 | args = parser.parse_args()
24 | distributed = (args.local_rank != -1)
25 | if distributed:
26 | print(args.local_rank)
27 | torch.cuda.set_device(args.local_rank)
28 | device = torch.device("cuda", args.local_rank)
29 | torch.distributed.init_process_group(backend='nccl', init_method='env://')
30 |
31 | vocab = myVocab(model_config.vocab_path)
32 |
33 | transformer = TransformerModel(n_layers=model_config.n_layers,
34 | n_embeddings=len(vocab),
35 | n_pos_embeddings=model_config.n_pos_embeddings,
36 | embeddings_size=model_config.embeddings_size,
37 | padding_idx=vocab.pad_id,
38 | n_heads=model_config.n_heads,
39 | dropout=model_config.dropout,
40 | embed_dropout=model_config.embed_dropout,
41 | attn_dropout=model_config.attn_dropout,
42 | ff_dropout=model_config.ff_dropout,
43 | bos_id=vocab.bos_id,
44 | eos_id=vocab.eos_id,
45 | max_seq_len=model_config.max_seq_len,
46 | beam_size=model_config.beam_size,
47 | length_penalty=model_config.length_penalty,
48 | n_segments=model_config.n_segments,
49 | annealing_topk=model_config.annealing_topk,
50 | temperature=model_config.temperature,
51 | annealing=model_config.annealing,
52 | diversity_coef=model_config.diversity_coef,
53 | diversity_groups=model_config.diversity_groups)
54 |
55 | if not trainer_config.load_last:
56 | openai_model = torch.load(trainer_config.openai_parameters_dir, map_location=device)
57 | openai_model.pop('decoder.pre_softmax.weight')
58 | b = list(openai_model.keys())
59 | for i in b:
60 | temp = i.split('.')
61 | keep = True
62 | for j in range(model_config.n_layers, 12):
63 | if str(j) in temp:
64 | keep = False
65 | break
66 | if keep:
67 | openai_model[i.split('.', 1)[1]] = openai_model.pop(i)
68 | else:
69 | print(i)
70 | openai_model.pop(i)
71 | #openai_model[i.split('.', 1)[1]] = openai_model.pop(i)
72 | transformer.transformer_module.load_state_dict(openai_model, strict=True)
73 | # load_openai_weights_chinese(transformer.transformer_module, trainer_config.openai_parameters_dir)
74 | print('OpenAI weights chinese loaded from {}'.format(trainer_config.openai_parameters_dir))
75 |
76 | train_dataset = S2sDataset_dialog(trainer_config.train_datasets, vocab, transformer.n_pos_embeddings - 1)
77 | test_dataset = S2sDataset_dialog(trainer_config.test_datasets, vocab, transformer.n_pos_embeddings - 1)
78 |
79 | model_trainer = Trainer(transformer,
80 | train_dataset,
81 | test_dataset,
82 | batch_size=trainer_config.batch_size,
83 | batch_split=trainer_config.batch_split,
84 | lr=trainer_config.lr,
85 | lr_warmup=trainer_config.lr_warmup,
86 | lm_weight=trainer_config.lm_weight,
87 | risk_weight=trainer_config.risk_weight,
88 | n_jobs=trainer_config.n_jobs,
89 | clip_grad=trainer_config.clip_grad,
90 | # label_smoothing=trainer_config.label_smoothing,
91 | device=device,
92 | ignore_idxs=vocab.special_tokens_ids,
93 | distributed=distributed)
94 | if distributed:
95 | model_trainer.model.transformer_module = DistributedDataParallel(model_trainer.model.transformer_module,
96 | device_ids=[args.local_rank], output_device=args.local_rank)
97 |
98 | start_epoch = 0
99 | init_epoch = 0
100 |
101 | if trainer_config.load_last:
102 | state_dict = torch.load(trainer_config.last_checkpoint_path + str(init_epoch - 1), map_location=device)
103 | model_trainer.load_state_dict(state_dict)
104 | # start_epoch = int(cop.sub('', trainer_config.last_checkpoint_path.split('/')[-1])) + 1
105 | start_epoch = init_epoch
106 | print('Weights loaded from {}'.format(trainer_config.last_checkpoint_path + str(init_epoch - 1)))
107 |
108 |
109 | # helpers -----------------------------------------------------
110 | def save_func(epoch):
111 | dirs = '/'.join(trainer_config.last_checkpoint_path.split('/')[:-1])
112 | if not os.path.exists(dirs):
113 | os.makedirs(dirs)
114 | torch.save(model_trainer.state_dict(), trainer_config.last_checkpoint_path)
115 | torch.save(model_trainer.state_dict(), trainer_config.last_checkpoint_path + str(epoch))
116 | if os.path.exists(trainer_config.last_checkpoint_path + str(epoch - 100)):
117 | os.remove(trainer_config.last_checkpoint_path + str(epoch - 100))
118 |
119 | def sample_text_func(epoch):
120 | n_samples = 5
121 | samples_idxs = random.sample(range(len(test_dataset)), n_samples)
122 | samples = [test_dataset[idx] for idx in samples_idxs]
123 | for source, target in samples:
124 | contexts = [torch.tensor([c], dtype=torch.long, device=model_trainer.device) for c in [source] if len(c) > 0]
125 | prediction = model_trainer.model.predict(contexts)[0]
126 | source_str = vocab.ids2string(source)
127 | target_str = vocab.ids2string(target[1:-1])
128 | prediction_str = vocab.ids2string(prediction)
129 | print('\n')
130 | print('Source:{}'.format(source_str))
131 | print('Target:\n\t{}'.format(target_str))
132 | print('Prediction:\n\t{}'.format(prediction_str))
133 |
134 | def test_func(epoch):
135 | if (epoch+1) % trainer_config.test_period == 0:
136 | metric_funcs = {'f1_score': f1_score}
137 | model_trainer.test(metric_funcs)
138 |
139 | def f1_risk(predictions, targets):
140 | scores = f1_score(predictions, targets, average=False)
141 | return [1-s for s in scores]
142 |
143 | # helpers -----------------------------------------------------
144 |
145 | # model_trainer.model.transformer_module = nn.DataParallel(model_trainer.model.transformer_module, device_ids=[0, 1])
146 | try:
147 | if args.local_rank in [-1, 0]:
148 | model_trainer.train(start_epoch, trainer_config.n_epochs, after_epoch_funcs=[save_func, sample_text_func, test_func],
149 | risk_func=f1_risk)
150 | else:
151 | model_trainer.train(start_epoch, trainer_config.n_epochs)
152 | except (KeyboardInterrupt, Exception, RuntimeError) as e:
153 | torch.save(model_trainer.state_dict(), trainer_config.interrupt_checkpoint_path)
154 | raise e
155 |
156 |
157 | if __name__ == '__main__':
158 | main()
159 |
160 |
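# ---- usage sketch (editorial addition, not part of train.py) ----
# (--local_rank is supplied automatically when the script is launched with torch.distributed.launch.)
# The loop above that filters/renames OpenAI checkpoint keys, replayed on a toy
# state dict (the key names and n_layers=2 are assumptions for illustration):
# drop the leading module name and discard layers >= model_config.n_layers.
ckpt = {'encoder.layers.0.w': 0, 'encoder.layers.1.w': 1, 'encoder.layers.11.w': 2}
n_layers = 2
for key in list(ckpt.keys()):
    keep = all(str(j) not in key.split('.') for j in range(n_layers, 12))
    if keep:
        ckpt[key.split('.', 1)[1]] = ckpt.pop(key)   # 'encoder.layers.0.w' -> 'layers.0.w'
    else:
        ckpt.pop(key)                                # layer 11 is beyond n_layers, dropped
print(ckpt)                                          # {'layers.0.w': 0, 'layers.1.w': 1}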
--------------------------------------------------------------------------------
/code/generation/model/transformer_module.py:
--------------------------------------------------------------------------------
1 | # transformer_chatbot
2 | # Copyright (C) 2018 Golovanov, Tselousov
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 |
17 | import math
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | from .utils import checkpoint_sequential
22 |
23 |
24 | class MultiheadAttention(nn.Module):
25 | @classmethod
26 | def _get_future_mask(cls, size, device):
27 | if not hasattr(cls, '_future_mask') or cls._future_mask.device != device or cls._future_mask.shape < size:
28 | cls._future_mask = torch.triu(torch.ones(size[0], size[1], dtype=torch.uint8, device=device), 1)
29 |
30 | mask = cls._future_mask[:size[0], :size[1]]
31 |
32 | return mask
33 |
34 | def __init__(self, n_features, n_heads, dropout):
35 | super(MultiheadAttention, self).__init__()
36 | assert n_features % n_heads == 0
37 |
38 | self.n_features = n_features
39 | self.n_heads = n_heads
40 | self.qkv_proj = nn.Linear(n_features, 3 * n_features)
41 | self.out_proj = nn.Linear(n_features, n_features)
42 | self.dropout = nn.Dropout(dropout)
43 |
44 | self._init_weights()
45 |
46 | def _init_weights(self):
47 | nn.init.normal_(self.qkv_proj.weight, std=0.02)
48 | nn.init.normal_(self.out_proj.weight, std=0.02)
49 |
50 | def _split_heads(self, x, is_key=False):
51 | x = x.view(x.shape[0], x.shape[1], self.n_heads, self.n_features // self.n_heads)
52 | x = x.permute(0, 2, 3, 1) if is_key else x.permute(0, 2, 1, 3)
53 |
54 | return x
55 |
56 | def _attn(self, q, k, v, apply_future_mask=True, padding_mask=None):
57 | w = torch.matmul(q, k) / math.sqrt(self.n_features // self.n_heads)
58 |
59 | if apply_future_mask:
60 | future_mask = MultiheadAttention._get_future_mask(w.shape[-2:], w.device).unsqueeze(0).unsqueeze(0)
61 | w.masked_fill_(future_mask, float('-inf'))
62 |
63 | if padding_mask is not None:
64 | w.masked_fill_(padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'))
65 |
66 | w = F.softmax(w, dim=-1)
67 | w = self.dropout(w)
68 |
69 | if padding_mask is not None:
70 | w.masked_fill_(padding_mask.all(dim=-1).unsqueeze(1).unsqueeze(2).unsqueeze(3), 0)
71 |
72 | out = torch.matmul(w, v)
73 |
74 | return out
75 |
76 | def _merge_heads(self, x):
77 | x = x.permute(0, 2, 1, 3).contiguous()
78 | x = x.view(x.shape[0], x.shape[1], self.n_features)
79 |
80 | return x
81 |
82 | def forward(self, query, key, value, padding_mask):
83 | qkv_same = (query.data_ptr() == key.data_ptr() == value.data_ptr())
84 | kv_same = (key.data_ptr() == value.data_ptr())
85 |
86 | if qkv_same:
87 | query, key, value = self.qkv_proj(query).split(self.n_features, dim=-1)
88 | apply_future_mask = True # self-attention
89 | elif kv_same:
90 | q_w, q_b = self.qkv_proj.weight[:self.n_features, :], self.qkv_proj.bias[:self.n_features]
91 | query = F.linear(query, q_w, q_b)
92 | kv_w, kv_b = self.qkv_proj.weight[self.n_features:, :], self.qkv_proj.bias[self.n_features:]
93 | key, value = F.linear(key, kv_w, kv_b).split(self.n_features, dim=-1)
94 | apply_future_mask = False
95 | else:
96 | assert False
97 |
98 | query = self._split_heads(query)
99 | key = self._split_heads(key, is_key=True)
100 | value = self._split_heads(value)
101 |
102 | x = self._attn(query, key, value, apply_future_mask, padding_mask)
103 | x = self._merge_heads(x)
104 |
105 | x = self.out_proj(x)
106 |
107 | return x
108 |
109 |
110 | class FeedForward(nn.Module):
111 | @staticmethod
112 | def gelu(x):
113 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
114 |
115 | def __init__(self, in_features, middle_features, dropout):
116 | super(FeedForward, self).__init__()
117 |
118 | self.layer_1 = nn.Linear(in_features, middle_features)
119 | self.layer_2 = nn.Linear(middle_features, in_features)
120 | self.dropout = nn.Dropout(dropout)
121 |
122 | self._init_weights()
123 |
124 | def _init_weights(self):
125 | nn.init.normal_(self.layer_1.weight, std=0.02)
126 | nn.init.normal_(self.layer_2.weight, std=0.02)
127 |
128 | def forward(self, x):
129 | x = FeedForward.gelu(self.layer_1(x))
130 | x = self.dropout(x)
131 | x = self.layer_2(x)
132 |
133 | return x
134 |
135 |
136 | class TransformerBlock(nn.Module):
137 | def __init__(self, n_features, n_heads, dropout, attn_dropout, ff_dropout):
138 | super(TransformerBlock, self).__init__()
139 |
140 | self.attn = MultiheadAttention(n_features, n_heads, attn_dropout)
141 | self.attn_norm = nn.LayerNorm(n_features)
142 | self.ff = FeedForward(n_features, 4 * n_features, ff_dropout)
143 | self.ff_norm = nn.LayerNorm(n_features)
144 | self.dropout = nn.Dropout(dropout)
145 |
146 | def forward(self, x, padding_mask, *contexts):
147 | '''contexts = [(context1, padding_mask1), ...]'''
148 |
149 | inputs = (x, padding_mask) + contexts
150 |
151 | full_attn = 0
152 | n_attn = len(inputs) // 2
153 | for i in range(0, len(inputs), 2):
154 | c, m = inputs[i], inputs[i+1].byte()
155 | a = self.attn(x, c, c, m)
156 | full_attn += (a / n_attn)
157 |
158 | full_attn = self.dropout(full_attn)
159 | x = self.attn_norm(x + full_attn)
160 |
161 | f = self.ff(x)
162 | f = self.dropout(f)
163 | x = self.ff_norm(x + f)
164 |
165 | return (x, padding_mask) + contexts
166 |
167 |
168 | class TransformerModule(nn.Module):
169 | def __init__(self, n_layers, n_embeddings, n_pos_embeddings, embeddings_size,
170 | padding_idx, n_heads, dropout, embed_dropout, attn_dropout, ff_dropout,
171 | n_segments=None):
172 | super(TransformerModule, self).__init__()
173 |
174 | self.embeddings = nn.Embedding(n_embeddings, embeddings_size, padding_idx=padding_idx)
175 | self.pos_embeddings = nn.Embedding(n_pos_embeddings + 1, embeddings_size, padding_idx=0)
176 | self.embed_dropout = nn.Dropout(embed_dropout)
177 | self.layers = nn.ModuleList([TransformerBlock(embeddings_size, n_heads, dropout, attn_dropout, ff_dropout) for _ in range(n_layers)])
178 | self.n_segments = n_segments
179 |
180 | self._init_weights()
181 |
182 | def _init_weights(self):
183 | nn.init.normal_(self.embeddings.weight, std=0.02)
184 | nn.init.normal_(self.pos_embeddings.weight, std=0.02)
185 |
186 | def forward(self, x, enc_contexts=[]):
187 | padding_mask = x.eq(self.embeddings.padding_idx)
188 |
189 | positions = torch.cumsum(~padding_mask, dim=-1, dtype=torch.long)
190 | positions.masked_fill_(padding_mask, self.pos_embeddings.padding_idx)
191 |
192 | x = self.embeddings(x) * math.sqrt(self.embeddings.embedding_dim) + self.pos_embeddings(positions)
193 | x = self.embed_dropout(x)
194 |
195 | enc_contexts = sum(enc_contexts, ())
196 |
197 | if self.n_segments is not None:
198 |             padding_mask = padding_mask.float()  # checkpointing workaround: bool masks cannot require grad
199 |             padding_mask.requires_grad_()  # checkpoint_sequential needs an input with requires_grad
200 | out = checkpoint_sequential(self.layers, self.n_segments, x, padding_mask, *enc_contexts)
201 | x = out[0]
202 | else:
203 | for layer in self.layers:
204 | out = layer(x, padding_mask, *enc_contexts)
205 | x = out[0]
206 |
207 | return x, padding_mask
208 |
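# ---- usage sketch (editorial addition, not part of transformer_module.py) ----
# Reproduces, in isolation, how TransformerModule.forward derives position ids
# that skip padding (the example ids and padding_idx=0 are assumptions).
import torch

x = torch.tensor([[5, 6, 7, 0, 0]])              # one right-padded sequence
padding_mask = x.eq(0)                           # padding_idx assumed to be 0
positions = torch.cumsum(~padding_mask, dim=-1, dtype=torch.long)
positions.masked_fill_(padding_mask, 0)          # 0 is pos_embeddings.padding_idx
print(positions.tolist())                        # [[1, 2, 3, 0, 0]]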
--------------------------------------------------------------------------------
/code/retrieval/pytorch_pretrained_bert/file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 | from __future__ import (absolute_import, division, print_function, unicode_literals)
7 |
8 | import json
9 | import logging
10 | import os
11 | import shutil
12 | import tempfile
13 | from functools import wraps
14 | from hashlib import sha256
15 | import sys
16 | from io import open
17 |
18 | import boto3
19 | import requests
20 | from botocore.exceptions import ClientError
21 | from tqdm import tqdm
22 |
23 | try:
24 | from urllib.parse import urlparse
25 | except ImportError:
26 | from urlparse import urlparse
27 |
28 | try:
29 | from pathlib import Path
30 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
31 | Path.home() / '.pytorch_pretrained_bert'))
32 | except (AttributeError, ImportError):
33 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
34 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
35 |
36 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
37 |
38 |
39 | def url_to_filename(url, etag=None):
40 | """
41 | Convert `url` into a hashed filename in a repeatable way.
42 | If `etag` is specified, append its hash to the url's, delimited
43 | by a period.
44 | """
45 | url_bytes = url.encode('utf-8')
46 | url_hash = sha256(url_bytes)
47 | filename = url_hash.hexdigest()
48 |
49 | if etag:
50 | etag_bytes = etag.encode('utf-8')
51 | etag_hash = sha256(etag_bytes)
52 | filename += '.' + etag_hash.hexdigest()
53 |
54 | return filename
55 |
56 |
57 | def filename_to_url(filename, cache_dir=None):
58 | """
59 | Return the url and etag (which may be ``None``) stored for `filename`.
60 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
61 | """
62 | if cache_dir is None:
63 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
64 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
65 | cache_dir = str(cache_dir)
66 |
67 | cache_path = os.path.join(cache_dir, filename)
68 | if not os.path.exists(cache_path):
69 | raise EnvironmentError("file {} not found".format(cache_path))
70 |
71 | meta_path = cache_path + '.json'
72 | if not os.path.exists(meta_path):
73 | raise EnvironmentError("file {} not found".format(meta_path))
74 |
75 | with open(meta_path, encoding="utf-8") as meta_file:
76 | metadata = json.load(meta_file)
77 | url = metadata['url']
78 | etag = metadata['etag']
79 |
80 | return url, etag
81 |
82 |
83 | def cached_path(url_or_filename, cache_dir=None):
84 | """
85 | Given something that might be a URL (or might be a local path),
86 | determine which. If it's a URL, download the file and cache it, and
87 | return the path to the cached file. If it's already a local path,
88 | make sure the file exists and then return the path.
89 | """
90 | if cache_dir is None:
91 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
92 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
93 | url_or_filename = str(url_or_filename)
94 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
95 | cache_dir = str(cache_dir)
96 |
97 | parsed = urlparse(url_or_filename)
98 |
99 | if parsed.scheme in ('http', 'https', 's3'):
100 | # URL, so get it from the cache (downloading if necessary)
101 | return get_from_cache(url_or_filename, cache_dir)
102 | elif os.path.exists(url_or_filename):
103 | # File, and it exists.
104 | return url_or_filename
105 | elif parsed.scheme == '':
106 | # File, but it doesn't exist.
107 | raise EnvironmentError("file {} not found".format(url_or_filename))
108 | else:
109 | # Something unknown
110 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
111 |
112 |
113 | def split_s3_path(url):
114 | """Split a full s3 path into the bucket name and path."""
115 | parsed = urlparse(url)
116 | if not parsed.netloc or not parsed.path:
117 | raise ValueError("bad s3 path {}".format(url))
118 | bucket_name = parsed.netloc
119 | s3_path = parsed.path
120 | # Remove '/' at beginning of path.
121 | if s3_path.startswith("/"):
122 | s3_path = s3_path[1:]
123 | return bucket_name, s3_path
124 |
125 |
126 | def s3_request(func):
127 | """
128 | Wrapper function for s3 requests in order to create more helpful error
129 | messages.
130 | """
131 |
132 | @wraps(func)
133 | def wrapper(url, *args, **kwargs):
134 | try:
135 | return func(url, *args, **kwargs)
136 | except ClientError as exc:
137 | if int(exc.response["Error"]["Code"]) == 404:
138 | raise EnvironmentError("file {} not found".format(url))
139 | else:
140 | raise
141 |
142 | return wrapper
143 |
144 |
145 | @s3_request
146 | def s3_etag(url):
147 | """Check ETag on S3 object."""
148 | s3_resource = boto3.resource("s3")
149 | bucket_name, s3_path = split_s3_path(url)
150 | s3_object = s3_resource.Object(bucket_name, s3_path)
151 | return s3_object.e_tag
152 |
153 |
154 | @s3_request
155 | def s3_get(url, temp_file):
156 | """Pull a file directly from S3."""
157 | s3_resource = boto3.resource("s3")
158 | bucket_name, s3_path = split_s3_path(url)
159 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
160 |
161 |
162 | def http_get(url, temp_file):
163 | req = requests.get(url, stream=True)
164 | content_length = req.headers.get('Content-Length')
165 | total = int(content_length) if content_length is not None else None
166 | progress = tqdm(unit="B", total=total)
167 | for chunk in req.iter_content(chunk_size=1024):
168 | if chunk: # filter out keep-alive new chunks
169 | progress.update(len(chunk))
170 | temp_file.write(chunk)
171 | progress.close()
172 |
173 |
174 | def get_from_cache(url, cache_dir=None):
175 | """
176 | Given a URL, look for the corresponding dataset in the local cache.
177 | If it's not there, download it. Then return the path to the cached file.
178 | """
179 | if cache_dir is None:
180 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
181 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
182 | cache_dir = str(cache_dir)
183 |
184 | if not os.path.exists(cache_dir):
185 | os.makedirs(cache_dir)
186 |
187 | # Get eTag to add to filename, if it exists.
188 | if url.startswith("s3://"):
189 | etag = s3_etag(url)
190 | else:
191 | response = requests.head(url, allow_redirects=True)
192 | if response.status_code != 200:
193 | raise IOError("HEAD request failed for url {} with status code {}"
194 | .format(url, response.status_code))
195 | etag = response.headers.get("ETag")
196 |
197 | filename = url_to_filename(url, etag)
198 |
199 | # get cache path to put the file
200 | cache_path = os.path.join(cache_dir, filename)
201 |
202 | if not os.path.exists(cache_path):
203 | # Download to temporary file, then copy to cache dir once finished.
204 | # Otherwise you get corrupt cache entries if the download gets interrupted.
205 | with tempfile.NamedTemporaryFile() as temp_file:
206 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
207 |
208 | # GET file object
209 | if url.startswith("s3://"):
210 | s3_get(url, temp_file)
211 | else:
212 | http_get(url, temp_file)
213 |
214 | # we are copying the file before closing it, so flush to avoid truncation
215 | temp_file.flush()
216 | # shutil.copyfileobj() starts at the current position, so go to the start
217 | temp_file.seek(0)
218 |
219 | logger.info("copying %s to cache at %s", temp_file.name, cache_path)
220 | with open(cache_path, 'wb') as cache_file:
221 | shutil.copyfileobj(temp_file, cache_file)
222 |
223 | logger.info("creating metadata file for %s", cache_path)
224 | meta = {'url': url, 'etag': etag}
225 | meta_path = cache_path + '.json'
226 | with open(meta_path, 'w', encoding="utf-8") as meta_file:
227 | json.dump(meta, meta_file)
228 |
229 | logger.info("removing temp file %s", temp_file.name)
230 |
231 | return cache_path
232 |
233 |
234 | def read_set_from_file(filename):
235 | '''
236 | Extract a de-duped collection (set) of text from a file.
237 | Expected file format is one item per line.
238 | '''
239 | collection = set()
240 | with open(filename, 'r', encoding='utf-8') as file_:
241 | for line in file_:
242 | collection.add(line.rstrip())
243 | return collection
244 |
245 |
246 | def get_file_extension(path, dot=True, lower=True):
247 | ext = os.path.splitext(path)[1]
248 | ext = ext if dot else ext[1:]
249 | return ext.lower() if lower else ext
250 |
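# ---- usage sketch (editorial addition, not part of file_utils.py) ----
# url_to_filename builds a deterministic cache key; cached_path leaves existing
# local paths untouched and only downloads genuine URLs into the cache dir.
print(url_to_filename('https://example.com/vocab.txt', etag='"abc"'))  # '<sha256(url)>.<sha256(etag)>'
print(cached_path(__file__))                                           # existing local file -> returned as-is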
--------------------------------------------------------------------------------
/code/retrieval/pytorch_pretrained_bert/tokenization_gpt2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for OpenAI GPT."""
16 | from __future__ import (absolute_import, division, print_function,
17 | unicode_literals)
18 |
19 | import json
20 | import logging
21 | import os
22 | # NOTE: the BPE pattern below uses \p{L}/\p{N}, which the stdlib `re` module cannot compile
23 | import regex as re
24 | from io import open
25 |
26 | try:
27 | from functools import lru_cache
28 | except ImportError:
29 | # Just a dummy decorator to get the checks to run on python2
30 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
31 | def lru_cache():
32 | return lambda func: func
33 |
34 | from .file_utils import cached_path
35 |
36 | logger = logging.getLogger(__name__)
37 |
38 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
39 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
40 | }
41 | PRETRAINED_MERGES_ARCHIVE_MAP = {
42 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
43 | }
44 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
45 | 'gpt2': 1024,
46 | }
47 | VOCAB_NAME = 'vocab.json'
48 | MERGES_NAME = 'merges.txt'
49 |
50 | @lru_cache()
51 | def bytes_to_unicode():
52 | """
53 |     Returns a dict mapping utf-8 bytes to corresponding unicode strings.
54 | The reversible bpe codes work on unicode strings.
55 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
56 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
57 |     This is a significant percentage of your normal, say, 32K bpe vocab.
58 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
59 | And avoids mapping to whitespace/control characters the bpe code barfs on.
60 | """
61 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
62 | cs = bs[:]
63 | n = 0
64 | for b in range(2**8):
65 | if b not in bs:
66 | bs.append(b)
67 | cs.append(2**8+n)
68 | n += 1
69 | cs = [chr(n) for n in cs]
70 | return dict(zip(bs, cs))
71 |
72 | def get_pairs(word):
73 | """Return set of symbol pairs in a word.
74 |
75 | Word is represented as tuple of symbols (symbols being variable-length strings).
76 | """
77 | pairs = set()
78 | prev_char = word[0]
79 | for char in word[1:]:
80 | pairs.add((prev_char, char))
81 | prev_char = char
82 | return pairs
83 |
84 | class GPT2Tokenizer(object):
85 | """
86 | GPT-2 BPE tokenizer. Peculiarities:
87 | - Byte-level BPE
88 | """
89 | @classmethod
90 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
91 | """
92 |         Instantiate a GPT2Tokenizer from a pre-trained vocabulary file.
93 | Download and cache the pre-trained model file if needed.
94 | """
95 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
96 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
97 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
98 | else:
99 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
100 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
101 | # redirect to the cache, if necessary
102 | try:
103 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
104 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
105 | except EnvironmentError:
106 | logger.error(
107 | "Model name '{}' was not found in model name list ({}). "
108 | "We assumed '{}' was a path or url but couldn't find files {} and {} "
109 | "at this path or url.".format(
110 | pretrained_model_name_or_path,
111 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
112 | pretrained_model_name_or_path,
113 | vocab_file, merges_file))
114 | return None
115 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
116 | logger.info("loading vocabulary file {}".format(vocab_file))
117 | logger.info("loading merges file {}".format(merges_file))
118 | else:
119 | logger.info("loading vocabulary file {} from cache at {}".format(
120 | vocab_file, resolved_vocab_file))
121 | logger.info("loading merges file {} from cache at {}".format(
122 | merges_file, resolved_merges_file))
123 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
124 |             # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
125 | # than the number of positional embeddings
126 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
127 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
128 | # Instantiate tokenizer.
129 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
130 | return tokenizer
131 |
132 | def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
133 | self.max_len = max_len if max_len is not None else int(1e12)
134 | self.encoder = json.load(open(vocab_file))
135 | self.decoder = {v:k for k,v in self.encoder.items()}
136 | self.errors = errors # how to handle errors in decoding
137 | self.byte_encoder = bytes_to_unicode()
138 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
139 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
140 | bpe_merges = [tuple(merge.split()) for merge in bpe_data]
141 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
142 | self.cache = {}
143 |
144 |         # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
145 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
146 |
147 | def __len__(self):
148 | return len(self.encoder)
149 |
150 | def bpe(self, token):
151 | if token in self.cache:
152 | return self.cache[token]
153 | word = tuple(token)
154 | pairs = get_pairs(word)
155 |
156 | if not pairs:
157 | return token
158 |
159 | while True:
160 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
161 | if bigram not in self.bpe_ranks:
162 | break
163 | first, second = bigram
164 | new_word = []
165 | i = 0
166 | while i < len(word):
167 | try:
168 | j = word.index(first, i)
169 | new_word.extend(word[i:j])
170 | i = j
171 | except:
172 | new_word.extend(word[i:])
173 | break
174 |
175 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
176 | new_word.append(first+second)
177 | i += 2
178 | else:
179 | new_word.append(word[i])
180 | i += 1
181 | new_word = tuple(new_word)
182 | word = new_word
183 | if len(word) == 1:
184 | break
185 | else:
186 | pairs = get_pairs(word)
187 | word = ' '.join(word)
188 | self.cache[token] = word
189 | return word
190 |
191 | def encode(self, text):
192 | bpe_tokens = []
193 | for token in re.findall(self.pat, text):
194 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
195 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
196 | if len(bpe_tokens) > self.max_len:
197 | logger.warning(
198 | "Token indices sequence length is longer than the specified maximum "
199 | " sequence length for this OpenAI GPT-2 model ({} > {}). Running this"
200 | " sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len)
201 | )
202 | return bpe_tokens
203 |
204 | def decode(self, tokens):
205 | text = ''.join([self.decoder[token] for token in tokens])
206 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
207 | return text
208 |
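# ---- usage sketch (editorial addition, not part of tokenization_gpt2.py) ----
# The two standalone helpers above, exercised directly.
mapping = bytes_to_unicode()
print(len(mapping), mapping[ord('A')])           # 256 A   (every byte gets a printable unicode char)
print(get_pairs(('h', 'e', 'l', 'l', 'o')))      # {('h','e'), ('e','l'), ('l','l'), ('l','o')} (set order varies)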
--------------------------------------------------------------------------------
/code/generation/model/trainer.py:
--------------------------------------------------------------------------------
1 | # transformer_chatbot
2 | # Copyright (C) 2018 Golovanov, Tselousov
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 | import random
21 | from torch.utils.data import DataLoader
22 | from tqdm import tqdm
23 | from .utils import pad_sequence
24 | from .optim import Adam, NoamOpt
25 | from .loss import LabelSmoothingLoss
26 | import json
27 | import logging
28 | import math
29 | logger = logging.getLogger('s2s')
30 | logger.setLevel(logging.INFO)
31 | fh = logging.FileHandler('s2s.log', encoding='utf-8')
32 | fh.setLevel(logging.INFO)
33 | ch = logging.StreamHandler()
34 | ch.setLevel(logging.INFO)
35 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
36 | fh.setFormatter(formatter)
37 | ch.setFormatter(formatter)
38 | logger.addHandler(fh)
39 | logger.addHandler(ch)
40 | # write a test log entry
41 | logger.info('python logging test')
42 |
43 |
44 | class Trainer:
45 | def __init__(self, model, train_dataset, test_dataset=None, batch_size=8,
46 | batch_split=1, lm_weight=0.5, risk_weight=0, lr=6.25e-5, lr_warmup=2000,
47 | n_jobs=0, clip_grad=None, label_smoothing=0, device=torch.device('cuda'),
48 | ignore_idxs=[], distributed=False):
49 | self.model = model.to(device)
50 | self.lm_criterion = nn.CrossEntropyLoss(ignore_index=self.model.padding_idx).to(device)
51 | self.criterion = LabelSmoothingLoss(n_labels=self.model.n_embeddings, smoothing=label_smoothing,
52 | ignore_index=self.model.padding_idx).to(device)
53 | base_optimizer = Adam(self.model.parameters(), lr=lr, weight_decay=0.01)
54 | self.optimizer = NoamOpt(self.model.embeddings_size, 0.1, lr_warmup, base_optimizer)
55 |
56 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed else None
57 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset) if distributed else None
58 | self.train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size//batch_split,
59 | shuffle=(not distributed), num_workers=n_jobs, collate_fn=self.collate_func_cn)
60 | if test_dataset is not None:
61 | self.test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=batch_size//batch_split,
62 | shuffle=False, num_workers=n_jobs, collate_fn=self.collate_func_cn)
63 |
64 | self.batch_split = batch_split
65 | self.lm_weight = lm_weight
66 | self.risk_weight = risk_weight
67 | self.clip_grad = clip_grad
68 | self.device = device
69 | self.ignore_idxs = ignore_idxs
70 |
71 | def state_dict(self):
72 | return {'model': self.model.state_dict(),
73 | 'optimizer': self.optimizer.state_dict()}
74 |
75 | def load_state_dict(self, state_dict):
76 | self.model.load_state_dict(state_dict['model'], strict=True)
77 | self.optimizer.load_state_dict(state_dict['optimizer'])
78 |
79 | def collate_func_cn(self, data):
80 | x, y = zip(*data)
81 | contexts = []
82 |
83 | if max(map(len, x)) > 0:
84 | x = [torch.tensor(d, dtype=torch.long) for d in x]
85 | x = pad_sequence(x, batch_first=True, padding_value=self.model.padding_idx)
86 | contexts.append(x)
87 | y = [torch.tensor(d, dtype=torch.long) for d in y]
88 | y = pad_sequence(y, batch_first=True, padding_value=self.model.padding_idx)
89 | return contexts, y
90 |
91 | def _eval_train(self, epoch, risk_func=None):
92 | self.model.train()
93 |
94 | tqdm_data = tqdm(self.train_dataloader, desc='Train (epoch #{})'.format(epoch))
95 | loss = 0
96 | lm_loss = 0
97 | for i, (contexts, targets) in enumerate(tqdm_data):
98 | contexts, targets = [c.to(self.device) for c in contexts], targets.to(self.device)
99 |
100 | enc_contexts = []
101 |
102 | # lm loss
103 | batch_lm_loss = torch.tensor(0, dtype=torch.float, device=self.device)
104 | for context in contexts:
105 | enc_context = self.model.encode(context.clone())
106 | enc_contexts.append(enc_context)
107 |
108 | if self.lm_weight > 0:
109 | context_outputs = self.model.generate(enc_context[0])
110 | ignore_mask = torch.stack([context == idx for idx in self.ignore_idxs], dim=-1).any(dim=-1)
111 | context.masked_fill_(ignore_mask, self.model.padding_idx)
112 | prevs, nexts = context_outputs[:, :-1, :].contiguous(), context[:, 1:].contiguous()
113 | batch_lm_loss += (self.lm_criterion(prevs.view(-1, prevs.shape[-1]), nexts.view(-1)) / len(contexts))
114 |
115 | # s2s loss
116 | prevs, nexts = targets[:, :-1].contiguous(), targets[:, 1:].contiguous()
117 | outputs = self.model.decode(prevs, enc_contexts)
118 | outputs = F.log_softmax(outputs, dim=-1)
119 | batch_loss = self.criterion(outputs.view(-1, outputs.shape[-1]), nexts.view(-1))
120 |
121 | # optimization
122 | full_loss = (batch_lm_loss * self.lm_weight + batch_loss) / self.batch_split
123 | full_loss.backward()
124 |
125 | if (i + 1) % self.batch_split == 0:
126 | if self.clip_grad is not None:
127 | for group in self.optimizer.param_groups:
128 | nn.utils.clip_grad_norm_(group['params'], self.clip_grad)
129 |
130 | self.optimizer.step()
131 | self.optimizer.zero_grad()
132 |
133 | lm_loss = (i * lm_loss + batch_lm_loss.item()) / (i + 1)
134 | loss = (i * loss + batch_loss.item()) / (i + 1)
135 | tqdm_data.set_postfix({'lm_loss': lm_loss, 'loss': loss, 'ppl': math.exp(loss), 'loss_step': batch_loss.item(),
136 | 'lr': self.optimizer.rate(), 'step': self.optimizer._step})
137 |
138 | log_dict = {'epoch': epoch, 'lm_loss': lm_loss, 'loss': loss, 'ppl': math.exp(loss), 'lr': self.optimizer.rate(),
139 | 'step': self.optimizer._step}
140 | log_dict_json = json.dumps(log_dict, ensure_ascii=False)
141 | logger.info(log_dict_json)
142 |
143 | def _eval_test(self, metric_funcs={}):
144 | self.model.eval()
145 |
146 | tqdm_data = tqdm(self.test_dataloader, desc='Test')
147 | loss = 0
148 | lm_loss = 0
149 | metrics = {name: 0 for name in metric_funcs.keys()}
150 | for i, (contexts, targets) in enumerate(tqdm_data):
151 | contexts, targets = [c.to(self.device) for c in contexts], targets.to(self.device)
152 | enc_contexts = []
153 |
154 | # lm loss
155 | batch_lm_loss = torch.tensor(0, dtype=torch.float, device=self.device)
156 | for context in contexts:
157 | enc_context = self.model.encode(context.clone())
158 | enc_contexts.append(enc_context)
159 |
160 | if self.lm_weight > 0:
161 | context_outputs = self.model.generate(enc_context[0])
162 | ignore_mask = torch.stack([context == idx for idx in self.ignore_idxs], dim=-1).any(dim=-1)
163 | context.masked_fill_(ignore_mask, self.model.padding_idx)
164 | prevs, nexts = context_outputs[:, :-1, :].contiguous(), context[:, 1:].contiguous()
165 | batch_lm_loss += (self.lm_criterion(prevs.view(-1, prevs.shape[-1]), nexts.view(-1)) / len(contexts))
166 |
167 | # s2s loss
168 | prevs, nexts = targets[:, :-1].contiguous(), targets[:, 1:].contiguous()
169 | outputs = self.model.decode(prevs, enc_contexts)
170 | outputs = F.log_softmax(outputs, dim=-1)
171 | batch_loss = self.criterion(outputs.view(-1, outputs.shape[-1]), nexts.view(-1))
172 |
173 | predictions = self.model.beam_search(enc_contexts)
174 | target_lens = targets.ne(self.model.padding_idx).sum(dim=-1)
175 | targets = [t[1:l-1].tolist() for t, l in zip(targets, target_lens)]
176 |
177 | lm_loss = (i * lm_loss + batch_lm_loss.item()) / (i + 1)
178 | loss = (i * loss + batch_loss.item()) / (i + 1)
179 | for name, func in metric_funcs.items():
180 | score = func(predictions, targets)
181 | metrics[name] = (metrics[name] * i + score) / (i + 1)
182 |
183 | tqdm_data.set_postfix(dict({'lm_loss': lm_loss, 'loss': loss, 'ppl': math.exp(loss)}, **metrics))
184 | log_dict = dict({'lm_loss': lm_loss, 'loss': loss, 'ppl': math.exp(loss)}, **metrics)
185 | log_dict_json = json.dumps(log_dict, ensure_ascii=False)
186 | logger.info(log_dict_json)
187 |
188 | def test(self, metric_funcs={}):
189 | if hasattr(self, 'test_dataloader'):
190 | self._eval_test(metric_funcs)
191 |
192 | def train(self, start_epoch, epochs, after_epoch_funcs=[], risk_func=None):
193 | for epoch in range(start_epoch, epochs):
194 | self._eval_train(epoch, risk_func)
195 | for func in after_epoch_funcs:
196 | func(epoch)
197 |
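# ---- usage sketch (editorial addition, not part of trainer.py) ----
# The incremental mean used for `loss`/`lm_loss` in _eval_train and _eval_test,
# i.e. loss = (i * loss + batch_loss) / (i + 1), equals the plain batch average.
batch_losses = [2.0, 4.0, 6.0]
loss = 0.0
for i, batch_loss in enumerate(batch_losses):
    loss = (i * loss + batch_loss) / (i + 1)
print(loss)                                      # 4.0 == sum(batch_losses) / len(batch_losses)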
--------------------------------------------------------------------------------
/code/generation/model/postprocessing.py:
--------------------------------------------------------------------------------
1 | # transformer_chatbot
2 | # Copyright (C) 2018 Golovanov, Tselousov
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 |
17 | import language_check
18 | from mosestokenizer import *
19 | from collections import deque
20 | from nltk import ngrams
21 | from nltk.corpus import wordnet
22 | import nltk
23 | import difflib
24 | import random
25 | from .retrieval import RetrievalBot
26 | import re
27 | import numpy as np
28 | from itertools import combinations
29 |
30 |
31 | SPELL_EXCEPTIONS = ['lol']
32 | STANDARD_ANSWERS = ['do you wanna talk about something else? ',
33 | 'tell me something about yourself.',
34 | 'it is interesting. how is it outside?',
35 | 'do you like walking outside?',
36 | 'cats... do you like cats!',
37 | 'how do you spend your free time?',
38 | 'how do you usually spend your weekend?',
39 | 'i think you are interesting person. tell me something about yourself.']
40 |
41 |
42 | def syntax_fix(text):
43 | def _i_replace(text):
44 | text = text.split()
45 | for i in range(len(text)):
46 | if text[i] == 'i':
47 | text[i] = 'I'
48 | if text[i] == 'i\'m':
49 | text[i] = 'I\'m'
50 |
51 | text = ' '.join(text)
52 |
53 | return text
54 |
55 | tool = language_check.LanguageTool('en-US')
56 |
57 | matches = tool.check(text)
58 | matches = [m for m in matches if text[m.fromx:m.tox].lower() not in SPELL_EXCEPTIONS]
59 |
60 | return _i_replace(language_check.correct(text, matches))
61 |
62 |
63 | def detokenize(text):
64 | text = text.split(' ')
65 | text[0] = text[0].title()
66 |
67 | with MosesDetokenizer('en') as detokenize:
68 | text = detokenize(text)
69 |
70 | text = syntax_fix(text)
71 |
72 | return text
73 |
74 |
75 | class ReplyChecker:
76 | def __init__(self, max_len=10, theshold=0.8, correct_generative=True, split_into_sentences=True):
77 | self._replies = deque([], maxlen=max_len)
78 | self._theshold = theshold
79 | self._retrieval = RetrievalBot()
80 | self._info = None
81 | self._max_len = max_len
82 |
83 | self._correct_generative = correct_generative
84 | self._split_into_sentences = split_into_sentences
85 |
86 | self._reset_prob()
87 |
88 | def _reset_prob(self):
89 | self._def_prob = np.ones(len(STANDARD_ANSWERS)) / len(STANDARD_ANSWERS)
90 |
91 | def _ratio(self, seq1, seq2):
92 |         # todo: only works well for same sequences
93 | return difflib.SequenceMatcher(None, seq1, seq2).ratio()
94 |
95 | def _sentence_max_coincidence_drop(self, reply):
96 | history = sum([re.split(r' *[\?\.\!][\'"\)\]]* *', r) for r in self._replies], [])
97 |
98 | split_reply = re.split(r' *[\?\.\!][\'"\)\]]* *', reply)
99 | punc = list(re.finditer(r' *[\?\.\!][\'"\)\]]* *', reply))
100 |
101 | # ratio = 0
102 | drop = []
103 |
104 | for i, r in enumerate(split_reply):
105 | for h in history:
106 | if h and r:
107 | ratio = self._ratio(r, h)
108 | if ratio > self._theshold:
109 | drop.append(i)
110 |
111 | drop = sorted(set(drop), reverse=True)
112 | for d in drop:
113 | split_reply.pop(d)
114 | punc.pop(d)
115 |
116 | original_text = ''
117 |
118 | for s, m in zip(split_reply, punc):
119 | original_text += s + m.group()
120 | if len(split_reply) > len(punc):
121 | original_text += split_reply[-1]
122 |
123 | return original_text.strip()
124 |
125 | def _max_coincidence(self, reply):
126 | if not self._replies:
127 | return None, reply
128 |
129 | if self._split_into_sentences:
130 | reply = self._sentence_max_coincidence_drop(reply)
131 | if not reply:
132 | return 1.0, reply
133 |
134 | mc = max(self._replies, key=lambda x: self._ratio(x, reply))
135 |
136 | ratio = self._ratio(mc, reply)
137 |
138 | return ratio, reply
139 |
140 | def _replase_reply(self, reply, request, info):
141 | dialog = 2 * ['None'] + [request]
142 | res = self._retrieval.generate_question(dialog, info)
143 | if res is None:
144 | if self._info is None:
145 | self._info = self._retrieval.get_reply_info(info)
146 |
147 | if not self._info:
148 | idx = np.random.choice(range(len(STANDARD_ANSWERS)), p=self._def_prob)
149 | self._def_prob[idx] = 0
150 |
151 | if np.sum(self._def_prob) == 0:
152 | self._reset_prob()
153 | else:
154 | self._def_prob /= np.sum(self._def_prob)
155 |
156 | return STANDARD_ANSWERS[idx]
157 |
158 | res = random.choice(list(self._info.keys()))
159 | del self._info[res]
160 |
161 | return res
162 |
163 | @staticmethod
164 | def _correct_repeated_sentences(text):
165 | split_text = re.split(r' *[\?\.\!][\'"\)\]]* *', text)
166 | matches = list(re.finditer(r' *[\?\.\!][\'"\)\]]* *', text))
167 |
168 | drop = []
169 | for i, j in combinations(range(len(split_text)), 2):
170 | if split_text[j] and split_text[j] in split_text[i]:
171 | drop.append(j)
172 | drop = set(drop)
173 | drop = sorted(drop, reverse=True)
174 |
175 | for d in drop:
176 | split_text.pop(d)
177 | matches.pop(d)
178 |
179 | original_text = ''
180 |
181 | for s, m in zip(split_text, matches):
182 | original_text += s + m.group()
183 | if len(split_text) > len(matches):
184 | original_text += split_text[-1]
185 | return original_text
186 |
187 | def check_reply(self, reply, request, info):
188 | log = [reply]
189 | log_names = ['IN: ', 'RL: ', 'RS: ']
190 |
191 | try:
192 | if self._correct_generative:
193 | reply = ReplyChecker._correct_repeated_sentences(reply)
194 |
195 | ratio, reply = self._max_coincidence(reply)
196 | log.append(reply)
197 | if ratio is not None:
198 | # ratio = self._ratio(mc, reply)
199 |
200 | if ratio > self._theshold:
201 | reply = self._replase_reply(reply, request, info)
202 | log.append(reply)
203 |
204 | except Exception as e:
205 | print('ERROR: ', e)
206 | reply = log[0]
207 |
208 | # print('[' + ' | '.join([n + str(v) for n, v in zip(log_names, log) ]) + ']')
209 | self._replies.append(reply)
210 |
211 | return reply
212 |
213 | def clean(self):
214 | self._info = None
215 | self._replies = deque([], maxlen=self._max_len)
216 | self._reset_prob()
217 |
218 |
219 | def get_syn(seq):
220 | seq = seq.replace('i ', 'I ')
221 | seq = nltk.pos_tag(nltk.word_tokenize(seq))
222 |
223 | synonyms = {}
224 |
225 | for w, s_p in seq:
226 | if len(w) < 3:
227 | continue
228 | if s_p not in ['VBP', 'NN', 'NNS']:
229 | continue
230 |
231 | pos = wordnet.VERB if s_p == 'VBP' else wordnet.NOUN
232 |
233 | s = wordnet.synsets(w, pos=pos)
234 | for word in s:
235 | for l in word.lemma_names():
236 | if l != w:
237 | synonyms[l.replace('_', ' ')] = w
238 | break
239 |
240 | if not synonyms:
241 | return None
242 |
243 | key = random.choice(list(synonyms.keys()))
244 | return synonyms[key], key
245 |
246 |
247 | def equal_phrases(phrases):
248 | matches = {' am ': '\'m ',
249 | ' are ': '\'re ',
250 | ' have ': '\'ve ',
251 | ' has ': '\'s ',
252 | 'do not': 'don\'t',
253 | 'does not': 'doesn\'t'
254 | }
255 |
256 | replasments = []
257 |
258 | for ph in phrases:
259 | a = ph
260 | for o, r in matches.items():
261 | if o in a:
262 | a = a.replace(o, r)
263 | break
264 | if r in a:
265 | a = a.replace(r, o)
266 | break
267 |
268 | if a == ph:
269 |             # todo: find synonyms
270 | syn = get_syn(a)
271 | if syn is None:
272 | a = a.split(' ')
273 | a[-2], a[-1] = a[-1], a[-2]
274 | a = ' '.join(a)
275 | else:
276 | a = a.replace(syn[0], syn[1])
277 |
278 | replasments.append(a)
279 |
280 | return replasments
281 |
282 |
283 | def ngram_replaser(info, reply, n=3):
284 | if info is None:
285 | return reply
286 |
287 | org_reply = reply
288 |
289 | info = re.split(r' *[\?\.\!][\'"\)\]]* *', info.strip().lower())
290 | reply = re.split(r' *[\?\.\!][\'"\)\]]* *', reply.strip().lower())
291 |
292 | info = sum([list(ngrams(i.split(), n=n)) for i in info if i], [])
293 | reply = sum([list(ngrams(r.split(), n=n)) for r in reply if r], [])
294 |
295 | phrases = []
296 |
297 | for i in info:
298 | for r in reply:
299 | if i == r:
300 | phrases.append(' '.join(r))
301 |
302 | replasments = equal_phrases(phrases)
303 |
304 | for o, r in zip(phrases, replasments):
305 | org_reply = org_reply.replace(o, r)
306 |
307 | return org_reply
308 |
--------------------------------------------------------------------------------
/code/generation/train_kd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import os
4 | from model.utils import load_openai_weights_chinese, set_seed, f1_score
5 | from model.transformer_model import TransformerModel
6 | from model.trainer_kd import Trainer
7 | from model.text import myVocab
8 | from model.dataset import S2sDataset_dialog
9 | from config import get_model_config_dialog, get_trainer_config_dialog
10 | from torch.nn.parallel import DistributedDataParallel
11 | import argparse
12 |
13 |
14 | def main():
15 | model_config = get_model_config_dialog()
16 | trainer_config = get_trainer_config_dialog()
17 |
18 | set_seed(trainer_config.seed)
19 | device = torch.device(trainer_config.device)
20 | # zrs
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument("--local_rank", type=int, default=0)
23 | args = parser.parse_args()
24 | distributed = (args.local_rank != -1)
25 | if distributed:
26 | print(args.local_rank)
27 | torch.cuda.set_device(args.local_rank)
28 | device = torch.device("cuda", args.local_rank)
29 | torch.distributed.init_process_group(backend='nccl', init_method='env://')
30 |
31 | vocab = myVocab(model_config.vocab_path)
32 |
33 | transformer = TransformerModel(n_layers=model_config.n_layers,
34 | n_embeddings=len(vocab),
35 | n_pos_embeddings=model_config.n_pos_embeddings,
36 | embeddings_size=model_config.embeddings_size,
37 | padding_idx=vocab.pad_id,
38 | n_heads=model_config.n_heads,
39 | dropout=model_config.dropout,
40 | embed_dropout=model_config.embed_dropout,
41 | attn_dropout=model_config.attn_dropout,
42 | ff_dropout=model_config.ff_dropout,
43 | bos_id=vocab.bos_id,
44 | eos_id=vocab.eos_id,
45 | max_seq_len=model_config.max_seq_len,
46 | beam_size=model_config.beam_size,
47 | length_penalty=model_config.length_penalty,
48 | n_segments=model_config.n_segments,
49 | annealing_topk=model_config.annealing_topk,
50 | temperature=model_config.temperature,
51 | annealing=model_config.annealing,
52 | diversity_coef=model_config.diversity_coef,
53 | diversity_groups=model_config.diversity_groups)
54 |
55 | transformer_T = TransformerModel(n_layers=model_config.n_layers,
56 | n_embeddings=len(vocab),
57 | n_pos_embeddings=model_config.n_pos_embeddings,
58 | embeddings_size=model_config.embeddings_size,
59 | padding_idx=vocab.pad_id,
60 | n_heads=model_config.n_heads,
61 | dropout=model_config.dropout,
62 | embed_dropout=model_config.embed_dropout,
63 | attn_dropout=model_config.attn_dropout,
64 | ff_dropout=model_config.ff_dropout,
65 | bos_id=vocab.bos_id,
66 | eos_id=vocab.eos_id,
67 | max_seq_len=model_config.max_seq_len,
68 | beam_size=model_config.beam_size,
69 | length_penalty=model_config.length_penalty,
70 | n_segments=model_config.n_segments,
71 | annealing_topk=model_config.annealing_topk,
72 | temperature=model_config.temperature,
73 | annealing=model_config.annealing,
74 | diversity_coef=model_config.diversity_coef,
75 | diversity_groups=model_config.diversity_groups)
76 |
77 | state_dict = torch.load(trainer_config.teacher_checkpoint_path, map_location=device)
78 | print(state_dict['model'].keys())
79 | temp = dict(state_dict['model'])
80 | keys = list(temp.keys())
81 | for key in keys:
82 | # new_key = '.'.join([i for i in key.split('.') if i != 'module'])
83 | new_key = key.replace('.module', '')
84 | temp[new_key] = temp.pop(key)
85 | transformer_T.load_state_dict(temp, strict=True)
86 | print('load teacher model ok')
87 |
88 | if not trainer_config.load_last:
89 | openai_model = torch.load(trainer_config.openai_parameters_dir, map_location=device)
90 | openai_model.pop('decoder.pre_softmax.weight')
91 | b = list(openai_model.keys())
92 | for i in b:
93 | temp = i.split('.')
94 | keep = True
95 | for j in range(model_config.n_layers, 12):
96 | if str(j) in temp:
97 | keep = False
98 | break
99 | if keep:
100 | openai_model[i.split('.', 1)[1]] = openai_model.pop(i)
101 | else:
102 | print(i)
103 | openai_model.pop(i)
104 | #openai_model[i.split('.', 1)[1]] = openai_model.pop(i)
105 | transformer.transformer_module.load_state_dict(openai_model, strict=True)
106 | # load_openai_weights_chinese(transformer.transformer_module, trainer_config.openai_parameters_dir)
107 | print('OpenAI weights chinese loaded from {}'.format(trainer_config.openai_parameters_dir))
108 |
109 | train_dataset = S2sDataset_dialog(trainer_config.train_datasets, vocab, transformer.n_pos_embeddings - 1)
110 | test_dataset = S2sDataset_dialog(trainer_config.test_datasets, vocab, transformer.n_pos_embeddings - 1)
111 |
112 | model_trainer = Trainer(transformer,
113 | transformer_T,
114 | train_dataset,
115 | test_dataset,
116 | batch_size=trainer_config.batch_size,
117 | batch_split=trainer_config.batch_split,
118 | lr=trainer_config.lr,
119 | lr_warmup=trainer_config.lr_warmup,
120 | lm_weight=trainer_config.lm_weight,
121 | risk_weight=trainer_config.risk_weight,
122 | n_jobs=trainer_config.n_jobs,
123 | clip_grad=trainer_config.clip_grad,
124 | # label_smoothing=trainer_config.label_smoothing,
125 | device=device,
126 | ignore_idxs=vocab.special_tokens_ids,
127 | distributed=distributed)
128 | if distributed:
129 | model_trainer.model.transformer_module = DistributedDataParallel(model_trainer.model.transformer_module,
130 | device_ids=[args.local_rank], output_device=args.local_rank)
131 | model_trainer.model_T.transformer_module = DistributedDataParallel(model_trainer.model_T.transformer_module,
132 | device_ids=[args.local_rank], output_device=args.local_rank)
133 |
134 | start_epoch = 0
135 | init_epoch = 35
136 |
137 | if trainer_config.load_last:
138 | state_dict = torch.load(trainer_config.last_checkpoint_path + str(init_epoch - 1), map_location=device)
139 | model_trainer.load_state_dict(state_dict)
140 | # start_epoch = int(cop.sub('', trainer_config.last_checkpoint_path.split('/')[-1])) + 1
141 | start_epoch = init_epoch
142 | print('Weights loaded from {}'.format(trainer_config.last_checkpoint_path + str(init_epoch - 1)))
143 |
144 |
145 | # helpers -----------------------------------------------------
146 | def save_func(epoch):
147 | dirs = '/'.join(trainer_config.last_checkpoint_path.split('/')[:-1])
148 | if not os.path.exists(dirs):
149 | os.makedirs(dirs)
150 | torch.save(model_trainer.state_dict(), trainer_config.last_checkpoint_path)
151 | torch.save(model_trainer.state_dict(), trainer_config.last_checkpoint_path + str(epoch))
152 | if os.path.exists(trainer_config.last_checkpoint_path + str(epoch - 150)):
153 | os.remove(trainer_config.last_checkpoint_path + str(epoch - 150))
154 |
155 | def sample_text_func(epoch):
156 | n_samples = 5
157 | samples_idxs = random.sample(range(len(test_dataset)), n_samples)
158 | samples = [test_dataset[idx] for idx in samples_idxs]
159 | for source, target in samples:
160 | contexts = [torch.tensor([c], dtype=torch.long, device=model_trainer.device) for c in [source] if len(c) > 0]
161 | prediction = model_trainer.model.predict(contexts)[0]
162 | source_str = vocab.ids2string(source)
163 | target_str = vocab.ids2string(target[1:-1])
164 | prediction_str = vocab.ids2string(prediction)
165 | print('\n')
166 | print('Source:{}'.format(source_str))
167 | print('Target:\n\t{}'.format(target_str))
168 | print('Prediction:\n\t{}'.format(prediction_str))
169 |
170 | def test_func(epoch):
171 | if (epoch+1) % trainer_config.test_period == 0:
172 | metric_funcs = {'f1_score': f1_score}
173 | model_trainer.test(metric_funcs)
174 |
175 | def f1_risk(predictions, targets):
176 | scores = f1_score(predictions, targets, average=False)
177 | return [1-s for s in scores]
178 |
179 | # helpers -----------------------------------------------------
180 |
181 | # model_trainer.model.transformer_module = nn.DataParallel(model_trainer.model.transformer_module, device_ids=[0, 1])
182 | try:
183 | if args.local_rank in [-1, 0]:
184 | model_trainer.train(start_epoch, trainer_config.n_epochs, after_epoch_funcs=[save_func, sample_text_func, test_func],
185 | risk_func=f1_risk)
186 | else:
187 | model_trainer.train(start_epoch, trainer_config.n_epochs)
188 | except (KeyboardInterrupt, Exception, RuntimeError) as e:
189 | torch.save(model_trainer.state_dict(), trainer_config.interrupt_checkpoint_path)
190 | raise e
191 |
192 |
193 | if __name__ == '__main__':
194 | main()
195 |
196 |
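The key-renaming loop above strips the `.module` prefix that `DistributedDataParallel` adds when the teacher checkpoint was saved from a wrapped model, so the plain `TransformerModel` can load it with `strict=True`. A tiny sketch of the same cleanup on purely illustrative keys:

    ddp_state = {'transformer_module.module.embeddings.weight': 'w0',   # saved from a DDP-wrapped model
                 'pre_softmax.weight': 'w1'}                            # already unprefixed
    clean_state = {k.replace('.module', ''): v for k, v in ddp_state.items()}
    print(clean_state)
    # {'transformer_module.embeddings.weight': 'w0', 'pre_softmax.weight': 'w1'}
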
--------------------------------------------------------------------------------
/code/generation/model/trainer_kd.py:
--------------------------------------------------------------------------------
1 | # transformer_chatbot
2 | # Copyright (C) 2018 Golovanov, Tselousov
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 | import random
21 | from torch.utils.data import DataLoader
22 | from tqdm import tqdm
23 | from .utils import pad_sequence
24 | from .optim import Adam, NoamOpt
25 | from .loss import LabelSmoothingLoss
26 | import json
27 | import logging
28 | import math
29 | logger = logging.getLogger('s2s-dialog-fake-kd')
30 | logger.setLevel(logging.INFO)
31 | fh = logging.FileHandler('s2s-dialog-fake-kd.log', encoding='utf-8')
32 | fh.setLevel(logging.INFO)
33 | ch = logging.StreamHandler()
34 | ch.setLevel(logging.INFO)
35 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
36 | fh.setFormatter(formatter)
37 | ch.setFormatter(formatter)
38 | logger.addHandler(fh)
39 | logger.addHandler(ch)
40 | # write a test log entry
41 | logger.info('python logging test')
42 |
43 |
44 | class Trainer:
45 | def __init__(self, model, model_T, train_dataset, test_dataset=None, batch_size=8,
46 | batch_split=1, lm_weight=0.5, risk_weight=0, lr=6.25e-5, lr_warmup=2000,
47 | n_jobs=0, clip_grad=None, label_smoothing=0, device=torch.device('cuda'),
48 | ignore_idxs=[], distributed=False):
49 | self.model = model.to(device)
50 | self.model_T = model_T.to(device)
51 |         self.model_T.requires_grad = False  # note: this does not freeze the teacher's parameters (requires_grad must be set per parameter); only the student's parameters are passed to the optimizer below
52 |         # self.model_T.eval()
53 | self.lm_criterion = nn.CrossEntropyLoss(ignore_index=self.model.padding_idx).to(device)
54 | self.criterion = LabelSmoothingLoss(n_labels=self.model.n_embeddings, smoothing=label_smoothing,
55 | ignore_index=self.model.padding_idx).to(device)
56 | self.kl_loss = torch.nn.KLDivLoss(reduction='batchmean').to(device)
57 | #base_optimizer = Adam(self.model.parameters(), lr=lr, weight_decay=0.01)
58 | base_optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=0.01)
59 | self.optimizer = NoamOpt(self.model.embeddings_size, 0.05, lr_warmup, base_optimizer)
60 |
61 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed else None
62 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset) if distributed else None
63 | self.train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size//batch_split,
64 | shuffle=(not distributed), num_workers=n_jobs, collate_fn=self.collate_func_cn)
65 | if test_dataset is not None:
66 | self.test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=batch_size//batch_split,
67 | shuffle=False, num_workers=n_jobs, collate_fn=self.collate_func_cn)
68 |
69 | self.batch_split = batch_split
70 | self.lm_weight = lm_weight
71 | self.risk_weight = risk_weight
72 | self.clip_grad = clip_grad
73 | self.device = device
74 | self.ignore_idxs = ignore_idxs
75 |
76 | def state_dict(self):
77 | return {'model': self.model.state_dict(),
78 | 'optimizer': self.optimizer.state_dict()}
79 |
80 | def load_state_dict(self, state_dict):
81 | self.model.load_state_dict(state_dict['model'], strict=True)
82 | self.optimizer.load_state_dict(state_dict['optimizer'])
83 |
84 | def collate_func_cn(self, data):
85 | x, y = zip(*data)
86 | contexts = []
87 |
88 | if max(map(len, x)) > 0:
89 | x = [torch.tensor(d, dtype=torch.long) for d in x]
90 | x = pad_sequence(x, batch_first=True, padding_value=self.model.padding_idx)
91 | contexts.append(x)
92 | y = [torch.tensor(d, dtype=torch.long) for d in y]
93 | y = pad_sequence(y, batch_first=True, padding_value=self.model.padding_idx)
94 | return contexts, y
95 |
96 | def _eval_train(self, epoch, risk_func=None):
97 | self.model.train()
98 |
99 | tqdm_data = tqdm(self.train_dataloader, desc='Train (epoch #{})'.format(epoch))
100 | loss = 0
101 | lm_loss = 0
102 | kl_loss = 0
103 | for i, (contexts, targets) in enumerate(tqdm_data):
104 | contexts, targets = [c.to(self.device) for c in contexts], targets.to(self.device)
105 |
106 | enc_contexts = []
107 | enc_contexts_t = []
108 |
109 | # lm loss
110 | batch_lm_loss = torch.tensor(0, dtype=torch.float, device=self.device)
111 | for context in contexts:
112 | enc_context = self.model.encode(context.clone())
113 | enc_context_t = self.model_T.encode(context.clone())
114 | enc_contexts.append(enc_context)
115 | enc_contexts_t.append(enc_context_t)
116 | # print('enc_context: ', enc_context)
117 | # print('enc_context_t: ', enc_context_t)
118 |
119 | if self.lm_weight > 0:
120 | context_outputs = self.model.generate(enc_context[0])
121 | ignore_mask = torch.stack([context == idx for idx in self.ignore_idxs], dim=-1).any(dim=-1)
122 | context.masked_fill_(ignore_mask, self.model.padding_idx)
123 | prevs, nexts = context_outputs[:, :-1, :].contiguous(), context[:, 1:].contiguous()
124 | batch_lm_loss += (self.lm_criterion(prevs.view(-1, prevs.shape[-1]), nexts.view(-1)) / len(contexts))
125 |
126 | # s2s loss
127 | prevs, nexts = targets[:, :-1].contiguous(), targets[:, 1:].contiguous()
128 | outputs = self.model.decode(prevs, enc_contexts)
129 | outputs = F.log_softmax(outputs, dim=-1)
130 | batch_loss = self.criterion(outputs.view(-1, outputs.shape[-1]), nexts.view(-1))
131 | outputs_t = self.model_T.decode(prevs, enc_contexts_t)
132 | outputs_t = F.softmax(outputs_t, dim=-1)
133 | batch_kl_loss = self.kl_loss(outputs.view(-1, outputs.shape[-1]), outputs_t.view(-1, outputs_t.shape[-1]))
134 | '''
135 | # teacher model
136 | prevs_t = targets[:, :-1].contiguous()
137 | enc_contexts_t = []
138 | for context in contexts:
139 | temp = context.clone()
140 | # print('temp: ', temp)
141 | enc_context_t = self.model_T.encode(temp)
142 | enc_contexts_t.append(enc_context_t)
143 | print('enc_context_t: ', enc_context_t)
144 | outputs_t = self.model_T.decode(prevs_t, enc_contexts_t)
145 | # print('outputs_t: ', outputs_t)
146 | outputs_t = F.softmax(outputs_t, dim=-1)
147 | batch_kl_loss = self.kl_loss(outputs.view(-1, outputs.shape[-1]), outputs_t.view(-1, outputs_t.shape[-1]).data)
148 | print(outputs.shape, outputs_t.shape, batch_kl_loss)
149 | '''
150 |
151 | # optimization
152 | full_loss = (batch_lm_loss * self.lm_weight + batch_loss + batch_kl_loss) / self.batch_split
153 | full_loss.backward()
154 |
155 | if (i + 1) % self.batch_split == 0:
156 | if self.clip_grad is not None:
157 | for group in self.optimizer.param_groups:
158 | nn.utils.clip_grad_norm_(group['params'], self.clip_grad)
159 |
160 | self.optimizer.step()
161 | self.optimizer.zero_grad()
162 |
163 | lm_loss = (i * lm_loss + batch_lm_loss.item()) / (i + 1)
164 | loss = (i * loss + batch_loss.item()) / (i + 1)
165 | kl_loss = (i * kl_loss + batch_kl_loss.item()) / (i + 1)
166 | tqdm_data.set_postfix({'lm_loss': lm_loss, 'loss': loss, 'kl_loss': kl_loss, 'ppl': math.exp(loss), 'loss_step': batch_loss.item(),
167 | 'lr': self.optimizer.rate(), 'step': self.optimizer._step})
168 |
169 | log_dict = {'epoch': epoch, 'lm_loss': lm_loss, 'loss': loss, 'kl_loss': kl_loss, 'ppl': math.exp(loss), 'lr': self.optimizer.rate(),
170 | 'step': self.optimizer._step}
171 | log_dict_json = json.dumps(log_dict, ensure_ascii=False)
172 | logger.info(log_dict_json)
173 |
174 | def _eval_test(self, metric_funcs={}):
175 | self.model.eval()
176 |
177 | tqdm_data = tqdm(self.test_dataloader, desc='Test')
178 | loss = 0
179 | lm_loss = 0
180 | metrics = {name: 0 for name in metric_funcs.keys()}
181 | for i, (contexts, targets) in enumerate(tqdm_data):
182 | contexts, targets = [c.to(self.device) for c in contexts], targets.to(self.device)
183 | enc_contexts = []
184 |
185 | # lm loss
186 | batch_lm_loss = torch.tensor(0, dtype=torch.float, device=self.device)
187 | for context in contexts:
188 | enc_context = self.model.encode(context.clone())
189 | enc_contexts.append(enc_context)
190 |
191 | if self.lm_weight > 0:
192 | context_outputs = self.model.generate(enc_context[0])
193 | ignore_mask = torch.stack([context == idx for idx in self.ignore_idxs], dim=-1).any(dim=-1)
194 | context.masked_fill_(ignore_mask, self.model.padding_idx)
195 | prevs, nexts = context_outputs[:, :-1, :].contiguous(), context[:, 1:].contiguous()
196 | batch_lm_loss += (self.lm_criterion(prevs.view(-1, prevs.shape[-1]), nexts.view(-1)) / len(contexts))
197 |
198 | # s2s loss
199 | prevs, nexts = targets[:, :-1].contiguous(), targets[:, 1:].contiguous()
200 | outputs = self.model.decode(prevs, enc_contexts)
201 | outputs = F.log_softmax(outputs, dim=-1)
202 | batch_loss = self.criterion(outputs.view(-1, outputs.shape[-1]), nexts.view(-1))
203 |
204 | predictions = self.model.beam_search(enc_contexts)
205 | target_lens = targets.ne(self.model.padding_idx).sum(dim=-1)
206 | targets = [t[1:l-1].tolist() for t, l in zip(targets, target_lens)]
207 |
208 | lm_loss = (i * lm_loss + batch_lm_loss.item()) / (i + 1)
209 | loss = (i * loss + batch_loss.item()) / (i + 1)
210 | for name, func in metric_funcs.items():
211 | score = func(predictions, targets)
212 | metrics[name] = (metrics[name] * i + score) / (i + 1)
213 |
214 | tqdm_data.set_postfix(dict({'lm_loss': lm_loss, 'loss': loss, 'ppl': math.exp(loss)}, **metrics))
215 | log_dict = dict({'lm_loss': lm_loss, 'loss': loss, 'ppl': math.exp(loss)}, **metrics)
216 | log_dict_json = json.dumps(log_dict, ensure_ascii=False)
217 | logger.info(log_dict_json)
218 |
219 | def test(self, metric_funcs={}):
220 | if hasattr(self, 'test_dataloader'):
221 | self._eval_test(metric_funcs)
222 |
223 | def train(self, start_epoch, epochs, after_epoch_funcs=[], risk_func=None):
224 | for epoch in range(start_epoch, epochs):
225 | self._eval_train(epoch, risk_func)
226 | for func in after_epoch_funcs:
227 | func(epoch)
228 |
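The distillation term above pairs student log-probabilities (`F.log_softmax`) with teacher probabilities (`F.softmax`) in `KLDivLoss(reduction='batchmean')`. A minimal standalone sketch of that pairing; the random logits and the flattened `[batch * seq, vocab]` shape are assumptions for illustration only:

    import torch
    import torch.nn.functional as F

    student_logits = torch.randn(4, 10)                     # stand-in for the student decoder outputs
    teacher_logits = torch.randn(4, 10)                     # stand-in for the teacher decoder outputs
    log_p_student = F.log_softmax(student_logits, dim=-1)   # KLDivLoss expects log-probabilities as input
    p_teacher = F.softmax(teacher_logits, dim=-1).detach()  # ...and probabilities as target
    kd_loss = F.kl_div(log_p_student, p_teacher, reduction='batchmean')
    print(kd_loss.item())
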
--------------------------------------------------------------------------------
/code/retrieval/pytorch_pretrained_bert/tokenization_openai.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for OpenAI GPT."""
16 | from __future__ import (absolute_import, division, print_function,
17 | unicode_literals)
18 |
19 | import json
20 | import logging
21 | import os
22 | import re
23 | import sys
24 | from io import open
25 |
26 | from tqdm import tqdm
27 |
28 | from .file_utils import cached_path
29 | from .tokenization import BasicTokenizer
30 |
31 | logger = logging.getLogger(__name__)
32 |
33 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
34 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
35 | }
36 | PRETRAINED_MERGES_ARCHIVE_MAP = {
37 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
38 | }
39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
40 | 'openai-gpt': 512,
41 | }
42 | VOCAB_NAME = 'vocab.json'
43 | MERGES_NAME = 'merges.txt'
44 |
45 | def get_pairs(word):
46 | """
47 | Return set of symbol pairs in a word.
48 | word is represented as tuple of symbols (symbols being variable-length strings)
49 | """
50 | pairs = set()
51 | prev_char = word[0]
52 | for char in word[1:]:
53 | pairs.add((prev_char, char))
54 | prev_char = char
55 | return pairs
56 |
57 | def text_standardize(text):
58 | """
59 | fixes some issues the spacy tokenizer had on books corpus
60 | also does some whitespace standardization
61 | """
62 | text = text.replace('—', '-')
63 | text = text.replace('–', '-')
64 | text = text.replace('―', '-')
65 | text = text.replace('…', '...')
66 | text = text.replace('´', "'")
67 | text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
68 | text = re.sub(r'\s*\n\s*', ' \n ', text)
69 | text = re.sub(r'[^\S\n]+', ' ', text)
70 | return text.strip()
71 |
72 | class OpenAIGPTTokenizer(object):
73 | """
74 | BPE tokenizer. Peculiarities:
75 | - lower case all inputs
76 | - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not.
77 | - argument special_tokens and function set_special_tokens:
78 | can be used to add additional symbols (ex: "__classify__") to a vocabulary.
79 | """
80 | @classmethod
81 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
82 | """
83 | Instantiate a PreTrainedBertModel from a pre-trained model file.
84 | Download and cache the pre-trained model file if needed.
85 | """
86 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
87 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
88 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
89 | else:
90 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
91 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
92 | # redirect to the cache, if necessary
93 | try:
94 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
95 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
96 | except EnvironmentError:
97 | logger.error(
98 | "Model name '{}' was not found in model name list ({}). "
99 | "We assumed '{}' was a path or url but couldn't find files {} and {} "
100 | "at this path or url.".format(
101 | pretrained_model_name_or_path,
102 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
103 | pretrained_model_name_or_path,
104 | vocab_file, merges_file))
105 | return None
106 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
107 | logger.info("loading vocabulary file {}".format(vocab_file))
108 | logger.info("loading merges file {}".format(merges_file))
109 | else:
110 | logger.info("loading vocabulary file {} from cache at {}".format(
111 | vocab_file, resolved_vocab_file))
112 | logger.info("loading merges file {} from cache at {}".format(
113 | merges_file, resolved_merges_file))
114 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
115 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
116 | # than the number of positional embeddings
117 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
118 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
119 | # Instantiate tokenizer.
120 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
121 | return tokenizer
122 |
123 | def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
124 | try:
125 | import ftfy
126 | import spacy
127 | self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
128 | self.fix_text = ftfy.fix_text
129 | except ImportError:
130 |             logger.warning("ftfy or spacy is not installed; using BERT BasicTokenizer instead of SpaCy & ftfy.")
131 | self.nlp = BasicTokenizer(do_lower_case=True,
132 | never_split=special_tokens if special_tokens is not None else [])
133 | self.fix_text = None
134 |
135 | self.max_len = max_len if max_len is not None else int(1e12)
136 | self.encoder = json.load(open(vocab_file, encoding="utf-8"))
137 | self.decoder = {v:k for k,v in self.encoder.items()}
138 | merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
139 | merges = [tuple(merge.split()) for merge in merges]
140 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
141 | self.cache = {}
142 | self.set_special_tokens(special_tokens)
143 |
144 | def __len__(self):
145 | return len(self.encoder) + len(self.special_tokens)
146 |
147 | def set_special_tokens(self, special_tokens):
148 | """ Add a list of additional tokens to the encoder.
149 | The additional tokens are indexed starting from the last index of the
150 | current vocabulary in the order of the `special_tokens` list.
151 | """
152 | if not special_tokens:
153 | self.special_tokens = {}
154 | self.special_tokens_decoder = {}
155 | return
156 | self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
157 | self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
158 | if self.fix_text is None:
159 | # Using BERT's BasicTokenizer: we can update the tokenizer
160 | self.nlp.never_split = special_tokens
161 | logger.info("Special tokens {}".format(self.special_tokens))
162 |
163 | def bpe(self, token):
164 |         word = tuple(token[:-1]) + (token[-1] + '</w>',)
165 | if token in self.cache:
166 | return self.cache[token]
167 | pairs = get_pairs(word)
168 |
169 | if not pairs:
170 |             return token+'</w>'
171 |
172 | while True:
173 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
174 | if bigram not in self.bpe_ranks:
175 | break
176 | first, second = bigram
177 | new_word = []
178 | i = 0
179 | while i < len(word):
180 | try:
181 | j = word.index(first, i)
182 | new_word.extend(word[i:j])
183 | i = j
184 | except:
185 | new_word.extend(word[i:])
186 | break
187 |
188 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
189 | new_word.append(first+second)
190 | i += 2
191 | else:
192 | new_word.append(word[i])
193 | i += 1
194 | new_word = tuple(new_word)
195 | word = new_word
196 | if len(word) == 1:
197 | break
198 | else:
199 | pairs = get_pairs(word)
200 | word = ' '.join(word)
201 |         if word == '\n  </w>':
202 |             word = '\n</w>'
203 | self.cache[token] = word
204 | return word
205 |
206 | def tokenize(self, text):
207 | """ Tokenize a string. """
208 | split_tokens = []
209 | if self.fix_text is None:
210 | # Using BERT's BasicTokenizer
211 | text = self.nlp.tokenize(text)
212 | for token in text:
213 | split_tokens.extend([t for t in self.bpe(token).split(' ')])
214 | else:
215 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
216 | text = self.nlp(text_standardize(self.fix_text(text)))
217 | for token in text:
218 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
219 | return split_tokens
220 |
221 | def convert_tokens_to_ids(self, tokens):
222 | """ Converts a sequence of tokens into ids using the vocab. """
223 | ids = []
224 | if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
225 | if tokens in self.special_tokens:
226 | return self.special_tokens[tokens]
227 | else:
228 | return self.encoder.get(tokens, 0)
229 | for token in tokens:
230 | if token in self.special_tokens:
231 | ids.append(self.special_tokens[token])
232 | else:
233 | ids.append(self.encoder.get(token, 0))
234 | if len(ids) > self.max_len:
235 | logger.warning(
236 | "Token indices sequence length is longer than the specified maximum "
237 | " sequence length for this OpenAI GPT model ({} > {}). Running this"
238 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
239 | )
240 | return ids
241 |
242 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
243 | """Converts a sequence of ids in BPE tokens using the vocab."""
244 | tokens = []
245 | for i in ids:
246 | if i in self.special_tokens_decoder:
247 | if not skip_special_tokens:
248 | tokens.append(self.special_tokens_decoder[i])
249 | else:
250 | tokens.append(self.decoder[i])
251 | return tokens
252 |
253 | def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=False):
254 | """Converts a sequence of ids in a string."""
255 | tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
256 |         out_string = ''.join(tokens).replace('</w>', ' ').strip()
257 | if clean_up_tokenization_spaces:
258 |             out_string = out_string.replace('<unk>', '')
259 | out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ','
260 | ).replace(" n't", "n't").replace(" 'm", "'m").replace(" 're", "'re").replace(" do not", " don't"
261 | ).replace(" 's", "'s").replace(" t ", "'t ").replace(" s ", "'s ").replace(" m ", "'m "
262 | ).replace(" 've", "'ve")
263 | return out_string
264 |
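`get_pairs` enumerates the adjacent symbol pairs that `bpe` then merges greedily according to `bpe_ranks`. A small illustration, assuming `get_pairs` from this file is in scope (the word tuple is an arbitrary example):

    print(get_pairs(('l', 'o', 'w', 'e', 'r</w>')))
    # {('l', 'o'), ('o', 'w'), ('w', 'e'), ('e', 'r</w>')}  -- a set, so order may vary
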
--------------------------------------------------------------------------------
/code/generation/model/transformer_model.py:
--------------------------------------------------------------------------------
1 | # transformer_chatbot
2 | # Copyright (C) 2018 Golovanov, Tselousov
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 |
17 | import random
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 |
22 | from .transformer_module import TransformerModule
23 |
24 |
25 | class TransformerModel(nn.Module):
26 | def __init__(self, n_layers, n_embeddings, n_pos_embeddings, embeddings_size,
27 | padding_idx, n_heads, dropout, embed_dropout, attn_dropout, ff_dropout,
28 | bos_id, eos_id, max_seq_len=256, beam_size=5, sample=False,
29 | length_penalty=0.8, annealing_topk=None, temperature=1.0, annealing=0,
30 | diversity_coef=0, diversity_groups=1, n_segments=None):
31 |
32 | super(TransformerModel, self).__init__()
33 |
34 | self.padding_idx = padding_idx
35 | self.n_embeddings = n_embeddings
36 | self.n_pos_embeddings = n_pos_embeddings
37 | self.embeddings_size = embeddings_size
38 |
39 | self.bos_id = bos_id
40 | self.eos_id = eos_id
41 |
42 | self.max_seq_len = max_seq_len
43 | self.beam_size = beam_size
44 | self.sample = sample
45 | self.length_penalty_coef = length_penalty
46 | self.annealing = annealing
47 | self.annealing_topk = annealing_topk
48 | self.temperature = temperature
49 | self.diversity_coef = diversity_coef
50 | self.diversity_groups = diversity_groups
51 |
52 | self.transformer_module = TransformerModule(n_layers, n_embeddings, n_pos_embeddings, embeddings_size,
53 | padding_idx, n_heads, dropout, embed_dropout, attn_dropout,
54 | ff_dropout, n_segments)
55 | self.pre_softmax = nn.Linear(embeddings_size, n_embeddings, bias=False)
56 | self.pre_softmax.weight = self.transformer_module.embeddings.weight
57 |
58 | def forward(self, x, contexts=[]):
59 | enc_contexts = [self.encode(c) for c in contexts]
60 | return self.decode(x, enc_contexts)
61 |
62 | def encode(self, x):
63 | return self.transformer_module(x)
64 |
65 | def generate(self, enc_x):
66 | return self.pre_softmax(enc_x)
67 |
68 | def decode(self, x, enc_contexts=[]):
69 | x, _ = self.transformer_module(x, enc_contexts)
70 | return self.generate(x)
71 |
72 | def predict(self, contexts=[]):
73 | enc_contexts = [self.encode(c) for c in contexts]
74 | prediction = self.beam_search(enc_contexts)
75 | return prediction
76 |
77 | def predict_beams(self, contexts=[]):
78 | enc_contexts = [self.encode(c) for c in contexts]
79 | prediction = self.beam_search(enc_contexts, return_beams=True)
80 | return prediction
81 |
82 | def _length_penalty(self, sequence_lengths):
83 | """https://arxiv.org/abs/1609.08144"""
84 | return (5 + sequence_lengths) ** self.length_penalty_coef / (5 + 1) ** self.length_penalty_coef
85 |
86 | def beam_search(self, enc_contexts=[], return_beams=False):
87 | with torch.no_grad():
88 | if len(enc_contexts) == 0:
89 | return []
90 |
91 | batch_size = enc_contexts[0][0].shape[0]
92 | device = next(self.parameters()).device
93 |
94 | prevs = torch.full((batch_size * self.beam_size, 1), fill_value=self.bos_id, dtype=torch.long, device=device)
95 |
96 | beam_scores = torch.zeros(batch_size, self.beam_size, device=device)
97 | beam_lens = torch.ones(batch_size, self.beam_size, dtype=torch.long, device=device)
98 | is_end = torch.zeros(batch_size, self.beam_size, dtype=torch.uint8, device=device)
99 |
100 | beam_enc_contexts = []
101 | for c, p in enc_contexts:
102 | c = c.unsqueeze(1).repeat(1, self.beam_size, 1, 1)
103 | c = c.view(-1, c.shape[2], c.shape[3])
104 | p = p.unsqueeze(1).repeat(1, self.beam_size, 1)
105 | p = p.view(-1, p.shape[2])
106 | beam_enc_contexts.append((c, p))
107 |
108 | current_sample_prob = 1
109 | group_size = self.beam_size // self.diversity_groups
110 | diversity_penalty = torch.zeros((batch_size, self.n_embeddings), device=device)
111 |
112 | # zrs:
113 | repeat = [{} for i in range(batch_size * self.beam_size)]
114 | # **********
115 | for i in range(self.max_seq_len):
116 | outputs, _ = self.transformer_module(prevs, beam_enc_contexts)
117 |
118 | logits = self.generate(outputs[:, -1, :])
119 | log_probs = F.log_softmax(logits, dim=-1)
120 | '''
121 | # zrs: remove n repeat. prevs: (batch_size*beam_size, 1)
122 | for idx in range(batch_size * self.beam_size):
123 | for key in repeat[idx]:
124 | for value in repeat[idx][key]:
125 | log_probs[idx][value] = -1000
126 | # **********
127 | '''
128 | # zrs
129 | prevs_list_temp = prevs.tolist()
130 | for idx in range(batch_size * self.beam_size):
131 | b_list = prevs_list_temp[idx]
132 | if len(b_list) > 1 and b_list[-1] != self.padding_idx and b_list[-1] != self.eos_id:
133 | key = (int(b_list[-2]), int(b_list[-1]))
134 | if key in repeat[idx]:
135 | for value in repeat[idx][key]:
136 | log_probs[idx][value] = -1000
137 | # **********
138 |
139 | log_probs = log_probs.view(batch_size, self.beam_size, -1)
140 |
141 | beam_scores = beam_scores.unsqueeze(-1) + log_probs * (1 - is_end.float().unsqueeze(-1))
142 |
143 | # zrs, log_probs: batch * beam * dim
144 | ba, be, dim = beam_scores.shape
145 | for ba_idx in range(ba):
146 | for be_idx in range(be):
147 | if int(torch.max(beam_scores[ba_idx][be_idx]) == torch.min(beam_scores[ba_idx][be_idx])):
148 | temp = float(beam_scores[ba_idx][be_idx][0])
149 | beam_scores[ba_idx][be_idx] = -float('inf')
150 | beam_scores[ba_idx][be_idx][0] = temp
151 | # **********
152 |
153 |
154 | penalty = self._length_penalty(beam_lens.float() + 1 - is_end.float())
155 | penalty = penalty.unsqueeze(-1).repeat(1, 1, self.n_embeddings)
156 | beam_scores = beam_scores / penalty
157 |
158 | if i == 0:
159 | penalty = penalty[:, 0, :]
160 | beam_scores = beam_scores[:, 0, :]
161 |
162 | beam_scores, idxs = beam_scores.topk(self.beam_size, dim=-1)
163 | beam_idxs = torch.zeros((batch_size, self.beam_size), dtype=torch.long, device=device)
164 | else:
165 | penalty = penalty.view(batch_size, self.diversity_groups, group_size, -1)
166 | beam_scores = beam_scores.view(batch_size, self.diversity_groups, group_size, -1)
167 |
168 | all_scores, all_idxs = [], []
169 | for g in range(self.diversity_groups):
170 | g_beam_scores = beam_scores[:, g, :, :]
171 | g_penalty = penalty[:, g, :, :]
172 | g_beam_scores -= self.diversity_coef * diversity_penalty.unsqueeze(1) / g_penalty
173 | g_beam_scores = g_beam_scores.view(batch_size, -1)
174 |
175 | if random.random() < current_sample_prob:
176 | beam_probas = F.softmax(g_beam_scores/self.temperature, dim=-1)
177 | if self.annealing_topk is not None:
178 | beam_probas, sample_idxs = beam_probas.topk(self.annealing_topk, dim=-1)
179 | g_idxs = torch.multinomial(beam_probas, group_size)
180 | g_idxs = torch.gather(sample_idxs, 1, g_idxs)
181 | else:
182 | g_idxs = torch.multinomial(beam_probas, group_size)
183 | else:
184 | _, g_idxs = g_beam_scores.topk(group_size, dim=-1)
185 |
186 | g_scores = torch.gather(beam_scores[:, g, :, :].view(batch_size, -1), 1, g_idxs)
187 | g_idxs += g * group_size * self.n_embeddings
188 |
189 | all_scores.append(g_scores)
190 | all_idxs.append(g_idxs)
191 |
192 | diversity_penalty.scatter_add_(1, torch.fmod(g_idxs, self.n_embeddings), torch.ones((batch_size, group_size), device=device))
193 |
194 | diversity_penalty.fill_(0)
195 | penalty = penalty.view(batch_size, -1)
196 | beam_scores = torch.cat(all_scores, dim=-1)
197 | idxs = torch.cat(all_idxs, dim=-1)
198 |
199 | beam_idxs = (idxs.float() / self.n_embeddings).long()
200 | # print('beam_scores: ', beam_scores)
201 |
202 | penalty = torch.gather(penalty, 1, idxs)
203 | sym_idxs = torch.fmod(idxs, log_probs.shape[-1])
204 | is_end = torch.gather(is_end, 1, beam_idxs)
205 | beam_lens = torch.gather(beam_lens, 1, beam_idxs)
206 |
207 | sym_idxs[is_end] = self.padding_idx
208 | beam_lens[~is_end] += 1
209 | is_end[sym_idxs == self.eos_id] = 1
210 |
211 | sym_idxs = sym_idxs.view(batch_size * self.beam_size, 1)
212 | prevs = prevs.view(batch_size, self.beam_size, -1)
213 | prevs = torch.gather(prevs, 1, beam_idxs.unsqueeze(-1).repeat(1, 1, prevs.shape[-1]))
214 | prevs = prevs.view(batch_size * self.beam_size, -1)
215 | prevs = torch.cat([prevs, sym_idxs], dim=1)
216 |
217 | # zrs:
218 | prevs_list = prevs.tolist()
219 | for b in range(batch_size * self.beam_size):
220 | b_list = prevs_list[b]
221 | if len(b_list) > 2 and b_list[-1] != self.padding_idx and b_list[-1] != self.eos_id:
222 | key = (int(b_list[-3]), int(b_list[-2]))
223 | if key in repeat[b]:
224 | repeat[b][key].append(int(b_list[-1]))
225 | else:
226 | repeat[b][key] = [int(b_list[-1])]
227 | # ********
228 |
229 | if all(is_end.view(-1)):
230 | break
231 | # print(beam_scores.shape)
232 | beam_scores *= penalty
233 | current_sample_prob *= self.annealing
234 |
235 | predicts = []
236 | result = prevs.view(batch_size, self.beam_size, -1)
237 |
238 | # if return_beams:
239 | # return result, beam_lens
240 | if return_beams:
241 | bests = torch.argsort(beam_scores, dim=-1, descending=True)
242 | for i in range(batch_size):
243 | temp = []
244 | for j in range(self.beam_size):
245 | best_len = beam_lens[i, bests[i][j]]
246 | best_seq = result[i, bests[i][j], 1:best_len - 1]
247 | temp.append(best_seq.tolist())
248 | predicts.append(temp)
249 | return predicts
250 |
251 | if self.sample:
252 | probs = F.softmax(beam_scores, dim=-1)
253 | bests = torch.multinomial(probs, 1).view(-1)
254 | else:
255 | bests = beam_scores.argmax(dim=-1)
256 |
257 | for i in range(batch_size):
258 | best_len = beam_lens[i, bests[i]]
259 | best_seq = result[i, bests[i], 1:best_len-1]
260 | predicts.append(best_seq.tolist())
261 |
262 | return predicts
263 |
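`_length_penalty` is the Google NMT length normalization lp(n) = (5 + n)^alpha / (5 + 1)^alpha (https://arxiv.org/abs/1609.08144), by which beam scores are divided so longer hypotheses are not unfairly penalized. A small numeric sketch using alpha = 0.8, the default `length_penalty` of this class:

    alpha = 0.8
    lp = lambda n: (5 + n) ** alpha / (5 + 1) ** alpha
    print(lp(1), lp(10), lp(30))  # 1.0, then growing with sequence length
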
--------------------------------------------------------------------------------
/code/retrieval/pytorch_pretrained_bert/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for BERT model."""
16 |
17 | import math
18 | import torch
19 | from torch.optim import Optimizer
20 | from torch.optim.optimizer import required
21 | from torch.nn.utils import clip_grad_norm_
22 | import logging
23 | import abc
24 | import sys
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 |
29 | if sys.version_info >= (3, 4):
30 | ABC = abc.ABC
31 | else:
32 | ABC = abc.ABCMeta('ABC', (), {})
33 |
34 |
35 | class _LRSchedule(ABC):
36 | """ Parent of all LRSchedules here. """
37 | warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense
38 | def __init__(self, warmup=0.002, t_total=-1, **kw):
39 | """
40 | :param warmup: what fraction of t_total steps will be used for linear warmup
41 | :param t_total: how many training steps (updates) are planned
42 | :param kw:
43 | """
44 | super(_LRSchedule, self).__init__(**kw)
45 | if t_total < 0:
46 | logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
47 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
48 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
49 | warmup = max(warmup, 0.)
50 | self.warmup, self.t_total = float(warmup), float(t_total)
51 | self.warned_for_t_total_at_progress = -1
52 |
53 | def get_lr(self, step, nowarn=False):
54 | """
55 | :param step: which of t_total steps we're on
56 | :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
57 | :return: learning rate multiplier for current update
58 | """
59 | if self.t_total < 0:
60 | return 1.
61 | progress = float(step) / self.t_total
62 | ret = self.get_lr_(progress)
63 |         # warning for exceeding t_total (only active with warmup_linear)
64 | if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress:
65 | logger.warning(
66 | "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly."
67 | .format(ret, self.__class__.__name__))
68 | self.warned_for_t_total_at_progress = progress
69 | # end warning
70 | return ret
71 |
72 | @abc.abstractmethod
73 | def get_lr_(self, progress):
74 | """
75 | :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
76 | :return: learning rate multiplier for current update
77 | """
78 | return 1.
79 |
80 |
81 | class ConstantLR(_LRSchedule):
82 | def get_lr_(self, progress):
83 | return 1.
84 |
85 |
86 | class WarmupCosineSchedule(_LRSchedule):
87 | """
88 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
89 | Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve.
90 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup.
91 | """
92 | warn_t_total = True
93 | def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
94 | """
95 | :param warmup: see LRSchedule
96 | :param t_total: see LRSchedule
97 | :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
98 | :param kw:
99 | """
100 | super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
101 | self.cycles = cycles
102 |
103 | def get_lr_(self, progress):
104 | if progress < self.warmup:
105 | return progress / self.warmup
106 | else:
107 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
108 | return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
109 |
110 |
111 | class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
112 | """
113 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
114 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
115 | learning rate (with hard restarts).
116 | """
117 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
118 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
119 | assert(cycles >= 1.)
120 |
121 | def get_lr_(self, progress):
122 | if progress < self.warmup:
123 | return progress / self.warmup
124 | else:
125 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
126 | ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1)))
127 | return ret
128 |
129 |
130 | class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
131 | """
132 | All training progress is divided in `cycles` (default=1.) parts of equal length.
133 | Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1.,
134 | followed by a learning rate decreasing from 1. to 0. following a cosine curve.
135 | """
136 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
137 | assert(warmup * cycles < 1.)
138 | warmup = warmup * cycles if warmup >= 0 else warmup
139 | super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
140 |
141 | def get_lr_(self, progress):
142 | progress = progress * self.cycles % 1.
143 | if progress < self.warmup:
144 | return progress / self.warmup
145 | else:
146 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
147 | ret = 0.5 * (1. + math.cos(math.pi * progress))
148 | return ret
149 |
150 |
151 | class WarmupConstantSchedule(_LRSchedule):
152 | """
153 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
154 | Keeps learning rate equal to 1. after warmup.
155 | """
156 | def get_lr_(self, progress):
157 | if progress < self.warmup:
158 | return progress / self.warmup
159 | return 1.
160 |
161 |
162 | class WarmupLinearSchedule(_LRSchedule):
163 | """
164 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
165 | Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps.
166 | """
167 | warn_t_total = True
168 | def get_lr_(self, progress):
169 | if progress < self.warmup:
170 | return progress / self.warmup
171 | return max((progress - 1.) / (self.warmup - 1.), 0.)
172 |
173 |
174 | SCHEDULES = {
175 | None: ConstantLR,
176 | "none": ConstantLR,
177 | "warmup_cosine": WarmupCosineSchedule,
178 | "warmup_constant": WarmupConstantSchedule,
179 | "warmup_linear": WarmupLinearSchedule
180 | }
181 |
182 |
183 | class BertAdam(Optimizer):
184 | """Implements BERT version of Adam algorithm with weight decay fix.
185 | Params:
186 | lr: learning rate
187 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
188 | t_total: total number of training steps for the learning
189 | rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1
190 | schedule: schedule to use for the warmup (see above).
191 | Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below).
192 | If `None` or `'none'`, learning rate is always kept constant.
193 | Default : `'warmup_linear'`
194 | b1: Adams b1. Default: 0.9
195 | b2: Adams b2. Default: 0.999
196 | e: Adams epsilon. Default: 1e-6
197 | weight_decay: Weight decay. Default: 0.01
198 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
199 | """
200 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
201 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
202 | if lr is not required and lr < 0.0:
203 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
204 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
205 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
206 | if not 0.0 <= b1 < 1.0:
207 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
208 | if not 0.0 <= b2 < 1.0:
209 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
210 | if not e >= 0.0:
211 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
212 | # initialize schedule object
213 | if not isinstance(schedule, _LRSchedule):
214 | schedule_type = SCHEDULES[schedule]
215 | schedule = schedule_type(warmup=warmup, t_total=t_total)
216 | else:
217 | if warmup != -1 or t_total != -1:
218 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. "
219 | "Please specify custom warmup and t_total in _LRSchedule object.")
220 | defaults = dict(lr=lr, schedule=schedule,
221 | b1=b1, b2=b2, e=e, weight_decay=weight_decay,
222 | max_grad_norm=max_grad_norm)
223 | super(BertAdam, self).__init__(params, defaults)
224 |
225 | def get_lr(self):
226 | lr = []
227 | for group in self.param_groups:
228 | for p in group['params']:
229 | state = self.state[p]
230 | if len(state) == 0:
231 | return [0]
232 | lr_scheduled = group['lr']
233 | lr_scheduled *= group['schedule'].get_lr(state['step'])
234 | lr.append(lr_scheduled)
235 | return lr
236 |
237 | def step(self, closure=None):
238 | """Performs a single optimization step.
239 | Arguments:
240 | closure (callable, optional): A closure that reevaluates the model
241 | and returns the loss.
242 | """
243 | loss = None
244 | if closure is not None:
245 | loss = closure()
246 |
247 | for group in self.param_groups:
248 | for p in group['params']:
249 | if p.grad is None:
250 | continue
251 | grad = p.grad.data
252 | if grad.is_sparse:
253 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
254 |
255 | state = self.state[p]
256 |
257 | # State initialization
258 | if len(state) == 0:
259 | state['step'] = 0
260 | # Exponential moving average of gradient values
261 | state['next_m'] = torch.zeros_like(p.data)
262 | # Exponential moving average of squared gradient values
263 | state['next_v'] = torch.zeros_like(p.data)
264 |
265 | next_m, next_v = state['next_m'], state['next_v']
266 | beta1, beta2 = group['b1'], group['b2']
267 |
268 | # Add grad clipping
269 | if group['max_grad_norm'] > 0:
270 | clip_grad_norm_(p, group['max_grad_norm'])
271 |
272 | # Decay the first and second moment running average coefficient
273 | # In-place operations to update the averages at the same time
274 | next_m.mul_(beta1).add_(1 - beta1, grad)
275 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
276 | update = next_m / (next_v.sqrt() + group['e'])
277 |
278 | # Just adding the square of the weights to the loss function is *not*
279 | # the correct way of using L2 regularization/weight decay with Adam,
280 | # since that will interact with the m and v parameters in strange ways.
281 | #
282 | # Instead we want to decay the weights in a manner that doesn't interact
283 | # with the m/v parameters. This is equivalent to adding the square
284 | # of the weights to the loss with plain (non-momentum) SGD.
285 | if group['weight_decay'] > 0.0:
286 | update += group['weight_decay'] * p.data
287 |
288 | lr_scheduled = group['lr']
289 | lr_scheduled *= group['schedule'].get_lr(state['step'])
290 |
291 | update_with_lr = lr_scheduled * update
292 | p.data.add_(-update_with_lr)
293 |
294 | state['step'] += 1
295 |
296 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
297 | # No bias correction
298 | # bias_correction1 = 1 - beta1 ** state['step']
299 | # bias_correction2 = 1 - beta2 ** state['step']
300 |
301 | return loss
302 |
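A minimal usage sketch for the optimizer above, assuming a `model` and a `num_train_steps` value supplied by the training script (both are placeholders here). With `warmup=0.1` and the default `warmup_linear` schedule, the learning rate ramps up over the first 10% of steps and then decays linearly to zero; the grouping that exempts biases and LayerNorm weights from weight decay is a common convention, not taken from this repository.

    # hypothetical setup -- `model` and `num_train_steps` come from the caller
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer = BertAdam(grouped_parameters,
                         lr=2e-5,
                         warmup=0.1,
                         t_total=num_train_steps,
                         schedule='warmup_linear')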
--------------------------------------------------------------------------------
/code/generation/config.py:
--------------------------------------------------------------------------------
1 | from attrdict import AttrDict
2 | from model.utils import openai_transformer_config
3 |
4 |
5 | # transformer config (dialog)
6 | def get_model_config_dialog():
7 | default_config = openai_transformer_config()
8 | config = AttrDict({'vocab_path': '/root/generation_with_augmentation/parameters/vocab.txt',
9 | 'checkpoint_path': '/root/generation_with_augmentation/checkpoints/dialog_300k/crowded_extend_both_kd/last_checkpoint',
10 | 'n_layers': 12,
11 | 'n_pos_embeddings': 512,
12 | 'embeddings_size': default_config.embeddings_size,
13 | 'n_heads': default_config.n_heads,
14 | 'dropout': default_config.dropout,
15 | 'embed_dropout': default_config.embed_dropout,
16 | 'attn_dropout': default_config.attn_dropout,
17 | 'ff_dropout': default_config.ff_dropout,
18 | 'max_seq_len': 32,
19 | 'beam_size': 1,
20 | 'diversity_coef': 0,
21 | 'diversity_groups': 1,
22 | 'temperature': 1.0,
23 | 'annealing_topk': None,
24 | 'annealing': 0,
25 | 'length_penalty': 1.0,
26 | 'n_segments': None})
27 |
28 | return config
29 |
30 |
31 | def get_trainer_config_dialog():
32 | config = AttrDict({'n_epochs': 100,
33 | 'batch_size': 256,
34 | 'batch_split': 32,
35 | 'lr': 6.25e-5,
36 | 'lr_warmup': 1000,
37 | 'lm_weight': 0.5,
38 | 'risk_weight': 0,
39 | 'n_jobs': 4,
40 | 'label_smoothing': 0.1,
41 | 'clip_grad': None,
42 | 'test_period': 1,
43 | 'seed': 0,
44 | 'device': 'cuda',
45 | 'load_last': True,
46 | 'openai_parameters_dir': '/root/generation_with_augmentation/parameters/chinese_pretrain.pt',
47 | 'last_checkpoint_path': '/root/generation_with_augmentation/checkpoints/dialog_300k/crowded_extend_both_kd/last_checkpoint',
48 | 'teacher_checkpoint_path': '/root/generation_with_augmentation/checkpoints/dialog_300k/crowded/last_checkpoint15',
49 | 'interrupt_checkpoint_path': '/root/generation_with_augmentation/checkpoints/dialog_300k/crowded_extend_both_kd/interrupt_checkpoint',
50 | 'train_datasets': [#'/root/generation_with_augmentation/dataset/dialog/crowded_500k.txt',
51 | #'/root/generation_with_augmentation/dataset/dialog/crowded_300k.txt',
52 | #'/root/generation_with_augmentation/dataset/dialog/crowded_100k.txt',
53 | #'/root/generation_with_augmentation/dataset/dialog/crowded_unpaired_500k.txt',
54 | #'/root/generation_with_augmentation/dataset/dialog/fake_unique.txt',
55 | #'/root/generation_with_augmentation/dataset/dialog/crowded_300k_eda_v2.txt',
56 | #'/root/generation_with_augmentation/dataset/dialog/crowded_300k_cvae_post.txt',
57 | #'/root/generation_with_augmentation/dataset/dialog/crowded_dirty_300k.txt',
58 | #'/root/generation_with_augmentation/dataset/dialog/crowded_300k_bt.txt',
59 | #'/root/generation_with_augmentation/dataset/dialog/crowded_300k_extend_post.txt',
60 | #'/root/generation_with_augmentation/dataset/dialog/crowded_300k_extend_resp.txt',
61 | '/root/generation_with_augmentation/dataset/dialog/crowded_300k_extend_both.txt',
62 | #'/root/generation_with_augmentation/dataset/dialog/fake_wo_matching_generation.txt',
63 | ],
64 | 'test_datasets': ['/root/generation_with_augmentation/dataset/dialog/valid_9k.txt']})
65 | return config
66 |
67 |
68 | def get_test_config_dialog():
69 | config = AttrDict({'seed': 0,
70 | 'device': 'cuda',
71 | 'load_last': True,
72 | 'openai_parameters_dir': '/root/generation_with_augmentation/parameters/chinese_pretrain.pt',
73 | 'last_checkpoint_path': '/root/generation_with_augmentation/checkpoints/dialog_kd/crowded_fake/last_checkpoint11'})
74 | return config
75 |
76 |
77 | # transformer config (overlap)
78 | def get_model_config_dialog_overlap():
79 | default_config = openai_transformer_config()
80 | config = AttrDict({'vocab_path': '/root/generation_with_augmentation/parameters/vocab_overlap.txt',
81 | 'checkpoint_path': '/root/generation_with_augmentation/checkpoints/dialog_overlap/v1/last_checkpoint',
82 | 'n_layers': 12,
83 | 'n_pos_embeddings': 512,
84 | 'embeddings_size': default_config.embeddings_size,
85 | 'n_heads': default_config.n_heads,
86 | 'dropout': default_config.dropout,
87 | 'embed_dropout': default_config.embed_dropout,
88 | 'attn_dropout': default_config.attn_dropout,
89 | 'ff_dropout': default_config.ff_dropout,
90 | 'max_seq_len': 32,
91 | 'beam_size': 5,
92 | 'diversity_coef': 0,
93 | 'diversity_groups': 1,
94 | 'temperature': 1.0,
95 | 'annealing_topk': None,
96 | 'annealing': 0,
97 | 'length_penalty': 1.5,
98 | 'n_segments': None})
99 |
100 | return config
101 |
102 |
103 | def get_trainer_config_dialog_overlap():
104 | config = AttrDict({'n_epochs': 30,
105 | 'batch_size': 256,
106 | 'batch_split': 16,
107 | 'lr': 6.25e-5,
108 | 'lr_warmup': 1000,
109 | 'lm_weight': 0.5,
110 | 'risk_weight': 0,
111 | 'n_jobs': 4,
112 | 'label_smoothing': 0.1,
113 | 'clip_grad': None,
114 | 'test_period': 1,
115 | 'seed': 0,
116 | 'device': 'cuda',
117 | 'load_last': False,
118 | 'openai_parameters_dir': '/root/generation_with_augmentation/parameters/chinese_pretrain.pt',
119 | 'last_checkpoint_path': '/root/generation_with_augmentation/checkpoints/dialog_overlap/v1/last_checkpoint',
120 | 'interrupt_checkpoint_path': '/root/generation_with_augmentation/checkpoints/dialog_overlap/v1/interrupt_checkpoint',
121 | 'train_datasets': ['/root/generation_with_augmentation/dataset/dialog/crowded_500k_overlap.txt',
122 | #'/root/generation_with_augmentation/dataset/dialog/crowded_unpaired_500k.txt',
123 | #'/root/generation_with_augmentation/dataset/dialog/fake_unique.txt',
124 | ],
125 | 'test_datasets': ['/root/generation_with_augmentation/dataset/dialog/valid_9k_overlap.txt']})
126 | return config
127 |
128 | def get_test_config_dialog_overlap():
129 | config = AttrDict({'seed': 0,
130 | 'device': 'cuda',
131 | 'load_last': True,
132 | 'openai_parameters_dir': '/root/generation_with_augmentation/parameters/chinese_pretrain.pt',
133 | 'last_checkpoint_path': '/root/generation_with_augmentation/checkpoints/dialog_overlap/v1/last_checkpoint15'})
134 | return config
135 |
136 | # transformer config (poem)
137 | def get_model_config_poem():
138 | default_config = openai_transformer_config()
139 | config = AttrDict({'vocab_path': '/root/generation_with_augmentation/parameters/vocab.txt',
140 | 'checkpoint_path': '/root/generation_with_augmentation/checkpoints/poem_wu/last_checkpoint',
141 | 'n_layers': 12,
142 | 'n_pos_embeddings': 512,
143 | 'embeddings_size': default_config.embeddings_size,
144 | 'n_heads': default_config.n_heads,
145 | 'dropout': default_config.dropout,
146 | 'embed_dropout': default_config.embed_dropout,
147 | 'attn_dropout': default_config.attn_dropout,
148 | 'ff_dropout': default_config.ff_dropout,
149 | 'max_seq_len': 128,
150 | 'beam_size': 15,
151 | 'diversity_coef': 0.5,
152 | 'diversity_groups': 5,
153 | 'temperature': 0.8,
154 | 'annealing_topk': 20,
155 | 'annealing': 1.0,
156 | 'length_penalty': 0.6,
157 | 'n_segments': None})
158 |
159 | return config
160 |
161 |
162 | def get_trainer_config_poem():
163 | config = AttrDict({'n_epochs': 30,
164 | 'batch_size': 256,
165 | 'batch_split': 8,
166 | 'lr': 6.25e-5,
167 | 'lr_warmup': 1000,
168 | 'lm_weight': 0,
169 | 'risk_weight': 0,
170 | 'n_jobs': 4,
171 | 'label_smoothing': 0,
172 | 'clip_grad': None,
173 | 'test_period': 1,
174 | 'seed': 0,
175 | 'device': 'cuda',
176 | 'load_last': False,
177 | 'openai_parameters_dir': '/root/generation_with_augmentation/parameters/chinese_pretrain.pt',
178 | 'last_checkpoint_path': '/root/generation_with_augmentation/checkpoints/poem_wu/last_checkpoint',
179 | 'interrupt_checkpoint_path': '/root/generation_with_augmentation/checkpoints/poem_wu/interrupt_checkpoint',
180 | 'train_datasets': ['/root/generation_with_augmentation/dataset/poem/train_wu.txt'],
181 | 'test_datasets': ['/root/generation_with_augmentation/dataset/poem/valid_wu.txt']})
182 | return config
183 |
184 |
185 | def get_test_config_poem():
186 | config = AttrDict({'seed': 0,
187 | 'device': 'cuda',
188 | 'load_last': True,
189 | 'openai_parameters_dir': '/root/generation_with_augmentation/parameters/chinese_pretrain.pt',
190 | 'last_checkpoint_path': '/root/generation_with_augmentation/checkpoints/poem/last_checkpoint20'})
191 |
192 | return config
193 |
194 | # transformer config (meme)
195 | def get_model_config_meme():
196 | default_config = openai_transformer_config()
197 | config = AttrDict({'vocab_path': '/root/generation_with_augmentation/parameters/vocab.txt',
198 | 'checkpoint_path': '/root/generation_with_augmentation/checkpoints/meme_all/last_checkpoint',
199 | 'n_layers': 6,
200 | 'n_pos_embeddings': 512,
201 | 'embeddings_size': default_config.embeddings_size,
202 | 'n_heads': default_config.n_heads,
203 | 'dropout': default_config.dropout,
204 | 'embed_dropout': default_config.embed_dropout,
205 | 'attn_dropout': default_config.attn_dropout,
206 | 'ff_dropout': default_config.ff_dropout,
207 | 'max_seq_len': 30,
208 | 'beam_size': 5,
209 | 'diversity_coef': 0,
210 | 'diversity_groups': 1,
211 | 'temperature': 1.0,
212 | 'annealing_topk': None,
213 | 'annealing': 0,
214 | 'length_penalty': 1.0,
215 | 'n_segments': None})
216 |
217 | return config
218 |
219 |
220 | def get_trainer_config_meme():
221 | config = AttrDict({'n_epochs': 50,
222 | 'batch_size': 256,
223 | 'batch_split': 8,
224 | 'lr': 6.25e-5,
225 | 'lr_warmup': 1000,
226 | 'lm_weight': 0,
227 | 'risk_weight': 0,
228 | 'n_jobs': 4,
229 | 'label_smoothing': 0,
230 | 'clip_grad': None,
231 | 'test_period': 1,
232 | 'seed': 0,
233 | 'device': 'cuda',
234 | 'load_last': False,
235 | 'openai_parameters_dir': '/root/generation_with_augmentation/parameters/chinese_pretrain.pt',
236 | 'last_checkpoint_path': '/root/generation_with_augmentation/checkpoints/meme_all/last_checkpoint',
237 | 'interrupt_checkpoint_path': '/root/generation_with_augmentation/checkpoints/meme_all/interrupt_checkpoint',
238 | 'train_datasets': ['/root/generation_with_augmentation/dataset/meme/train_all.txt'],
239 | 'test_datasets': ['/root/generation_with_augmentation/dataset/meme/valid_all.txt']})
240 | return config
241 |
242 |
243 | def get_test_config_meme():
244 | config = AttrDict({'seed': 0,
245 | 'device': 'cuda',
246 | 'load_last': True,
247 | 'openai_parameters_dir': '/root/generation_with_augmentation/parameters/chinese_pretrain.pt',
248 | 'last_checkpoint_path': '/root/generation_with_augmentation/checkpoints/meme_all/last_checkpoint19'})
249 | return config
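Each of these helpers returns an AttrDict (from the attrdict package), so fields can be read either as dictionary keys or as attributes. A minimal sketch of how a training script might consume them, assuming it is run from code/generation so that `config` is importable:

    from config import get_model_config_dialog, get_trainer_config_dialog

    model_config = get_model_config_dialog()
    trainer_config = get_trainer_config_dialog()
    # attribute access and key access are interchangeable on an AttrDict
    assert model_config.beam_size == model_config['beam_size']
    print(model_config.checkpoint_path, trainer_config.lr, trainer_config.n_epochs)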
--------------------------------------------------------------------------------
/code/retrieval/pytorch_pretrained_bert/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import, division, print_function, unicode_literals
18 |
19 | import collections
20 | import logging
21 | import os
22 | import unicodedata
23 | from io import open
24 |
25 | from .file_utils import cached_path
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
37 | }
38 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
39 | 'bert-base-uncased': 512,
40 | 'bert-large-uncased': 512,
41 | 'bert-base-cased': 512,
42 | 'bert-large-cased': 512,
43 | 'bert-base-multilingual-uncased': 512,
44 | 'bert-base-multilingual-cased': 512,
45 | 'bert-base-chinese': 512,
46 | }
47 | VOCAB_NAME = 'vocab.txt'
48 |
49 |
50 | def load_vocab(vocab_file):
51 | """Loads a vocabulary file into a dictionary."""
52 | vocab = collections.OrderedDict()
53 | index = 0
54 | with open(vocab_file, "r", encoding="utf-8") as reader:
55 | while True:
56 | token = reader.readline()
57 | if not token:
58 | break
59 | token = token.strip()
60 | vocab[token] = index
61 | index += 1
62 | return vocab
63 |
64 |
65 | def whitespace_tokenize(text):
66 | """Runs basic whitespace cleaning and splitting on a piece of text."""
67 | text = text.strip()
68 | if not text:
69 | return []
70 | tokens = text.split()
71 | return tokens
72 |
73 |
74 | class BertTokenizer(object):
75 | """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
76 |
77 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
78 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
79 | """Constructs a BertTokenizer.
80 |
81 | Args:
82 | vocab_file: Path to a one-wordpiece-per-line vocabulary file
83 | do_lower_case: Whether to lower case the input
84 |                          Only has an effect when do_basic_tokenize=True
85 | do_basic_tokenize: Whether to do basic tokenization before wordpiece.
86 | max_len: An artificial maximum length to truncate tokenized sequences to;
87 | Effective maximum length is always the minimum of this
88 | value (if specified) and the underlying BERT model's
89 | sequence length.
90 | never_split: List of tokens which will never be split during tokenization.
91 |                          Only has an effect when do_basic_tokenize=True
92 | """
93 | if not os.path.isfile(vocab_file):
94 | raise ValueError(
95 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
96 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
97 | self.vocab = load_vocab(vocab_file)
98 | self.ids_to_tokens = collections.OrderedDict(
99 | [(ids, tok) for tok, ids in self.vocab.items()])
100 | self.do_basic_tokenize = do_basic_tokenize
101 | if do_basic_tokenize:
102 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
103 | never_split=never_split)
104 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
105 | self.max_len = max_len if max_len is not None else int(1e12)
106 |
107 | def tokenize(self, text):
108 | if self.do_basic_tokenize:
109 | split_tokens = []
110 | for token in self.basic_tokenizer.tokenize(text):
111 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
112 | split_tokens.append(sub_token)
113 | else:
114 | split_tokens = self.wordpiece_tokenizer.tokenize(text)
115 | return split_tokens
116 |
117 | def convert_tokens_to_ids(self, tokens):
118 | """Converts a sequence of tokens into ids using the vocab."""
119 | ids = []
120 | for token in tokens:
121 | ids.append(self.vocab[token])
122 | if len(ids) > self.max_len:
123 | logger.warning(
124 | "Token indices sequence length is longer than the specified maximum "
125 | " sequence length for this BERT model ({} > {}). Running this"
126 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
127 | )
128 | return ids
129 |
130 | def convert_ids_to_tokens(self, ids):
131 | """Converts a sequence of ids in wordpiece tokens using the vocab."""
132 | tokens = []
133 | for i in ids:
134 | tokens.append(self.ids_to_tokens[i])
135 | return tokens
136 |
137 | @classmethod
138 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
139 | """
140 | Instantiate a PreTrainedBertModel from a pre-trained model file.
141 | Download and cache the pre-trained model file if needed.
142 | """
143 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
144 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
145 | else:
146 | vocab_file = pretrained_model_name_or_path
147 | if os.path.isdir(vocab_file):
148 | vocab_file = os.path.join(vocab_file, VOCAB_NAME)
149 | # redirect to the cache, if necessary
150 | try:
151 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
152 | except EnvironmentError:
153 | logger.error(
154 | "Model name '{}' was not found in model name list ({}). "
155 | "We assumed '{}' was a path or url but couldn't find any file "
156 | "associated to this path or url.".format(
157 | pretrained_model_name_or_path,
158 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
159 | vocab_file))
160 | return None
161 | if resolved_vocab_file == vocab_file:
162 | logger.info("loading vocabulary file {}".format(vocab_file))
163 | else:
164 | logger.info("loading vocabulary file {} from cache at {}".format(
165 | vocab_file, resolved_vocab_file))
166 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
167 |             # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
168 | # than the number of positional embeddings
169 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
170 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
171 | # Instantiate tokenizer.
172 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
173 | return tokenizer
174 |
175 |
176 | class BasicTokenizer(object):
177 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
178 |
179 | def __init__(self,
180 | do_lower_case=True,
181 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
182 | """Constructs a BasicTokenizer.
183 |
184 | Args:
185 | do_lower_case: Whether to lower case the input.
186 | """
187 | self.do_lower_case = do_lower_case
188 | self.never_split = never_split
189 |
190 | def tokenize(self, text):
191 | """Tokenizes a piece of text."""
192 | text = self._clean_text(text)
193 | # This was added on November 1st, 2018 for the multilingual and Chinese
194 | # models. This is also applied to the English models now, but it doesn't
195 | # matter since the English models were not trained on any Chinese data
196 | # and generally don't have any Chinese data in them (there are Chinese
197 | # characters in the vocabulary because Wikipedia does have some Chinese
198 | # words in the English Wikipedia.).
199 | text = self._tokenize_chinese_chars(text)
200 | orig_tokens = whitespace_tokenize(text)
201 | split_tokens = []
202 | for token in orig_tokens:
203 | if self.do_lower_case and token not in self.never_split:
204 | token = token.lower()
205 | token = self._run_strip_accents(token)
206 | split_tokens.extend(self._run_split_on_punc(token))
207 |
208 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
209 | return output_tokens
210 |
211 | def _run_strip_accents(self, text):
212 | """Strips accents from a piece of text."""
213 | text = unicodedata.normalize("NFD", text)
214 | output = []
215 | for char in text:
216 | cat = unicodedata.category(char)
217 | if cat == "Mn":
218 | continue
219 | output.append(char)
220 | return "".join(output)
221 |
222 | def _run_split_on_punc(self, text):
223 | """Splits punctuation on a piece of text."""
224 | if text in self.never_split:
225 | return [text]
226 | chars = list(text)
227 | i = 0
228 | start_new_word = True
229 | output = []
230 | while i < len(chars):
231 | char = chars[i]
232 | if _is_punctuation(char):
233 | output.append([char])
234 | start_new_word = True
235 | else:
236 | if start_new_word:
237 | output.append([])
238 | start_new_word = False
239 | output[-1].append(char)
240 | i += 1
241 |
242 | return ["".join(x) for x in output]
243 |
244 | def _tokenize_chinese_chars(self, text):
245 | """Adds whitespace around any CJK character."""
246 | output = []
247 | for char in text:
248 | cp = ord(char)
249 | if self._is_chinese_char(cp):
250 | output.append(" ")
251 | output.append(char)
252 | output.append(" ")
253 | else:
254 | output.append(char)
255 | return "".join(output)
256 |
257 | def _is_chinese_char(self, cp):
258 | """Checks whether CP is the codepoint of a CJK character."""
259 | # This defines a "chinese character" as anything in the CJK Unicode block:
260 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
261 | #
262 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
263 | # despite its name. The modern Korean Hangul alphabet is a different block,
264 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
265 | # space-separated words, so they are not treated specially and handled
266 |         # like all of the other languages.
267 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
268 | (cp >= 0x3400 and cp <= 0x4DBF) or #
269 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
270 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
271 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
272 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
273 | (cp >= 0xF900 and cp <= 0xFAFF) or #
274 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
275 | return True
276 |
277 | return False
278 |
279 | def _clean_text(self, text):
280 | """Performs invalid character removal and whitespace cleanup on text."""
281 | output = []
282 | for char in text:
283 | cp = ord(char)
284 | if cp == 0 or cp == 0xfffd or _is_control(char):
285 | continue
286 | if _is_whitespace(char):
287 | output.append(" ")
288 | else:
289 | output.append(char)
290 | return "".join(output)
291 |
292 |
293 | class WordpieceTokenizer(object):
294 | """Runs WordPiece tokenization."""
295 |
296 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
297 | self.vocab = vocab
298 | self.unk_token = unk_token
299 | self.max_input_chars_per_word = max_input_chars_per_word
300 |
301 | def tokenize(self, text):
302 | """Tokenizes a piece of text into its word pieces.
303 |
304 | This uses a greedy longest-match-first algorithm to perform tokenization
305 | using the given vocabulary.
306 |
307 | For example:
308 | input = "unaffable"
309 | output = ["un", "##aff", "##able"]
310 |
311 | Args:
312 | text: A single token or whitespace separated tokens. This should have
313 | already been passed through `BasicTokenizer`.
314 |
315 | Returns:
316 | A list of wordpiece tokens.
317 | """
318 |
319 | output_tokens = []
320 | for token in whitespace_tokenize(text):
321 | chars = list(token)
322 | if len(chars) > self.max_input_chars_per_word:
323 | output_tokens.append(self.unk_token)
324 | continue
325 |
326 | is_bad = False
327 | start = 0
328 | sub_tokens = []
329 | while start < len(chars):
330 | end = len(chars)
331 | cur_substr = None
332 | while start < end:
333 | substr = "".join(chars[start:end])
334 | if start > 0:
335 | substr = "##" + substr
336 | if substr in self.vocab:
337 | cur_substr = substr
338 | break
339 | end -= 1
340 | if cur_substr is None:
341 | is_bad = True
342 | break
343 | sub_tokens.append(cur_substr)
344 | start = end
345 |
346 | if is_bad:
347 | output_tokens.append(self.unk_token)
348 | else:
349 | output_tokens.extend(sub_tokens)
350 | return output_tokens
351 |
352 |
353 | def _is_whitespace(char):
354 | """Checks whether `chars` is a whitespace character."""
355 |     # \t, \n, and \r are technically control characters but we treat them
356 | # as whitespace since they are generally considered as such.
357 | if char == " " or char == "\t" or char == "\n" or char == "\r":
358 | return True
359 | cat = unicodedata.category(char)
360 | if cat == "Zs":
361 | return True
362 | return False
363 |
364 |
365 | def _is_control(char):
366 | """Checks whether `chars` is a control character."""
367 | # These are technically control characters but we count them as whitespace
368 | # characters.
369 | if char == "\t" or char == "\n" or char == "\r":
370 | return False
371 | cat = unicodedata.category(char)
372 | if cat.startswith("C"):
373 | return True
374 | return False
375 |
376 |
377 | def _is_punctuation(char):
378 | """Checks whether `chars` is a punctuation character."""
379 | cp = ord(char)
380 | # We treat all non-letter/number ASCII as punctuation.
381 | # Characters such as "^", "$", and "`" are not in the Unicode
382 | # Punctuation class but we treat them as punctuation anyways, for
383 | # consistency.
384 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
385 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
386 | return True
387 | cat = unicodedata.category(char)
388 | if cat.startswith("P"):
389 | return True
390 | return False
391 |
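A short usage sketch of the tokenization classes above, assuming it is run from code/retrieval so that the pytorch_pretrained_bert package is importable; 'bert-base-chinese' is one of the keys in PRETRAINED_VOCAB_ARCHIVE_MAP defined at the top of this file:

    from pytorch_pretrained_bert.tokenization import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese', do_lower_case=True)
    tokens = tokenizer.tokenize("今天天气不错")           # basic + WordPiece tokenization
    ids = tokenizer.convert_tokens_to_ids(tokens)       # wordpieces -> vocabulary ids
    back = tokenizer.convert_ids_to_tokens(ids)         # inverse mapping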
--------------------------------------------------------------------------------