├── segment.py
├── README.md
├── word2index.py
├── utils.py
└── align-and-translate-char.ipynb
/segment.py:
--------------------------------------------------------------------------------
1 | import jieba
2 | from wordsegment import segment
3 | import os
4 | train_prefix = 'data/UM-Corpus/data/Bilingual/'
5 | test_prefix = 'data/UM-Corpus/data/Testing/'
6 | train_classes = os.listdir(train_prefix)
7 | 
8 | train_files = [os.listdir(train_prefix + i)[0] for i in train_classes]
9 | 
10 | train_chinese = []
11 | train_english = []
12 | for i,j in zip(train_classes,train_files):
13 |     filename = os.path.join(train_prefix,i,j)
14 |     with open(filename,encoding='utf-8') as fhdl:
15 |         flag = 0  # lines alternate: English sentence first, then its Chinese translation
16 |         for line in fhdl:
17 |             if flag == 0:
18 |                 train_english.append(line.strip())
19 |             else:
20 |                 train_chinese.append(line.strip())
21 |             flag = flag ^ 1
22 | 
23 | print(len(train_chinese),len(train_english))
24 | from utils import *
25 | pb = ProgressBar(worksum=len(train_chinese),auto_display=False)
26 | 
27 | pb.startjob()
28 | train_token_chinese = []
29 | train_token_english = []
30 | num = 0
31 | with open('middleresult/segmented_train.txt','w',encoding='utf-8') as whdl:
32 |     for ch,en in zip(train_chinese,train_english):
33 |         num += 1
34 |         token_en = [i.lower() for i in jieba.cut(en) if i.strip()]  # English: lowercase word tokens
35 |         token_ch = [i for i in ch if i.strip()]  # Chinese: individual characters
36 |         train_token_chinese.append(token_ch)
37 |         train_token_english.append(token_en)
38 |         whdl.write("{}\n".format(' '.join(token_en)))
39 |         whdl.write("{}\n".format(' '.join(token_ch)))
40 |         pb.complete(1)
41 |         if num % 32 == 0:
42 |             pb.display_progress_bar()
43 | 
44 | import pickle
45 | with open('middleresult/segmented_train.pkl','wb') as whdl:
46 |     pickle.dump((train_token_chinese,train_token_english),whdl)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # icytranslate_offline
2 | This project is the offline part of [Icytranslate](http://translate.icybee.cn/), an English-Chinese translation platform. The output of this project is a translation model, which is the core component of Icytranslate.
3 | 
4 | # data preparation
5 | We use UM-Corpus as our default training dataset, which can be applied for here:
6 | 
7 | [http://nlp2ct.cis.umac.mo/um-corpus/](http://nlp2ct.cis.umac.mo/um-corpus/)
8 | 
9 | # Use your own dataset
10 | Although UM-Corpus is a fine dataset, we encourage you to use your own dataset and report your results. If you want to use another dataset, you need to modify the code in segment.py and change the train_prefix and test_prefix to your actual data directories.
11 | 
12 | Be aware that the dataset you use should have the same structure as UM-Corpus; otherwise you will need to read ```segment.py``` and modify some of the code in it.
13 | 
14 | # data preprocessing
15 | 
16 | ### tokenizer
17 | We first need to split the corpus into sequences of tokens; run ```python segment.py``` to do that. The output is ```segmented_train.pkl``` in the ```middleresult``` dir.
18 | 
19 | We tokenize all English sentences into lowercase words, and split all Chinese sentences into lists of Chinese characters.
20 | 
21 | ### encode the tokenized series
22 | The next step is to convert the tokenized sentences into sequences of word indices. To do that, you only need to run
23 | 
24 | ```python word2index.py --max_words=[max words in a sentence that you want]```
25 | 
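If you want to sanity-check the result, the script dumps a pickle named after the flags you passed (with the defaults it is ```middleresult/tokenlizer_output_60000ch_60000en_-1words.pkl```). Below is a minimal sketch of loading that file and decoding one encoded pair back to text; the tuple layout follows what ```word2index.py``` writes, and the filename shown assumes the default flags.

```
import pickle

# name follows word2index.py's format string; adjust it to the flags you actually used
with open('middleresult/tokenlizer_output_60000ch_60000en_-1words.pkl', 'rb') as fhdl:
    ind2ch, ch2ind, ind2en, en2ind, train_x, train_y = pickle.load(fhdl)

# each example is a list of integer indices; index 0 marks out-of-vocabulary tokens
print(train_x[0])
print(' '.join(ind2en.get(i, '') for i in train_x[0]))  # English source words
print(' '.join(ind2ch.get(i, '') for i in train_y[0]))  # Chinese target characters
```
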
26 | # model training
27 | Now we can train our model. You will find an ```align-and-translate-char``` ipynb file in the folder; open it with an IDE or Jupyter notebook and follow the steps there, and you will get a trained model with a test BLEU of around 0.22.
28 | 
29 | # dependencies
30 | ```
31 | tensorflow 1.2.0 for the neural network
32 | jieba for English word tokenization
33 | nltk to calculate the BLEU score
34 | sklearn, numpy as toolkits
35 | ```
36 | 
--------------------------------------------------------------------------------
/word2index.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from utils import ProgressBar
3 | import argparse
4 | 
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument("--max_words", type=int,
7 |                     help="max words in a tokenized sentence; -1 for no limit",default=-1)
8 | parser.add_argument("--test_lines", type=int,
9 |                     help="number of lines used to test the program; leave unset to process the full text",default=-1)
10 | parser.add_argument("--vac_dict_en", type=int,
11 |                     help="dictionary size for English",default=60000)
12 | parser.add_argument("--vac_dict_ch", type=int,
13 |                     help="dictionary size for Chinese",default=60000)
14 | args = parser.parse_args()
15 | 
16 | 
17 | print('reading pkl data...')
18 | with open('middleresult/segmented_train.pkl','rb') as fhdl:
19 |     (train_token_chinese,train_token_english) = pickle.load(fhdl)
20 | if args.test_lines != -1:
21 |     train_token_chinese = train_token_chinese[:args.test_lines]
22 |     train_token_english = train_token_english[:args.test_lines]
23 | print(len(train_token_chinese),len(train_token_english))
24 | 
25 | import random
26 | index = random.randint(0,len(train_token_chinese))
27 | print(train_token_chinese[index])
28 | print(train_token_english[index])
29 | 
30 | from collections import Counter
31 | from functools import reduce
32 | print('counting...')
33 | 
34 | ch_dic = {}
35 | def get_most_common(a1,a2):
36 |     temp_dict1 = {}
37 |     temp_dict2 = {}
38 |     pb = ProgressBar(worksum=len(a1),auto_display=False)
39 |     pb.startjob()
40 |     num = 0
41 |     for s1,s2 in zip(a1,a2):
42 |         num += 1
43 |         pb.complete(1)
44 |         if args.max_words != -1 and len(s1) > args.max_words:
45 |             continue
46 |         for w1 in s1:
47 |             temp_dict1.setdefault(w1,0)
48 |             temp_dict1[w1] += 1
49 |         for w2 in s2:
50 |             temp_dict2.setdefault(w2,0)
51 |             temp_dict2[w2] += 1
52 | 
53 |         if num % 32 == 0:
54 |             pb.display_progress_bar()
55 |     sorted1 = sorted(temp_dict1.items(),key=lambda i:i[1],reverse=True)
56 |     sorted2 = sorted(temp_dict2.items(),key=lambda i:i[1],reverse=True)
57 |     #print(sorted1[:100])
58 |     #print(sorted2[:100])
59 |     return [i[0] for i in sorted1[:args.vac_dict_ch]],[i[0] for i in sorted2[:args.vac_dict_en]]
60 | 
61 | most_common_ch ,most_common_en = get_most_common(train_token_chinese,train_token_english)
62 | print("\n ch words:{} en words:{}".format(len(most_common_ch),len(most_common_en)))
63 | print(most_common_ch[:20])
64 | print(most_common_en[:20])
65 | 
66 | print('zipping...')
67 | ind2ch = dict(zip(range(1,len(most_common_ch) + 1),most_common_ch))  # indices start at 1; 0 is reserved for OOV/padding
68 | ch2ind = dict(zip(most_common_ch,range(1,len(most_common_ch) + 1)))
69 | ind2en = dict(zip(range(1,len(most_common_en) + 1),most_common_en))
70 | en2ind = dict(zip(most_common_en,range(1,len(most_common_en) + 1)))
71 | 
72 | print('tokenizing...')
73 | train_x = [[en2ind.get(j,0) for j in i] for i in train_token_english if (args.max_words == -1 or len(i) < args.max_words)]
74 | train_y = [[ch2ind.get(j,0) for j in i] for i in train_token_chinese if (args.max_words == -1 or len(i) < args.max_words)]
75 | 
76
| print(len(train_x),len(train_y)) 77 | print(train_x[0]) 78 | print(' '.join([ind2en.get(i,'') for i in train_x[0]])) 79 | print(train_y[0]) 80 | print(' '.join([ind2ch.get(i,'') for i in train_y[0]])) 81 | 82 | with open('middleresult/tokenlizer_output_{}ch_{}en_{}words.pkl'.format(args.vac_dict_ch,args.vac_dict_en,args.max_words),'wb') as whdl: 83 | pickle.dump(( 84 | ind2ch, 85 | ch2ind, 86 | ind2en, 87 | en2ind, 88 | train_x, 89 | train_y, 90 | ),whdl) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # define dataset class to feed the model 2 | import numpy as np 3 | import os 4 | import sys 5 | import time 6 | 7 | class Dataset(): 8 | def __init__(self,data,label): 9 | self._index_in_epoch = 0 10 | self._epochs_completed = 0 11 | self._data = data 12 | self._label = label 13 | assert(data.shape[0] == label.shape[0]) 14 | self._num_examples = data.shape[0] 15 | pass 16 | 17 | @property 18 | def data(self): 19 | return self._data 20 | 21 | @property 22 | def label(self): 23 | return self._label 24 | 25 | def next_batch(self,batch_size,shuffle = True): 26 | start = self._index_in_epoch 27 | if start == 0 and self._epochs_completed == 0: 28 | idx = np.arange(0, self._num_examples) # get all possible indexes 29 | np.random.shuffle(idx) # shuffle indexe 30 | self._data = self.data[idx] # get list of `num` random samples 31 | self._label = self.label[idx] 32 | 33 | # go to the next batch 34 | if start + batch_size > self._num_examples: 35 | self._epochs_completed += 1 36 | rest_num_examples = self._num_examples - start 37 | data_rest_part = self.data[start:self._num_examples] 38 | label_rest_part = self.label[start:self._num_examples] 39 | idx0 = np.arange(0, self._num_examples) # get all possible indexes 40 | np.random.shuffle(idx0) # shuffle indexes 41 | self._data = self.data[idx0] # get list of `num` random samples 42 | self._label = self.label[idx0] 43 | 44 | start = 0 45 | self._index_in_epoch = batch_size - rest_num_examples #avoid the case where the #sample != integar times of batch_size 46 | end = self._index_in_epoch 47 | data_new_part = self._data[start:end] 48 | label_new_part = self._label[start:end] 49 | return np.concatenate((data_rest_part, data_new_part), axis=0),np.concatenate((label_rest_part, label_new_part), axis=0) 50 | else: 51 | self._index_in_epoch += batch_size 52 | end = self._index_in_epoch 53 | return self._data[start:end],self._label[start:end] 54 | 55 | class ProgressBar(): 56 | def __init__(self,worksum,info="",auto_display=True): 57 | self.worksum = worksum 58 | self.info = info 59 | self.finishsum = 0 60 | self.auto_display = auto_display 61 | def startjob(self): 62 | self.begin_time = time.time() 63 | def complete(self,num): 64 | self.gaptime = time.time() - self.begin_time 65 | self.finishsum += num 66 | if self.auto_display == True: 67 | self.display_progress_bar() 68 | def display_progress_bar(self): 69 | percent = self.finishsum * 100 / self.worksum 70 | eta_time = self.gaptime * 100 / (percent + 0.001) - self.gaptime 71 | strprogress = "[" + "=" * int(percent // 2) + ">" + "-" * int(50 - percent // 2) + "]" 72 | str_log = ("%s %.2f %% %s %s/%s \t used:%ds eta:%d s" % (self.info,percent,strprogress,self.finishsum,self.worksum,self.gaptime,eta_time)) 73 | sys.stdout.write('\r' + str_log) 74 | 75 | def get_dataset(paths): 76 | dataset = [] 77 | for path in paths.split(':'): 78 | path_exp = os.path.expanduser(path) 79 | classes = 
os.listdir(path_exp) 80 | classes.sort() 81 | nrof_classes = len(classes) 82 | for i in range(nrof_classes): 83 | class_name = classes[i] 84 | facedir = os.path.join(path_exp, class_name) 85 | if os.path.isdir(facedir): 86 | images = os.listdir(facedir) 87 | image_paths = [os.path.join(facedir,img) for img in images] 88 | dataset.append(ImageClass(class_name, image_paths)) 89 | 90 | return dataset 91 | 92 | class ImageClass(): 93 | "Stores the paths to images for a given class" 94 | def __init__(self, name, image_paths): 95 | self.name = name 96 | self.image_paths = image_paths 97 | 98 | def __str__(self): 99 | return self.name + ', ' + str(len(self.image_paths)) + ' images' 100 | 101 | def __len__(self): 102 | return len(self.image_paths) 103 | 104 | def split_dataset(dataset, split_ratio, mode): 105 | if mode=='SPLIT_CLASSES': 106 | nrof_classes = len(dataset) 107 | class_indices = np.arange(nrof_classes) 108 | np.random.shuffle(class_indices) 109 | split = int(round(nrof_classes*split_ratio)) 110 | train_set = [dataset[i] for i in class_indices[0:split]] 111 | test_set = [dataset[i] for i in class_indices[split:-1]] 112 | elif mode=='SPLIT_IMAGES': 113 | train_set = [] 114 | test_set = [] 115 | min_nrof_images = 2 116 | for cls in dataset: 117 | paths = cls.image_paths 118 | np.random.shuffle(paths) 119 | split = int(round(len(paths)*split_ratio)) 120 | if split= src_vocab_size] = 1 \n", 229 | "test_x[test_x >= src_vocab_size] = 1 " 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 182, 235 | "metadata": { 236 | "collapsed": true 237 | }, 238 | "outputs": [], 239 | "source": [ 240 | "train_y[train_y >= target_vocat_size] = 1 \n", 241 | "test_y[test_y >= target_vocat_size] = 1 " 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 183, 247 | "metadata": { 248 | "collapsed": true 249 | }, 250 | "outputs": [], 251 | "source": [ 252 | "import numpy as np\n", 253 | "x_last_index = np.max(train_x)\n", 254 | "y_last_index = np.max(train_y)\n", 255 | "\n", 256 | "ind2en[x_last_index] = ''\n", 257 | "ind2ch[y_last_index] = ''\n", 258 | "en2ind[''] = x_last_index\n", 259 | "ch2ind[''] = y_last_index\n", 260 | "\n", 261 | "ind2en[1] = ''\n", 262 | "ind2ch[1] = ''\n", 263 | "en2ind[''] = 1\n", 264 | "ch2ind[''] = 1" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 184, 270 | "metadata": { 271 | "scrolled": false 272 | }, 273 | "outputs": [ 274 | { 275 | "name": "stdout", 276 | "output_type": "stream", 277 | "text": [ 278 | "licence authorizing the use of a factory situated in hong kong for treating whales \n", 279 | "批 准 在 位 於 香 港 的 工 厂 加 工 处 理 鲸 的 牌 照 \n", 280 | "here ' s what ' s different about people . we have the same needs , \n", 281 | "这 就 是 人 们 不 一 样 的 地 方 。 我 们 有 同 样 的 需 求 , \n", 282 | "this said , ... he wished to have me in his sight / once , as a friend : this fixed a day in spring \n", 283 | "我 膝 上 . 这 封 说 : 他 多 盼 望 有 个 机 会 , / 能 作 为 朋 友 , 见 一 见 我 . 这 一 封 又 订 了 \n", 284 | "the issue occurs when authenticated mysql users overwrite arbitrary files by using a attack . \n", 285 | "题 存 在 于 经 过 身 份 验 证 的 m y s q l 用 户 利 用 s y m l i n k 攻 击 覆 盖 任 意 文 件 的 过 程 中 。\n", 286 | "watch your back , fish , ' cause squirrel master ain ' t gonna be there for you all the time . 
\n", 287 | "小 心 点 , 鱼 , 斯 葵 尔 玛 斯 特 不 会 总 照 着 你 的 。 \n" 288 | ] 289 | } 290 | ], 291 | "source": [ 292 | "import random\n", 293 | "for i in range(5):\n", 294 | " index = random.randint(0,len(train_x))\n", 295 | " print(' '.join([ind2en.get(i,'') for i in train_x[index]]))\n", 296 | " print(' '.join([ind2ch.get(i,'') for i in train_y[index]]))" 297 | ] 298 | }, 299 | { 300 | "cell_type": "markdown", 301 | "metadata": {}, 302 | "source": [ 303 | "# define the translate nmt model" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": 26, 309 | "metadata": { 310 | "collapsed": true 311 | }, 312 | "outputs": [], 313 | "source": [ 314 | "from tensorflow.python.layers import core as layers_core\n" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 27, 320 | "metadata": { 321 | "collapsed": true 322 | }, 323 | "outputs": [], 324 | "source": [ 325 | "import tensorflow as tf\n", 326 | "import tflearn\n", 327 | "tf.reset_default_graph()\n", 328 | "config = tf.ConfigProto(log_device_placement=True,allow_soft_placement = True)\n", 329 | "config.gpu_options.allow_growth = True\n", 330 | "#config.gpu_options.per_process_gpu_memory_fraction = 0.4\n", 331 | "session = tf.Session(config=config)\n", 332 | "\n", 333 | "\n", 334 | "with tf.device('/gpu:1'):\n", 335 | " initializer = tf.random_uniform_initializer(\n", 336 | " -0.08, 0.08)\n", 337 | " tf.get_variable_scope().set_initializer(initializer)\n", 338 | " \n", 339 | " x = tf.placeholder(\"int32\", [None, None])\n", 340 | " y = tf.placeholder(\"int32\", [None, None])\n", 341 | " y_in = tf.placeholder(\"int32\",[None,None])\n", 342 | " x_len = tf.placeholder(\"int32\",[None])\n", 343 | " y_len = tf.placeholder(\"int32\",[None])\n", 344 | " x_real_len = tf.placeholder(\"int32\",[None])\n", 345 | " y_real_len = tf.placeholder(\"int32\",[None])\n", 346 | " learning_rate = tf.placeholder(tf.float32, shape=[])\n", 347 | " \n", 348 | " # embedding\n", 349 | " embedding_encoder = tf.get_variable(\n", 350 | " \"embedding_encoder\", [src_vocab_size, embedding_size],dtype=tf.float32)\n", 351 | " embedding_decoder = tf.get_variable(\n", 352 | " \"embedding_decoder\", [target_vocat_size, embedding_size],dtype=tf.float32)\n", 353 | " \n", 354 | " encoder_emb_inp = tf.nn.embedding_lookup(\n", 355 | " embedding_encoder, x)\n", 356 | " decoder_emb_inp = tf.nn.embedding_lookup(\n", 357 | " embedding_decoder, y_in)\n", 358 | " \n", 359 | " # encoder\n", 360 | " num_bi_layers = int(layer_number / 2)\n", 361 | " cell_list = []\n", 362 | " for i in range(num_bi_layers):\n", 363 | " cell_list.append(\n", 364 | " tf.contrib.rnn.DropoutWrapper(\n", 365 | " tf.contrib.rnn.BasicLSTMCell(num_units), input_keep_prob=(1.0 - dropout)\n", 366 | " )\n", 367 | " )\n", 368 | " if len(cell_list) == 1:\n", 369 | " encoder_cell = cell_list[0]\n", 370 | " else:\n", 371 | " encoder_cell = tf.contrib.rnn.MultiRNNCell(cell_list)\n", 372 | " \n", 373 | " cell_list = []\n", 374 | " \n", 375 | " for i in range(num_bi_layers):\n", 376 | " cell_list.append(\n", 377 | " tf.contrib.rnn.DropoutWrapper(\n", 378 | " tf.contrib.rnn.BasicLSTMCell(num_units), input_keep_prob=(1.0 - dropout)\n", 379 | " )\n", 380 | " )\n", 381 | " if len(cell_list) == 1:\n", 382 | " encoder_backword_cell = cell_list[0]\n", 383 | " else:\n", 384 | " encoder_backword_cell = tf.contrib.rnn.MultiRNNCell(cell_list)\n", 385 | " \n", 386 | " bi_outputs, bi_encoder_state = tf.nn.bidirectional_dynamic_rnn(\n", 387 | " encoder_cell,encoder_backword_cell, encoder_emb_inp,\n", 388 | " 
sequence_length=x_len,dtype=tf.float32)\n", 389 | " encoder_outputs = tf.concat(bi_outputs, -1)\n", 390 | " \n", 391 | " if num_bi_layers == 1:\n", 392 | " encoder_state = bi_encoder_state\n", 393 | " else:\n", 394 | " encoder_state = []\n", 395 | " for layer_id in range(num_bi_layers):\n", 396 | " encoder_state.append(bi_encoder_state[0][layer_id]) # forward\n", 397 | " encoder_state.append(bi_encoder_state[1][layer_id]) # backward\n", 398 | " encoder_state = tuple(encoder_state)\n", 399 | " \n", 400 | " # decoder \n", 401 | " #decoder_cell = tf.contrib.rnn.BasicLSTMCell(num_units)\n", 402 | " cell_list = []\n", 403 | " for i in range(layer_number):\n", 404 | " cell_list.append(\n", 405 | " tf.contrib.rnn.DropoutWrapper(\n", 406 | " tf.contrib.rnn.BasicLSTMCell(num_units), input_keep_prob=(1.0 - dropout)\n", 407 | " )\n", 408 | " )\n", 409 | " if len(cell_list) == 1:\n", 410 | " decoder_cell = cell_list[0]\n", 411 | " else:\n", 412 | " decoder_cell = tf.contrib.rnn.MultiRNNCell(cell_list)\n", 413 | " \n", 414 | " # Helper\n", 415 | " \n", 416 | " # attention\n", 417 | " attention_mechanism = tf.contrib.seq2seq.LuongAttention(\n", 418 | " attention_hidden_size, encoder_outputs,\n", 419 | " memory_sequence_length=x_real_len,scale=True)\n", 420 | " decoder_cell = tf.contrib.seq2seq.AttentionWrapper(\n", 421 | " decoder_cell, attention_mechanism,\n", 422 | " attention_layer_size=attention_output_size)\n", 423 | " \n", 424 | " \n", 425 | " projection_layer = layers_core.Dense(\n", 426 | " target_vocat_size, use_bias=False)\n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " # Dynamic decoding\n", 431 | " with tf.variable_scope(\"decode_layer\"):\n", 432 | " helper = tf.contrib.seq2seq.TrainingHelper(\n", 433 | " decoder_emb_inp,sequence_length= y_len)\n", 434 | " decoder = tf.contrib.seq2seq.BasicDecoder(\n", 435 | " decoder_cell, helper, initial_state = decoder_cell.zero_state(dtype=tf.float32,batch_size=batch_size),\n", 436 | " output_layer=projection_layer)\n", 437 | " \n", 438 | " outputs, _,___ = tf.contrib.seq2seq.dynamic_decode(decoder)\n", 439 | " logits = outputs.rnn_output\n", 440 | "\n", 441 | " target_weights = tf.sequence_mask(\n", 442 | " y_real_len, seq_max_len, dtype=logits.dtype)\n", 443 | " \n", 444 | " # predicting\n", 445 | " # Helper\n", 446 | " with tf.variable_scope(\"decode_layer\", reuse=True):\n", 447 | " helper_predict = tf.contrib.seq2seq.GreedyEmbeddingHelper(\n", 448 | " embedding_decoder,\n", 449 | " tf.fill([batch_size], ch2ind['']), 0)\n", 450 | " decoder_predict = tf.contrib.seq2seq.BasicDecoder(\n", 451 | " decoder_cell, helper_predict, initial_state = decoder_cell.zero_state(dtype=tf.float32,batch_size=batch_size),\n", 452 | " output_layer=projection_layer)\n", 453 | " outputs_predict,_, __ = tf.contrib.seq2seq.dynamic_decode(\n", 454 | " decoder_predict, maximum_iterations=test_y.shape[1] * 2)\n", 455 | " translations = outputs_predict.sample_id\n", 456 | "\n", 457 | " # calculate loss\n", 458 | " crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(\n", 459 | " labels=y, logits=logits)\n", 460 | " train_loss = (tf.reduce_sum(crossent * target_weights) /\n", 461 | " batch_size)\n", 462 | " \n", 463 | " optimizer_ori = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)\n", 464 | " trainable_params = tf.trainable_variables()\n", 465 | " gradients = tf.gradients(train_loss, trainable_params)\n", 466 | " clip_gradients, _ = tf.clip_by_global_norm(gradients, max_grad)\n", 467 | " global_step = tf.Variable(0, trainable=False, 
name='global_step')\n", 468 | " optimizer = optimizer_ori.apply_gradients(\n", 469 | " zip(clip_gradients, trainable_params), global_step=global_step)\n", 470 | " #optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(train_loss)\n", 471 | " #trainop = tflearn.TrainOp(loss=train_loss, optimizer=optimizer,\n", 472 | " # metric=train_loss, batch_size=64)\n" 473 | ] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "execution_count": 28, 478 | "metadata": { 479 | "collapsed": true 480 | }, 481 | "outputs": [], 482 | "source": [ 483 | "def cal_acc(logits,target):\n", 484 | " max_seq = max(target.shape[1], logits.shape[1])\n", 485 | " if max_seq - target.shape[1]:\n", 486 | " target = np.pad(\n", 487 | " target,\n", 488 | " [(0,0),(0,max_seq - target.shape[1])],\n", 489 | " 'constant')\n", 490 | " if max_seq - logits.shape[1]:\n", 491 | " logits = np.pad(\n", 492 | " logits,\n", 493 | " [(0,0),(0,max_seq - logits.shape[1])],\n", 494 | " 'constant')\n", 495 | "\n", 496 | " return np.mean(np.equal(target[:,:seq_max_len], logits[:,:seq_max_len]))\n" 497 | ] 498 | }, 499 | { 500 | "cell_type": "markdown", 501 | "metadata": {}, 502 | "source": [ 503 | "# init the model" 504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "execution_count": 29, 509 | "metadata": { 510 | "collapsed": true 511 | }, 512 | "outputs": [], 513 | "source": [ 514 | "session.run(tf.global_variables_initializer())" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": 30, 520 | "metadata": { 521 | "collapsed": true 522 | }, 523 | "outputs": [], 524 | "source": [ 525 | "saver = tf.train.Saver()" 526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "execution_count": 29, 531 | "metadata": {}, 532 | "outputs": [ 533 | { 534 | "name": "stdout", 535 | "output_type": "stream", 536 | "text": [ 537 | "INFO:tensorflow:Restoring parameters from middleresult/align/result_1_20847\n" 538 | ] 539 | } 540 | ], 541 | "source": [ 542 | "#saver.restore(session,'middleresult/align/result_1_20847')" 543 | ] 544 | }, 545 | { 546 | "cell_type": "code", 547 | "execution_count": 32, 548 | "metadata": { 549 | "scrolled": true 550 | }, 551 | "outputs": [ 552 | { 553 | "data": { 554 | "text/plain": [ 555 | "'middleresult/result_char'" 556 | ] 557 | }, 558 | "execution_count": 32, 559 | "metadata": {}, 560 | "output_type": "execute_result" 561 | } 562 | ], 563 | "source": [ 564 | "saver.save(session,'middleresult/result_char')" 565 | ] 566 | }, 567 | { 568 | "cell_type": "code", 569 | "execution_count": 33, 570 | "metadata": { 571 | "collapsed": true 572 | }, 573 | "outputs": [], 574 | "source": [ 575 | "from utils import Dataset,ProgressBar" 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": 34, 581 | "metadata": { 582 | "collapsed": true 583 | }, 584 | "outputs": [], 585 | "source": [ 586 | "from utils import *\n", 587 | "train_set = Dataset(train_x,train_y)\n", 588 | "test_set = Dataset(test_x,test_y)" 589 | ] 590 | }, 591 | { 592 | "cell_type": "code", 593 | "execution_count": 35, 594 | "metadata": { 595 | "collapsed": true 596 | }, 597 | "outputs": [], 598 | "source": [ 599 | "def get_bleu_score(predict,target):\n", 600 | " try:\n", 601 | " target = [[[j for index,j in enumerate(i) if j > 0 or index < 4]] for i in target]\n", 602 | " predict = [[j for index,j in enumerate(i) if j > 0 or index < 4] for i in predict]\n", 603 | " BLEUscore = nltk.translate.bleu_score.corpus_bleu(target,predict)\n", 604 | " except:\n", 605 | " BLEUscore = -1\n", 606 | " return 
BLEUscore" 607 | ] 608 | }, 609 | { 610 | "cell_type": "code", 611 | "execution_count": 36, 612 | "metadata": { 613 | "collapsed": true 614 | }, 615 | "outputs": [], 616 | "source": [ 617 | "import numpy as np\n", 618 | "def calc_test_loss(test_set = Dataset(test_x,test_y),display=True):\n", 619 | " accs = []\n", 620 | " worksum = int(len(test_x) / batch_size)\n", 621 | " loss_list = []\n", 622 | " predict_list = []\n", 623 | " target_list = []\n", 624 | " source_list = []\n", 625 | " pb = ProgressBar(worksum=worksum,info=\"validating...\",auto_display=display)\n", 626 | " pb.startjob()\n", 627 | " #test_set = Dataset(test_x,test_y)\n", 628 | " for j in range(worksum):\n", 629 | " batch_x,batch_y = test_set.next_batch(batch_size)\n", 630 | " lx = [seq_max_len] * batch_size\n", 631 | " ly = [seq_max_len] * batch_size\n", 632 | " bx = [np.sum(m > 0) for m in batch_x]\n", 633 | " by = [np.sum(m > 0) for m in batch_y]\n", 634 | " tmp_loss,tran = session.run([train_loss,translations],feed_dict={x:batch_x,y:batch_y,\n", 635 | " y_in:\n", 636 | " np.concatenate((\n", 637 | " np.ones((batch_y.shape[0],1),dtype=np.int) * ch2ind[''],batch_y[:,:-1]) ,axis=1)\n", 638 | " ,x_len:lx,y_len:ly,\n", 639 | " y_real_len:by,\n", 640 | " x_real_len:bx})\n", 641 | " loss_list.append(tmp_loss)\n", 642 | " tmp_acc = cal_acc(tran,batch_y)\n", 643 | " accs.append(tmp_acc)\n", 644 | " predict_list += [i for i in tran]\n", 645 | " target_list += [i for i in batch_y]\n", 646 | " source_list += [i for i in batch_x]\n", 647 | " pb.complete(1)\n", 648 | " return np.average(loss_list),np.average(accs),get_bleu_score(predict_list,target_list),predict_list,target_list,source_list" 649 | ] 650 | }, 651 | { 652 | "cell_type": "markdown", 653 | "metadata": {}, 654 | "source": [ 655 | "# calculate the initial loss and see what model outputs in the begining." 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": 38, 661 | "metadata": {}, 662 | "outputs": [ 663 | { 664 | "name": "stdout", 665 | "output_type": "stream", 666 | "text": [ 667 | "validating... 
100.00 % [==================================================>] 300/300 \t used:120s eta:0 s" 668 | ] 669 | } 670 | ], 671 | "source": [ 672 | "w_loss,w_acc,bleu_score,predict_list,target_list,source_list = calc_test_loss(Dataset(train_x[::100],train_y[::100]))" 673 | ] 674 | }, 675 | { 676 | "cell_type": "code", 677 | "execution_count": 39, 678 | "metadata": {}, 679 | "outputs": [ 680 | { 681 | "data": { 682 | "text/plain": [ 683 | "(233.5325, 5.5989583333333333e-05, 0.004130308764851706)" 684 | ] 685 | }, 686 | "execution_count": 39, 687 | "metadata": {}, 688 | "output_type": "execute_result" 689 | } 690 | ], 691 | "source": [ 692 | "w_loss,w_acc,bleu_score" 693 | ] 694 | }, 695 | { 696 | "cell_type": "code", 697 | "execution_count": 40, 698 | "metadata": { 699 | "collapsed": true 700 | }, 701 | "outputs": [], 702 | "source": [ 703 | "def get_all_text(x):\n", 704 | " return [' '.join([ind2ch.get(j,'') for j in i]) for i in x]\n", 705 | "def get_all_en_text(x):\n", 706 | " return [' '.join([ind2en.get(j,'') for j in i]) for i in x]" 707 | ] 708 | }, 709 | { 710 | "cell_type": "code", 711 | "execution_count": 41, 712 | "metadata": { 713 | "scrolled": true 714 | }, 715 | "outputs": [ 716 | { 717 | "data": { 718 | "text/plain": [ 719 | "['紮 紮 紮 紮 紮 麈 紮 麈 麈 麈 麈 麈 麈 麈 麈 黉 黉 佗 佗 黉 黉 黉 眃 眃 眃 眃 眃 眃 眃 眃 眃 格 格 格 格 贤 贤 格 ㏑ 格 格 格 格 格 ㏑ 格 格 ㏑ ㏑ ㏑ ㏑ ㏑ ㏑ ㏑ ㏑ ㏑ ㏑ ㏑ ㏑ 衲 衲 衲 衲 衲 衲 衲 衲 衲 衲 衲 衲 衲 衲 衲 衲 挪 衲 衲 衲 衲',\n", 720 | " '間 間 間 間 煜 煜 煜 臆 前 前 前 蝚 蝚 靘 μ ; 讫 讫 讫 讫 讫 讫 讫 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 鞑 蝚 铲 蝚 蝚 蝚 铲 铲 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 铲 铲 铲 铲 铲 铲 铲 铲 睦 睦 睦 睦 睦 睦 睦 睦 睦 睦 睦 睦 睦 睦 睦 睦',\n", 721 | " '馓 戴 籁 庥 龋 龋 语 д 4 囊 砖 气 气 气 气 气 气 气 气 气 气 气 气 气 气 粕 民 民 民 民 民 民 赔 民 筽 民 迹 迹 迹 迹 迹 迹 迹 鰶 犸 犸 犸 犸 犸 天 天 天 天 猴 猴 炫 猴 猴 洐 洐 猴 腋 钩 钩 钩 钩 钩 马 \\ue009 臻 臻 臻 臻 \\ue009 \\ue009 \\ue009 \\ue009 \\ue009 \\ue009 \\ue009',\n", 722 | " '呔 呔 浼 浼 浼 浼 浼 浼 驶 里 里 里 里 虬 蝚 蝚 灑 灑 蝚 蝚 蝚 蝚 》 苞 苞 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 涩 钭 钭 钭 钭 钭 钭 钭 钭 鯦 鯦 钭 骄 钭 钭 钭 桕 桕 桕 未 桕 桕 桕 拾 拾 拾 拾 拾 绗 绗 鰶',\n", 723 | " '庢 庢 庢 臆 臆 臆 臆 臆 磌 磌 磌 磌 磌 衲 衲 衲 衲 衲 衲 衲 衲 衲 衲 衲 衲 衲 衲 衲 衲 衲 衲 衲 衲 衲 \\ue40f 奇 \\ue40f 衲 衲 叻 奇 奇 圈 圈 谜 谜 谜 谜 谜 谜 谜 谜 谜 谜 谜 谜 谜 谜 谜 谜 谜 辆 辆 谜 辆 辆 辆 辆 辆 辆 罙 阍 罙 阍 阍 阍 叻 阍 阍 叻',\n", 724 | " '轰 轰 轰 浼 癞 浼 臆 臆 臆 臆 臆 臆 臆 臆 臆 臆 铜 臆 臆 臆 臆 臆 灑 灑 灑 灑 冼 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 鞑 蝚 鞑 鞑 鞑 蝚 蝚 蝚 蝚 蝚 匙 匙 匙 匙 匙 匙 蝚 蝚 蝚 蝚 蝚 蝚',\n", 725 | " '潵 間 潵 轰 轰 轰 \\ue620 局 \\ue620 辙 辙 辙 辙 辙 辙 辙 辙 讫 讫 甡 ㏑ 甡 甡 甡 》 》 》 》 箴 箴 》 》 》 》 》 豪 》 》 》 》 》 》 》 》 》 》 》 豪 杯 杯 茧 茧 茧 茧 簺 揄 簺 茧 奇 奇 奇 奇 奇 奇 奇 奇 奇 奇 奇 挢 挢 挢 醪 绗 绗 绗 绗 晦 里 里',\n", 726 | " '呔 呔 呔 呔 圞 拊 里 里 里 臆 臆 臆 算 算 算 臆 殭 殭 苞 苞 苞 苞 苞 苞 苞 苞 苞 苞 苞 里 里 苞 苞 苞 苞 苞 袈 衲 衲 桼 衲 衲 衲 衲 蝚 黔 黔 黔 黔 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 橹 匙 蝚 蝚 蝚 蝚 蝚 桼 畄 庇 庇 庇 庇 庇 庇 庇 庇 庇',\n", 727 | " '間 間 間 浼 杯 杯 杯 臆 臆 臆 臆 臆 讫 讫 臆 灑 臆 讫 灑 讫 蝚 蝚 佗 梢 佗 佗 佗 佗 佗 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 蝚 贤 贤 贤 贤 贤 黔 黔 せ 黔 碹 碹 碹 碹 碹 碹 碹 碹 碹 碹 珂 珂 叿 叿 芙 芙 芙 叿 叿 芙 芙 芙',\n", 728 | " 'k k 厉 厉 厉 厉 厉 里 蓆 蓆 蓆 拾 蓆 拾 拾 堤 叿 叿 拾 叿 叿 叿 叿 叿 叿 叿 叿 増 锇 锇 锇 锇 锇 锇 糗 蘧 蘧 蘧 蘧 蘧 蘧 虽 虽 虽 虽 虽 虽 虽 虽 虽 虽 虽 虽 虽 虽 虽 虽 虽 虽 虽 虽 虽 \\ue677 \\ue677 庚 庚 庚 庚 庚 嗙 癇 庚 癇 庚 癇 像 像 像 像 像']" 729 | ] 730 | }, 731 | "execution_count": 41, 732 | "metadata": {}, 733 | "output_type": "execute_result" 734 | } 735 | ], 736 | "source": [ 737 | "texts = get_all_text(predict_list)\n", 738 | "texts[:10]" 739 | ] 740 | }, 741 | { 742 | "cell_type": "code", 743 | "execution_count": 42, 744 | "metadata": {}, 745 | "outputs": [ 746 | { 747 | "data": { 748 | "text/plain": [ 749 | "['不 然 的 话 , 我 把 你 屁 股 捣 烂 - - 噢 , 上 帝 啊 ',\n", 750 | " '医 生 给 她 注 射 以 减 轻 疼 痛 ',\n", 751 | " '“ 
我 们 希 望 中 国 能 够 慎 重 考 虑 很 多 谈 判 伙 伴 本 周 表 达 的 关 切 , 调 整 立 场 , 及 早 恢 复 谈 判 。 ”',\n", 752 | " '《 牛 津 当 代 大 辞 典 》 词 库 1 0 0 % 采 用 真 人 女 声 。 ',\n", 753 | " '定 一 个 牌 手 基 于 对 手 给 他 的 错 误 信 息 采 取 了 行 动 , 将 酌 情 按 第 2 1 条 或 第 4 7 条 e 款 处 理 。',\n", 754 | " '易 腐 产 品 标 准 化 和 质 量 改 进 工 作 队 ',\n", 755 | " '本 品 具 有 优 越 的 匀 染 效 果 , 优 越 的 耐 硬 水 及 耐 盐 性 。 ',\n", 756 | " '由 通 知 行 签 发 的 本 信 用 证 正 本 通 知 书 , 因 为 有 效 信 用 证 必 须 由 指 定 银 行 在 该 通 知 书 上 背 书 。',\n", 757 | " 'r s s 根 据 应 用 的 地 址 将 中 断 指 向 特 定 的 处 理 器 内 核 。 ',\n", 758 | " ') 报 送 上 市 公 司 收 购 报 告 书 时 所 持 有 被 收 购 公 司 股 份 数 占 该 公 司 已 发 行 的 股 份 总 数 的 比 例 。']" 759 | ] 760 | }, 761 | "execution_count": 42, 762 | "metadata": {}, 763 | "output_type": "execute_result" 764 | } 765 | ], 766 | "source": [ 767 | "texts = get_all_text(target_list)\n", 768 | "texts[:10]" 769 | ] 770 | }, 771 | { 772 | "cell_type": "markdown", 773 | "metadata": {}, 774 | "source": [ 775 | "# now train the model" 776 | ] 777 | }, 778 | { 779 | "cell_type": "code", 780 | "execution_count": 43, 781 | "metadata": { 782 | "collapsed": true 783 | }, 784 | "outputs": [], 785 | "source": [ 786 | "#tran.shape\n", 787 | "i_save = 0\n", 788 | "j_save = 0" 789 | ] 790 | }, 791 | { 792 | "cell_type": "code", 793 | "execution_count": 44, 794 | "metadata": {}, 795 | "outputs": [ 796 | { 797 | "name": "stdout", 798 | "output_type": "stream", 799 | "text": [ 800 | "0 0\n" 801 | ] 802 | } 803 | ], 804 | "source": [ 805 | "print(i_save,j_save)" 806 | ] 807 | }, 808 | { 809 | "cell_type": "code", 810 | "execution_count": 45, 811 | "metadata": { 812 | "collapsed": true 813 | }, 814 | "outputs": [], 815 | "source": [ 816 | "model_path = 'align_char'" 817 | ] 818 | }, 819 | { 820 | "cell_type": "code", 821 | "execution_count": 46, 822 | "metadata": { 823 | "collapsed": true 824 | }, 825 | "outputs": [], 826 | "source": [ 827 | "os.mkdir('middleresult/{}'.format(model_path))\n", 828 | "os.mkdir('eval/{}'.format(model_path))" 829 | ] 830 | }, 831 | { 832 | "cell_type": "code", 833 | "execution_count": 49, 834 | "metadata": { 835 | "scrolled": true 836 | }, 837 | "outputs": [ 838 | { 839 | "name": "stdout", 840 | "output_type": "stream", 841 | "text": [ 842 | "validating... 100.00 % [==================================================>] 300/300 \t used:106s eta:0 s7445/29780 \t used:1984s eta:5951 s\n", 843 | "iter 13 step 7445 train loss 51.49441909790039 train acc 0.3890546875 test loss 56.52604675292969 test acc 0.3802877604166666 bleu 0.20683722390776155 lr 0.0625\n", 844 | "\n", 845 | "validating... 100.00 % [==================================================>] 300/300 \t used:105s eta:0 s14890/29780 \t used:4461s eta:4460 s\n", 846 | "iter 13 step 14890 train loss 51.49875259399414 train acc 0.38941145833333335 test loss 56.567684173583984 test acc 0.38301171875000006 bleu 0.2169776407879747 lr 0.03125\n", 847 | "\n", 848 | "validating... 100.00 % [==================================================>] 300/300 \t used:103s eta:0 s22335/29780 \t used:6938s eta:2312 ss\n", 849 | "iter 13 step 22335 train loss 51.54439163208008 train acc 0.39030338541666665 test loss 56.541290283203125 test acc 0.3805 bleu 0.22007911262894841 lr 0.03125\n", 850 | "\n", 851 | "validating... 100.00 % [==================================================>] 300/300 \t used:104s eta:0 s 7445/29780 \t used:2240s eta:6721 ss\n", 852 | "iter 14 step 7445 train loss 51.224857330322266 train acc 0.3899205729166667 test loss 56.82951736450195 test acc 0.378453125 bleu 0.21533934461057255 lr 0.015625\n", 853 | "\n", 854 | "validating... 
100.00 % [==================================================>] 300/300 \t used:104s eta:0 s] 14890/29780 \t used:4706s eta:4706 s\n", 855 | "iter 14 step 14890 train loss 51.121070861816406 train acc 0.3904127604166666 test loss 56.48078155517578 test acc 0.3814088541666667 bleu 0.218569247826193 lr 0.0078125\n", 856 | "\n", 857 | "validating... 100.00 % [==================================================>] 300/300 \t used:106s eta:0 s] 22335/29780 \t used:7177s eta:2392 ss\n", 858 | "iter 14 step 22335 train loss 51.33269500732422 train acc 0.3897239583333333 test loss 56.58401107788086 test acc 0.38013671875 bleu 0.21346233125359212 lr 0.0078125\n", 859 | "\n", 860 | "validating... 100.00 % [==================================================>] 300/300 \t used:101s eta:0 s-] 7445/29780 \t used:2236s eta:6709 ss\n", 861 | "iter 15 step 7445 train loss 51.140235900878906 train acc 0.39041406250000005 test loss 56.441165924072266 test acc 0.379875 bleu 0.22126377277611756 lr 0.00390625\n", 862 | "\n", 863 | "validating... 100.00 % [==================================================>] 300/300 \t used:105s eta:0 s-] 14890/29780 \t used:4700s eta:4700 ss\n", 864 | "iter 15 step 14890 train loss 51.05939865112305 train acc 0.3896588541666667 test loss 56.73957061767578 test acc 0.38069791666666664 bleu 0.21723252877675084 lr 0.001953125\n", 865 | "\n", 866 | "validating... 100.00 % [==================================================>] 300/300 \t used:101s eta:0 s--] 22335/29780 \t used:7168s eta:2389 ss\n", 867 | "iter 15 step 22335 train loss 51.1295166015625 train acc 0.3886432291666666 test loss 56.51940155029297 test acc 0.3808815104166666 bleu 0.2187056862095337 lr 0.001953125\n", 868 | "\n", 869 | "validating... 100.00 % [==================================================>] 300/300 \t used:107s eta:0 s----] 7445/29780 \t used:2228s eta:6686 s\n", 870 | "iter 16 step 7445 train loss 51.121768951416016 train acc 0.3897473958333333 test loss 56.849037170410156 test acc 0.37901822916666666 bleu 0.21269532562586332 lr 0.0009765625\n", 871 | "\n", 872 | "validating... 100.00 % [==================================================>] 300/300 \t used:106s eta:0 s---] 14890/29780 \t used:4695s eta:4695 ss\n", 873 | "iter 16 step 14890 train loss 51.05976486206055 train acc 0.38977734375 test loss 56.51028823852539 test acc 0.3819075520833333 bleu 0.21460443822701872 lr 0.00048828125\n", 874 | "\n", 875 | "validating... 100.00 % [==================================================>] 300/300 \t used:105s eta:0 s-----] 22335/29780 \t used:7164s eta:2388 s\n", 876 | "iter 16 step 22335 train loss 51.14641189575195 train acc 0.3891145833333333 test loss 56.565284729003906 test acc 0.38264583333333335 bleu 0.21811796139827241 lr 0.00048828125\n", 877 | "\n", 878 | "validating... 100.00 % [==================================================>] 300/300 \t used:106s eta:0 s-----] 7445/29780 \t used:2229s eta:6687 ss\n", 879 | "iter 17 step 7445 train loss 51.078975677490234 train acc 0.3900143229166667 test loss 56.53923034667969 test acc 0.37984375 bleu 0.2164484593871533 lr 0.000244140625\n", 880 | "\n", 881 | "validating... 100.00 % [==================================================>] 300/300 \t used:103s eta:0 s------] 14890/29780 \t used:4702s eta:4701 s\n", 882 | "iter 17 step 14890 train loss 51.1338005065918 train acc 0.3888997395833334 test loss 56.62047576904297 test acc 0.3807604166666667 bleu 0.21641832416101056 lr 0.0001220703125\n", 883 | "\n", 884 | "validating... 
100.00 % [==================================================>] 300/300 \t used:103s eta:0 s------] 22335/29780 \t used:7172s eta:2390 ss\n", 885 | "iter 17 step 22335 train loss 51.171836853027344 train acc 0.39005989583333334 test loss 56.64115905761719 test acc 0.3802877604166667 bleu 0.21857356474866954 lr 0.0001220703125\n", 886 | "\n", 887 | "iter 17 loss:52.313621520996094 lr:0.0001220703125 81.34 % [========================================>----------] 24223/29780 \t used:7968s eta:1827 s" 888 | ] 889 | }, 890 | { 891 | "ename": "KeyboardInterrupt", 892 | "evalue": "", 893 | "output_type": "error", 894 | "traceback": [ 895 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 896 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 897 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 27\u001b[0m np.ones((batch_y.shape[0],1),dtype=np.int) * ch2ind[''],batch_y[:,:-1]) ,axis=1)\n\u001b[1;32m 28\u001b[0m \u001b[1;33m,\u001b[0m\u001b[0my_real_len\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mby\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m---> 29\u001b[0;31m \u001b[0mx_real_len\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mbx\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 30\u001b[0m })\n\u001b[1;32m 31\u001b[0m \u001b[0mtrain_loss_list\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 898 | "\u001b[0;32mC:\\Program Files\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 787\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 788\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 789\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 790\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 791\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 899 | "\u001b[0;32mC:\\Program Files\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 995\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 996\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m--> 997\u001b[0;31m feed_dict_string, options, run_metadata)\n\u001b[0m\u001b[1;32m 998\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 999\u001b[0m \u001b[0mresults\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 900 | "\u001b[0;32mC:\\Program Files\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1130\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mhandle\u001b[0m 
\u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 1131\u001b[0m return self._do_call(_run_fn, self._session, feed_dict, fetch_list,\n\u001b[0;32m-> 1132\u001b[0;31m target_list, options, run_metadata)\n\u001b[0m\u001b[1;32m 1133\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 1134\u001b[0m return self._do_call(_prun_fn, self._session, handle, feed_dict,\n", 901 | "\u001b[0;32mC:\\Program Files\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m 1137\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 1138\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1139\u001b[0;31m \u001b[1;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1140\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 1141\u001b[0m \u001b[0mmessage\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 902 | "\u001b[0;32mC:\\Program Files\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m 1119\u001b[0m return tf_session.TF_Run(session, options,\n\u001b[1;32m 1120\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1121\u001b[0;31m status, run_metadata)\n\u001b[0m\u001b[1;32m 1122\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m 1123\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msession\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 903 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 904 | ] 905 | } 906 | ], 907 | "source": [ 908 | "n_epoch = 15\n", 909 | "restore = True\n", 910 | "#lr = 1\n", 911 | "for i in range(i_save,n_epoch):\n", 912 | " \n", 913 | " i_save = i\n", 914 | " worksum = int(len(train_y)/batch_size)\n", 915 | " pb = ProgressBar(worksum=worksum)\n", 916 | " pb.startjob()\n", 917 | " train_loss_list = []\n", 918 | " train_acc_list = []\n", 919 | " for j in range(worksum):\n", 920 | " if restore == True and j < j_save:\n", 921 | " pb.finishsum += 1\n", 922 | " continue\n", 923 | " restore = False\n", 924 | " \n", 925 | " j_save = j\n", 926 | " batch_x,batch_y = train_set.next_batch(batch_size)\n", 927 | " lx = [seq_max_len] * batch_size\n", 928 | " ly = [seq_max_len] * 
batch_size\n", 929 | " bx = [np.sum(m > 0) for m in batch_x]\n", 930 | " by = [np.sum(m > 0) for m in batch_y]\n", 931 | " by =[m + 2 if m < seq_max_len - 1 else m for m in by ]\n", 932 | " _, loss = session.run([optimizer,train_loss],feed_dict={x:batch_x,y:batch_y,x_len:lx,y_len:ly,learning_rate:lr,y_in:\n", 933 | " np.concatenate((\n", 934 | " np.ones((batch_y.shape[0],1),dtype=np.int) * ch2ind[''],batch_y[:,:-1]) ,axis=1)\n", 935 | " ,y_real_len:by,\n", 936 | " x_real_len:bx\n", 937 | " })\n", 938 | " train_loss_list.append(loss)\n", 939 | " #tmp_train_acc = cal_acc(tran,batch_y)\n", 940 | " #train_acc_list.append(tmp_train_acc)\n", 941 | " pb.info = \"iter {} loss:{} lr:{}\".format(i + 1,loss,lr)\n", 942 | " val_step = int(worksum / 4)\n", 943 | " if j % val_step == 0 and j != 0:\n", 944 | " test_loss,test_acc,bleu_score,predict_list,target_list,source_list = calc_test_loss()\n", 945 | " _,train_acc,train_bleu_score,train_predict_list,train_target_list,train_source_list = calc_test_loss(Dataset(train_x[::100],train_y[::100]),display=False)\n", 946 | " predict_texts = get_all_text(predict_list)\n", 947 | " target_texts = get_all_text(target_list)\n", 948 | " source_texts = get_all_en_text(source_list)\n", 949 | " \n", 950 | " train_predict_texts = get_all_text(train_predict_list)\n", 951 | " train_target_texts = get_all_text(train_target_list)\n", 952 | " train_source_texts = get_all_en_text(train_source_list)\n", 953 | " \n", 954 | " with open('eval/{}/{}_{}_predict'.format(model_path,i + 1,j),'w',encoding='utf-8') as whdl:\n", 955 | " for line in predict_texts:\n", 956 | " whdl.write(\"{}\\n\".format(line))\n", 957 | " with open('eval/{}/{}_{}_target'.format(model_path,i + 1,j),'w',encoding='utf-8') as whdl:\n", 958 | " for line in target_texts:\n", 959 | " whdl.write(\"{}\\n\".format(line))\n", 960 | " with open('eval/{}/{}_{}_source'.format(model_path,i + 1,j),'w',encoding='utf-8') as whdl:\n", 961 | " for line in source_texts:\n", 962 | " whdl.write(\"{}\\n\".format(line))\n", 963 | " \n", 964 | " with open('eval/{}/{}_{}_predict_train'.format(model_path,i + 1,j),'w',encoding='utf-8') as whdl:\n", 965 | " for line in train_predict_texts:\n", 966 | " whdl.write(\"{}\\n\".format(line))\n", 967 | " with open('eval/{}/{}_{}_target_train'.format(model_path,i + 1,j),'w',encoding='utf-8') as whdl:\n", 968 | " for line in train_target_texts:\n", 969 | " whdl.write(\"{}\\n\".format(line))\n", 970 | " with open('eval/{}/{}_{}_source_train'.format(model_path,i + 1,j),'w',encoding='utf-8') as whdl:\n", 971 | " for line in train_source_texts:\n", 972 | " whdl.write(\"{}\\n\".format(line))\n", 973 | " print(\"\\niter {} step {} train loss {} train acc {} test loss {} test acc {} bleu {} lr {}\\n\".format(i+1,j,np.average(train_loss_list[-val_step:]),train_acc,test_loss,test_acc,bleu_score,lr))\n", 974 | " try:\n", 975 | " saver = tf.train.Saver()\n", 976 | " saver.save(session,'middleresult/{}/result_{}_{}'.format(model_path,i + 1,j))\n", 977 | " except:\n", 978 | " print('save fail')\n", 979 | " lr_step = int(worksum / 2) - 1\n", 980 | " if j % lr_step == 0 and j != 0:\n", 981 | " if (i + 1) > 10:\n", 982 | " lr = lr / 2\n", 983 | " pb.complete(1)" 984 | ] 985 | }, 986 | { 987 | "cell_type": "code", 988 | "execution_count": 50, 989 | "metadata": {}, 990 | "outputs": [ 991 | { 992 | "data": { 993 | "text/plain": [ 994 | "" 995 | ] 996 | }, 997 | "execution_count": 50, 998 | "metadata": {}, 999 | "output_type": "execute_result" 1000 | } 1001 | ], 1002 | "source": [ 1003 | "session" 1004 | ] 1005 
| }, 1006 | { 1007 | "cell_type": "code", 1008 | "execution_count": 51, 1009 | "metadata": {}, 1010 | "outputs": [ 1011 | { 1012 | "data": { 1013 | "text/plain": [ 1014 | "40" 1015 | ] 1016 | }, 1017 | "execution_count": 51, 1018 | "metadata": {}, 1019 | "output_type": "execute_result" 1020 | } 1021 | ], 1022 | "source": [ 1023 | "len([ind2en.get(i,'') for i in test_x[1]])" 1024 | ] 1025 | }, 1026 | { 1027 | "cell_type": "markdown", 1028 | "metadata": {}, 1029 | "source": [ 1030 | "# try to translate" 1031 | ] 1032 | }, 1033 | { 1034 | "cell_type": "code", 1035 | "execution_count": 227, 1036 | "metadata": { 1037 | "collapsed": true, 1038 | "scrolled": true 1039 | }, 1040 | "outputs": [], 1041 | "source": [ 1042 | "sent = 'you are too stupid to know that you are an idiot .'.split()" 1043 | ] 1044 | }, 1045 | { 1046 | "cell_type": "code", 1047 | "execution_count": 228, 1048 | "metadata": {}, 1049 | "outputs": [ 1050 | { 1051 | "data": { 1052 | "text/plain": [ 1053 | "['you',\n", 1054 | " 'are',\n", 1055 | " 'too',\n", 1056 | " 'stupid',\n", 1057 | " 'to',\n", 1058 | " 'know',\n", 1059 | " 'that',\n", 1060 | " 'you',\n", 1061 | " 'are',\n", 1062 | " 'an',\n", 1063 | " 'idiot',\n", 1064 | " '.']" 1065 | ] 1066 | }, 1067 | "execution_count": 228, 1068 | "metadata": {}, 1069 | "output_type": "execute_result" 1070 | } 1071 | ], 1072 | "source": [ 1073 | "sent" 1074 | ] 1075 | }, 1076 | { 1077 | "cell_type": "code", 1078 | "execution_count": 229, 1079 | "metadata": { 1080 | "collapsed": true 1081 | }, 1082 | "outputs": [], 1083 | "source": [ 1084 | "sents = [en2ind.get(i) for i in sent]" 1085 | ] 1086 | }, 1087 | { 1088 | "cell_type": "code", 1089 | "execution_count": 230, 1090 | "metadata": {}, 1091 | "outputs": [], 1092 | "source": [ 1093 | "sents = tf.contrib.keras.preprocessing.sequence.pad_sequences([sents],seq_max_len,padding='post')" 1094 | ] 1095 | }, 1096 | { 1097 | "cell_type": "code", 1098 | "execution_count": 231, 1099 | "metadata": {}, 1100 | "outputs": [ 1101 | { 1102 | "data": { 1103 | "text/plain": [ 1104 | "array([[ 13, 31, 198, 3197, 7, 83, 17, 13, 31, 35, 7122,\n", 1105 | " 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1106 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1107 | " 0, 0, 0, 0, 0, 0, 0]])" 1108 | ] 1109 | }, 1110 | "execution_count": 231, 1111 | "metadata": {}, 1112 | "output_type": "execute_result" 1113 | } 1114 | ], 1115 | "source": [ 1116 | "sents" 1117 | ] 1118 | }, 1119 | { 1120 | "cell_type": "code", 1121 | "execution_count": 232, 1122 | "metadata": {}, 1123 | "outputs": [ 1124 | { 1125 | "data": { 1126 | "text/plain": [ 1127 | "array([[ 13, 31, 198, ..., 0, 0, 0],\n", 1128 | " [ 13, 31, 198, ..., 0, 0, 0],\n", 1129 | " [ 13, 31, 198, ..., 0, 0, 0],\n", 1130 | " ..., \n", 1131 | " [ 13, 31, 198, ..., 0, 0, 0],\n", 1132 | " [ 13, 31, 198, ..., 0, 0, 0],\n", 1133 | " [ 13, 31, 198, ..., 0, 0, 0]])" 1134 | ] 1135 | }, 1136 | "execution_count": 232, 1137 | "metadata": {}, 1138 | "output_type": "execute_result" 1139 | } 1140 | ], 1141 | "source": [ 1142 | "np.repeat(sents,35,axis=0)" 1143 | ] 1144 | }, 1145 | { 1146 | "cell_type": "code", 1147 | "execution_count": 233, 1148 | "metadata": {}, 1149 | "outputs": [ 1150 | { 1151 | "data": { 1152 | "text/plain": [ 1153 | "13" 1154 | ] 1155 | }, 1156 | "execution_count": 233, 1157 | "metadata": {}, 1158 | "output_type": "execute_result" 1159 | } 1160 | ], 1161 | "source": [ 1162 | "sum(sents[0] > 0) + 1" 1163 | ] 1164 | }, 1165 | { 1166 | "cell_type": "code", 1167 | "execution_count": 234, 1168 | "metadata": { 1169 | "collapsed": true 
1170 | }, 1171 | "outputs": [], 1172 | "source": [ 1173 | "tran = session.run([translations],feed_dict={x:np.repeat(sents,64,axis=0),x_len:[35] * 64, x_real_len:[sum(sents[0] > 0) + 1] * 64})" 1174 | ] 1175 | }, 1176 | { 1177 | "cell_type": "code", 1178 | "execution_count": 235, 1179 | "metadata": {}, 1180 | "outputs": [ 1181 | { 1182 | "data": { 1183 | "text/plain": [ 1184 | "[('你太傻了,知道你是个白痴', 34),\n", 1185 | " ('你太傻了,不知道你是个白痴', 8),\n", 1186 | " ('你太傻了,你就知道你是个白痴', 6),\n", 1187 | " ('你太愚蠢了,知道你是个白痴', 2),\n", 1188 | " ('你太傻了,知道你是白痴', 2)]" 1189 | ] 1190 | }, 1191 | "execution_count": 235, 1192 | "metadata": {}, 1193 | "output_type": "execute_result" 1194 | } 1195 | ], 1196 | "source": [ 1197 | "from collections import Counter\n", 1198 | "Counter(''.join([ind2ch.get(i,'') for i in j]) for j in tran[0]).most_common(5)" 1199 | ] 1200 | }, 1201 | { 1202 | "cell_type": "code", 1203 | "execution_count": 203, 1204 | "metadata": {}, 1205 | "outputs": [ 1206 | { 1207 | "data": { 1208 | "text/plain": [ 1209 | "(50002, 8824)" 1210 | ] 1211 | }, 1212 | "execution_count": 203, 1213 | "metadata": {}, 1214 | "output_type": "execute_result" 1215 | } 1216 | ], 1217 | "source": [ 1218 | "en2ind[''],ch2ind['']" 1219 | ] 1220 | }, 1221 | { 1222 | "cell_type": "markdown", 1223 | "metadata": {}, 1224 | "source": [ 1225 | "# release the model" 1226 | ] 1227 | }, 1228 | { 1229 | "cell_type": "code", 1230 | "execution_count": 135, 1231 | "metadata": {}, 1232 | "outputs": [ 1233 | { 1234 | "name": "stderr", 1235 | "output_type": "stream", 1236 | "text": [ 1237 | "A subdirectory or file release already exists.\n" 1238 | ] 1239 | } 1240 | ], 1241 | "source": [ 1242 | "!mkdir release" 1243 | ] 1244 | }, 1245 | { 1246 | "cell_type": "code", 1247 | "execution_count": 136, 1248 | "metadata": { 1249 | "collapsed": true 1250 | }, 1251 | "outputs": [], 1252 | "source": [ 1253 | "os.mkdir('release/align_and_translate_char_50000')" 1254 | ] 1255 | }, 1256 | { 1257 | "cell_type": "code", 1258 | "execution_count": 137, 1259 | "metadata": {}, 1260 | "outputs": [ 1261 | { 1262 | "data": { 1263 | "text/plain": [ 1264 | "'release/align_and_translate_char_50000/align_and_translate_model'" 1265 | ] 1266 | }, 1267 | "execution_count": 137, 1268 | "metadata": {}, 1269 | "output_type": "execute_result" 1270 | } 1271 | ], 1272 | "source": [ 1273 | "saver.save(session,'release/align_and_translate_char_50000/align_and_translate_model')" 1274 | ] 1275 | }, 1276 | { 1277 | "cell_type": "code", 1278 | "execution_count": 207, 1279 | "metadata": {}, 1280 | "outputs": [], 1281 | "source": [ 1282 | "with open('release/align_and_translate_char_50000/dic.pkl','wb') as whdl:\n", 1283 | " pickle.dump(\n", 1284 | " ( \n", 1285 | " ind2ch,\n", 1286 | " ch2ind,\n", 1287 | " ind2en,\n", 1288 | " en2ind,\n", 1289 | " ),whdl,protocol=2)" 1290 | ] 1291 | }, 1292 | { 1293 | "cell_type": "code", 1294 | "execution_count": null, 1295 | "metadata": { 1296 | "collapsed": true 1297 | }, 1298 | "outputs": [], 1299 | "source": [] 1300 | } 1301 | ], 1302 | "metadata": { 1303 | "anaconda-cloud": {}, 1304 | "kernelspec": { 1305 | "display_name": "Python [conda root]", 1306 | "language": "python", 1307 | "name": "conda-root-py" 1308 | }, 1309 | "language_info": { 1310 | "codemirror_mode": { 1311 | "name": "ipython", 1312 | "version": 3 1313 | }, 1314 | "file_extension": ".py", 1315 | "mimetype": "text/x-python", 1316 | "name": "python", 1317 | "nbconvert_exporter": "python", 1318 | "pygments_lexer": "ipython3", 1319 | "version": "3.5.2" 1320 | } 1321 | }, 1322 | 
"nbformat": 4, 1323 | "nbformat_minor": 1 1324 | } 1325 | --------------------------------------------------------------------------------