├── .gitignore ├── README.md ├── data └── scripts │ ├── __init__.py │ ├── dataloader.py │ └── dataset.py ├── models ├── MTModel_Hybird.py ├── MaskedCELoss.py └── __init__.py ├── requirement.txt └── scripts ├── __init__.py ├── run_example.sh ├── train.py ├── translate.py ├── translate_example.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | /models/__pycache__/ 3 | /models/__pycache__/* 4 | /scripts/__pycache__/* 5 | /scripts/logs 6 | /scripts/parameters/* 7 | /scripts/demonstration 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # bert_transformer_for_machine_translation 2 | BERT2Transformer(Decoder) 3 | Deep learning framework using Mxnet (gluon, gluonnlp) 4 | 5 | # Prepare 6 | 1.The arguments require a training file and an evaluation file, with each line of each file being a parallel corpus pair, separated by "\t" by default. 7 | ___________ 8 | 2.There is no need for operations such as splitting words and building dictionaries. The encoder of the model is a pre-trained BERT, and it is good to use its corresponding Tokenizer. The decoder also does not need to build its own dictionary, and uses the dictionary of BERT of a certain language. The dictionaries of the encoder and decoder are set in "--src_bert_dataset" and "--tgt_bert_dataset", and the specific BERT versions that can be used are listed below. 9 | ### The supported bert datasets are: 10 | 'book_corpus_wiki_en_cased', 11 | 'book_corpus_wiki_en_uncased', 12 | 'wiki_cn_cased', 13 | 'openwebtext_book_corpus_wiki_en_uncased', 14 | 'wiki_multilingual_uncased', 15 | 'wiki_multilingual_cased', 16 | 17 | ___________ 18 | 3.You simply need to check whether the source and target languages to be used have their corresponding BERT versions. If it is OK, you just need to prepare the training parallel corpus and set it up in the arguments. 19 | 20 | ## Environment 21 | Use requirement.txt to install the required packages, the first time you use it, you will automatically download the corresponding BERT pre-training parameters. 22 | 23 | # Train 24 | Set your arguments in train.py, then 'python train.py' 25 | 26 | # Translate 27 | If you need to use the trained model for translation, please use translate.py and set the trained model parameter address in the arguments. -------------------------------------------------------------------------------- /data/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | -------------------------------------------------------------------------------- /data/scripts/dataloader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import multiprocessing 3 | import os 4 | import re 5 | import sys 6 | sys.path.append("../") 7 | 8 | import gluonnlp as nlp 9 | from gluonnlp import Vocab 10 | from gluonnlp.data import BERTSentenceTransform, BERTTokenizer 11 | from mxnet import nd 12 | from mxnet.gluon import data 13 | 14 | 15 | class DatasetAssiantTransformer(): 16 | def __init__(self, src_vocab=None, tgt_vocab=None, max_src_len=None, max_tgt_len=None): 17 | self.src_vocab = src_vocab 18 | self.tgt_vocab = tgt_vocab 19 | self.max_src_len = max_src_len 20 | self.max_tgt_len = max_tgt_len 21 | self.bert_src_tokenzier = BERTTokenizer(src_vocab) 22 | self.bert_tgt_tokenzier = BERTTokenizer(tgt_vocab) 23 | self.bos_token = "[CLS]" 24 | self.eos_token = "[SEP]" 25 | 26 | def MTSentenceProcess(self, *src_and_tgt): 27 | src, tgt = src_and_tgt 28 | assert isinstance(src, str), 'the input type must be str' 29 | assert isinstance(tgt, str), 'the input type must be str' 30 | 31 | src = self.bert_src_tokenzier(src) 32 | tgt = self.bert_tgt_tokenzier(tgt) 33 | 34 | if self.max_src_len and len(src) > self.max_src_len-2: 35 | src = src[0:self.max_src_len-2] 36 | if self.max_tgt_len and len(tgt) > self.max_tgt_len-1: 37 | tgt = tgt[0:self.max_tgt_len-1] 38 | 39 | src = [self.src_vocab.cls_token] + \ 40 | src + [self.src_vocab.sep_token] 41 | tgt = [self.bos_token] + tgt 42 | 43 | src_valid_len = len(src) 44 | tgt_valid_len = len(tgt) 45 | 46 | src = self.src_vocab[src] 47 | tgt = self.tgt_vocab[tgt] 48 | label = tgt[1:] + [self.tgt_vocab(self.eos_token)] 49 | 50 | return src, tgt, label, src_valid_len, tgt_valid_len 51 | 52 | 53 | class MTDataLoader(object): 54 | def __init__(self, dataset, batch_size, assiant, shuffle=False, num_workers=3, lazy=True): 55 | trans_func = assiant.MTSentenceProcess 56 | self.dataset = dataset.transform(trans_func, lazy=lazy) 57 | self.batch_size = batch_size 58 | self.src_pad_val = assiant.src_vocab[assiant.src_vocab.padding_token] 59 | self.tgt_pad_val = assiant.tgt_vocab[assiant.tgt_vocab.padding_token] 60 | self.shuffle = shuffle 61 | self.num_workers = num_workers 62 | self.dataloader = self._build_dataloader() 63 | 64 | def _build_dataloader(self): 65 | batchify_fn = nlp.data.batchify.Tuple( 66 | nlp.data.batchify.Pad(pad_val=self.src_pad_val), 67 | nlp.data.batchify.Pad(pad_val=self.tgt_pad_val), 68 | nlp.data.batchify.Pad(pad_val=self.tgt_pad_val), 69 | nlp.data.batchify.Stack(dtype="float32"), 70 | nlp.data.batchify.Stack(dtype="float32"),) 71 | dataloader = data.DataLoader(dataset=self.dataset, batch_size=self.batch_size, 72 | shuffle=self.shuffle, batchify_fn=batchify_fn, 73 | num_workers=self.num_workers) 74 | return dataloader 75 | 76 | @property 77 | def dataiter(self): 78 | return self.dataloader 79 | 80 | @property 81 | def data_lengths(self): 82 | return len(self.dataset) 83 | -------------------------------------------------------------------------------- /data/scripts/dataset.py: -------------------------------------------------------------------------------- 1 | from mxnet import gluon 2 | 3 | 4 | class MTDataset(gluon.data.Dataset): 5 | def __init__(self, data_path, **kwargs): 6 | super(MTDataset, self).__init__(**kwargs) 7 | self.src_sentences, self.tgt_sentences = self._get_data(data_path) 8 | 9 | def _get_data(self, data_path, sep="\t"): 10 | src_sentences = [] 11 | tgt_sentences = [] 12 | with open(data_path, 'r', encoding='utf-8') as fr_trans: 13 | lines = [line.strip() 14 | for line in fr_trans.readlines()] 15 | for line in lines: 16 | src, tgt = line.split(sep) 17 | # tgt, src = line.split(sep) 18 | src_sentences.append(src) 19 | tgt_sentences.append(tgt) 20 | if len(src_sentences) != len(tgt_sentences): 21 | assert "lens of SRC and TGT is not the same!" 22 | return src_sentences, tgt_sentences 23 | 24 | def __getitem__(self, item): 25 | return self.src_sentences[item], self.tgt_sentences[item] 26 | 27 | def __len__(self): 28 | return len(self.src_sentences) 29 | -------------------------------------------------------------------------------- /models/MTModel_Hybird.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | sys.path.append("../../") 4 | 5 | import gluonnlp 6 | import mxnet as mx 7 | import numpy as np 8 | from mxnet import nd 9 | from mxnet.gluon import nn 10 | 11 | 12 | 13 | class Transformer(nn.Block): 14 | def __init__(self, src_vocab, tgt_vocab, embedding_dim, model_dim, head_num, layer_num, ffn_dim, dropout, att_dropout, ffn_dropout, ctx=mx.cpu(), **kwargs): 15 | super(Transformer, self).__init__(**kwargs) 16 | self._ctx = ctx 17 | self._embedding_dim = embedding_dim 18 | self._model_dim = model_dim 19 | self.src_pad_idx = src_vocab(src_vocab.padding_token) 20 | self.tgt_pad_idx = tgt_vocab(tgt_vocab.padding_token) 21 | self.tgt_embedding = nn.Embedding( 22 | len(tgt_vocab.idx_to_token), embedding_dim) 23 | with self.name_scope(): 24 | self.decoder = Decoder(embedding_dim, model_dim, head_num, 25 | layer_num, ffn_dim, dropout, att_dropout, ffn_dropout) 26 | self.linear = nn.Dense( 27 | len(tgt_vocab.idx_to_token), flatten=False, params=self.tgt_embedding.collect_params()) 28 | 29 | def forward(self, src_bert_output, src_idx, tgt_idx): 30 | self_tril_mask = self._get_self_tril_mask(tgt_idx) 31 | self_key_mask = self._get_key_mask( 32 | tgt_idx, tgt_idx, pad_idx=self.tgt_pad_idx) 33 | self_att_mask = nd.greater((self_key_mask + self_tril_mask), 1) 34 | 35 | context_att_mask = self._get_key_mask( 36 | src_idx, tgt_idx, pad_idx=self.src_pad_idx) 37 | non_pad_mask = self._get_non_pad_mask( 38 | tgt_idx, pad_idx=self.tgt_pad_idx) 39 | 40 | position = nd.array(self._position_encoding_init( 41 | tgt_idx.shape[1], self._model_dim), ctx=self._ctx) 42 | position = nd.expand_dims(position, axis=0) 43 | position = nd.broadcast_axes(position, axis=0, size=tgt_idx.shape[0]) 44 | position = position * non_pad_mask 45 | tgt_emb = self.tgt_embedding(tgt_idx) 46 | outputs = self.decoder( 47 | src_bert_output, tgt_emb, position, self_att_mask, context_att_mask, non_pad_mask) 48 | outputs = self.linear(outputs) 49 | return outputs 50 | 51 | def _get_non_pad_mask(self, seq, pad_idx=None): 52 | if pad_idx: 53 | non_pad_mask = nd.not_equal(seq, pad_idx) 54 | else: 55 | non_pad_mask = nd.not_equal(seq, 0) 56 | non_pad_mask = nd.expand_dims(non_pad_mask, axis=2) 57 | return non_pad_mask 58 | 59 | def _get_key_mask(self, enc_idx, dec_idx, pad_idx=None): 60 | seq_len = dec_idx.shape[1] 61 | if pad_idx: 62 | pad_mask = nd.not_equal(enc_idx, pad_idx) 63 | else: 64 | pad_mask = nd.not_equal(enc_idx, 0) 65 | pad_mask = nd.expand_dims(pad_mask, axis=1) 66 | pad_mask = nd.broadcast_axes(pad_mask, axis=1, size=seq_len) 67 | return pad_mask 68 | 69 | def _get_self_tril_mask(self, dec_idx): 70 | batch_size, seq_len = dec_idx.shape 71 | mask_matrix = np.ones(shape=(seq_len, seq_len)) 72 | mask = np.tril(mask_matrix, k=0) 73 | mask = nd.expand_dims(nd.array(mask, ctx=self._ctx), axis=0) 74 | mask = nd.broadcast_axes(mask, axis=0, size=batch_size) 75 | return mask 76 | 77 | def _position_encoding_init(self, max_length, dim): 78 | """Init the sinusoid position encoding table """ 79 | position_enc = np.arange(max_length).reshape((-1, 1)) \ 80 | / (np.power(10000, (2. / dim) * np.arange(dim).reshape((1, -1)))) 81 | # Apply the cosine to even columns and sin to odds. 82 | position_enc[:, 0::2] = np.sin(position_enc[:, 0::2]) # dim 2i 83 | position_enc[:, 1::2] = np.cos(position_enc[:, 1::2]) # dim 2i+1 84 | return position_enc 85 | 86 | 87 | class Decoder(nn.HybridBlock): 88 | def __init__(self, embedding_dim, model_dim, head_num, layer_num, ffn_dim, dropout, att_dropout, ffn_dropout, **kwargs): 89 | super(Decoder, self).__init__(**kwargs) 90 | with self.name_scope(): 91 | self.enc_model_dense = nn.Dense( 92 | model_dim, flatten=False, use_bias=False) 93 | self.dec_model_dense = nn.Dense( 94 | model_dim, flatten=False, use_bias=False) 95 | self.decoder_layers = [] 96 | for i in range(layer_num): 97 | sub_layer = DecoderLayer(model_dim, head_num, ffn_dim, dropout, 98 | att_dropout, ffn_dropout) 99 | self.register_child(sub_layer) 100 | self.decoder_layers.append(sub_layer) 101 | 102 | def hybrid_forward(self, F, bert_output, lm_output, position, self_att_mask, context_att_mask, non_pad_mask): 103 | bert_output = self.enc_model_dense(bert_output) 104 | lm_output = self.dec_model_dense(lm_output) 105 | 106 | dec_output = lm_output + position 107 | 108 | for sub_layer in self.decoder_layers: 109 | dec_output = sub_layer( 110 | bert_output, dec_output, self_att_mask, context_att_mask, non_pad_mask) 111 | return dec_output 112 | 113 | 114 | class DecoderLayer(nn.HybridBlock): 115 | def __init__(self, model_dim, head_num, ffn_dim, dropout, att_dropout, ffn_dropout, ** kwargs): 116 | super(DecoderLayer, self).__init__(**kwargs) 117 | with self.name_scope(): 118 | self.self_masked_attention = MultiHeadAttention( 119 | model_dim, head_num, dropout, att_dropout) 120 | self.context_attention = MultiHeadAttention( 121 | model_dim, head_num, dropout, att_dropout) 122 | self.feed_forward = FeedForward( 123 | model_dim, ffn_dim, ffn_dropout) 124 | 125 | def hybrid_forward(self, F, enc_emb, dec_emb, self_att_mask, context_att_mask, non_pad_mask): 126 | dec_output = self.self_masked_attention( 127 | dec_emb, 128 | dec_emb, 129 | dec_emb, 130 | self_att_mask, 131 | ) 132 | dec_output = F.broadcast_mul(dec_output, non_pad_mask) 133 | dec_output = self.context_attention( 134 | dec_output, 135 | enc_emb, 136 | enc_emb, 137 | context_att_mask, 138 | ) 139 | dec_output = F.broadcast_mul(dec_output, non_pad_mask) 140 | dec_output = self.feed_forward(dec_output) 141 | dec_output = F.broadcast_mul(dec_output, non_pad_mask) 142 | return dec_output 143 | 144 | 145 | class MultiHeadAttention(nn.HybridBlock): 146 | def __init__(self, model_dim, head_num, dropout, att_dropout, **kwargs): 147 | super(MultiHeadAttention, self).__init__(**kwargs) 148 | self._model_dim = model_dim 149 | self._head_num = head_num 150 | if self._model_dim % self._head_num != 0: 151 | raise ValueError('In MultiHeadAttetion, the model_dim should be divided exactly' 152 | ' by the number of head_num. Received model_dim={}, head_num={}' 153 | .format(model_dim, head_num)) 154 | with self.name_scope(): 155 | self.queries_dense = nn.Dense( 156 | model_dim, use_bias=False, flatten=False, prefix="query_") 157 | self.keys_dense = nn.Dense( 158 | model_dim, use_bias=False, flatten=False, prefix="keys_") 159 | self.values_dense = nn.Dense( 160 | model_dim, use_bias=False, flatten=False, prefix="values_") 161 | self.att_dropout = nn.Dropout(att_dropout) 162 | self.dropout = nn.Dropout(dropout) 163 | self.LayerNorm = nn.LayerNorm() 164 | 165 | def hybrid_forward(self, F, queries, keys, values, mask=None): 166 | Q = self.queries_dense(queries) 167 | K = self.keys_dense(keys) 168 | V = self.values_dense(values) 169 | c_dim = int(self._model_dim / self._head_num) 170 | 171 | Q = F.reshape(F.transpose(F.reshape(Q, shape=(0, 0, self._head_num, -1)), 172 | axes=(0, 2, 1, 3)), shape=(-1, 0, 0), reverse=True) 173 | K = F.reshape(F.transpose(F.reshape(K, shape=(0, 0, self._head_num, -1)), 174 | axes=(0, 2, 1, 3)), shape=(-1, 0, 0), reverse=True) 175 | V = F.reshape(F.transpose(F.reshape(V, shape=(0, 0, self._head_num, -1)), 176 | axes=(0, 2, 1, 3)), shape=(-1, 0, 0), reverse=True) 177 | 178 | scale = c_dim ** -0.5 179 | # att_score 180 | att_scores = F.batch_dot(Q, K, transpose_b=True) 181 | 182 | # scale 183 | att_scores = att_scores * scale 184 | 185 | # mask 186 | if mask is not None: 187 | mask = F.reshape(F.broadcast_axes(F.expand_dims(mask, axis=1), 188 | axis=1, size=self._head_num), shape=(-1, 0, 0), reverse=True) 189 | padding = F.ones_like(mask) * -np.inf 190 | att_scores = F.where(mask, att_scores, padding) 191 | att_weights = F.softmax(att_scores, axis=-1) 192 | 193 | outputs = F.batch_dot(att_weights, V) 194 | outputs = F.reshape(F.transpose(F.reshape(outputs, shape=(-1, self._head_num, 195 | 0, 0), reverse=True), axes=(0, 2, 1, 3)), shape=(0, 0, -1)) 196 | outputs = self.dropout(outputs) 197 | # residual 198 | outputs = F.broadcast_add(outputs, queries) 199 | outputs = self.LayerNorm(outputs) 200 | return outputs 201 | 202 | 203 | class FeedForward(nn.HybridBlock): 204 | def __init__(self, model_dim, ffn_dim, ffn_dropout, use_bias=True, activation="relu", **kwargs): 205 | super(FeedForward, self).__init__(**kwargs) 206 | with self.name_scope(): 207 | self.ffn_dense = nn.Dense( 208 | ffn_dim, activation=activation, use_bias=use_bias, flatten=False) 209 | self.model_dense = nn.Dense( 210 | model_dim, use_bias=use_bias, bias_initializer="zeros", flatten=False) 211 | self.dropout = nn.Dropout(ffn_dropout) 212 | self.layer_norm = nn.LayerNorm() 213 | 214 | def hybrid_forward(self, F, x, *args): 215 | output = self.ffn_dense(x) 216 | output = self.model_dense(output) 217 | output = self.dropout(output) 218 | output = self.layer_norm(F.broadcast_add(x, output)) 219 | return output 220 | -------------------------------------------------------------------------------- /models/MaskedCELoss.py: -------------------------------------------------------------------------------- 1 | from mxnet.gluon.loss import SoftmaxCrossEntropyLoss 2 | from mxnet import ndarray 3 | 4 | 5 | class MaskedCELoss(SoftmaxCrossEntropyLoss): 6 | def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None, 7 | batch_axis=0, **kwargs): 8 | super(MaskedCELoss, self).__init__( 9 | weight, batch_axis, **kwargs) 10 | self._axis = axis 11 | self._batch_axis = batch_axis 12 | self._sparse_label = sparse_label 13 | self._from_logits = from_logits 14 | 15 | def hybrid_forward(self, F, pred, label, mask=None): 16 | if not self._from_logits: 17 | pred = F.log_softmax(pred, self._axis) 18 | if self._sparse_label: 19 | loss = -F.pick(pred, label, axis=self._axis, keepdims=True) 20 | else: 21 | label = _reshape_like(F, label, pred) 22 | loss = -F.sum(pred * label, axis=self._axis) 23 | if mask is not None: 24 | loss = loss * mask 25 | return F.sum(loss) / F.sum(mask) 26 | 27 | 28 | def _reshape_like(F, x, y): 29 | """Reshapes x to the same shape as y.""" 30 | return x.reshape(y.shape) if F is ndarray else F.reshape_like(x, y) 31 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | gluon==1.1.0 2 | gluonnlp==0.8.2 3 | multiprocess==0.70.11.1 4 | mxboard==0.1.0 5 | mxnet-cu101==1.5.1.post0 6 | nlp==0.4.0 7 | nltk==3.4.5 8 | numpy==1.19.5 9 | scikit-learn==0.23.2 10 | six==1.15.0 11 | sklearn==0.0 12 | subprocess32==3.5.4 13 | tensorboard==2.4.1 14 | tensorboard-plugin-wit==1.8.0 15 | tensorboardX==2.1 16 | tqdm==4.58.0 17 | 18 | 19 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | -------------------------------------------------------------------------------- /scripts/run_example.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --train_data_path=/is/bes/to/use/absolute/paths/wmt2020_ende.txt \ 3 | --eval_data_path=./is/bes/to/use/absolute/paths/eval_wmt2017_ende.txt \ 4 | --src_bert_dataset=wiki_cn_cased \ 5 | --tgt_bert_dataset=book_corpus_wiki_en_uncased \ 6 | --train_lr=0.0002 \ 7 | --finetune_lr=2e-5 \ 8 | --batch_size=32 \ 9 | --epochs=5 \ 10 | --log_root=../logs \ 11 | --log_step=100 \ 12 | --eval_step=10000 \ 13 | --check_step=300 \ 14 | --max_src_len=50 \ 15 | --max_tgt_len=50 \ -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import sys 5 | sys.path.append("..") 6 | 7 | import gluonnlp 8 | import mxnet as mx 9 | import numpy as np 10 | from data.scripts.dataloader import DatasetAssiantTransformer, MTDataLoader 11 | from data.scripts.dataset import MTDataset 12 | from gluonnlp.data import train_valid_split 13 | from gluonnlp.model import BeamSearchScorer 14 | from models.MaskedCELoss import MaskedCELoss 15 | from models.MTModel_Hybird import Transformer as MTModel_Hybird 16 | from mxnet import autograd, gluon, init, nd 17 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu 18 | from numpy import random 19 | from tensorboardX import SummaryWriter 20 | from tqdm import tqdm 21 | 22 | from utils import config_logger 23 | 24 | 25 | 26 | np.random.seed(100) 27 | random.seed(100) 28 | mx.random.seed(10000) 29 | BOS = "[CLS]" 30 | EOS = "[SEP]" 31 | writer = SummaryWriter("../runs") 32 | 33 | 34 | def train_and_valid(src_bert, 35 | mt_model, 36 | src_vocab, 37 | tgt_vocab, 38 | train_dataiter, 39 | dev_dataiter, 40 | trainer, 41 | finetune_trainer, 42 | epochs, 43 | loss_func, 44 | ctx, 45 | lr, 46 | batch_size, 47 | params_save_path_root, 48 | eval_step, 49 | log_step, 50 | check_step, 51 | label_smooth, 52 | logger, 53 | num_train_examples, 54 | warmup_ratio): 55 | batches = len(train_dataiter) 56 | 57 | num_train_steps = int(num_train_examples / batch_size * epochs) 58 | num_warmup_steps = int(num_train_steps * warmup_ratio) 59 | global_step = 0 60 | dev_bleu_score = 0 61 | 62 | for epoch in range(epochs): 63 | for src, tgt, label, src_valid_len, tgt_valid_len in train_dataiter: 64 | # learning rate strategy 65 | if global_step < num_warmup_steps: 66 | new_lr = lr * global_step / num_warmup_steps 67 | else: 68 | non_warmup_steps = global_step - num_warmup_steps 69 | offset = non_warmup_steps / \ 70 | (num_train_steps - num_warmup_steps) 71 | new_lr = lr - offset * lr 72 | trainer.set_learning_rate(new_lr) 73 | 74 | src = src.as_in_context(ctx) 75 | tgt = tgt.as_in_context(ctx) 76 | label = label.as_in_context(ctx) 77 | src_valid_len = src_valid_len.as_in_context(ctx) 78 | src_token_type = nd.zeros_like(src, ctx=ctx) 79 | 80 | tgt_mask = nd.not_equal(tgt, tgt_vocab(tgt_vocab.padding_token)) 81 | 82 | if label_smooth: 83 | eps = 0.1 84 | num_class = len(tgt_vocab.idx_to_token) 85 | one_hot = nd.one_hot(label, num_class) 86 | one_hot_label = one_hot * \ 87 | (1 - eps) + (1 - one_hot) * eps / num_class 88 | 89 | with autograd.record(): 90 | src_bert_outputs = src_bert( 91 | src, src_token_type, src_valid_len) 92 | mt_outputs = mt_model(src_bert_outputs, src, tgt) 93 | loss_mean = loss_func(mt_outputs, one_hot_label, tgt_mask) 94 | 95 | loss_mean.backward() 96 | loss_scalar = loss_mean.asscalar() 97 | 98 | trainer.step(1) 99 | finetune_trainer.step(1) 100 | 101 | if global_step and global_step % log_step == 0: 102 | predicts = nd.argmax(nd.softmax(mt_outputs, axis=-1), axis=-1) 103 | correct = nd.equal(label, predicts) 104 | accuracy = (nd.sum(correct * tgt_mask) / 105 | nd.sum(tgt_mask)).asscalar() 106 | logger.info("epoch:{}, batch:{}/{}, bleu:{}, acc:{}, loss:{}, (lr:{}s)" 107 | .format(epoch, global_step % batches, batches, dev_bleu_score, accuracy, loss_scalar, trainer.learning_rate)) 108 | 109 | if global_step and global_step % check_step == 0: 110 | predicts = nd.argmax(nd.softmax(mt_outputs, axis=-1), axis=-1) 111 | refer_sample = src.asnumpy().tolist() 112 | label_sample = label.asnumpy().tolist() 113 | pred_sample = predicts.asnumpy().tolist() 114 | logger.info("train sample:") 115 | logger.info("refer :{}".format( 116 | " ".join([src_vocab.idx_to_token[int(idx)] for idx in refer_sample[0]])).replace(src_vocab.padding_token, "")) 117 | logger.info("target :{}".format( 118 | " ".join([tgt_vocab.idx_to_token[int(idx)] for idx in label_sample[0]])).replace(EOS, "[EOS]").replace(tgt_vocab.padding_token, "")) 119 | logger.info("predict:{}".format( 120 | " ".join([tgt_vocab.idx_to_token[int(idx)] for idx in pred_sample[0]])).replace(EOS, "[EOS]")) 121 | 122 | if global_step and global_step % eval_step == 0: 123 | dev_bleu_score = eval(src_bert, mt_model, src_vocab, 124 | tgt_vocab, dev_dataiter, logger, ctx=ctx) 125 | if not os.path.exists(params_save_path_root): 126 | os.makedirs(params_save_path_root) 127 | model_params_file = params_save_path_root + \ 128 | "src_bert_step_{}.params".format(global_step) 129 | src_bert.save_parameters(model_params_file) 130 | logger.info("{} Save Completed.".format(model_params_file)) 131 | 132 | model_params_file = params_save_path_root + \ 133 | "mt_step_{}.params".format(global_step) 134 | mt_model.save_parameters(model_params_file) 135 | logger.info("{} Save Completed.".format(model_params_file)) 136 | writer.add_scalar("loss", loss_scalar, global_step) 137 | global_step += 1 138 | 139 | 140 | def eval(src_bert, mt_model, src_vocab, tgt_vocab, dev_dataiter, logger, ctx): 141 | references = [] 142 | hypothesis = [] 143 | score = 0 144 | chencherry = SmoothingFunction() 145 | for src, _, label, src_valid_len, label_valid_len in tqdm(dev_dataiter): 146 | src = src.as_in_context(ctx) 147 | src_valid_len = src_valid_len.as_in_context(ctx) 148 | batch_size = src.shape[0] 149 | 150 | src_token_type = nd.zeros_like(src) 151 | src_bert_outputs = src_bert(src, src_token_type, src_valid_len) 152 | 153 | tgt_sentences = [BOS] 154 | tgt = tgt_vocab[tgt_sentences] 155 | tgt = nd.array([tgt], ctx=ctx) 156 | tgt = nd.broadcast_axes(tgt, axis=0, size=batch_size) 157 | 158 | for n in range(0, args.max_tgt_len): 159 | mt_outputs = mt_model(src_bert_outputs, src, tgt) 160 | predicts = nd.argmax(nd.softmax(mt_outputs, axis=-1), axis=-1) 161 | final_predict = predicts[:, -1:] 162 | tgt = nd.concat(tgt, final_predict, dim=1) 163 | 164 | label = label.asnumpy().tolist() 165 | predict_valid_len = nd.sum(nd.not_equal(predicts, tgt_vocab( 166 | tgt_vocab.padding_token)), axis=-1).asnumpy().tolist() 167 | predicts = tgt[:, 1:].asnumpy().tolist() 168 | label_valid_len = label_valid_len.asnumpy().tolist() 169 | 170 | for refer, hypoth, l_v_len, p_v_len in zip(label, predicts, label_valid_len, predict_valid_len): 171 | l_v_len = int(l_v_len) 172 | p_v_len = int(p_v_len) 173 | refer = refer[:l_v_len] 174 | refer_str = [tgt_vocab.idx_to_token[int(idx)] for idx in refer] 175 | hypoth_str = [tgt_vocab.idx_to_token[int(idx)] for idx in hypoth] 176 | hypoth_str_valid = [] 177 | for token in hypoth_str: 178 | if token == EOS: 179 | hypoth_str_valid.append(token) 180 | break 181 | hypoth_str_valid.append(token) 182 | references.append(refer_str) 183 | hypothesis.append(hypoth_str_valid) 184 | 185 | for refer, hypoth in zip(references, hypothesis): 186 | score += sentence_bleu([refer], hypoth, 187 | smoothing_function=chencherry.method1) 188 | logger.info("dev sample:") 189 | logger.info("refer :{}".format(" ".join(references[0]).replace( 190 | EOS, "[EOS]").replace(tgt_vocab.padding_token, ""))) 191 | logger.info("hypoth:{}".format(" ".join(hypothesis[0]).replace( 192 | EOS, "[EOS]"))) 193 | return score / len(references) 194 | 195 | 196 | def main(args): 197 | # init some setting 198 | # config logging 199 | log_path = os.path.join(args.log_root, '{}.log'.format(args.model_name)) 200 | if not os.path.exists(args.log_root): 201 | os.makedirs(args.log_root) 202 | logger = config_logger(log_path) 203 | 204 | gpu_idx = args.gpu 205 | if not gpu_idx: 206 | ctx = mx.cpu() 207 | else: 208 | ctx = mx.gpu(gpu_idx - 1) 209 | logger.info("Using ctx: {}".format(ctx)) 210 | 211 | # Loading vocab and model 212 | src_bert, src_vocab = gluonnlp.model.get_model(args.bert_model, 213 | dataset_name=args.src_bert_dataset, 214 | pretrained=True, 215 | ctx=ctx, 216 | use_pooler=False, 217 | use_decoder=False, 218 | use_classifier=False) 219 | _, tgt_vocab = gluonnlp.model.get_model(args.bert_model, 220 | dataset_name=args. tgt_bert_dataset, 221 | pretrained=True, 222 | ctx=ctx, 223 | use_pooler=False, 224 | use_decoder=False, 225 | use_classifier=False) 226 | 227 | mt_model = MTModel_Hybird(src_vocab=src_vocab, 228 | tgt_vocab=tgt_vocab, 229 | embedding_dim=args.mt_emb_dim, 230 | model_dim=args.mt_model_dim, 231 | head_num=args.mt_head_num, 232 | layer_num=args.mt_layer_num, 233 | ffn_dim=args.mt_ffn_dim, 234 | dropout=args.mt_dropout, 235 | att_dropout=args.mt_att_dropout, 236 | ffn_dropout=args.mt_ffn_dropout, 237 | ctx=ctx) 238 | logger.info("Model Creating Completed.") 239 | 240 | # init or load params for model 241 | mt_model.initialize(init.Xavier(), ctx) 242 | 243 | if args.src_bert_load_path: 244 | src_bert.load_parameters(args.src_bert_load_path, ctx=ctx) 245 | if args.mt_model_load_path: 246 | mt_model.load_parameters(args.mt_model_load_path, ctx=ctx) 247 | logger.info("Parameters Initing and Loading Completed") 248 | 249 | src_bert.hybridize() 250 | mt_model.hybridize() 251 | 252 | # Loading dataloader 253 | assiant = DatasetAssiantTransformer( 254 | src_vocab=src_vocab, tgt_vocab=tgt_vocab, max_src_len=args.max_src_len, max_tgt_len=args.max_tgt_len) 255 | train_dataset = MTDataset(args.train_data_path) 256 | eval_dataset = MTDataset(args.eval_data_path) 257 | 258 | train_dataiter = MTDataLoader(train_dataset, batch_size=args.batch_size, 259 | assiant=assiant, shuffle=True).dataiter 260 | dev_dataiter = MTDataLoader(eval_dataset, batch_size=args.batch_size, 261 | assiant=assiant, shuffle=True).dataiter 262 | logger.info("Data Loading Completed") 263 | 264 | # build trainer 265 | finetune_trainer = gluon.Trainer(src_bert.collect_params(), 266 | args.optimizer, {"learning_rate": args.finetune_lr}) 267 | trainer = gluon.Trainer(mt_model.collect_params(), args.optimizer, 268 | {"learning_rate": args.train_lr}) 269 | 270 | # loss function 271 | if args.label_smooth: 272 | loss_func = MaskedCELoss(sparse_label=False) 273 | else: 274 | loss_func = MaskedCELoss() 275 | 276 | logger.info("## Trainning Start ##") 277 | train_and_valid( 278 | src_bert=src_bert, 279 | mt_model=mt_model, 280 | src_vocab=src_vocab, 281 | tgt_vocab=tgt_vocab, 282 | train_dataiter=train_dataiter, 283 | dev_dataiter=dev_dataiter, 284 | trainer=trainer, 285 | finetune_trainer=finetune_trainer, 286 | epochs=args.epochs, 287 | loss_func=loss_func, 288 | ctx=ctx, 289 | lr=args.train_lr, 290 | batch_size=args.batch_size, 291 | params_save_path_root=args.params_save_path_root, 292 | eval_step=args.eval_step, 293 | log_step=args.log_step, 294 | check_step=args.check_step, 295 | label_smooth=args.label_smooth, 296 | logger=logger, 297 | num_train_examples=len(train_dataset), 298 | warmup_ratio=args.warmup_ratio 299 | ) 300 | 301 | 302 | if __name__ == "__main__": 303 | parser = argparse.ArgumentParser( 304 | "Use Bert's Outputs as Transformer's Encoder Outputs, the Decoder is a Norm Transformer's Decoder.") 305 | parser.add_argument("--model_name", type=str, 306 | default="bert2transformerOnMachineTranslation") 307 | parser.add_argument("--train_data_path", type=str, 308 | default=None, required=True) 309 | parser.add_argument("--eval_data_path", type=str, 310 | default=None, required=True) 311 | parser.add_argument("--bert_model", type=str, 312 | default="bert_12_768_12") 313 | parser.add_argument("--src_bert_dataset", type=str, 314 | default=None, required=True) 315 | parser.add_argument("--tgt_bert_dataset", type=str, 316 | default=None, required=True) 317 | parser.add_argument("--src_bert_load_path", type=str, 318 | default=None) 319 | parser.add_argument("--mt_model_load_path", type=str, 320 | default=None) 321 | parser.add_argument("--gpu", type=int, 322 | default=1, help='which gpu to use for finetuning. CPU is used if set 0.') 323 | parser.add_argument("--optimizer", type=str, default="adam") 324 | parser.add_argument("--train_lr", type=float, default=0.0002) 325 | parser.add_argument("--finetune_lr", type=float, default=2e-5) 326 | parser.add_argument("--label_smooth", type=bool, default=True) 327 | parser.add_argument("--batch_size", type=int, 328 | default=None, required=True) 329 | parser.add_argument("--epochs", type=int, 330 | default=None, required=True) 331 | parser.add_argument("--log_root", type=str, default=None, required=True) 332 | parser.add_argument("--log_step", type=int, default=None, required=True) 333 | parser.add_argument("--eval_step", type=int, default=None, required=True) 334 | parser.add_argument("--check_step", type=int, default=None, required=True) 335 | parser.add_argument("--params_save_path_root", 336 | type=str, default="../checkpoints/") 337 | parser.add_argument('--warmup_ratio', type=float, default=0.1, 338 | help='ratio of warmup steps that linearly increase learning rate from ' 339 | '0 to target learning rate. default is 0.1') 340 | parser.add_argument("--max_src_len", type=int, 341 | default=128) 342 | parser.add_argument("--max_tgt_len", type=int, 343 | default=128) 344 | # translation model parameters setting 345 | parser.add_argument("--mt_model_dim", type=int, 346 | default=768) 347 | parser.add_argument("--mt_emb_dim", type=int, 348 | default=768) 349 | parser.add_argument("--mt_head_num", type=int, 350 | default=8) 351 | parser.add_argument("--mt_layer_num", type=int, 352 | default=6) 353 | parser.add_argument("--mt_ffn_dim", type=int, 354 | default=2048) 355 | parser.add_argument("--mt_dropout", type=float, 356 | default=0.1) 357 | parser.add_argument("--mt_ffn_dropout", type=float, 358 | default=0.1) 359 | parser.add_argument("--mt_att_dropout", type=float, 360 | default=0.1) 361 | 362 | args = parser.parse_args() 363 | 364 | main(args) 365 | -------------------------------------------------------------------------------- /scripts/translate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | import sys 4 | sys.path.append("..") 5 | 6 | import gluonnlp 7 | import mxnet as mx 8 | from gluonnlp.data import BERTTokenizer 9 | from models.MTModel_Hybird import Transformer as MTModel_Hybird 10 | from mxnet import nd 11 | 12 | BOS = "[CLS]" 13 | EOS = "[SEP]" 14 | 15 | 16 | def translate(args): 17 | gpu_idx = args.gpu 18 | if not gpu_idx: 19 | ctx = mx.cpu() 20 | else: 21 | ctx = mx.gpu(gpu_idx - 1) 22 | src_bert, src_vocab = gluonnlp.model.get_model(args.bert_model, 23 | dataset_name=args.src_bert_dataset, 24 | pretrained=True, 25 | ctx=ctx, 26 | use_pooler=False, 27 | use_decoder=False, 28 | use_classifier=False) 29 | _, tgt_vocab = gluonnlp.model.get_model(args.bert_model, 30 | dataset_name=args.tgt_bert_dataset, 31 | pretrained=True, 32 | ctx=ctx, 33 | use_pooler=False, 34 | use_decoder=False, 35 | use_classifier=False) 36 | 37 | mt_model = MTModel_Hybird(src_vocab=src_vocab, 38 | tgt_vocab=tgt_vocab, 39 | embedding_dim=args.mt_emb_dim, 40 | model_dim=args.mt_model_dim, 41 | head_num=args.mt_head_num, 42 | layer_num=args.mt_layer_num, 43 | ffn_dim=args.mt_ffn_dim, 44 | dropout=args.mt_dropout, 45 | att_dropout=args.mt_att_dropout, 46 | ffn_dropout=args.mt_ffn_dropout, 47 | ctx=ctx) 48 | 49 | src_bert.load_parameters(args.bert_model_params_path, ctx=ctx) 50 | mt_model.load_parameters(args.mt_model_params_path, ctx=ctx) 51 | 52 | src_bert_tokenzier = BERTTokenizer(src_vocab) 53 | tgt_bert_tokenzier = BERTTokenizer(tgt_vocab) 54 | 55 | while True: 56 | src = input("input:") 57 | 58 | src = src_bert_tokenzier(src) 59 | src = [src_vocab.cls_token] + \ 60 | src + [src_vocab.sep_token] 61 | 62 | src_valid_len = len(src) 63 | 64 | if args.max_src_len and len(src) > args.max_src_len: 65 | src = src[0:args.max_src_len] 66 | 67 | tgt = [BOS] 68 | 69 | src = src_vocab[src] 70 | tgt = tgt_vocab[tgt] 71 | 72 | tgt = nd.array([tgt], ctx=ctx) 73 | 74 | src = nd.array([src], ctx=ctx) 75 | src_valid_len = nd.array([src_valid_len], ctx=ctx) 76 | src_token_types = nd.zeros_like(src) 77 | 78 | beam_size = 6 79 | 80 | src_bert_outputs = src_bert(src, src_token_types, src_valid_len) 81 | mt_outputs = mt_model(src_bert_outputs, src, tgt) 82 | 83 | src_bert_outputs = nd.broadcast_axes( 84 | src_bert_outputs, axis=0, size=beam_size) 85 | src = nd.broadcast_axes(src, axis=0, size=beam_size) 86 | targets = None 87 | for n in range(0, args.max_tgt_len): 88 | tgt, targets = beam_search( 89 | mt_outputs[:, n, :], targets=targets, max_seq_len=args.max_tgt_len, ctx=ctx, beam_width=beam_size) 90 | mt_outputs = mt_model(src_bert_outputs, src, tgt) 91 | 92 | predict = tgt.asnumpy().tolist() 93 | predict_strs = [] 94 | for pred in predict: 95 | predict_token = [tgt_vocab.idx_to_token[int(idx)] for idx in pred] 96 | predict_str = "" 97 | sub_token = [] 98 | for token in predict_token: 99 | # if token in ["[CLS]", EOS, "[SEP]"]: 100 | # continue 101 | if len(sub_token) == 0: 102 | sub_token.append(token) 103 | elif token[:2] != "##" and len(sub_token) != 0: 104 | predict_str += "".join(sub_token) + " " 105 | sub_token = [] 106 | sub_token.append(token) 107 | else: 108 | if token[:2] == "##": 109 | token = token.replace("##", "") 110 | sub_token.append(token) 111 | if token == EOS: 112 | if len(sub_token) != 0: 113 | predict_str += "".join(sub_token) + " " 114 | break 115 | predict_strs.append(predict_str.replace( 116 | "[SEP]", "").replace("[CLS]", "").replace(EOS, "")) 117 | for predict_str in predict_strs: 118 | print(predict_str) 119 | 120 | 121 | def beam_search(outputs, ctx, targets, max_seq_len, beam_width): 122 | predicts = nd.topk(nd.softmax(outputs, axis=-1), 123 | axis=-1, k=beam_width, ret_typ='both') 124 | 125 | if not targets: 126 | targets = {} 127 | beam_result_idxs = [] 128 | beam_result_score = [] 129 | count = 0 130 | for score, idx in zip(predicts[0][0], predicts[1][0]): 131 | idx = [2] + [int(idx.asscalar())] 132 | beam_result_idxs.append(idx) 133 | beam_result_score.append(score) 134 | targets.update( 135 | {"beam_{}".format(count): {"idx": idx, "score": score}}) 136 | count += 1 137 | 138 | result = [] 139 | for idx in beam_result_idxs: 140 | idx = idx[:max_seq_len] + \ 141 | [0] * (max_seq_len - len(idx)) 142 | result.append(idx) 143 | return nd.array(result, ctx=ctx), targets 144 | 145 | else: 146 | beam_idxs = [] 147 | beam_score = [] 148 | for scores, idxs, target in zip(predicts[0], predicts[1], targets.values()): 149 | last_score = target["score"] 150 | last_idxs = target["idx"] 151 | max_score = 0 152 | max_score_idx = [] 153 | for score, idx in zip(scores, idxs): 154 | if last_score + score > max_score: 155 | max_score = last_score + score 156 | idx = int(idx.asscalar()) 157 | max_score_idx = last_idxs[:] + [idx] 158 | 159 | beam_idxs.append(max_score_idx) 160 | beam_score.append(max_score) 161 | 162 | beam_score, beam_idxs = (list(t) 163 | for t in zip(*sorted(zip(beam_score, beam_idxs), reverse=True))) 164 | 165 | targets = {} 166 | count = 0 167 | for idx, score in zip(beam_idxs, beam_score): 168 | targets.update( 169 | {"beam_{}".format(count): {"idx": idx, "score": score}}) 170 | count += 1 171 | 172 | result = [] 173 | for idx in beam_idxs: 174 | idx = idx[:max_seq_len] + \ 175 | [0] * (max_seq_len - len(idx)) 176 | result.append(idx) 177 | return nd.array(result, ctx=ctx), targets 178 | 179 | 180 | if __name__ == "__main__": 181 | parser = argparse.ArgumentParser() 182 | parser.add_argument("--model_name", type=str, 183 | default="Bert2transformer_translate") 184 | parser.add_argument("--bert_model", type=str, 185 | default="bert_12_768_12") 186 | parser.add_argument("--src_bert_dataset", type=str, 187 | default=None, required=True) 188 | parser.add_argument("--tgt_bert_dataset", type=str, 189 | default=None, required=True) 190 | parser.add_argument("--bert_model_params_path", type=str, 191 | default=None, required=True) 192 | parser.add_argument("--mt_model_params_path", type=str, 193 | default=None, required=True) 194 | parser.add_argument("--gpu", type=int, 195 | default=1, help='which gpu to use for finetuning. CPU is used if set 0.') 196 | parser.add_argument("--max_src_len", type=int, 197 | default=128) 198 | parser.add_argument("--max_tgt_len", type=int, 199 | default=128) 200 | 201 | # translation model parameters setting 202 | parser.add_argument("--mt_model_dim", type=int, 203 | default=768) 204 | parser.add_argument("--mt_emb_dim", type=int, 205 | default=768) 206 | parser.add_argument("--mt_head_num", type=int, 207 | default=8) 208 | parser.add_argument("--mt_layer_num", type=int, 209 | default=6) 210 | parser.add_argument("--mt_ffn_dim", type=int, 211 | default=2048) 212 | parser.add_argument("--mt_dropout", type=float, 213 | default=0.1) 214 | parser.add_argument("--mt_ffn_dropout", type=float, 215 | default=0.1) 216 | parser.add_argument("--mt_att_dropout", type=float, 217 | default=0.1) 218 | 219 | args = parser.parse_args() 220 | 221 | translate(args) 222 | -------------------------------------------------------------------------------- /scripts/translate_example.sh: -------------------------------------------------------------------------------- 1 | python translate.py \ 2 | --src_bert_dataset=wiki_cn_cased \ 3 | --tgt_bert_dataset=book_corpus_wiki_en_uncased \ 4 | --bert_model_params_path=../checkpoints/src_bert_step_100.params \ 5 | --mt_model_params_path=../checkpoints/mt_step_100.params \ 6 | --max_src_len=50 \ 7 | --max_tgt_len=50 \ 8 | -------------------------------------------------------------------------------- /scripts/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def config_logger(log_path): 5 | # Configuring logger 6 | logger = logging.getLogger() 7 | logger.setLevel(logging.INFO) 8 | fhandler = logging.FileHandler(log_path, mode='w') 9 | shandler = logging.StreamHandler() 10 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 11 | fhandler.setFormatter(formatter) 12 | shandler.setFormatter(formatter) 13 | logger.addHandler(fhandler) 14 | logger.addHandler(shandler) 15 | 16 | return logger 17 | --------------------------------------------------------------------------------