├── LICENSE ├── split_source_target.py ├── split_train_test_valid.py ├── .gitignore ├── parse_bccwj.py ├── predict_ngram.py ├── evaluate.py ├── predict.py ├── rnn_predictor.py ├── predict_both.py ├── train.py ├── grid_search.py ├── rnn_trainer.py ├── utility.py ├── README.md ├── experiment.py ├── decode_ngram.py ├── decode_both.py └── decode.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Yoh Okuno 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /split_source_target.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | 4 | 5 | def parse_file(file): 6 | for line in file: 7 | line = line.rstrip('\n') 8 | words = line.split(' ') 9 | yield [word.split('/', 1) for word in words] 10 | 11 | 12 | def split_source_target(sentences): 13 | target = '' 14 | source = '' 15 | for sentence in sentences: 16 | for target_word, source_word in sentence: 17 | target += target_word 18 | source += source_word 19 | target += '\n' 20 | source += '\n' 21 | return target, source 22 | 23 | 24 | def main(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('input_file', type=argparse.FileType('r')) 27 | parser.add_argument('target_file', type=argparse.FileType('w')) 28 | parser.add_argument('source_file', type=argparse.FileType('w')) 29 | args = parser.parse_args() 30 | 31 | sentences = parse_file(args.input_file) 32 | target, source = split_source_target(sentences) 33 | args.target_file.write(target) 34 | args.source_file.write(source) 35 | 36 | if __name__ == '__main__': 37 | main() 38 | -------------------------------------------------------------------------------- /split_train_test_valid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import random 4 | 5 | 6 | def main(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('input_file', type=argparse.FileType('r')) 9 | parser.add_argument('train_file', type=argparse.FileType('w')) 10 | parser.add_argument('test_file', type=argparse.FileType('w')) 11 | parser.add_argument('valid_file', type=argparse.FileType('w')) 12 | parser.add_argument('--test_size', type=float, default=0.1) 13 | parser.add_argument('--valid_size', type=float, default=0.01) 14 | 
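    # The defaults below hold out 10% of the lines for testing and 1% for validation; the fixed seed makes the shuffle reproducible.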
parser.add_argument('--seed', type=int, default=0) 15 | args = parser.parse_args() 16 | 17 | input_data = args.input_file.readlines() 18 | 19 | random.seed(args.seed) 20 | random.shuffle(input_data) 21 | 22 | test_index = int(len(input_data) * args.test_size) 23 | test_data = input_data[:test_index] 24 | 25 | valid_index = test_index + int(len(input_data) * args.valid_size) 26 | valid_data = input_data[test_index:valid_index] 27 | train_data = input_data[valid_index:] 28 | 29 | args.train_file.writelines(train_data) 30 | args.test_file.writelines(test_data) 31 | args.valid_file.writelines(valid_data) 32 | 33 | if __name__ == '__main__': 34 | main() 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /parse_bccwj.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import glob 4 | import xml.etree.ElementTree 5 | 6 | 7 | def parse_sentence(sentence): 8 | for element in sentence.iter('SUW'): 9 | target = ''.join(element.itertext()) 10 | target = target.replace('\n', '') 11 | if ' ' in target: 12 | # Ignore full width space 13 | continue 14 | 15 | if 'kana' in element.attrib: 16 | source = element.attrib['kana'] 17 | else: 18 | source = element.attrib['formBase'] 19 | if source == '': 20 | source = target 21 | 22 | yield target, source 23 | 24 | 25 | def parse_file(filename): 26 | tree = xml.etree.ElementTree.parse(filename) 27 | 28 | for sentence in tree.iter('sentence'): 29 | sentence = list(parse_sentence(sentence)) 30 | if sentence: 31 | yield sentence 32 | 33 | 34 | def parse_pathname(pathname): 35 | for filename in glob.glob(pathname): 36 | corpus = parse_file(filename) 37 | for sentence in corpus: 38 | yield sentence 39 | 40 | 41 | def katakana_to_hiragana(string): 42 | result = '' 43 | for character in string: 44 | code = ord(character) 45 | # if 0x30A1 <= code 
<= 0x30F6: 46 | if ord('ァ') <= code <= ord('ヶ'): 47 | # result += chr(code + 0x3041 - 0x30A1) 48 | result += chr(code - ord('ァ') + ord('ぁ')) 49 | else: 50 | result += character 51 | return result 52 | 53 | 54 | def main(): 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument('pathname') 57 | args = parser.parse_args() 58 | 59 | sentences = parse_pathname(args.pathname) 60 | for sentence in sentences: 61 | line = ' '.join('{}/{}'.format(target, katakana_to_hiragana(source)) for target, source in sentence) 62 | print(line) 63 | 64 | 65 | if __name__ == '__main__': 66 | main() 67 | -------------------------------------------------------------------------------- /predict_ngram.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import math 4 | import collections 5 | 6 | 7 | def parse_srilm(file): 8 | order = 0 9 | ngrams = collections.defaultdict(list) 10 | 11 | for line in file: 12 | line = line.rstrip('\n') 13 | fields = line.split('\t', 2) 14 | 15 | if len(fields) not in (2, 3): 16 | continue 17 | 18 | cost = -math.log(10 ** float(fields[0])) 19 | ngram = fields[1].split(' ') 20 | 21 | if len(ngram) > order: 22 | order = len(ngram) 23 | 24 | context = tuple(ngram[:-1]) 25 | word = ngram[-1] 26 | ngrams[context].append((cost, word)) 27 | 28 | return ngrams, order 29 | 30 | 31 | def predict(ngrams, context): 32 | if context not in ngrams: 33 | return predict(ngrams, context[1:]) 34 | 35 | cost, word = min(ngrams[context]) 36 | return word 37 | 38 | 39 | def match_predictions(ngrams, order, words): 40 | for i in range(1, len(words) - 1): 41 | word = words[i] 42 | context = tuple(words[max(i-order+1, 0):i]) 43 | prediction = predict(ngrams, context) 44 | yield prediction == word 45 | 46 | 47 | def main(): 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument('test_file', type=argparse.FileType('r')) 50 | parser.add_argument('ngram_file', type=argparse.FileType('r')) 51 | args = parser.parse_args() 52 | 53 | ngrams, order = parse_srilm(args.ngram_file) 54 | 55 | all_predictions = 0 56 | correct_predictions = 0 57 | 58 | for line in args.test_file: 59 | words = line.rstrip('\n').split(' ') 60 | words = ['<s>'] + words + ['</s>'] 61 | result = list(match_predictions(ngrams, order, words)) 62 | all_predictions += len(result) 63 | correct_predictions += sum(result) 64 | print(correct_predictions / all_predictions, end='\r') 65 | 66 | print(correct_predictions / all_predictions) 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | 4 | 5 | def get_common_length(left, right): 6 | # Compute length of the longest common sub-sequence of two strings 7 | table = [[0 for _ in range(len(right) + 1)] for _ in range(len(left) + 1)] 8 | 9 | for i in range(1, len(left) + 1): 10 | for j in range(1, len(right) + 1): 11 | if left[i - 1] == right[j - 1]: 12 | table[i][j] = table[i-1][j-1] + 1 13 | else: 14 | table[i][j] = max(table[i-1][j], table[i][j-1]) 15 | return table[-1][-1] 16 | 17 | 18 | def evaluate(system, reference): 19 | # extract statistics 20 | common_length = sum(get_common_length(r, s) for r, s in zip(reference, system)) 21 | reference_length = len(''.join(reference)) 22 | system_length = len(''.join(system)) 23 | sentence_match = sum(r == s for r, s in zip(reference, system)) 24 | 25 | 
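    # Worked example: reference '今日は晴れ' and system output '今日は腫れ' share the longest
    # common subsequence '今日はれ' (length 4), so precision = recall = 100 * 4 / 5 = 80
    # and f-score = 200 * 4 / (5 + 5) = 80, while accuracy only counts exact sentence matches.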
# calculate metrics 26 | if system_length > 0: 27 | precision = 100. * common_length / system_length 28 | else: 29 | precision = 0. 30 | recall = 100. * common_length / reference_length 31 | fscore = 200. * common_length / (reference_length + system_length) 32 | accuracy = 100. * sentence_match / len(reference) 33 | 34 | # return metrics 35 | return precision, recall, fscore, accuracy 36 | 37 | 38 | def main(): 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument('system', type=argparse.FileType('r')) 41 | parser.add_argument('reference', type=argparse.FileType('r')) 42 | args = parser.parse_args() 43 | 44 | # load data 45 | system = [line.rstrip('\n') for line in args.system] 46 | reference = [line.rstrip('\n') for line in args.reference] 47 | reference = reference[:len(system)] 48 | 49 | # calculate metrics 50 | metrics = evaluate(system, reference) 51 | 52 | # print metrics 53 | print('precision: {:.2f} recall: {:.2f} f-score: {:.2f} accuracy: {:.2f}'.format(*metrics)) 54 | 55 | 56 | if __name__ == '__main__': 57 | main() 58 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import os 4 | import collections 5 | import numpy as np 6 | 7 | from decode import load_settings 8 | from rnn_predictor import RNNPredictor 9 | 10 | 11 | def load_dictionary(model_directory): 12 | vocabulary_path = os.path.join(model_directory, 'vocabulary.txt') 13 | dictionary = collections.defaultdict(int) 14 | for word_id, line in enumerate(open(vocabulary_path)): 15 | word = line.rstrip('\n') 16 | dictionary[word] = word_id 17 | vocabulary = [line.rstrip('\n') for line in open(vocabulary_path)] 18 | return dictionary, vocabulary 19 | 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('test_file', type=argparse.FileType('r')) 24 | parser.add_argument('model_directory') 25 | parser.add_argument('--model_file') 26 | args = parser.parse_args() 27 | 28 | # Load settings and vocabulary 29 | settings = load_settings(args.model_directory) 30 | dictionary, vocabulary = load_dictionary(args.model_directory) 31 | 32 | # Create model and load parameters 33 | rnn_predictor = RNNPredictor(settings.vocabulary_size, settings.hidden_size, settings.layer_size, settings.cell_type) 34 | if args.model_file: 35 | rnn_predictor.restore_from_file(args.model_file) 36 | else: 37 | rnn_predictor.restore_from_directory(args.model_directory) 38 | 39 | all_predictions = 0 40 | correct_predictions = 0 41 | 42 | for line in args.test_file: 43 | line = line.rstrip('\n') 44 | words = line.split(' ') 45 | words = ['_BOS/_BOS'] + words + ['_EOS/_EOS'] 46 | state = None 47 | 48 | for i in range(len(words) - 2): 49 | word_id = dictionary[words[i]] 50 | predictions, state = rnn_predictor.predict([word_id], state) 51 | prediction = vocabulary[np.argmin(predictions[0])] 52 | 53 | if prediction == words[i + 1]: 54 | correct_predictions += 1 55 | all_predictions += 1 56 | 57 | print(correct_predictions / all_predictions, end='\r') 58 | 59 | print(correct_predictions / all_predictions) 60 | 61 | if __name__ == '__main__': 62 | main() 63 | -------------------------------------------------------------------------------- /rnn_predictor.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class RNNPredictor: 5 | # TODO: Load meta graph from file and use it instead to 
support various internal structures 6 | def __init__(self, vocabulary_size, hidden_size, layer_size, cell_type): 7 | with tf.Graph().as_default(): 8 | # Placeholder for test data 9 | self.input = tf.placeholder(tf.int32, [None]) 10 | 11 | # Lookup word embedding 12 | embedding = tf.Variable(tf.zeros([vocabulary_size, hidden_size]), name='embedding') 13 | rnn_inputs = tf.nn.embedding_lookup(embedding, self.input) 14 | 15 | # Create RNN cell 16 | if cell_type == 'lstm': 17 | cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size) 18 | elif cell_type == 'rnn': 19 | cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) 20 | else: 21 | cell = tf.nn.rnn_cell.GRUCell(hidden_size) 22 | 23 | # Stack multiple RNN cells 24 | if layer_size > 1: 25 | cell = tf.nn.rnn_cell.MultiRNNCell([cell] * layer_size) 26 | 27 | self.initial_state = cell.zero_state(1, dtype=tf.float32) 28 | 29 | # Call the RNN cell 30 | with tf.variable_scope('RNN'): 31 | rnn_output, self.next_state = cell(rnn_inputs, self.initial_state) 32 | 33 | # Predict distribution over next word 34 | softmax_w = tf.Variable(tf.zeros([hidden_size, vocabulary_size]), name='softmax_w') 35 | softmax_b = tf.Variable(tf.zeros([vocabulary_size]), name='softmax_b') 36 | logits = tf.matmul(rnn_output, softmax_w) + softmax_b 37 | 38 | # predictions is negative log probability of shape [vocabulary_size] 39 | self.predictions = -tf.nn.log_softmax(logits) 40 | 41 | self.saver = tf.train.Saver(tf.trainable_variables()) 42 | self.session = tf.Session() 43 | 44 | def predict(self, input_value, state_value=None): 45 | if state_value is not None: 46 | feed_dict = {self.input: input_value, self.initial_state: state_value} 47 | else: 48 | feed_dict = {self.input: input_value} 49 | predictions, next_state = self.session.run([self.predictions, self.next_state], feed_dict) 50 | return predictions, next_state 51 | 52 | def restore_from_directory(self, model_directory): 53 | model_path = tf.train.latest_checkpoint(model_directory) 54 | self.saver.restore(self.session, model_path) 55 | 56 | def restore_from_file(self, model_path): 57 | self.saver.restore(self.session, model_path) 58 | 59 | def close(self): 60 | self.session.close() 61 | -------------------------------------------------------------------------------- /predict_both.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import numpy as np 4 | 5 | from decode import load_settings 6 | from decode_ngram import get_ngram_cost, parse_srilm 7 | from predict import load_dictionary 8 | from rnn_predictor import RNNPredictor 9 | 10 | 11 | def match_predictions(rnn_predictor, dictionary, vocabulary, ngrams, words): 12 | state = None 13 | 14 | for i in range(len(words) - 2): 15 | # RNN prediction 16 | word_id = dictionary[words[i]] 17 | predictions, state = rnn_predictor.predict([word_id], state) 18 | rnn_prediction = predictions[0] 19 | 20 | # N-gram prediction 21 | context = words[max(i - 3, 0):i + 1] 22 | ngram_prediction = np.zeros(len(rnn_prediction)) 23 | for word in list(dictionary.values()): 24 | history = tuple(context + [word]) 25 | probability = get_ngram_cost(ngrams, history) 26 | word_id = dictionary[word] 27 | ngram_prediction[word_id] = probability 28 | 29 | interpolation = -np.log((np.exp(-rnn_prediction) + np.exp(-ngram_prediction)) / 2.0) 30 | prediction = vocabulary[np.argmin(interpolation)] 31 | yield prediction == words[i + 1] 32 | 33 | 34 | def main(): 35 | parser = argparse.ArgumentParser() 36 | 
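    # test_file is in the same word-segmented target/source pair format as the training data, one sentence per line.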
parser.add_argument('test_file', type=argparse.FileType('r')) 37 | parser.add_argument('model_directory') 38 | parser.add_argument('ngram_file', type=argparse.FileType('r')) 39 | parser.add_argument('--model_file') 40 | args = parser.parse_args() 41 | 42 | # Load settings and vocabulary 43 | settings = load_settings(args.model_directory) 44 | dictionary, vocabulary = load_dictionary(args.model_directory) 45 | 46 | # Create model and load parameters 47 | rnn_predictor = RNNPredictor(settings.vocabulary_size, settings.hidden_size, settings.layer_size, settings.cell_type) 48 | if args.model_file: 49 | rnn_predictor.restore_from_file(args.model_file) 50 | else: 51 | rnn_predictor.restore_from_directory(args.model_directory) 52 | 53 | # Load N-gram model 54 | ngrams = parse_srilm(args.ngram_file) 55 | 56 | all_predictions = 0 57 | correct_predictions = 0 58 | 59 | for line in args.test_file: 60 | line = line.rstrip('\n') 61 | words = line.split(' ') 62 | words = ['_BOS/_BOS'] + words + ['_EOS/_EOS'] 63 | 64 | result = list(match_predictions(rnn_predictor, dictionary, vocabulary, ngrams, words)) 65 | all_predictions += len(result) 66 | correct_predictions += sum(result) 67 | print(correct_predictions / all_predictions, end='\r') 68 | 69 | print(correct_predictions / all_predictions) 70 | 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import json 4 | import math 5 | import os 6 | import time 7 | 8 | from utility import load_train_data 9 | from rnn_trainer import RNNTrainer 10 | 11 | 12 | def train_epoch(rnn_trainer, train_data, model_directory, epoch): 13 | total_loss = 0.0 14 | start_time = time.time() 15 | 16 | for batch, (input_, output_) in enumerate(train_data): 17 | start_time_batch = time.time() 18 | loss, gradient_norm = rnn_trainer.train(input_, output_) 19 | train_time_batch = time.time() - start_time_batch 20 | total_loss += loss 21 | perplexity = math.exp(loss) 22 | print('training batch: {:} perplexity: {:.2f} time: {:.2f}'.format(batch, perplexity, train_time_batch), end='\r') 23 | 24 | train_time = time.time() - start_time 25 | perplexity = math.exp(total_loss / len(train_data)) 26 | 27 | log_text = 'training epoch: {} perplexity: {:.2f} \n'.format(epoch, perplexity) 28 | log_text += 'training total time: {:.2f} average time: {:.2f}'.format(train_time, train_time / len(train_data)) 29 | print(log_text) 30 | 31 | # Save model every epoch 32 | model_path = os.path.join(model_directory, 'model.ckpt') 33 | rnn_trainer.save(model_path, epoch) 34 | 35 | return perplexity, train_time 36 | 37 | 38 | def main(): 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument('train_file') 41 | parser.add_argument('model_directory') 42 | parser.add_argument('--sentence_size', type=int, default=30) 43 | parser.add_argument('--vocabulary_size', type=int, default=50000) 44 | parser.add_argument('--batch_size', type=int, default=50) 45 | parser.add_argument('--hidden_size', type=int, default=400) 46 | parser.add_argument('--layer_size', type=int, default=1) 47 | parser.add_argument('--epoch_size', type=int, default=10) 48 | parser.add_argument('--clip_norm', type=float, default=5) 49 | parser.add_argument('--keep_prob', type=float, default=0.5) 50 | parser.add_argument('--cell_type', default='gru') 51 | parser.add_argument('--optimizer_type', default='adam') 52 | 
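    # max_keep is forwarded to the checkpoint Saver's max_to_keep in rnn_trainer.py, so it controls how many epoch checkpoints are retained.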
parser.add_argument('--max_keep', type=int, default=0) 53 | args = parser.parse_args() 54 | 55 | # Print settings 56 | print(json.dumps(vars(args), indent=4)) 57 | 58 | # Load and preprocess training data 59 | train_data = load_train_data(args) 60 | print('number of batches:', len(train_data)) 61 | 62 | # Create RNN model for training 63 | rnn_trainer = RNNTrainer(args.batch_size, args.sentence_size, args.vocabulary_size, args.hidden_size, 64 | args.layer_size, args.cell_type, args.optimizer_type, args.clip_norm, 65 | args.keep_prob, args.max_keep) 66 | 67 | start_time = time.time() 68 | 69 | for epoch in range(args.epoch_size): 70 | # Train one epoch and save the model 71 | train_epoch(rnn_trainer, train_data, args.model_directory, epoch) 72 | 73 | total_time = time.time() - start_time 74 | print('total training time:', total_time) 75 | print() 76 | 77 | rnn_trainer.close() 78 | 79 | if __name__ == '__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /grid_search.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import json 4 | import os 5 | import time 6 | import itertools 7 | from experiment import experiment 8 | 9 | 10 | def grid_search(args): 11 | start_time = time.time() 12 | 13 | results = [] 14 | 15 | # split args to default settings and hyperparameters 16 | default_settings = {} 17 | search_settings = {} 18 | for name, value in vars(args).items(): 19 | if type(value) == list and len(value) == 1: 20 | value = value[0] 21 | if type(value) != list: 22 | default_settings[name] = value 23 | else: 24 | search_settings[name] = value 25 | 26 | # Search hyperparameters 27 | for values in itertools.product(*search_settings.values()): 28 | # Merge default and search settings 29 | hyperparameters = dict(zip(search_settings.keys(), values)) 30 | merged_settings = {**default_settings, **hyperparameters} 31 | settings = argparse.Namespace(**merged_settings) 32 | 33 | # Set directory path 34 | directory_name = '-'.join(name + str(value) for name, value in hyperparameters.items()) 35 | settings.model_directory = os.path.join(settings.model_directory, directory_name) 36 | 37 | # Run experiment 38 | metrics, epoch = experiment(settings) 39 | result = (hyperparameters, metrics, epoch) 40 | results.append(result) 41 | 42 | # print best experiment result 43 | hyperparameters, metrics, epoch = max(results, key=lambda x: x[1][2]) 44 | print('best experiment settings in epoch', epoch) 45 | print(json.dumps(hyperparameters, indent=4)) 46 | print('best experiment metrics: precision: {:.2f} recall: {:.2f} f-score: {:.2f} accuracy: {:.2f}'.format(*metrics)) 47 | 48 | # save all results 49 | results_path = os.path.join(args.model_directory, 'results.json') 50 | json.dump(results, open(results_path, 'w'), indent=4) 51 | 52 | # Print total time 53 | total_time = time.time() - start_time 54 | print('total grid search time:', total_time) 55 | 56 | 57 | def main(): 58 | parser = argparse.ArgumentParser() 59 | # mandatory parameters 60 | parser.add_argument('train_file') 61 | parser.add_argument('valid_target_file') 62 | parser.add_argument('valid_source_file') 63 | parser.add_argument('model_directory') 64 | # grid search parameters 65 | parser.add_argument('--hidden_size', type=int, default=400, nargs='*') 66 | parser.add_argument('--layer_size', type=int, default=1, nargs='*') 67 | parser.add_argument('--keep_prob', type=float, default=0.5, nargs='*') 68 | 
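    # The string-valued options below also accept several values (nargs='*'), e.g. --cell_type gru lstm.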
parser.add_argument('--cell_type', default='gru', nargs='*') 69 | parser.add_argument('--optimizer_type', default='adam', nargs='*') 70 | parser.add_argument('--vocabulary_size', type=int, default=50000, nargs='*') 71 | parser.add_argument('--sentence_size', type=int, default=30, nargs='*') 72 | parser.add_argument('--batch_size', type=int, default=50, nargs='*') 73 | parser.add_argument('--clip_norm', type=float, default=5, nargs='*') 74 | # optional parameters 75 | parser.add_argument('--epoch_size', type=int, default=10) 76 | parser.add_argument('--max_keep', type=int, default=0) 77 | parser.add_argument('--beam_size', type=int, default=5) 78 | parser.add_argument('--viterbi_size', type=int, default=1) 79 | args = parser.parse_args() 80 | 81 | # Run grid search. This might take long time to complete. 82 | grid_search(args) 83 | 84 | 85 | if __name__ == '__main__': 86 | main() 87 | -------------------------------------------------------------------------------- /rnn_trainer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class RNNTrainer: 5 | # TODO split to separate functions 6 | # TODO support truncated BPTT 7 | def __init__(self, batch_size, sentence_size, vocabulary_size, hidden_size, layer_size, 8 | cell_type, optimizer_type, clip_norm, keep_prob, max_keep): 9 | 10 | with tf.Graph().as_default(): 11 | # Placeholders for training data 12 | self.input = tf.placeholder(tf.int32, [batch_size, sentence_size]) 13 | self.output = tf.placeholder(tf.int32, [batch_size, sentence_size]) 14 | 15 | # Lookup word embedding 16 | embedding = tf.Variable(tf.truncated_normal([vocabulary_size, hidden_size], stddev=0.01), name='embedding') 17 | inputs = tf.nn.embedding_lookup(embedding, self.input) 18 | inputs = tf.nn.dropout(inputs, keep_prob) 19 | 20 | # Create and connect RNN cells 21 | if cell_type == 'lstm': 22 | cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size) 23 | elif cell_type == 'rnn': 24 | cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) 25 | else: 26 | cell = tf.nn.rnn_cell.GRUCell(hidden_size) 27 | 28 | # Dropout output of the RNN cell 29 | cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=keep_prob) 30 | 31 | # Stack multiple RNN cells 32 | if layer_size > 1: 33 | cell = tf.nn.rnn_cell.MultiRNNCell([cell] * layer_size) 34 | 35 | rnn_inputs = [tf.squeeze(input_, [1]) for input_ in tf.split(1, sentence_size, inputs)] 36 | outputs, _ = tf.nn.rnn(cell, rnn_inputs, dtype=tf.float32) 37 | 38 | # Predict distribution over next word 39 | output = tf.reshape(tf.concat(1, outputs), [-1, hidden_size]) 40 | softmax_w = tf.Variable(tf.truncated_normal([hidden_size, vocabulary_size], stddev=0.01), name='softmax_w') 41 | softmax_b = tf.Variable(tf.zeros(shape=[vocabulary_size]), name='softmax_b') 42 | logits = tf.matmul(output, softmax_w) + softmax_b 43 | 44 | # Define loss function and optimizer 45 | self.loss = tf.nn.seq2seq.sequence_loss( 46 | [logits], 47 | [tf.reshape(self.output, [-1])], 48 | [tf.ones([batch_size * sentence_size], dtype=tf.float32)]) 49 | 50 | # Apply gradient clipping to address gradient explosion 51 | trainable_variables = tf.trainable_variables() 52 | gradients = tf.gradients(self.loss, trainable_variables) 53 | clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients, clip_norm) 54 | if optimizer_type == 'sgd': 55 | optimizer = tf.train.GradientDescentOptimizer(0.01) 56 | elif optimizer_type == 'adagrad': 57 | optimizer = tf.train.AdagradOptimizer(0.01) 58 | elif 
optimizer_type == 'rmsprop': 59 | optimizer = tf.train.RMSPropOptimizer(0.01) 60 | else: 61 | optimizer = tf.train.AdamOptimizer() 62 | self.train_step = optimizer.apply_gradients(zip(clipped_gradients, trainable_variables)) 63 | 64 | # Keep latest max_keep checkpoints 65 | self.saver = tf.train.Saver(trainable_variables, max_to_keep=max_keep) 66 | self.session = tf.Session() 67 | self.session.run(tf.initialize_all_variables()) 68 | 69 | def train(self, input_, output_): 70 | _, loss, gradient_norm = self.session.run([self.train_step, self.loss, self.gradient_norm], 71 | {self.input: input_, self.output: output_}) 72 | return loss, gradient_norm 73 | 74 | def save(self, model_path, epoch): 75 | self.saver.save(self.session, model_path, global_step=epoch) 76 | 77 | def close(self): 78 | self.session.close() 79 | -------------------------------------------------------------------------------- /utility.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import os 3 | import json 4 | 5 | 6 | def parse_file(file): 7 | for line in file: 8 | line = line.rstrip('\n') 9 | sentence = line.split(' ') 10 | yield sentence 11 | 12 | 13 | # TODO: current method does not allow the model to learn boundary beyond bigram. 14 | def adjust_size(sentences, sentence_size): 15 | # Increment sentence size for shifting output later 16 | sentence_size_plus = sentence_size + 1 17 | 18 | for sentence in sentences: 19 | # Insert BOS = Beginning Of Sentence 20 | sentence.insert(0, '_BOS/_BOS') 21 | 22 | # Split long sentence allowing overlap of 1 word 23 | while len(sentence) >= sentence_size_plus: 24 | yield sentence[:sentence_size_plus] 25 | sentence = sentence[sentence_size:] 26 | 27 | # Do not yield EOS-only sentence 28 | if sentence: 29 | # Insert EOS = End Of Sentence 30 | sentence.append('_EOS/_EOS') 31 | 32 | if len(sentence) < sentence_size_plus: 33 | # Padding sentence to make its size sentence_size_plus 34 | sentence += ['_PAD/_PAD'] * (sentence_size_plus - len(sentence)) 35 | yield sentence 36 | 37 | 38 | def create_vocabulary(sentences, vocabulary_size): 39 | # Create list of words indexed by word ID 40 | counter = Counter(word for words in sentences for word in words) 41 | most_common = counter.most_common(vocabulary_size - 1) 42 | vocabulary = [word for word, count in most_common] 43 | vocabulary.insert(0, '_UNK/_UNK') 44 | return vocabulary 45 | 46 | 47 | def convert_to_ids(sentences, vocabulary): 48 | dictionary = dict((word, word_id) for word_id, word in enumerate(vocabulary)) 49 | 50 | for sentence in sentences: 51 | word_ids = [] 52 | 53 | for word in sentence: 54 | if word in dictionary: 55 | word_id = dictionary[word] 56 | else: 57 | word_id = dictionary['_UNK/_UNK'] 58 | word_ids.append(word_id) 59 | 60 | yield word_ids 61 | 62 | 63 | # TODO: current batching ignores sentences that does't fit into last batch. 
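# For example, an adjusted sentence ['_BOS/_BOS', 'w1', 'w2', 'w3'] (sentence_size=3 plus one extra token)
# becomes input ['_BOS/_BOS', 'w1', 'w2'] and output ['w1', 'w2', 'w3'] below: the output is the input
# shifted by one time step ('w1' etc. stand for target/source pair tokens).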
64 | def create_batches(sentences, batch_size): 65 | all_batches = int(len(sentences) / batch_size) 66 | 67 | for i in range(all_batches): 68 | batch_sentences = sentences[i * batch_size:(i + 1) * batch_size] 69 | batch_input = [] 70 | batch_output = [] 71 | 72 | for sentence in batch_sentences: 73 | # Shift sentence by 1 time step 74 | input_ = sentence[:-1] 75 | output_ = sentence[1:] 76 | 77 | batch_input.append(input_) 78 | batch_output.append(output_) 79 | 80 | yield batch_input, batch_output 81 | 82 | 83 | def save_metadata(args, vocabulary): 84 | # Create the directory if it does not exist 85 | if not os.path.exists(args.model_directory): 86 | os.makedirs(args.model_directory) 87 | 88 | # Save settings 89 | settings_path = os.path.join(args.model_directory, 'settings.json') 90 | with open(settings_path, 'w') as settings_file: 91 | json.dump(vars(args), settings_file, indent=4) 92 | 93 | # Save vocabulary 94 | vocabulary_path = os.path.join(args.model_directory, 'vocabulary.txt') 95 | with open(vocabulary_path, 'w') as vocabulary_file: 96 | vocabulary_file.write('\n'.join(vocabulary)) 97 | 98 | 99 | def load_train_data(args): 100 | sentences = parse_file(open(args.train_file)) 101 | sentences = list(adjust_size(sentences, args.sentence_size)) 102 | vocabulary = create_vocabulary(sentences, args.vocabulary_size) 103 | sentences = list(convert_to_ids(sentences, vocabulary)) 104 | train_data = list(create_batches(sentences, args.batch_size)) 105 | save_metadata(args, vocabulary) 106 | return train_data 107 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural IME: Neural Input Method Engine 2 | A Japanese input method engine can reach the next level with deep learning technology. 3 | 4 | # Prerequisites 5 | You need the following software to use Neural IME. 6 | 7 | * Python 3.5 8 | * TensorFlow 0.10 9 | 10 | The developer uses Mac OS X 10.11.4, Anaconda 4.1 and PyCharm 5.0.4, but it should work elsewhere. 11 | 12 | 13 | # Experimental results 14 | The neural model outperformed the N-gram model on the reference corpus, as shown below. 15 | 16 | | Metrics | N-gram | RNN | 17 | |:-------------------:|:------:|:---------:| 18 | | Sentence Accuracy | 41.5% | **44.2%** | 19 | | Prediction Accuracy | 22.9% | **26.7%** | 20 | 21 | # Training your own models 22 | For training and testing your own models, you need annotated data 23 | such as the DVD version of BCCWJ (Balanced Corpus of Contemporary Written Japanese). 24 | 25 | > http://pj.ninjal.ac.jp/corpus_center/bccwj/en/ 26 | 27 | You probably want a modern GPU to train faster; the developer uses a p2.xlarge instance on AWS. 28 | 29 | ## Preparing your data 30 | Training data is a UTF-8 text file in which each line corresponds to a sentence. 31 | A sentence is segmented into words by the space character, and a word is a pair of *target* 32 | (i.e. Kanji, Hiragana or Katakana) and *source* (Hiragana) joined by a slash character. 33 | 34 | > 私/わたし の/の 名前/なまえ は/は 中野/なかの です/です 。/。 35 | 36 | Test data should contain sentences different from the training data, but ideally from the same domain. 37 | The source file should contain source sentences without spaces. 38 | 39 | > きょうのてんきははれです。 40 | 41 | The target file should contain target sentences without spaces. 42 | 43 | > 今日の天気は晴れです。 44 | 45 | ## Pre-processing BCCWJ 46 | The developer uses the human-annotated part of BCCWJ as the training and testing corpus. 
47 | You can use the scripts in this repository to pre-process the XML files after extracting them from the compressed archive. 48 | For example, the following commands parse and split the data into train, test source and test target files. 49 | 50 | parse_bccwj.py 'BCCWJ/CORE/M-XML/*.xml' > parsed.txt 51 | 52 | split_train_test_valid.py parsed.txt train.txt test.txt valid.txt 53 | 54 | split_source_target.py test.txt test.target.txt test.source.txt 55 | 56 | 57 | ## Training neural models 58 | Now you can train your own model with the default parameters. 59 | 60 | train.py train.txt model 61 | 62 | See the help for optional parameters such as the number of hidden units and the dropout probability. 63 | 64 | train.py --help 65 | 66 | ## Decoding sentences 67 | Once you have trained your model, you can decode sentences with it. 68 | 69 | decode.py model 70 | 71 | Type a source sentence into your console and it will show the decoded sentence like this: 72 | 73 | きょうのてんきははれです。 74 | 今日の天気は晴れです。 75 | きょじんにせんせい 76 | 巨人に先制 77 | 78 | Alternatively, you can give file names for input and output. 79 | 80 | decode.py model --input_file test.source.txt --output_file model/test.decode.txt 81 | 82 | You can trade decoding time for accuracy by tuning pruning parameters such as the beam size and viterbi size. 83 | For example, the following option is faster than the default beam size of 5, but less accurate. 84 | 85 | decode.py model --beam_size 1 86 | 87 | ## Evaluating results 88 | You can evaluate the decoded results if you have the target sentences as a reference. 89 | 90 | evaluate.py model/test.decode.txt test.target.txt 91 | 92 | This command prints something like this: 93 | 94 | > precision: 93.59 recall: 93.58 f-score: 93.59 accuracy: 34.06 95 | 96 | Precision, recall and F-score are character-based metrics derived from the longest common subsequence, 97 | and accuracy is a sentence-level metric. 98 | 99 | ## Hyperparameter search 100 | You can use the grid search script to find the best hyperparameters. 101 | 102 | grid_search.py train.txt valid.target.txt valid.source.txt model --hidden_size 50 100 200 400 103 | 104 | ## Training N-gram models 105 | In order to train N-gram models as a baseline for comparison with the neural models, you need to install and use the SRILM toolkit. 106 | 107 | > http://www.speech.sri.com/projects/srilm/ 108 | 109 | Once installed, you can run the following command to train the model. 110 | 111 | ngram-count -text train.txt -lm ngram.txt -kndiscount -order 2 112 | 113 | Now you can decode sentences using the N-gram model. 114 | 115 | decode_ngram.py ngram.txt 116 | 117 | Or you can combine the neural model and the N-gram model. 118 | 119 | decode_both.py neural_model ngram.txt 120 | 121 | 122 | ## Reference 123 | > Yoh Okuno, Neural IME: Neural Input Method Engine, The 8th Input Method Workshop, 2016. 
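## Measuring prediction accuracy
The prediction accuracy reported above can likely be reproduced with the prediction scripts, which read the word-segmented test file (the target/source pair format, not the plain source text). A sketch of their invocation, assuming the file names used above:

    predict.py test.txt model

    predict_ngram.py test.txt ngram.txt

    predict_both.py test.txt model ngram.txt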
124 | -------------------------------------------------------------------------------- /experiment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import json 4 | import os 5 | import time 6 | 7 | from utility import load_train_data 8 | from rnn_trainer import RNNTrainer 9 | from train import train_epoch 10 | from rnn_predictor import RNNPredictor 11 | from decode import load_dictionary, decode 12 | from evaluate import evaluate 13 | 14 | 15 | def decode_all(rnn_predictor, valid_source_data, dictionary, beam_size, viterbi_size): 16 | start_time = time.time() 17 | system = [] 18 | for i, source in enumerate(valid_source_data): 19 | start_time_sentence = time.time() 20 | top_result, _, _, _ = decode(source, dictionary, rnn_predictor, beam_size, viterbi_size) 21 | decode_time_sentence = time.time() - start_time_sentence 22 | print('decoding sentence: {} time: {:.2f}'.format(i, decode_time_sentence), end='\r') 23 | system.append(top_result) 24 | 25 | decode_time = time.time() - start_time 26 | return system, decode_time 27 | 28 | 29 | def train(rnn_trainer, rnn_predictor, train_data, valid_target_data, valid_source_data, dictionary, 30 | epoch_size, model_directory, beam_size, viterbi_size): 31 | start_time = time.time() 32 | log_path = os.path.join(model_directory, 'log.txt') 33 | log_file = open(log_path, 'w') 34 | best_epoch = None 35 | best_metrics = None 36 | 37 | for epoch in range(epoch_size): 38 | # Train one epoch and save the model 39 | train_epoch(rnn_trainer, train_data, model_directory, epoch) 40 | 41 | # Decode all sentences 42 | rnn_predictor.restore_from_directory(model_directory) 43 | system, decode_time = decode_all(rnn_predictor, valid_source_data, dictionary, beam_size, viterbi_size) 44 | 45 | # Evaluate results 46 | metrics = evaluate(system, valid_target_data) 47 | 48 | # Print metrics 49 | log_text = 'decoding precision: {:.2f} recall: {:.2f} f-score: {:.2f} accuracy: {:.2f}\n'.format(*metrics) 50 | log_text += 'decoding total time: {:.2f} average time: {:.2f}'.format(decode_time, decode_time / len(system)) 51 | print(log_text) 52 | print(log_text, file=log_file) 53 | 54 | # Write decoded results to file 55 | decode_path = os.path.join(model_directory, 'decode-{}.txt'.format(epoch)) 56 | with open(decode_path, 'w') as file: 57 | file.write('\n'.join(system)) 58 | 59 | # Update best epoch 60 | if not best_epoch or best_metrics[2] < metrics[2]: 61 | best_epoch = epoch 62 | best_metrics = metrics 63 | 64 | total_time = time.time() - start_time 65 | print('best epoch:', best_epoch) 66 | print('best epoch metrics: precision: {:.2f} recall: {:.2f} f-score: {:.2f} accuracy: {:.2f}'.format(*best_metrics)) 67 | print('total experiment time:', total_time) 68 | print() 69 | return best_metrics, best_epoch 70 | 71 | 72 | def experiment(settings): 73 | # Print settings 74 | print(json.dumps(vars(settings), indent=4)) 75 | 76 | # Load and preprocess training data 77 | train_data = load_train_data(settings) 78 | print('number of batches:', len(train_data)) 79 | 80 | # Load validation data 81 | valid_target_data = [line.rstrip('\n') for line in open(settings.valid_target_file)] 82 | valid_source_data = [line.rstrip('\n') for line in open(settings.valid_source_file)] 83 | 84 | # Load dictionary for decoding 85 | dictionary = load_dictionary(settings.model_directory) 86 | 87 | # Create RNN model for training 88 | rnn_trainer = RNNTrainer(settings.batch_size, settings.sentence_size, 
settings.vocabulary_size, settings.hidden_size, 89 | settings.layer_size, settings.cell_type, settings.optimizer_type, settings.clip_norm, 90 | settings.keep_prob, settings.max_keep) 91 | 92 | # Create RNN model for prediction 93 | rnn_predictor = RNNPredictor(settings.vocabulary_size, settings.hidden_size, settings.layer_size, settings.cell_type) 94 | 95 | # Run experiment 96 | result = train(rnn_trainer, rnn_predictor, train_data, valid_target_data, valid_source_data, dictionary, 97 | settings.epoch_size, settings.model_directory, settings.beam_size, settings.viterbi_size) 98 | 99 | rnn_trainer.close() 100 | rnn_predictor.close() 101 | return result 102 | 103 | 104 | def main(): 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument('train_file') 107 | parser.add_argument('valid_target_file') 108 | parser.add_argument('valid_source_file') 109 | parser.add_argument('model_directory') 110 | parser.add_argument('--sentence_size', type=int, default=30) 111 | parser.add_argument('--vocabulary_size', type=int, default=50000) 112 | parser.add_argument('--batch_size', type=int, default=50) 113 | parser.add_argument('--hidden_size', type=int, default=400) 114 | parser.add_argument('--layer_size', type=int, default=1) 115 | parser.add_argument('--epoch_size', type=int, default=10) 116 | parser.add_argument('--clip_norm', type=float, default=5) 117 | parser.add_argument('--keep_prob', type=float, default=0.5) 118 | parser.add_argument('--cell_type', default='gru') 119 | parser.add_argument('--optimizer_type', default='adam') 120 | parser.add_argument('--max_keep', type=int, default=0) 121 | parser.add_argument('--beam_size', type=int, default=5) 122 | parser.add_argument('--viterbi_size', type=int, default=1) 123 | args = parser.parse_args() 124 | experiment(args) 125 | 126 | if __name__ == '__main__': 127 | main() 128 | -------------------------------------------------------------------------------- /decode_ngram.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import sys 4 | import heapq 5 | import operator 6 | import math 7 | from collections import defaultdict 8 | 9 | 10 | def parse_ngram(ngram): 11 | for word in ngram.split(' '): 12 | if word == '': 13 | yield '_BOS', '_BOS' 14 | elif word == '': 15 | yield '_EOS', '_EOS' 16 | else: 17 | yield tuple(word.split('/', 1)) 18 | 19 | 20 | def parse_srilm(file): 21 | ngrams = {} 22 | for line in file: 23 | line = line.rstrip('\n') 24 | fields = line.split('\t', 2) 25 | 26 | if len(fields) < 2: 27 | continue 28 | 29 | if len(fields) == 2: 30 | logprob, ngram = fields 31 | backoff = None 32 | elif len(fields) == 3: 33 | logprob, ngram, backoff = fields 34 | backoff = -math.log(10 ** float(backoff)) 35 | cost = -math.log(10 ** float(logprob)) 36 | ngram = tuple(parse_ngram(ngram)) 37 | ngrams[ngram] = (cost, backoff) 38 | return ngrams 39 | 40 | 41 | def create_dictionary(ngrams): 42 | dictionary = defaultdict(list) 43 | for ngram in ngrams.keys(): 44 | if len(ngram) == 1: 45 | target, source = ngram[0] 46 | dictionary[source].append(target) 47 | return dictionary 48 | 49 | 50 | def create_lattice(input_, dictionary): 51 | lattice = [[[] for _ in range(len(input_) + 1)] for _ in range(len(input_) + 2)] 52 | 53 | for i in range(1, len(input_) + 1): 54 | for j in range(i): 55 | source = input_[j:i] 56 | if source in dictionary: 57 | for target in dictionary[source]: 58 | lattice[i][j].append((target, source)) 59 | elif len(source) == 1: 60 | 
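                # A single-character key that is not in the dictionary is passed through verbatim as its own target (the same fallback as the _UNK node in decode.py).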
lattice[i][j].append((source, source)) 61 | 62 | lattice[-1][-1].append(('_EOS', '_EOS')) 63 | return lattice 64 | 65 | 66 | def initialize_queues(lattice): 67 | # A hypothesis is tuple of (cost, history) 68 | queues = [[] for _ in range(len(lattice))] 69 | bos_hypothesis = (0.0, [('_BOS', '_BOS')]) 70 | queues[0].append(bos_hypothesis) 71 | return queues 72 | 73 | 74 | def get_ngram_cost(ngrams, history): 75 | if type(history) is list: 76 | history = tuple(history) 77 | if history in ngrams: 78 | cost, _ = ngrams[history] 79 | return cost 80 | 81 | if len(history) == 1: 82 | return 100.0 83 | 84 | return get_ngram_cost(ngrams, history[1:]) 85 | 86 | 87 | def search(lattice, ngrams, queues, beam_size, viterbi_size): 88 | for i in range(len(lattice)): 89 | for j in range(len(lattice[i])): 90 | for target, source in lattice[i][j]: 91 | 92 | word_queue = [] 93 | for previous_cost, previous_history in queues[j]: 94 | history = previous_history + [(target, source)] 95 | cost = previous_cost + get_ngram_cost(ngrams, tuple(history[-3:])) 96 | hypothesis = (cost, history) 97 | word_queue.append(hypothesis) 98 | 99 | # prune word_queue to viterbi size 100 | if viterbi_size > 0: 101 | word_queue = heapq.nsmallest(viterbi_size, word_queue, key=operator.itemgetter(0)) 102 | 103 | queues[i] += word_queue 104 | 105 | # prune queues[i] to beam size 106 | if beam_size > 0: 107 | queues[i] = heapq.nsmallest(beam_size, queues[i], key=operator.itemgetter(0)) 108 | return queues 109 | 110 | 111 | def decode(input_, dictionary, ngrams, beam_size, viterbi_size): 112 | lattice = create_lattice(input_, dictionary) 113 | queues = initialize_queues(lattice) 114 | queue = search(lattice, ngrams, queues, beam_size, viterbi_size) 115 | 116 | candidates = [] 117 | for cost, history in queue[-1]: 118 | result = ''.join(target for target, source in history[1:-1]) 119 | candidates.append((result, cost)) 120 | 121 | top_result = candidates[0][0] 122 | return top_result, candidates, lattice, queues 123 | 124 | 125 | def main(): 126 | parser = argparse.ArgumentParser() 127 | parser.add_argument('ngram_file') 128 | parser.add_argument('--input_file', type=argparse.FileType('r'), default=sys.stdin) 129 | parser.add_argument('--output_file', type=argparse.FileType('w'), default=sys.stdout) 130 | parser.add_argument('--beam_size', type=int, default=5) 131 | parser.add_argument('--viterbi_size', type=int, default=1) 132 | parser.add_argument('--print_nbest', action='store_true') 133 | parser.add_argument('--print_queue', action='store_true') 134 | parser.add_argument('--print_lattice', action='store_true') 135 | args = parser.parse_args() 136 | 137 | ngrams = parse_srilm(open(args.ngram_file)) 138 | dictionary = create_dictionary(ngrams) 139 | 140 | for line in args.input_file: 141 | line = line.rstrip('\n') 142 | result, candidates, lattice, queues = decode(line, dictionary, ngrams, args.beam_size, args.viterbi_size) 143 | 144 | # Print decoded results 145 | if not args.print_nbest: 146 | print(result, file=args.output_file) 147 | else: 148 | for string, cost in candidates: 149 | print(string, cost, file=args.output_file) 150 | 151 | # Print lattice for debug 152 | if args.print_lattice: 153 | for i in range(len(lattice)): 154 | for j in range(len(lattice[i])): 155 | print('i = {}, j = {}'.format(i, j), file=args.output_file) 156 | for target, source in lattice[i][j]: 157 | print(target, source, file=args.output_file) 158 | 159 | # Print queues for debug 160 | if args.print_queue: 161 | for i, queue in enumerate(queues): 162 | 
print('queue', i, file=args.output_file) 163 | for cost, history in queue: 164 | print(cost, history, file=args.output_file) 165 | 166 | if __name__ == '__main__': 167 | main() 168 | -------------------------------------------------------------------------------- /decode_both.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import sys 4 | import heapq 5 | import operator 6 | import math 7 | 8 | from rnn_predictor import RNNPredictor 9 | from decode import load_settings, load_dictionary 10 | from decode_ngram import parse_srilm, get_ngram_cost 11 | 12 | 13 | def create_lattice(input_, dictionary): 14 | lattice = [[[] for _ in range(len(input_) + 1)] for _ in range(len(input_) + 2)] 15 | _, unk_id = dictionary['_UNK'][0] 16 | 17 | for i in range(1, len(input_) + 1): 18 | for j in range(i): 19 | source = input_[j:i] 20 | if source in dictionary: 21 | for target, word_id in dictionary[source]: 22 | lattice[i][j].append((target, source, word_id)) 23 | elif len(source) == 1: 24 | # Create _UNK node with verbatim target when single character key is not found in the dictionary. 25 | lattice[i][j].append((source, source, unk_id)) 26 | 27 | _, eos_id = dictionary['_EOS'][0] 28 | lattice[-1][-1].append(('_EOS', '_EOS', eos_id)) 29 | return lattice 30 | 31 | 32 | def initialize_queues(lattice, rnn_predictor, dictionary): 33 | # A hypothesis is tuple of (cost, history, state, prediction) 34 | _, bos_id = dictionary['_BOS'][0] 35 | bos_predictions, bos_states = rnn_predictor.predict([bos_id]) 36 | bos_hypothesis = (0.0, [('_EOS', '_EOS')], bos_states[0], bos_predictions[0]) 37 | queues = [[] for _ in range(len(lattice))] 38 | queues[0].append(bos_hypothesis) 39 | return queues 40 | 41 | 42 | def interpolate(rnn_cost, ngram_cost): 43 | # Linear interpolation needs to be done in probability space, not log probability space 44 | return -math.log((math.exp(-rnn_cost) + math.exp(-ngram_cost)) / 2.0) 45 | 46 | 47 | def search(lattice, queues, rnn_predictor, ngrams, beam_size, viterbi_size): 48 | # Breadth first search with beam pruning and viterbi-like pruning 49 | for i in range(len(lattice)): 50 | queue = [] 51 | 52 | # create hypotheses without predicting next word 53 | for j in range(len(lattice[i])): 54 | for target, source, word_id in lattice[i][j]: 55 | 56 | word_queue = [] 57 | for previous_cost, previous_history, previous_state, previous_prediction in queues[j]: 58 | history = previous_history + [(target, source)] 59 | cost = previous_cost + interpolate(previous_prediction[word_id], get_ngram_cost(ngrams, history)) 60 | # Temporal hypothesis is tuple of (cost, history, word_id, previous_state) 61 | # Lazy prediction replaces word_id and previous_state to state and prediction 62 | hypothesis = (cost, history, word_id, previous_state) 63 | word_queue.append(hypothesis) 64 | 65 | # prune word_queue to viterbi size 66 | if viterbi_size > 0: 67 | word_queue = heapq.nsmallest(viterbi_size, word_queue, key=operator.itemgetter(0)) 68 | 69 | queue += word_queue 70 | 71 | # prune queue to beam size 72 | if beam_size > 0: 73 | queue = heapq.nsmallest(beam_size, queue, key=operator.itemgetter(0)) 74 | 75 | # predict next word and state before continue 76 | for cost, history, word_id, previous_state in queue: 77 | predictions, states = rnn_predictor.predict([word_id], [previous_state]) 78 | hypothesis = (cost, history, states[0], predictions[0]) 79 | queues[i].append(hypothesis) 80 | 81 | return queues 82 | 83 | 84 | def 
decode(source, dictionary, rnn_predictor, ngrams, beam_size, viterbi_size): 85 | lattice = create_lattice(source, dictionary) 86 | queues = initialize_queues(lattice, rnn_predictor, dictionary) 87 | queues = search(lattice, queues, rnn_predictor, ngrams, beam_size, viterbi_size) 88 | 89 | candidates = [] 90 | for cost, history, _, _ in queues[-1]: 91 | result = ''.join(target for target, source in history[1:-1]) 92 | candidates.append((result, cost)) 93 | 94 | top_result = candidates[0][0] 95 | return top_result, candidates, lattice, queues 96 | 97 | 98 | def main(): 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument('model_directory') 101 | parser.add_argument('ngram_file') 102 | parser.add_argument('--model_file') 103 | parser.add_argument('--input_file', type=argparse.FileType('r'), default=sys.stdin) 104 | parser.add_argument('--output_file', type=argparse.FileType('w'), default=sys.stdout) 105 | parser.add_argument('--beam_size', type=int, default=5) 106 | parser.add_argument('--viterbi_size', type=int, default=1) 107 | parser.add_argument('--print_nbest', action='store_true') 108 | parser.add_argument('--print_queue', action='store_true') 109 | parser.add_argument('--print_lattice', action='store_true') 110 | args = parser.parse_args() 111 | 112 | # Load settings and vocabulary 113 | settings = load_settings(args.model_directory) 114 | dictionary = load_dictionary(args.model_directory) 115 | 116 | # Create model and load parameters 117 | rnn_predictor = RNNPredictor(settings.vocabulary_size, settings.hidden_size, settings.layer_size, settings.cell_type) 118 | if args.model_file: 119 | rnn_predictor.restore_from_file(args.model_file) 120 | else: 121 | rnn_predictor.restore_from_directory(args.model_directory) 122 | 123 | # Load ngram file in SRILM format 124 | ngrams = parse_srilm(open(args.ngram_file)) 125 | 126 | # Iterate input file line by line 127 | for line in args.input_file: 128 | line = line.rstrip('\n') 129 | 130 | # Decode - this might take some time 131 | result, candidates, lattice, queues = decode(line, dictionary, rnn_predictor, ngrams, args.beam_size, args.viterbi_size) 132 | 133 | # Print decoded results 134 | if not args.print_nbest: 135 | print(result, file=args.output_file) 136 | else: 137 | for string, cost in candidates: 138 | print(string, cost, file=args.output_file) 139 | 140 | # Print lattice for debug 141 | if args.print_lattice: 142 | for i in range(len(lattice)): 143 | for j in range(len(lattice[i])): 144 | print('i = {}, j = {}'.format(i, j), file=args.output_file) 145 | for target, source, word_id in lattice[i][j]: 146 | print(target, source, word_id, file=args.output_file) 147 | 148 | # Print queues for debug 149 | if args.print_queue: 150 | for i, queue in enumerate(queues): 151 | print('queue', i, file=args.output_file) 152 | for cost, history, state, prediction in queue: 153 | print(cost, history, file=args.output_file) 154 | 155 | 156 | if __name__ == '__main__': 157 | main() 158 | -------------------------------------------------------------------------------- /decode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import sys 4 | import os 5 | import json 6 | import collections 7 | import heapq 8 | import operator 9 | from rnn_predictor import RNNPredictor 10 | 11 | 12 | def load_settings(model_directory): 13 | settings_path = os.path.join(model_directory, 'settings.json') 14 | settings = json.load(open(settings_path)) 15 | return 
argparse.Namespace(**settings) 16 | 17 | 18 | def load_dictionary(model_directory): 19 | vocabulary_path = os.path.join(model_directory, 'vocabulary.txt') 20 | vocabulary = [] 21 | for line in open(vocabulary_path): 22 | line = line.rstrip('\n') 23 | target, source = line.split('/', 1) 24 | vocabulary.append((target, source)) 25 | 26 | dictionary = collections.defaultdict(list) 27 | for i, (target, source) in enumerate(vocabulary): 28 | dictionary[source].append((target, i)) 29 | 30 | return dictionary 31 | 32 | 33 | def create_lattice(input_, dictionary): 34 | lattice = [[[] for _ in range(len(input_) + 1)] for _ in range(len(input_) + 2)] 35 | _, unk_id = dictionary['_UNK'][0] 36 | 37 | for i in range(1, len(input_) + 1): 38 | for j in range(i): 39 | key = input_[j:i] 40 | if key in dictionary: 41 | for target, word_id in dictionary[key]: 42 | lattice[i][j].append((target, word_id)) 43 | elif len(key) == 1: 44 | # Create _UNK node with verbatim target when single character key is not found in the dictionary. 45 | lattice[i][j].append((key, unk_id)) 46 | 47 | _, eos_id = dictionary['_EOS'][0] 48 | lattice[-1][-1].append(('', eos_id)) 49 | return lattice 50 | 51 | 52 | def initialize_queues(lattice, rnn_predictor, dictionary): 53 | # Initialize priority queues for keeping hypotheses 54 | # A hypothesis is a tuple of (cost, string, state, prediction) 55 | # cost is total negative log probability 56 | # state.shape == [hidden_size * layer_size] 57 | # prediction.shape == [vocabulary_size] 58 | _, bos_id = dictionary['_BOS'][0] 59 | bos_predictions, bos_states = rnn_predictor.predict([bos_id]) 60 | bos_hypothesis = (0.0, '', bos_states[0], bos_predictions[0]) 61 | queues = [[] for _ in range(len(lattice))] 62 | queues[0].append(bos_hypothesis) 63 | return queues 64 | 65 | 66 | def simple_search(lattice, queues, rnn_predictor, beam_size): 67 | # Simple but slow implementation of beam search 68 | for i in range(len(lattice)): 69 | for j in range(len(lattice[i])): 70 | for target, word_id in lattice[i][j]: 71 | for previous_cost, previous_string, previous_state, previous_prediction in queues[j]: 72 | cost = previous_cost + previous_prediction[word_id] 73 | string = previous_string + target 74 | predictions, states = rnn_predictor.predict([word_id], [previous_state]) 75 | hypothesis = (cost, string, states[0], predictions[0]) 76 | queues[i].append(hypothesis) 77 | 78 | # prune queues[i] to beam size 79 | queues[i] = heapq.nsmallest(beam_size, queues[i], key=operator.itemgetter(0)) 80 | return queues 81 | 82 | 83 | def search(lattice, queues, rnn_predictor, beam_size, viterbi_size): 84 | # Breadth first search with beam pruning and viterbi-like pruning 85 | for i in range(len(lattice)): 86 | queue = [] 87 | 88 | # create hypotheses without predicting next word 89 | for j in range(len(lattice[i])): 90 | for target, word_id in lattice[i][j]: 91 | 92 | word_queue = [] 93 | for previous_cost, previous_string, previous_state, previous_prediction in queues[j]: 94 | cost = previous_cost + previous_prediction[word_id] 95 | string = previous_string + target 96 | hypothesis = (cost, string, word_id, previous_state) 97 | word_queue.append(hypothesis) 98 | 99 | # prune word_queue to viterbi size 100 | if viterbi_size > 0: 101 | word_queue = heapq.nsmallest(viterbi_size, word_queue, key=operator.itemgetter(0)) 102 | 103 | queue += word_queue 104 | 105 | # prune queue to beam size 106 | if beam_size > 0: 107 | queue = heapq.nsmallest(beam_size, queue, key=operator.itemgetter(0)) 108 | 109 | # predict next word 
and state before continue 110 | for cost, string, word_id, previous_state in queue: 111 | predictions, states = rnn_predictor.predict([word_id], [previous_state]) 112 | hypothesis = (cost, string, states[0], predictions[0]) 113 | queues[i].append(hypothesis) 114 | 115 | return queues 116 | 117 | 118 | def decode(source, dictionary, rnn_predictor, beam_size, viterbi_size): 119 | lattice = create_lattice(source, dictionary) 120 | queues = initialize_queues(lattice, rnn_predictor, dictionary) 121 | queues = search(lattice, queues, rnn_predictor, beam_size, viterbi_size) 122 | 123 | candidates = [] 124 | for cost, string, _, _ in queues[-1]: 125 | candidates.append((string, cost)) 126 | 127 | top_result = candidates[0][0] 128 | return top_result, candidates, lattice, queues 129 | 130 | 131 | def main(): 132 | parser = argparse.ArgumentParser() 133 | parser.add_argument('model_directory') 134 | parser.add_argument('--model_file') 135 | parser.add_argument('--input_file', type=argparse.FileType('r'), default=sys.stdin) 136 | parser.add_argument('--output_file', type=argparse.FileType('w'), default=sys.stdout) 137 | parser.add_argument('--beam_size', type=int, default=5) 138 | parser.add_argument('--viterbi_size', type=int, default=1) 139 | parser.add_argument('--print_nbest', action='store_true') 140 | parser.add_argument('--print_queue', action='store_true') 141 | parser.add_argument('--print_lattice', action='store_true') 142 | args = parser.parse_args() 143 | 144 | # Load settings and vocabulary 145 | settings = load_settings(args.model_directory) 146 | dictionary = load_dictionary(args.model_directory) 147 | 148 | # Create model and load parameters 149 | rnn_predictor = RNNPredictor(settings.vocabulary_size, settings.hidden_size, settings.layer_size, settings.cell_type) 150 | if args.model_file: 151 | rnn_predictor.restore_from_file(args.model_file) 152 | else: 153 | rnn_predictor.restore_from_directory(args.model_directory) 154 | 155 | # Iterate input file line by line 156 | for line in args.input_file: 157 | line = line.rstrip('\n') 158 | 159 | # Decode - this might take ~10 seconds per line 160 | result, candidates, lattice, queues = decode(line, dictionary, rnn_predictor, args.beam_size, args.viterbi_size) 161 | 162 | # Print decoded results 163 | if not args.print_nbest: 164 | print(result, file=args.output_file) 165 | else: 166 | for string, cost in candidates: 167 | print(string, cost, file=args.output_file) 168 | 169 | # Print lattice for debug 170 | if args.print_lattice: 171 | for i in range(len(lattice)): 172 | for j in range(len(lattice[i])): 173 | print('i = {}, j = {}'.format(i, j), file=args.output_file) 174 | for target, word_id in lattice[i][j]: 175 | print(target, word_id, file=args.output_file) 176 | 177 | # Print queues for debug 178 | if args.print_queue: 179 | for i, queue in enumerate(queues): 180 | print('queue', i, file=args.output_file) 181 | for cost, string, state, prediction in queue: 182 | print(string, cost, file=args.output_file) 183 | 184 | if __name__ == '__main__': 185 | main() 186 | --------------------------------------------------------------------------------
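The decoders above share one structure: lattice[i][j] holds the candidate words covering input[j:i], and hypotheses ending at position j are extended by each such word and pruned with heapq.nsmallest. The following is a toy, self-contained sketch of that lattice/beam idea with a hypothetical dictionary and made-up per-word costs, not the repository's API:

import heapq
import operator


def toy_decode(source, dictionary, costs, beam_size=2):
    # queues[i] holds (cost, string) hypotheses that cover source[:i]
    queues = [[] for _ in range(len(source) + 1)]
    queues[0].append((0.0, ''))
    for i in range(1, len(source) + 1):
        for j in range(i):
            for target in dictionary.get(source[j:i], []):
                for previous_cost, previous_string in queues[j]:
                    queues[i].append((previous_cost + costs[target], previous_string + target))
        # Keep only the beam_size cheapest hypotheses at each position
        queues[i] = heapq.nsmallest(beam_size, queues[i], key=operator.itemgetter(0))
    return queues[-1]


if __name__ == '__main__':
    dictionary = {'はし': ['橋', '箸'], 'は': ['歯'], 'し': ['市']}
    costs = {'橋': 1.0, '箸': 1.5, '歯': 0.7, '市': 0.9}
    print(toy_decode('はし', dictionary, costs))  # [(1.0, '橋'), (1.5, '箸')]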