├── README.md ├── bucket_data_helper.py ├── inference_helper.py ├── make_dataset.py ├── result_img └── final.PNG ├── transformer.py └── translation_train.py /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow-Transformer 2 | Attention Is All You Need 3 | 4 | 5 | ## Translate EN into DE (Train with WMT17) 6 | * newstest2014 BLEU: 28.51 7 | * newstest2015 BLEU: 30.22 8 | * newstest2016 BLEU: 33.88 9 | ![final.PNG](./result_img/final.PNG) 10 | 11 | 12 | ## Paper 13 | * Attention Is All You Need: https://arxiv.org/abs/1706.03762 14 | * Layer Normalization: https://arxiv.org/abs/1607.06450 15 | * Label Smoothing: https://arxiv.org/abs/1512.00567 16 | * Byte-Pair Encoding (BPE): https://arxiv.org/abs/1508.07909 17 | * Beam-Search length penalty: https://arxiv.org/abs/1609.08144 18 | 19 | ## Env 20 | * GTX1080TI 21 | * ubuntu 16.04 22 | * CUDA 8.0 23 | * cuDNN 5.1 24 | * tensorflow 1.4 25 | * numpy 26 | * nltk (bleu) 27 | * tqdm (iteration check bar) 28 | * python 3 29 | 30 | 31 | 32 | ## Dataset 33 | * Preprocessed WMT17 en-de: http://data.statmt.org/wmt17/translation-task/preprocessed/ 34 | * train_set: corpus.tc.[en, de]/corpus.tc.[en, de] 35 | * dev_set: dev.tar/newstest[2014, 2015, 2016].tc.[en, de] 36 | 37 | * learn and apply [Sentences were encoded using byte-pair encoding](https://github.com/SeonbeomKim/Python-Bype_Pair_Encoding) 38 | * -num_merges: 35000 39 | * -final_voca_threshold: 50 40 | * -train_voca_threshold: 1 41 | * make_file: bpe applied documents and voca 42 | 43 | ## Code 44 | * transformer.py 45 | * Transformer graph 46 | 47 | * inference_helper.py 48 | * greedy 49 | * beam (length penalty applied) 50 | * bleu (nltk) 51 | 52 | * bucket_data_helper.py 53 | * bucket으로 구성된 데이터를 쉽게 가져오도록 하는 class 54 | 55 | * make_dataset.py 56 | * generate bucketed bpe2idx dataset for train, valid, test from bpe applied dataset 57 | * need MakeFile of [Sentences were encoded using byte-pair encoding](https://github.com/SeonbeomKim/Python-Bype_Pair_Encoding) 58 | * command: 59 | * make bucket train_set wmt17 60 | ``` 61 | python make_dataset.py 62 | -mode train 63 | -source_input_path path/bpe_wmt17.en (source bpe applied document data) 64 | -source_out_path path/source_idx_wmt17_en.csv (source bpe idx data) 65 | -target_input_path path/bpe_wmt17.de (target bpe applied document data) 66 | -target_out_path path/target_idx_wmt17_de.csv (target bpe idx data) 67 | -bucket_out_path ./bpe_dataset/train_set_wmt17 (bucket trainset from source bpe idx data, target bpe idx data) 68 | -voca_path voca_path/voca_file_name (bpe voca from bpe_learn.py) 69 | ``` 70 | * make bucket valid_set newstest2014 71 | ``` 72 | python make_dataset.py 73 | -mode infer 74 | -source_input_path path/bpe_newstest2014.en (source bpe applied document data) 75 | -source_out_path path/source_idx_newstest2014_en.csv (source bpe idx data) 76 | -target_input_path path/dev.tar/newstest2014.tc.de (target original raw data) 77 | -bucket_out_path ./bpe_dataset/valid_set_newstest2014 (bucket validset from source bpe idx data, target original raw data) 78 | -voca_path voca_path/voca_file_name (bpe voca from bpe_learn.py) 79 | ``` 80 | * make bucket test_set newstest2015 81 | ``` 82 | python make_dataset.py 83 | -mode infer 84 | -source_input_path path/bpe_newstest2015.en (source bpe applied document data) 85 | -source_out_path path/source_idx_newstest2015_en.csv (source bpe idx data) 86 | -target_input_path path/dev.tar/newstest2015.tc.de (target original raw data) 87 | 
-bucket_out_path ./bpe_dataset/test_set_newstest2015 (bucket testset from source bpe idx data, target original raw data) 88 | -voca_path voca_path/voca_file_name (bpe voca from bpe_learn.py) 89 | ``` 90 | * make bucket test_set newstest2016 91 | ``` 92 | python make_dataset.py 93 | -mode infer 94 | -source_input_path path/bpe_newstest2016.en (source bpe applied document data) 95 | -source_out_path path/source_idx_newstest2016_en.csv (source bpe idx data) 96 | -target_input_path path/dev.tar/newstest2016.tc.de (target original raw data) 97 | -bucket_out_path ./bpe_dataset/test_set_newstest2016 (bucket testset from source bpe idx data, target original raw data) 98 | -voca_path voca_path/voca_file_name (bpe voca from bpe_learn.py) 99 | ``` 100 | * translation_train.py 101 | * en -> de translation train, validation, test 102 | * command 103 | ``` 104 | python translation_train.py 105 | -train_path_2017 ./bpe_dataset/train_set_wmt17 106 | -valid_path_2014 ./bpe_dataset/valid_set_newstest2014 107 | -test_path_2015 ./bpe_dataset/test_set_newstest2015 108 | -test_path_2016 ./bpe_dataset/test_set_newstest2016 109 | -voca_path voca_path/voca_file_name 110 | ``` 111 | 112 | ## Training 113 | 1. [WMT17 Dataset Download](http://data.statmt.org/wmt17/translation-task/preprocessed/) 114 | 2. [apply byte-pair_encoding](https://github.com/SeonbeomKim/Python-Bype_Pair_Encoding) 115 | 3. run make_dataset.py 116 | 4. run translation_train.py 117 | 118 | ## Reference 119 | * https://jalammar.github.io/illustrated-transformer/ 120 | * https://github.com/Kyubyong/transformer 121 | -------------------------------------------------------------------------------- /bucket_data_helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class bucket_data: 4 | def __init__(self, data, batch_token = 16000): 5 | self.data = data 6 | self.batch_token = batch_token 7 | 8 | 9 | def get_dataset(self, bucket_shuffle=False, dataset_shuffle=False): 10 | # bucket_shuffle: 버켓별로 셔플. 
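# each batch holds roughly batch_token tokens: batch_size = batch_token // (source_bucket_length + target_bucket_length)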
11 | # dataset_shuffle: data_list 셔플 12 | 13 | data_list = [] 14 | for key in self.data: 15 | batch_size = self.batch_token // sum(key) 16 | 17 | if bucket_shuffle is True: 18 | source, target = self.data[key] 19 | indices = np.arange(len(source)) 20 | np.random.shuffle(indices) 21 | self.data[key] = [source[indices], target[indices]] 22 | 23 | for i in range( int(np.ceil(len(self.data[key][0])/batch_size)) ): 24 | bucket_data = self.data[key] 25 | batch_source = bucket_data[0][i*batch_size : (i+1)*batch_size] 26 | batch_target = bucket_data[1][i*batch_size : (i+1)*batch_size] 27 | data_list.append([batch_source, batch_target]) 28 | 29 | if dataset_shuffle is True: 30 | np.random.shuffle(data_list) 31 | return data_list 32 | -------------------------------------------------------------------------------- /inference_helper.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import nltk 4 | 5 | class greedy: 6 | def __init__(self, sess, model, go_idx): 7 | self.sess = sess 8 | self.model = model 9 | self.go_idx = go_idx 10 | 11 | def decode(self, encoder_input, target_length): 12 | sess = self.sess 13 | model = self.model 14 | encoder_input = np.array(encoder_input, dtype=np.int32) 15 | 16 | input_token = np.zeros([encoder_input.shape[0], target_length+1], np.int32) # go || target_length 17 | input_token[:, 0] = self.go_idx 18 | 19 | encoder_embedding = sess.run(model.encoder_embedding, 20 | { 21 | model.encoder_input:encoder_input, 22 | model.keep_prob:1 23 | } 24 | ) # [N, self.encoder_input_length, self.embedding_size] 25 | 26 | for index in range(target_length): 27 | current_pred = sess.run(model.decoder_pred, 28 | { 29 | model.encoder_embedding:encoder_embedding, 30 | model.decoder_input:input_token[:, :index+1], 31 | model.keep_prob:1 32 | } 33 | ) # [N, target_length+1] 34 | input_token[:, index+1] = current_pred[:, index] 35 | 36 | # [N, target_length] 37 | return input_token[:, 1:] 38 | 39 | 40 | 41 | 42 | class beam: 43 | def __init__(self, sess, model, go_idx, eos_idx, beam_width, length_penalty=0.6): 44 | self.sess = sess 45 | self.model = model 46 | self.go_idx = go_idx 47 | self.eos_idx = eos_idx 48 | self.beam_width = beam_width 49 | self.length_penalty = length_penalty 50 | self.build_beam_graph() 51 | 52 | 53 | def build_beam_graph(self): 54 | model = self.model 55 | 56 | self.time_step = tf.placeholder(tf.int32, name='time_step_placeholder') 57 | 58 | self.tile_encoder_embedding = tf.contrib.seq2seq.tile_batch(model.encoder_embedding, self.beam_width) 59 | tile_current_embedding = model.decoder_embedding[:, self.time_step, :] # [N*beam_width, voca_size] 60 | 61 | top_k_prob, top_k_indices = tf.nn.top_k( 62 | tf.nn.softmax(tile_current_embedding, dim=-1), # [N*beam_width, self.voca_size] 63 | self.beam_width 64 | ) # [N*beam_width, beam_width], [N*beam_width, beam_width] 65 | 66 | # lp(length_penalty) https://arxiv.org/pdf/1609.08144.pdf 67 | Y_length = tf.to_float(self.time_step) + 1 68 | lp = ((5. + Y_length)**self.length_penalty) / ((5. 
+ 1.)**self.length_penalty) 69 | self.top_k_prob = tf.log(tf.reshape(top_k_prob, [-1, 1])) / lp # [N*beam_width*beam_width, 1] 70 | self.top_k_indices = tf.reshape(top_k_indices, [-1, 1]) # [N*beam_width*beam_width, 1] 71 | 72 | 73 | def decode(self, encoder_input, target_length): 74 | sess = self.sess 75 | model = self.model 76 | beam_width = self.beam_width 77 | 78 | encoder_input = np.array(encoder_input, dtype=np.int32) 79 | 80 | N = encoder_input.shape[0] 81 | for_indexing = np.arange(N).reshape(-1, 1) * beam_width * beam_width # [N, 1] 82 | 83 | # for eos check, one-initialize 84 | is_previous_eos = np.ones([N*beam_width*beam_width, 1], dtype=np.float32) 85 | 86 | input_token = np.zeros([N*beam_width, target_length+1], np.int32) # go || target_length 87 | input_token[:, 0] = self.go_idx 88 | 89 | encoder_embedding = sess.run(self.tile_encoder_embedding, 90 | { 91 | model.encoder_input:encoder_input, 92 | model.keep_prob:1, 93 | } 94 | ) # [N*beam_width, self.encoder_input_length, self.embedding_size] 95 | 96 | for index in range(target_length): 97 | prob, indices = sess.run([self.top_k_prob, self.top_k_indices], 98 | { 99 | model.encoder_embedding:encoder_embedding, 100 | model.decoder_input:input_token[:, :index+1], 101 | model.keep_prob:1, 102 | self.time_step:index, 103 | } 104 | ) # each [N*beam_width*beam_width, 1] 105 | 106 | if index == 0: 107 | prob = prob.reshape([-1, beam_width, beam_width]) # [N, beam_width, beam_width] 108 | prob = prob.transpose([0, 2, 1]) # [N, beam_width, beam_width] 109 | prob = prob.reshape([-1, 1]) # [N*beam_width*beam_width, 1] 110 | indices = indices.reshape([-1, beam_width, beam_width]) # # [N, beam_width, beam_width] 111 | indices = indices.transpose([0, 2, 1]) # [N, beam_width, beam_width] 112 | indices = indices.reshape([-1, 1]) # [N*beam_width*beam_width, 1] 113 | input_token[:, 1] = indices[np.arange(0, N*beam_width*beam_width, beam_width)].reshape(-1) # [N*beam_width] 114 | # save 115 | prob_list = prob # [N*beam_width*beam_width, 1] 116 | indices_list = indices # [N*beam_width*beam_width, 1] 117 | 118 | else: 119 | # 이전 output 중에 한번이라도 eos가 있으면 prob 반영 안함. eos가 없으면 1, 있으면 0 120 | is_previous_eos *= (indices_list[:, -1:] != self.eos_idx) # [N*beam_width*beam_width, 1] 121 | masked_prob = prob * is_previous_eos # [N*beam_width*beam_width, 1] 122 | prob_list += masked_prob # [N*beam_width*beam_width, 1] 123 | indices_list = np.concatenate((indices_list, indices), axis=1) # [N*beam_width*beam_width, index+1] 124 | 125 | batch_split_prob_list = prob_list.reshape([-1, beam_width*beam_width]) # [N, beam_width*beam_width] 126 | top_k_indices = np.argsort(-batch_split_prob_list)[:, :beam_width] # -붙여야 내림차순 정렬. 
[N, beam_width] 127 | top_k_indices += for_indexing # [N, beam_width] 128 | top_k_indices = top_k_indices.reshape(-1) # [N*beam_width] 129 | 130 | is_previous_eos = is_previous_eos[top_k_indices] # [N*beam_width, 1] 131 | top_k_prob = prob_list[top_k_indices] # [N*beam_width, 1] 132 | indices_list = indices_list[top_k_indices] # [N*beam_width, index+1] 133 | input_token[:, 1:index+2] = indices_list 134 | 135 | if index < target_length-1: 136 | # save 137 | is_previous_eos = np.tile(is_previous_eos, beam_width) # [N*beam_width, beam_width] 138 | is_previous_eos = is_previous_eos.reshape(N*beam_width*beam_width, 1) # [N*beam_width*beam_width, 1] 139 | indices_list = np.tile(indices_list, beam_width) # [N*beam_width, beam_width*(index+1)] 140 | indices_list = indices_list.reshape(N*beam_width*beam_width, -1) # [N*beam_width*beam_width, (index+1)] 141 | prob_list = np.tile(top_k_prob, beam_width) # [N*beam_width, beam_width] 142 | prob_list = prob_list.reshape(N*beam_width*beam_width, 1) # [N*beam_width*beam_width, 1] 143 | 144 | indices_list = indices_list.reshape(N, beam_width, target_length) 145 | 146 | # [N, target_length] 147 | return indices_list[:, 0, :] # batch마다 가장 probability가 높은 결과 리턴. 148 | 149 | 150 | 151 | 152 | class utils: 153 | def __init__(self): 154 | pass 155 | 156 | def bleu(self, target, pred): 157 | smoothing = nltk.translate.bleu_score.SmoothingFunction() 158 | score = nltk.translate.bleu_score.corpus_bleu(target, pred, smoothing_function=smoothing.method0) 159 | #score = nltk.translate.bleu_score.corpus_bleu(target, pred, smoothing_function=smoothing.method4) 160 | return score 161 | 162 | 163 | -------------------------------------------------------------------------------- /make_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import csv 4 | import os 5 | from tqdm import tqdm 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | '-mode', 10 | help="train or infer", 11 | choices=['train', 'infer'], 12 | required=True, 13 | ) 14 | parser.add_argument( 15 | '-source_input_path', 16 | help="source document path", 17 | required=True, 18 | ) 19 | parser.add_argument( 20 | '-source_out_path', 21 | help="preprocessed source output path", 22 | required=True, 23 | ) 24 | parser.add_argument( 25 | '-target_input_path', 26 | help="target document path", 27 | required=True, 28 | ) 29 | parser.add_argument( 30 | '-target_out_path', 31 | help="preprocessed target output path", 32 | required=False 33 | ) 34 | parser.add_argument( 35 | '-bucket_out_path', 36 | help="bucket output path", 37 | required=True, 38 | ) 39 | parser.add_argument( 40 | '-voca_path', 41 | help="Vocabulary_path", 42 | required=True 43 | ) 44 | 45 | args = parser.parse_args() 46 | 47 | mode = args.mode 48 | source_input_path = args.source_input_path 49 | source_out_path = args.source_out_path 50 | target_input_path = args.target_input_path 51 | target_out_path = args.target_out_path 52 | bucket_out_path = args.bucket_out_path 53 | voca_path = args.voca_path 54 | 55 | 56 | 57 | def read_voca(path): 58 | sorted_voca = [] 59 | with open(path, 'r', encoding='utf-8') as f: 60 | for bpe_voca in f: 61 | bpe_voca = bpe_voca.strip() 62 | if bpe_voca: 63 | bpe_voca = bpe_voca.split() 64 | sorted_voca.append(bpe_voca) 65 | return sorted_voca 66 | 67 | 68 | 69 | def make_bpe2idx(voca): 70 | bpe2idx = {'
<PAD>
':0, '':1, '':2, '':3} 71 | idx = 4 72 | 73 | for word, _ in voca: 74 | bpe2idx[word] = idx 75 | idx += 1 76 | 77 | return bpe2idx 78 | 79 | 80 | 81 | def bpe2idx_out_csv(data_path, out_path, bpe2idx, info='source'): #info: 'source' or 'target' 82 | print('documents to idx csv', data_path, '->', out_path) 83 | 84 | o = open(out_path, 'w', newline='', encoding='utf-8') 85 | wr = csv.writer(o) 86 | 87 | with open(data_path, 'r', encoding='utf-8') as f: 88 | documents = f.readlines() 89 | 90 | for i in tqdm(range(len(documents)), ncols=50): 91 | sentence = documents[i] 92 | 93 | # bpe2idx 94 | if info == 'target': 95 | row_idx = [bpe2idx['']] 96 | else: 97 | row_idx = [] 98 | 99 | for word in sentence.strip().split(): 100 | if word in bpe2idx: 101 | row_idx.append(bpe2idx[word]) 102 | else: 103 | row_idx.append(bpe2idx['']) ## 1 104 | row_idx.append(bpe2idx['']) ## eos:3 105 | 106 | wr.writerow(row_idx) 107 | 108 | o.close() 109 | print('saved', out_path, '\n') 110 | 111 | 112 | 113 | def _make_bucket_dataset(source_path, target_path, out_path, bucket, pad_idx, file_mode='w', is_trainset=True): 114 | print('make bucket dataset') 115 | print('source:', source_path, 'target:', target_path) 116 | 117 | if not os.path.exists(out_path): 118 | os.makedirs(out_path) 119 | 120 | # 저장시킬 object 생성 121 | source_open_list = [] 122 | target_open_list = [] 123 | for bucket_size in bucket: 124 | o_s = open(os.path.join(out_path, 'source_'+str(bucket_size)+'.csv'), file_mode, newline='') 125 | o_s_csv = csv.writer(o_s) 126 | source_open_list.append((o_s, o_s_csv)) 127 | 128 | if is_trainset: 129 | o_t = open(os.path.join(out_path, 'target_'+str(bucket_size)+'.csv'), file_mode, newline='') 130 | o_t_csv = csv.writer(o_t) 131 | target_open_list.append((o_t, o_t_csv)) 132 | else: 133 | o_t = open(os.path.join(out_path, 'target_'+str(bucket_size)+'.txt'), file_mode, encoding='utf-8') 134 | target_open_list.append(o_t) 135 | 136 | 137 | 138 | with open(source_path, 'r') as s: 139 | source = s.readlines() 140 | 141 | if is_trainset: 142 | with open(target_path, 'r') as t: 143 | target = t.readlines() 144 | else: 145 | with open(target_path, 'r', encoding='utf-8') as t: 146 | target = t.readlines() 147 | 148 | 149 | for i in tqdm(range(len(source)), ncols=50): 150 | source_sentence = np.array(source[i].strip().split(','), dtype=np.int32) 151 | if is_trainset: 152 | target_sentence = np.array(target[i].strip().split(','), dtype=np.int32) 153 | else: 154 | target_sentence = target[i] 155 | 156 | 157 | for bucket_index, bucket_size in enumerate(bucket): 158 | source_size, target_size = bucket_size 159 | # 버켓에 없는것은 데이터는 제외. 160 | 161 | if is_trainset: 162 | if len(source_sentence) <= source_size and len(target_sentence) <= target_size: # (1,2) <= (10, 40) 163 | source_sentence = np.pad( 164 | source_sentence, 165 | (0, source_size-len(source_sentence)), 166 | 'constant', 167 | constant_values = pad_idx# bpe2idx['
<PAD>'] # pad value 168 | ) 169 | target_sentence = np.pad( 170 | target_sentence, 171 | (0, target_size+1-len(target_sentence)), # +1 because [0:-1] is used as decoder_input and [1:] as decoder_target 172 | 'constant', 173 | constant_values = pad_idx # bpe2idx['<PAD>'] # pad value 174 | ) 175 | source_open_list[bucket_index][1].writerow(source_sentence) 176 | target_open_list[bucket_index][1].writerow(target_sentence) 177 | break 178 | 179 | else: 180 | if len(source_sentence) <= source_size: 181 | source_sentence = np.pad( 182 | source_sentence, 183 | (0, source_size-len(source_sentence)), 184 | 'constant', 185 | constant_values = pad_idx# bpe2idx['<PAD>
'] # pad value 186 | ) 187 | source_open_list[bucket_index][1].writerow(source_sentence) 188 | target_open_list[bucket_index].write(target_sentence) 189 | break 190 | 191 | 192 | # close object 193 | for i in range(len(bucket)): 194 | source_open_list[i][0].close() 195 | 196 | if is_trainset: 197 | target_open_list[i][0].close() 198 | else: 199 | target_open_list[i].close() 200 | print('saved', out_path) 201 | 202 | 203 | 204 | def make_bucket_dataset(data_path, idx_out_path, bucket_out_path, bucket, bpe2idx, file_mode='w', is_trainset=True): 205 | print('start make_bucket_dataset', 'is_trainset:', is_trainset) 206 | 207 | bpe2idx_out_csv( 208 | data_path=data_path['source'], 209 | out_path=idx_out_path['source'], 210 | bpe2idx=bpe2idx, 211 | info='source' 212 | ) 213 | 214 | if is_trainset: 215 | bpe2idx_out_csv( 216 | data_path=data_path['target'], 217 | out_path=idx_out_path['target'], 218 | bpe2idx=bpe2idx, 219 | info='target' 220 | ) 221 | 222 | # padding and bucketing 223 | _make_bucket_dataset( 224 | source_path=idx_out_path['source'], 225 | target_path=idx_out_path['target'], 226 | out_path=bucket_out_path, 227 | bucket=bucket, 228 | pad_idx=bpe2idx['
<PAD>'], 229 | file_mode=file_mode, 230 | is_trainset=is_trainset 231 | ) 232 | 233 | else: 234 | # padding and bucketing 235 | _make_bucket_dataset( 236 | source_path=idx_out_path['source'], 237 | target_path=data_path['target'], 238 | out_path=bucket_out_path, 239 | bucket=bucket, 240 | pad_idx=bpe2idx['<PAD>
'], 241 | file_mode=file_mode, 242 | is_trainset=is_trainset 243 | ) 244 | 245 | print('\n\n') 246 | 247 | 248 | 249 | voca = read_voca(voca_path) 250 | bpe2idx = make_bpe2idx(voca) 251 | 252 | 253 | 254 | 255 | if mode == 'train': 256 | data_path = {'source':source_input_path, 'target':target_input_path} 257 | idx_out_path = {'source':source_out_path, 'target':target_out_path} 258 | 259 | #bucket (source, target) 260 | train_bucket = [(i*5, i*5 + j*10) for i in range(1, 31) for j in range(4)]# [(5, 5), (5, 15), .., (5, 35), ... , (150, 150), .., (150, 180)] 261 | print('train_bucket\n', train_bucket,'\n') 262 | 263 | make_bucket_dataset( 264 | data_path, 265 | idx_out_path, 266 | bucket_out_path, 267 | train_bucket, 268 | bpe2idx 269 | ) 270 | 271 | elif mode == 'infer': 272 | data_path = {'source':source_input_path, 'target':target_input_path} 273 | idx_out_path = {'source':source_out_path} 274 | 275 | #bucket (source, target) 276 | infer_bucket = [(i*5, i*5+50) for i in range(1, 31)] # [(5, 55), (10, 60), ..., (150, 200)] 277 | print('infer_bucket\n', infer_bucket,'\n') 278 | 279 | make_bucket_dataset( 280 | data_path, 281 | idx_out_path, 282 | bucket_out_path, 283 | infer_bucket, 284 | bpe2idx, 285 | is_trainset=False 286 | ) 287 | 288 | 289 | 290 | ''' 291 | # make trainset 292 | data_path = {'source':'./bpe_dataset/bpe_wmt17.en', 'target':'./bpe_dataset/bpe_wmt17.de'} 293 | idx_out_path = {'source':'./bpe_dataset/source_idx_wmt17_en.csv', 'target':'./bpe_dataset/target_idx_wmt17_de.csv'} 294 | bucket_out_path = './bpe_dataset/train_set_wmt17/' 295 | make_bucket_dataset(data_path, idx_out_path, bucket_out_path, train_bucket, bpe2idx) 296 | 297 | # make validset 298 | data_path = {'source':'./bpe_dataset/bpe_newstest2014.en', 'target':'./dataset/dev.tar/newstest2014.tc.de'} 299 | idx_out_path = {'source':'./bpe_dataset/source_idx_newstest2014_en.csv'} 300 | bucket_out_path = './bpe_dataset/valid_set_newstest2014/' 301 | make_bucket_dataset(data_path, idx_out_path, bucket_out_path, infer_bucket, bpe2idx, is_trainset=False) 302 | 303 | # make testset 304 | data_path = {'source':'./bpe_dataset/bpe_newstest2015.en', 'target':'./dataset/dev.tar/newstest2015.tc.de'} 305 | idx_out_path = {'source':'./bpe_dataset/source_idx_newstest2015_en.csv'} 306 | bucket_out_path = './bpe_dataset/test_set_newstest2015/' 307 | make_bucket_dataset(data_path, idx_out_path, bucket_out_path, infer_bucket, bpe2idx, is_trainset=False) 308 | 309 | # make testset 310 | data_path = {'source':'./bpe_dataset/bpe_newstest2016.en', 'target':'./dataset/dev.tar/newstest2016.tc.de'} 311 | idx_out_path = {'source':'./bpe_dataset/source_idx_newstest2016_en.csv'} 312 | bucket_out_path = './bpe_dataset/test_set_newstest2016/' 313 | make_bucket_dataset(data_path, idx_out_path, bucket_out_path, infer_bucket, bpe2idx, is_trainset=False) 314 | ''' -------------------------------------------------------------------------------- /result_img/final.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeonbeomKim/TensorFlow-Transformer/20834412313a3290b1805ca47b0f077f8fb59e96/result_img/final.PNG -------------------------------------------------------------------------------- /transformer.py: -------------------------------------------------------------------------------- 1 | #https://arxiv.org/abs/1706.03762 Attention Is All You Need(Transformer) 2 | #https://arxiv.org/abs/1607.06450 Layer Normalization 3 | #https://arxiv.org/abs/1512.00567 Label Smoothing 4 | 5 | 
import tensorflow as tf #version 1.4 6 | import numpy as np 7 | import os 8 | #tf.set_random_seed(787) 9 | 10 | class Transformer: 11 | def __init__(self, sess, voca_size, embedding_size, is_embedding_scale, PE_sequence_length, 12 | encoder_decoder_stack, multihead_num, eos_idx, pad_idx, label_smoothing): 13 | 14 | self.sess = sess 15 | self.voca_size = voca_size 16 | self.embedding_size = embedding_size 17 | self.is_embedding_scale = is_embedding_scale # True or False 18 | self.PE_sequence_length = PE_sequence_length 19 | self.encoder_decoder_stack = encoder_decoder_stack 20 | self.multihead_num = multihead_num 21 | self.eos_idx = eos_idx # <'eos'> symbol index 22 | self.pad_idx = pad_idx # <'pad'> symbol index 23 | self.label_smoothing = label_smoothing # if 1.0, then one-hot encooding 24 | self.PE = tf.convert_to_tensor(self.positional_encoding(), dtype=tf.float32) #[self.PE_sequence_length, self.embedding_siz] #slice해서 쓰자. 25 | 26 | 27 | with tf.name_scope("placeholder"): 28 | self.lr = tf.placeholder(tf.float32) 29 | self.encoder_input = tf.placeholder(tf.int32, [None, None], name='encoder_input') 30 | self.encoder_input_length = tf.shape(self.encoder_input)[1] 31 | 32 | self.decoder_input = tf.placeholder(tf.int32, [None, None], name='decoder_input') #'go a b c eos pad' 33 | self.decoder_input_length = tf.shape(self.decoder_input)[1] 34 | 35 | self.target = tf.placeholder(tf.int32, [None, None], name='target') # 'a b c eos pad pad' 36 | 37 | self.keep_prob = tf.placeholder(tf.float32, name='keep_prob') 38 | # dropout (each sublayers before add and norm) and (sums of the embeddings and the PE) and (attention) 39 | 40 | 41 | with tf.name_scope("embedding_table"): 42 | with tf.device('/cpu:0'): 43 | zero = tf.zeros([1, self.embedding_size], dtype=tf.float32) # for padding 44 | #embedding_table = tf.Variable(tf.random_uniform([self.voca_size-1, self.embedding_size], -1, 1)) 45 | embedding_table = tf.get_variable( # https://github.com/tensorflow/models/blob/master/official/transformer/model/embedding_layer.py 46 | 'embedding_table', 47 | [self.voca_size-1, self.embedding_size], 48 | initializer=tf.random_normal_initializer(0., self.embedding_size ** -0.5)) 49 | front, end = tf.split(embedding_table, [self.pad_idx, self.voca_size-1-self.pad_idx]) 50 | self.embedding_table = tf.concat((front, zero, end), axis=0) # [self.voca_size, self.embedding_size] 51 | 52 | 53 | with tf.name_scope('encoder'): 54 | encoder_input_embedding, encoder_input_mask = self.embedding_and_PE(self.encoder_input, self.encoder_input_length) 55 | self.encoder_embedding = self.encoder(encoder_input_embedding, encoder_input_mask) 56 | 57 | 58 | with tf.name_scope('decoder'): 59 | decoder_input_embedding, decoder_input_mask = self.embedding_and_PE(self.decoder_input, self.decoder_input_length) # decoder_input은 go 붙어있어야함. 
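# decoder() below returns the vocabulary logits (weight-tied with the embedding table) and their argmax predictions.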
60 | self.decoder_embedding, self.decoder_pred = self.decoder(decoder_input_embedding, self.encoder_embedding, decoder_input_mask) 61 | 62 | 63 | with tf.name_scope('train_cost'): 64 | # target mask ( masking pad of target ) 65 | self.target_pad_mask = tf.cast( #sequence_mask처럼 생성됨 66 | tf.not_equal(self.target, self.pad_idx), 67 | dtype=tf.float32 68 | ) # [N, target_length] (include eos) 69 | 70 | # make smoothing target one hot vector 71 | self.target_one_hot = tf.one_hot( 72 | self.target, 73 | depth=self.voca_size, 74 | on_value = (1.0-self.label_smoothing) + (self.label_smoothing / self.voca_size), # tf.float32 75 | off_value = (self.label_smoothing / self.voca_size), # tf.float32 76 | dtype= tf.float32 77 | ) # [N, self.target_length, self.voca_size] 78 | 79 | # calc train_cost 80 | self.train_cost = tf.nn.softmax_cross_entropy_with_logits( 81 | labels = self.target_one_hot, 82 | logits = self.decoder_embedding 83 | ) # [N, self.target_length] 84 | self.train_cost *= self.target_pad_mask # except pad 85 | self.train_cost = tf.reduce_sum(self.train_cost) / tf.reduce_sum(self.target_pad_mask) 86 | 87 | 88 | with tf.name_scope('optimizer'): 89 | optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.9, beta2=0.98, epsilon=1e-9) 90 | self.minimize = optimizer.minimize(self.train_cost) 91 | 92 | 93 | with tf.name_scope("saver"): 94 | self.saver = tf.train.Saver(max_to_keep=10000) 95 | 96 | sess.run(tf.global_variables_initializer()) 97 | 98 | 99 | 100 | def embedding_and_PE(self, data, data_length): 101 | # data: [N, data_length] 102 | 103 | # embedding lookup and scale 104 | with tf.device('/cpu:0'): 105 | embedding = tf.nn.embedding_lookup( 106 | self.embedding_table, 107 | data 108 | ) # [N, data_length, self.embedding_size] 109 | if self.is_embedding_scale is True: 110 | embedding *= self.embedding_size**0.5 111 | 112 | embedding_mask = tf.expand_dims( 113 | tf.cast(tf.not_equal(data, self.pad_idx), dtype=tf.float32), # [N, data_length] 114 | axis=-1 115 | ) # [N, data_length, 1] 116 | 117 | # Add Position Encoding 118 | embedding += self.PE[:data_length, :] 119 | 120 | # pad masking (set 0 PE added pad position) 121 | embedding *= embedding_mask 122 | 123 | # Drop out 124 | embedding = tf.nn.dropout(embedding, keep_prob=self.keep_prob) 125 | return embedding, embedding_mask 126 | 127 | 128 | 129 | def encoder(self, encoder_input_embedding, encoder_input_mask): 130 | # encoder_input_embedding: [N, self.encoder_input_length, self.embedding_size] , pad mask applied 131 | # encoder_input_mask: [N, self.encoder_input_length, 1] 132 | 133 | # mask 134 | encoder_self_attention_mask = tf.tile( 135 | tf.matmul(encoder_input_mask, tf.transpose(encoder_input_mask, [0, 2, 1])), # [N, encoder_input_length, encoder_input_length] 136 | [self.multihead_num, 1, 1] 137 | ) # [self.multihead_num*N, encoder_input_length, encoder_input_length] 138 | 139 | for i in range(self.encoder_decoder_stack): #6 140 | # Multi-Head Attention 141 | Multihead_add_norm = self.multi_head_attention_add_norm( 142 | query=encoder_input_embedding, 143 | key_value=encoder_input_embedding, 144 | score_mask=encoder_self_attention_mask, 145 | output_mask=encoder_input_mask, 146 | activation=None, 147 | name='encoder'+str(i) 148 | ) # [N, self.encoder_input_length, self.embedding_size] 149 | 150 | # Feed Forward 151 | encoder_input_embedding = self.dense_add_norm( 152 | Multihead_add_norm, 153 | self.embedding_size, 154 | output_mask=encoder_input_mask, # set 0 bias added pad position 155 | activation=tf.nn.relu, 156 | 
name='encoder_dense'+str(i) 157 | ) # [N, self.encoder_input_length, self.embedding_size] 158 | 159 | return encoder_input_embedding # [N, self.encoder_input_length, self.embedding_size] 160 | 161 | 162 | 163 | def decoder(self, decoder_input_embedding, encoder_embedding, decoder_input_mask): 164 | # decoder_input_embedding: [N, self.decoder_input_length, self.embedding_size] , pad mask applied 165 | # encoder_embedding: [N, self.encoder_input_length, self.embedding_size] , pad mask applied 166 | # decoder_input_mask: [N, self.decoder_input_length, 1] 167 | 168 | # mask 169 | pad_of_encoder_embedding = tf.transpose( 170 | tf.reduce_sum(tf.abs(encoder_embedding), axis=-1, keep_dims=True), # [N, self.encoder_input_length, 1] 171 | [0, 2, 1] 172 | ) # [N, 1, encoder_input_length] 173 | decoder_ED_attention_mask = tf.tile( 174 | tf.cast(tf.not_equal(pad_of_encoder_embedding, self.pad_idx), dtype=tf.float32), # [N, 1, encoder_input_length] 175 | [self.multihead_num, 1, 1] 176 | ) # [self.multihead_num*N, 1, encoder_input_length] 177 | decoder_self_attention_mask = tf.sequence_mask( 178 | tf.range(start=1, limit=self.decoder_input_length+1), # [start, limit) 179 | maxlen=self.decoder_input_length,#.eval(session=sess), 180 | dtype=tf.float32 181 | ) # [decoder_input_length, decoder_input_length] 182 | 183 | 184 | for i in range(self.encoder_decoder_stack): 185 | # Masked Multi-Head Attention 186 | Masked_Multihead_add_norm = self.multi_head_attention_add_norm( 187 | query=decoder_input_embedding, 188 | key_value=decoder_input_embedding, 189 | score_mask=decoder_self_attention_mask, 190 | output_mask=decoder_input_mask, 191 | activation=None, 192 | name='self_attention_decoder'+str(i) 193 | ) 194 | 195 | # Multi-Head Attention(Encoder Decoder Attention) 196 | ED_Multihead_add_norm = self.multi_head_attention_add_norm( 197 | query=Masked_Multihead_add_norm, 198 | key_value=encoder_embedding, 199 | score_mask=decoder_ED_attention_mask, 200 | output_mask=decoder_input_mask, 201 | activation=None, 202 | name='ED_attention_decoder'+str(i) 203 | ) 204 | 205 | #Feed Forward 206 | decoder_input_embedding = self.dense_add_norm( 207 | ED_Multihead_add_norm, 208 | units=self.embedding_size, 209 | output_mask=decoder_input_mask, # set 0 bias added pad position 210 | activation=tf.nn.relu, 211 | name='decoder_dense'+str(i) 212 | ) # [N, self.decoder_input_length, self.embedding_size] 213 | 214 | 215 | # share weight, input embeddings, per-softmax layer 216 | decoder_embedding = tf.reshape( 217 | decoder_input_embedding, 218 | [-1, self.embedding_size] 219 | ) # [N*self.decoder_input_length, self.embedding_size] 220 | decoder_embedding = tf.matmul( 221 | decoder_embedding, 222 | tf.transpose(self.embedding_table) # [self.embedding_size, self.voca_size] 223 | ) # [N*self.decoder_input_length, self.voca_size] 224 | decoder_embedding = tf.reshape( 225 | decoder_embedding, 226 | [-1, self.decoder_input_length, self.voca_size] 227 | ) # [N, self.decoder_input_length, self.voca_size] 228 | 229 | decoder_pred = tf.argmax( 230 | decoder_embedding, 231 | axis=-1, 232 | output_type=tf.int32 233 | ) # [N, self.decoder_input_length] 234 | 235 | return decoder_embedding, decoder_pred 236 | 237 | 238 | 239 | def multi_head_attention_add_norm(self, query, key_value, score_mask=None, output_mask=None, activation=None, name=None): 240 | # Sharing Variables 241 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE): 242 | # for문으로 self.multihead_num번 돌릴 필요 없이 embedding_size 만큼 만들고 self.multihead_num등분해서 연산하면 됨. 
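# Q, K, V are each projected to embedding_size with a single dense layer, then split into multihead_num heads below.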
243 | V = tf.layers.dense( # layers dense는 배치(N)별로 동일하게 연산됨. 244 | key_value, 245 | units=self.embedding_size, 246 | activation=activation, 247 | use_bias=False, 248 | name='V' 249 | ) # [N, key_value_sequence_length, self.embedding_size] 250 | K = tf.layers.dense( 251 | key_value, 252 | units=self.embedding_size, 253 | activation=activation, 254 | use_bias=False, 255 | name='K' 256 | ) # [N, key_value_sequence_length, self.embedding_size] 257 | Q = tf.layers.dense( 258 | query, 259 | units=self.embedding_size, 260 | activation=activation, 261 | use_bias=False, 262 | name='Q' 263 | ) # [N, query_sequence_length, self.embedding_size] 264 | 265 | # linear 결과를 self.multihead_num등분하고 연산에 지장을 주지 않도록 batch화 시킴. 266 | # https://github.com/Kyubyong/transformer 참고. 267 | # split: [N, key_value_sequence_length, self.embedding_size/self.multihead_num]이 self.multihead_num개 존재 268 | V = tf.concat(tf.split(V, self.multihead_num, axis=-1), axis=0) # [self.multihead_num*N, key_value_sequence_length, self.embedding_size/self.multihead_num] 269 | K = tf.concat(tf.split(K, self.multihead_num, axis=-1), axis=0) # [self.multihead_num*N, key_value_sequence_length, self.embedding_size/self.multihead_num] 270 | Q = tf.concat(tf.split(Q, self.multihead_num, axis=-1), axis=0) # [self.multihead_num*N, query_sequence_length, self.embedding_size/self.multihead_num] 271 | 272 | 273 | # Q * (K.T) and scaling , [self.multihead_num*N, query_sequence_length, key_value_sequence_length] 274 | score = tf.matmul(Q, tf.transpose(K, [0, 2, 1])) / tf.sqrt(self.embedding_size/self.multihead_num) 275 | 276 | # masking 277 | if score_mask is not None: 278 | score *= score_mask # zero mask 279 | score += ((score_mask-1) * 1e+9) # -inf mask 280 | # decoder self_attention: 281 | # 1 0 0 282 | # 1 1 0 283 | # 1 1 1 형태로 마스킹 284 | 285 | # encoder_self_attention 286 | # if encoder_input_data: i like 287 | # 1 1 0 288 | # 1 1 0 289 | # 0 0 0 형태로 마스킹 290 | 291 | # ED_attention 292 | # if encoder_input_data: i like 293 | # 1 1 0 294 | # 1 1 0 295 | # 1 1 0 형태로 마스킹 296 | 297 | softmax = tf.nn.softmax(score, dim=2) # [self.multihead_num*N, query_sequence_length, key_value_sequence_length] 298 | 299 | # Attention dropout 300 | # https://arxiv.org/abs/1706.03762v4 => v4 paper에는 attention dropout 하라고 되어 있음. 
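# keep_prob is fed as 0.9 during training (dropout rate 0.1) and as 1.0 at inference.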
301 | softmax = tf.nn.dropout(softmax, keep_prob=self.keep_prob) 302 | 303 | # Attention weighted sum 304 | attention = tf.matmul(softmax, V) # [self.multihead_num*N, query_sequence_length, self.embedding_size/self.multihead_num] 305 | 306 | # split: [N, query_sequence_length, self.embedding_size/self.multihead_num]이 self.multihead_num개 존재 307 | concat = tf.concat(tf.split(attention, self.multihead_num, axis=0), axis=-1) # [N, query_sequence_length, self.embedding_size] 308 | 309 | # Linear 310 | Multihead = tf.layers.dense( 311 | concat, 312 | units=self.embedding_size, 313 | activation=activation, 314 | use_bias=False, 315 | name='linear' 316 | ) # [N, query_sequence_length, self.embedding_size] 317 | 318 | if output_mask is not None: 319 | Multihead *= output_mask 320 | 321 | # residual Drop Out 322 | Multihead = tf.nn.dropout(Multihead, keep_prob=self.keep_prob) 323 | # Add 324 | Multihead += query 325 | # Layer Norm 326 | Multihead = tf.contrib.layers.layer_norm(Multihead, begin_norm_axis=2) # [N, query_sequence_length, self.embedding_size] 327 | 328 | return Multihead 329 | 330 | 331 | 332 | def dense_add_norm(self, embedding, units, output_mask=None, activation=None, name=None): 333 | # FFN(x) = max(0, x*W1+b1)*W2 + b2 334 | # Sharing Variables 335 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE): 336 | inner_layer = tf.layers.dense( 337 | embedding, 338 | units=4*self.embedding_size, #bert paper 339 | activation=activation # relu 340 | ) # [N, self.decoder_input_length, 4*self.embedding_size] 341 | dense = tf.layers.dense( 342 | inner_layer, 343 | units=units, 344 | activation=None 345 | ) # [N, self.decoder_input_length, self.embedding_size] 346 | 347 | if output_mask is not None: 348 | dense *= output_mask # set 0 bias added pad position 349 | 350 | # Drop out 351 | dense = tf.nn.dropout(dense, keep_prob=self.keep_prob) 352 | # Add 353 | dense += embedding 354 | # Layer Norm 355 | dense = tf.contrib.layers.layer_norm(dense, begin_norm_axis=2) 356 | 357 | return dense 358 | 359 | 360 | 361 | def positional_encoding(self): 362 | PE = np.zeros([self.PE_sequence_length, self.embedding_size], np.float32) 363 | for pos in range(self.PE_sequence_length): #충분히 크게 만들어두고 slice 해서 쓰자. 364 | sin, cos = [], [] 365 | for i in range(0, self.embedding_size//2): 366 | sin.append(np.sin( pos / np.power(10000, 2*i/self.embedding_size) ).astype(np.float32)) 367 | cos.append(np.cos( pos / np.power(10000, 2*i/self.embedding_size) ).astype(np.float32)) 368 | PE[pos] = np.concatenate((sin,cos)) 369 | return PE #[self.PE_sequence_length, self.embedding_siz] 370 | 371 | ''' 372 | # 기존 373 | def positional_encoding(self): 374 | PE = np.zeros([self.PE_sequence_length, self.embedding_size]) 375 | for pos in range(self.PE_sequence_length): #충분히 크게 만들어두고 slice 해서 쓰자. 
376 | for i in range(self.embedding_size//2): 377 | PE[pos, 2*i] = np.sin( pos / np.power(10000, 2*i/self.embedding_size) ) 378 | PE[pos, 2*i+1] = np.cos( pos / np.power(10000, 2*i/self.embedding_size) ) 379 | 380 | return PE #[self.PE_sequence_length, self.embedding_siz] 381 | ''' -------------------------------------------------------------------------------- /translation_train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import transformer 4 | import inference_helper 5 | import bucket_data_helper 6 | import os 7 | from tqdm import tqdm 8 | import warnings 9 | import argparse 10 | 11 | warnings.filterwarnings('ignore') 12 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | '-train_path_2017', 17 | help="train_path", 18 | required=True 19 | ) 20 | parser.add_argument( 21 | '-valid_path_2014', 22 | help="valid_path", 23 | required=True 24 | ) 25 | parser.add_argument( 26 | '-test_path_2015', 27 | help="test_path", 28 | required=True 29 | ) 30 | parser.add_argument( 31 | '-test_path_2016', 32 | help="test_path", 33 | required=True 34 | ) 35 | parser.add_argument( 36 | '-voca_path', 37 | help="Vocabulary_path", 38 | required=True 39 | ) 40 | 41 | args = parser.parse_args() 42 | train_path_2017 = args.train_path_2017 43 | valid_path_2014 = args.valid_path_2014 44 | test_path_2015 = args.test_path_2015 45 | test_path_2016 = args.test_path_2016 46 | voca_path = args.voca_path 47 | 48 | 49 | saver_path = './saver/' 50 | tensorboard_path = './tensorboard/' 51 | 52 | 53 | def read_voca(path): 54 | sorted_voca = [] 55 | with open(path, 'r', encoding='utf-8') as f: 56 | for bpe_voca in f: 57 | bpe_voca = bpe_voca.strip() 58 | if bpe_voca: 59 | bpe_voca = bpe_voca.split() 60 | sorted_voca.append(bpe_voca) 61 | return sorted_voca 62 | 63 | 64 | def _read_csv(path): 65 | data = np.loadtxt( 66 | path, 67 | delimiter=",", 68 | dtype=np.int32, 69 | ndmin=2 # csv가 1줄이여도 2차원으로 출력. 70 | ) 71 | return data 72 | 73 | def _read_txt(path): 74 | with open(path, 'r', encoding='utf-8') as f: 75 | documents = f.readlines() 76 | 77 | data = [] 78 | for sentence in documents: 79 | data.append(sentence.strip().split()) 80 | return data 81 | 82 | 83 | def _get_bucket_name(path): 84 | bucket = {} 85 | for filename in os.listdir(path): 86 | bucket[filename.split('.')[-2].split('_')[-1]] = 1 87 | return tuple(bucket.keys()) 88 | 89 | 90 | def make_bpe2idx(voca): 91 | bpe2idx = {'
<PAD>':0, '<UNK>':1, '<GO>':2, '<EOS>':3} 92 | idx2bpe = ['<PAD>
', '', '', ''] 93 | idx = 4 94 | 95 | for word, _ in voca: 96 | bpe2idx[word] = idx 97 | idx += 1 98 | idx2bpe.append(word) 99 | 100 | return bpe2idx, idx2bpe 101 | 102 | 103 | def read_data_set(path, target_type='csv'): 104 | buckets = _get_bucket_name(path) 105 | 106 | dictionary = {} 107 | total_sentence = 0 108 | 109 | for i in tqdm(range(len(buckets)), ncols=50): 110 | bucket = buckets[i] # '(35, 35)' string 111 | 112 | source_path = os.path.join(path, 'source_'+bucket+'.csv') 113 | sentence = _read_csv(source_path) 114 | 115 | if target_type == 'csv': 116 | target_path = os.path.join(path, 'target_'+bucket+'.csv') 117 | target = _read_csv(target_path) 118 | else: 119 | target_path = os.path.join(path, 'target_'+bucket+'.txt') 120 | target = _read_txt(target_path) 121 | 122 | # 개수가 0인 bucket은 버림. 123 | sentence_num = len(sentence) 124 | if sentence_num != 0: 125 | total_sentence += sentence_num 126 | 127 | sentence_bucket, target_bucket = bucket[1:-1].split(',') 128 | tuple_bucket = (int(sentence_bucket), int(target_bucket)) 129 | dictionary[tuple_bucket] = [sentence, target] 130 | 131 | print('data_path:', path, 'data_size:', total_sentence, '\n') 132 | return dictionary 133 | 134 | 135 | 136 | def get_lr(embedding_size, step_num): 137 | ''' 138 | https://ufal.mff.cuni.cz/pbml/110/art-popel-bojar.pdf 139 | step_num(training_steps): number of iterations, ie. the number of times the optimizer update was run 140 | This number also equals the number of mini batches that were processed. 141 | ''' 142 | lr = (embedding_size**-0.5) * min( (step_num**-0.5), (step_num*(warmup_steps**-1.5)) ) 143 | return lr 144 | 145 | 146 | 147 | def train(model, data, epoch): 148 | loss = 0 149 | 150 | dataset = data.get_dataset(bucket_shuffle=True, dataset_shuffle=True) 151 | total_iter = len(dataset) 152 | 153 | for i in tqdm(range(total_iter), ncols=50): 154 | step_num = ((epoch-1)*total_iter)+(i+1) 155 | lr = get_lr(embedding_size=embedding_size, step_num=step_num) # epoch: [1, @], i:[0, total_iter) 156 | 157 | encoder_input, temp = dataset[i] 158 | decoder_input = temp[:, :-1] 159 | #print(encoder_input.shape, decoder_input.shape, 4*np.multiply(*encoder_input.shape)*512/1000000000,"GB", 4*np.multiply(*decoder_input.shape)*40297/1000000000,'GB') 160 | target = temp[:, 1:] # except '' 161 | train_loss, _ = sess.run([model.train_cost, model.minimize], 162 | { 163 | model.lr:lr, 164 | model.encoder_input:encoder_input, 165 | model.decoder_input:decoder_input, 166 | model.target:target, 167 | model.keep_prob:0.9 # dropout rate = 0.1 168 | } 169 | ) 170 | loss += train_loss 171 | #if (i+1) % 5000 == 0: 172 | # print(i+1,loss/(i+1), 'lr:', lr) 173 | 174 | print('current step_num:', step_num, 'lr:', lr) 175 | return loss/total_iter 176 | 177 | 178 | def infer(model, data): 179 | pred_list = [] 180 | target_list = [] 181 | 182 | dataset = data.get_dataset(bucket_shuffle=False, dataset_shuffle=False) 183 | total_iter = len(dataset) 184 | 185 | for i in tqdm(range(total_iter), ncols=50): 186 | encoder_input, target = dataset[i] 187 | target_length = encoder_input.shape[1] + 30 188 | 189 | pred = infer_helper.decode(encoder_input, target_length) # [N, target_length] 190 | del encoder_input 191 | first_eos = np.argmax(pred == bpe2idx[''], axis=1) # [N] 최초로 eos 나오는 index. 
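# np.argmax returns 0 when no eos token appears in a prediction, so those rows are kept untrimmed in the loop below.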
192 | 193 | for _pred, _first_eos, _target in (zip(pred, first_eos, target)): 194 | if _first_eos != 0: 195 | _pred = _pred[:_first_eos] 196 | _pred = [idx2bpe[idx] for idx in _pred] # idx2bpe 197 | _pred = ''.join(_pred) # 공백 없이 전부 concat 198 | _pred = _pred.replace('', ' ') # 공백 symbol을 공백으로 치환. 199 | pred_list.append(_pred.split()) 200 | target_list.append([_target]) 201 | 202 | bleu = utils.bleu(target_list, pred_list) * 100 203 | return bleu 204 | 205 | 206 | 207 | def run(model, trainset2017, validset2014, testset2015, testset2016, restore=0): 208 | if restore != 0: 209 | model.saver.restore(sess, saver_path+str(restore)+".ckpt") 210 | print('restore:', restore) 211 | 212 | 213 | with tf.name_scope("tensorboard"): 214 | train_loss_tensorboard_2017 = tf.placeholder(tf.float32, name='train_loss_2017') 215 | valid_bleu_tensorboard_2014 = tf.placeholder(tf.float32, name='valid_bleu_2014') 216 | test_bleu_tensorboard_2015 = tf.placeholder(tf.float32, name='test_bleu_2015') 217 | test_bleu_tensorboard_2016 = tf.placeholder(tf.float32, name='test_bleu_2016') 218 | 219 | train_summary_2017 = tf.summary.scalar("train_loss_wmt17", train_loss_tensorboard_2017) 220 | valid_summary_2014 = tf.summary.scalar("valid_bleu_newstest2014", valid_bleu_tensorboard_2014) 221 | test_summary_2015 = tf.summary.scalar("test_bleu_newstest2015", test_bleu_tensorboard_2015) 222 | test_summary_2016 = tf.summary.scalar("test_bleu_newstest2016", test_bleu_tensorboard_2016) 223 | 224 | merged = tf.summary.merge_all() 225 | writer = tf.summary.FileWriter(tensorboard_path, sess.graph) 226 | #merged_train_valid = tf.summary.merge([train_summary, valid_summary]) 227 | #merged_test = tf.summary.merge([test_summary]) 228 | 229 | 230 | if not os.path.exists(saver_path): 231 | print("create save directory") 232 | os.makedirs(saver_path) 233 | 234 | for epoch in range(restore+1, 20000+1): 235 | #train 236 | train_loss_2017 = train(model, trainset2017, epoch) 237 | 238 | #save 239 | model.saver.save(sess, saver_path+str(epoch)+".ckpt") 240 | 241 | #validation 242 | valid_bleu_2014 = infer(model, validset2014) 243 | 244 | #test 245 | test_bleu_2015 = infer(model, testset2015) 246 | test_bleu_2016 = infer(model, testset2016) 247 | print("epoch:", epoch) 248 | print('train_loss_wmt17:', train_loss_2017, 'valid_bleu_newstest2014:', valid_bleu_2014) 249 | print('test_bleu_newstest2015:', test_bleu_2015, 'test_bleu_newstest2016:', test_bleu_2016, '\n') 250 | 251 | 252 | #tensorboard 253 | summary = sess.run(merged, { 254 | train_loss_tensorboard_2017:train_loss_2017, 255 | valid_bleu_tensorboard_2014:valid_bleu_2014, 256 | test_bleu_tensorboard_2015:test_bleu_2015, 257 | test_bleu_tensorboard_2016:test_bleu_2016, 258 | } 259 | ) 260 | writer.add_summary(summary, epoch) 261 | 262 | 263 | 264 | print('Data read') # key: bucket_size(tuple) , value: [source, target] 265 | train_dict_2017 = read_data_set(train_path_2017) 266 | valid_dict_2014 = read_data_set(valid_path_2014, 'txt') 267 | test_dict_2015 = read_data_set(test_path_2015, 'txt') 268 | test_dict_2016 = read_data_set(test_path_2016, 'txt') 269 | 270 | train_set_2017 = bucket_data_helper.bucket_data(train_dict_2017, batch_token = 11000) # batch_token // len(sentence||target token) == batch_size 271 | valid_set_2014 = bucket_data_helper.bucket_data(valid_dict_2014, batch_token = 9000) # batch_token // len(sentence||target token) == batch_size 272 | test_set_2015 = bucket_data_helper.bucket_data(test_dict_2015, batch_token = 9000) # batch_token // len(sentence||target token) == 
batch_size 273 | test_set_2016 = bucket_data_helper.bucket_data(test_dict_2016, batch_token = 9000) # batch_token // len(sentence||target token) == batch_size 274 | del train_dict_2017, valid_dict_2014, test_dict_2015, test_dict_2016 275 | 276 | 277 | print("Model read") 278 | config = tf.ConfigProto() 279 | config.gpu_options.allow_growth=True 280 | sess = tf.Session(config=config) 281 | 282 | voca = read_voca(voca_path) 283 | bpe2idx, idx2bpe = make_bpe2idx(voca) 284 | warmup_steps = 4000 * 8 # paper warmup_steps: 4000(with 8-gpus), so warmup_steps of single gpu: 4000*8 285 | embedding_size = 512 286 | encoder_decoder_stack = 6 287 | multihead_num = 8 288 | label_smoothing = 0.1 289 | beam_width = 4 290 | length_penalty = 0.6 291 | 292 | print('voca_size:', len(bpe2idx)) 293 | print('warmup_steps:', warmup_steps) 294 | print('embedding_size:', embedding_size) 295 | print('encoder_decoder_stack:', encoder_decoder_stack) 296 | print('multihead_num:', multihead_num) 297 | print('label_smoothing:', label_smoothing) 298 | print('beam_width:', beam_width) 299 | print('length_penalty:', length_penalty, '\n') 300 | 301 | model = transformer.Transformer( 302 | sess = sess, 303 | voca_size = len(bpe2idx), 304 | embedding_size = embedding_size, 305 | is_embedding_scale = True, 306 | PE_sequence_length = 300, 307 | encoder_decoder_stack = encoder_decoder_stack, 308 | multihead_num = multihead_num, 309 | eos_idx=bpe2idx[''], 310 | pad_idx=bpe2idx['
<PAD>'], 311 | label_smoothing=label_smoothing 312 | ) 313 | 314 | # beam search 315 | infer_helper = inference_helper.beam( 316 | sess = sess, 317 | model = model, 318 | go_idx = bpe2idx['<GO>'], 319 | eos_idx = bpe2idx['<EOS>'], 320 | beam_width = beam_width, 321 | length_penalty = length_penalty 322 | ) 323 | # bleu util 324 | utils = inference_helper.utils() 325 | 326 | 327 | print('run') 328 | run(model, train_set_2017, valid_set_2014, test_set_2015, test_set_2016) 329 | 330 | ''' 331 | # greedy search 332 | infer_helper = inference_helper.greedy( 333 | sess = sess, 334 | model = model, 335 | go_idx = bpe2idx['<GO>'] 336 | ) 337 | ''' --------------------------------------------------------------------------------