├── LICENSE
├── split_source_target.py
├── split_train_test_valid.py
├── .gitignore
├── parse_bccwj.py
├── predict_ngram.py
├── evaluate.py
├── predict.py
├── rnn_predictor.py
├── predict_both.py
├── train.py
├── grid_search.py
├── rnn_trainer.py
├── utility.py
├── README.md
├── experiment.py
├── decode_ngram.py
├── decode_both.py
└── decode.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2016 Yoh Okuno
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/split_source_target.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 |
4 |
5 | def parse_file(file):
6 | for line in file:
7 | line = line.rstrip('\n')
8 | words = line.split(' ')
9 | yield [word.split('/', 1) for word in words]
10 |
11 |
12 | def split_source_target(sentences):
13 | target = ''
14 | source = ''
15 | for sentence in sentences:
16 | for target_word, source_word in sentence:
17 | target += target_word
18 | source += source_word
19 | target += '\n'
20 | source += '\n'
21 | return target, source
22 |
23 |
24 | def main():
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument('input_file', type=argparse.FileType('r'))
27 | parser.add_argument('target_file', type=argparse.FileType('w'))
28 | parser.add_argument('source_file', type=argparse.FileType('w'))
29 | args = parser.parse_args()
30 |
31 | sentences = parse_file(args.input_file)
32 | target, source = split_source_target(sentences)
33 | args.target_file.write(target)
34 | args.source_file.write(source)
35 |
36 | if __name__ == '__main__':
37 | main()
38 |
--------------------------------------------------------------------------------
/split_train_test_valid.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import random
4 |
5 |
6 | def main():
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('input_file', type=argparse.FileType('r'))
9 | parser.add_argument('train_file', type=argparse.FileType('w'))
10 | parser.add_argument('test_file', type=argparse.FileType('w'))
11 | parser.add_argument('valid_file', type=argparse.FileType('w'))
12 | parser.add_argument('--test_size', type=float, default=0.1)
13 | parser.add_argument('--valid_size', type=float, default=0.01)
14 | parser.add_argument('--seed', type=int, default=0)
15 | args = parser.parse_args()
16 |
17 | input_data = args.input_file.readlines()
18 |
19 | random.seed(args.seed)
20 | random.shuffle(input_data)
21 |
22 | test_index = int(len(input_data) * args.test_size)
23 | test_data = input_data[:test_index]
24 |
25 | valid_index = test_index + int(len(input_data) * args.valid_size)
26 | valid_data = input_data[test_index:valid_index]
27 | train_data = input_data[valid_index:]
28 |
29 | args.train_file.writelines(train_data)
30 | args.test_file.writelines(test_data)
31 | args.valid_file.writelines(valid_data)
32 |
33 | if __name__ == '__main__':
34 | main()
35 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 |
--------------------------------------------------------------------------------
/parse_bccwj.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import glob
4 | import xml.etree.ElementTree
5 |
6 |
7 | def parse_sentence(sentence):
8 | for element in sentence.iter('SUW'):
9 | target = ''.join(element.itertext())
10 | target = target.replace('\n', '')
11 |         if ' ' in target or '　' in target:
12 | # Ignore full width space
13 | continue
14 |
15 | if 'kana' in element.attrib:
16 | source = element.attrib['kana']
17 | else:
18 | source = element.attrib['formBase']
19 | if source == '':
20 | source = target
21 |
22 | yield target, source
23 |
24 |
25 | def parse_file(filename):
26 | tree = xml.etree.ElementTree.parse(filename)
27 |
28 | for sentence in tree.iter('sentence'):
29 | sentence = list(parse_sentence(sentence))
30 | if sentence:
31 | yield sentence
32 |
33 |
34 | def parse_pathname(pathname):
35 | for filename in glob.glob(pathname):
36 | corpus = parse_file(filename)
37 | for sentence in corpus:
38 | yield sentence
39 |
40 |
41 | def katakana_to_hiragana(string):
42 | result = ''
43 | for character in string:
44 | code = ord(character)
45 |         # if 0x30A1 <= code <= 0x30F6:
46 | if ord('ァ') <= code <= ord('ヶ'):
47 | # result += chr(code + 0x3041 - 0x30A1)
48 | result += chr(code - ord('ァ') + ord('ぁ'))
49 | else:
50 | result += character
51 | return result
52 |
53 |
54 | def main():
55 | parser = argparse.ArgumentParser()
56 | parser.add_argument('pathname')
57 | args = parser.parse_args()
58 |
59 | sentences = parse_pathname(args.pathname)
60 | for sentence in sentences:
61 | line = ' '.join('{}/{}'.format(target, katakana_to_hiragana(source)) for target, source in sentence)
62 | print(line)
63 |
64 |
65 | if __name__ == '__main__':
66 | main()
67 |
--------------------------------------------------------------------------------
/predict_ngram.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import math
4 | import collections
5 |
6 |
7 | def parse_srilm(file):
8 | order = 0
9 | ngrams = collections.defaultdict(list)
10 |
11 | for line in file:
12 | line = line.rstrip('\n')
13 | fields = line.split('\t', 2)
14 |
15 | if len(fields) not in (2, 3):
16 | continue
17 |
18 | cost = -math.log(10 ** float(fields[0]))
19 | ngram = fields[1].split(' ')
20 |
21 | if len(ngram) > order:
22 | order = len(ngram)
23 |
24 | context = tuple(ngram[:-1])
25 | word = ngram[-1]
26 | ngrams[context].append((cost, word))
27 |
28 | return ngrams, order
29 |
30 |
31 | def predict(ngrams, context):
32 | if context not in ngrams:
33 | return predict(ngrams, context[1:])
34 |
35 | cost, word = min(ngrams[context])
36 | return word
37 |
38 |
39 | def match_predictions(ngrams, order, words):
40 | for i in range(1, len(words) - 1):
41 | word = words[i]
42 | context = tuple(words[max(i-order+1, 0):i])
43 | prediction = predict(ngrams, context)
44 | yield prediction == word
45 |
46 |
47 | def main():
48 | parser = argparse.ArgumentParser()
49 | parser.add_argument('test_file', type=argparse.FileType('r'))
50 | parser.add_argument('ngram_file', type=argparse.FileType('r'))
51 | args = parser.parse_args()
52 |
53 | ngrams, order = parse_srilm(args.ngram_file)
54 |
55 | all_predictions = 0
56 | correct_predictions = 0
57 |
58 | for line in args.test_file:
59 | words = line.split(' ')
60 | words = [''] + words + ['']
61 | result = list(match_predictions(ngrams, order, words))
62 | all_predictions += len(result)
63 | correct_predictions += sum(result)
64 | print(correct_predictions / all_predictions, end='\r')
65 |
66 | print(correct_predictions / all_predictions)
67 |
68 |
69 | if __name__ == '__main__':
70 | main()
71 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 |
4 |
5 | def get_common_length(left, right):
6 | # Compute length of the longest common sub-sequence of two strings
7 | table = [[0 for _ in range(len(right) + 1)] for _ in range(len(left) + 1)]
8 |
9 | for i in range(1, len(left) + 1):
10 | for j in range(1, len(right) + 1):
11 | if left[i - 1] == right[j - 1]:
12 | table[i][j] = table[i-1][j-1] + 1
13 | else:
14 | table[i][j] = max(table[i-1][j], table[i][j-1])
15 | return table[-1][-1]
16 |
17 |
18 | def evaluate(system, reference):
19 | # extract statistics
20 | common_length = sum(get_common_length(r, s) for r, s in zip(reference, system))
21 | reference_length = len(''.join(reference))
22 | system_length = len(''.join(system))
23 | sentence_match = sum(r == s for r, s in zip(reference, system))
24 |
25 | # calculate metrics
26 | if system_length > 0:
27 | precision = 100. * common_length / system_length
28 | else:
29 | precision = 0.
30 | recall = 100. * common_length / reference_length
31 | fscore = 200. * common_length / (reference_length + system_length)
32 | accuracy = 100. * sentence_match / len(reference)
33 |
34 | # return metrics
35 | return precision, recall, fscore, accuracy
36 |
37 |
38 | def main():
39 | parser = argparse.ArgumentParser()
40 | parser.add_argument('system', type=argparse.FileType('r'))
41 | parser.add_argument('reference', type=argparse.FileType('r'))
42 | args = parser.parse_args()
43 |
44 | # load data
45 | system = [line.rstrip('\n') for line in args.system]
46 | reference = [line.rstrip('\n') for line in args.reference]
47 | reference = reference[:len(system)]
48 |
49 | # calculate metrics
50 | metrics = evaluate(system, reference)
51 |
52 | # print metrics
53 | print('precision: {:.2f} recall: {:.2f} f-score: {:.2f} accuracy: {:.2f}'.format(*metrics))
54 |
55 |
56 | if __name__ == '__main__':
57 | main()
58 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import os
4 | import collections
5 | import numpy as np
6 |
7 | from decode import load_settings
8 | from rnn_predictor import RNNPredictor
9 |
10 |
11 | def load_dictionary(model_directory):
12 | vocabulary_path = os.path.join(model_directory, 'vocabulary.txt')
13 | dictionary = collections.defaultdict(int)
14 | for word_id, line in enumerate(open(vocabulary_path)):
15 | word = line.rstrip('\n')
16 | dictionary[word] = word_id
17 | vocabulary = [line.rstrip('\n') for line in open(vocabulary_path)]
18 | return dictionary, vocabulary
19 |
20 |
21 | def main():
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('test_file', type=argparse.FileType('r'))
24 | parser.add_argument('model_directory')
25 | parser.add_argument('--model_file')
26 | args = parser.parse_args()
27 |
28 | # Load settings and vocabulary
29 | settings = load_settings(args.model_directory)
30 | dictionary, vocabulary = load_dictionary(args.model_directory)
31 |
32 | # Create model and load parameters
33 | rnn_predictor = RNNPredictor(settings.vocabulary_size, settings.hidden_size, settings.layer_size, settings.cell_type)
34 | if args.model_file:
35 | rnn_predictor.restore_from_file(args.model_file)
36 | else:
37 | rnn_predictor.restore_from_directory(args.model_directory)
38 |
39 | all_predictions = 0
40 | correct_predictions = 0
41 |
42 | for line in args.test_file:
43 | line = line.rstrip('\n')
44 | words = line.split(' ')
45 | words = ['_BOS/_BOS'] + words + ['_EOS/_EOS']
46 | state = None
47 |
48 | for i in range(len(words) - 2):
49 | word_id = dictionary[words[i]]
50 | predictions, state = rnn_predictor.predict([word_id], state)
51 | prediction = vocabulary[np.argmin(predictions[0])]
52 |
53 | if prediction == words[i + 1]:
54 | correct_predictions += 1
55 | all_predictions += 1
56 |
57 | print(correct_predictions / all_predictions, end='\r')
58 |
59 | print(correct_predictions / all_predictions)
60 |
61 | if __name__ == '__main__':
62 | main()
63 |
--------------------------------------------------------------------------------
/rnn_predictor.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | class RNNPredictor:
5 | # TODO: Load meta graph from file and use it instead to support various internal structures
6 | def __init__(self, vocabulary_size, hidden_size, layer_size, cell_type):
7 | with tf.Graph().as_default():
8 | # Placeholder for test data
9 | self.input = tf.placeholder(tf.int32, [None])
10 |
11 | # Lookup word embedding
12 | embedding = tf.Variable(tf.zeros([vocabulary_size, hidden_size]), name='embedding')
13 | rnn_inputs = tf.nn.embedding_lookup(embedding, self.input)
14 |
15 | # Create RNN cell
16 | if cell_type == 'lstm':
17 | cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size)
18 | elif cell_type == 'rnn':
19 | cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
20 | else:
21 | cell = tf.nn.rnn_cell.GRUCell(hidden_size)
22 |
23 | # Stack multiple RNN cells
24 | if layer_size > 1:
25 | cell = tf.nn.rnn_cell.MultiRNNCell([cell] * layer_size)
26 |
27 | self.initial_state = cell.zero_state(1, dtype=tf.float32)
28 |
29 | # Call the RNN cell
30 | with tf.variable_scope('RNN'):
31 | rnn_output, self.next_state = cell(rnn_inputs, self.initial_state)
32 |
33 | # Predict distribution over next word
34 | softmax_w = tf.Variable(tf.zeros([hidden_size, vocabulary_size]), name='softmax_w')
35 | softmax_b = tf.Variable(tf.zeros([vocabulary_size]), name='softmax_b')
36 | logits = tf.matmul(rnn_output, softmax_w) + softmax_b
37 |
38 |             # predictions holds negative log probabilities, one row of shape [vocabulary_size] per input word
39 | self.predictions = -tf.nn.log_softmax(logits)
40 |
41 | self.saver = tf.train.Saver(tf.trainable_variables())
42 | self.session = tf.Session()
43 |
44 | def predict(self, input_value, state_value=None):
45 | if state_value is not None:
46 | feed_dict = {self.input: input_value, self.initial_state: state_value}
47 | else:
48 | feed_dict = {self.input: input_value}
49 | predictions, next_state = self.session.run([self.predictions, self.next_state], feed_dict)
50 | return predictions, next_state
51 |
52 | def restore_from_directory(self, model_directory):
53 | model_path = tf.train.latest_checkpoint(model_directory)
54 | self.saver.restore(self.session, model_path)
55 |
56 | def restore_from_file(self, model_path):
57 | self.saver.restore(self.session, model_path)
58 |
59 | def close(self):
60 | self.session.close()
61 |
--------------------------------------------------------------------------------
/predict_both.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import numpy as np
4 |
5 | from decode import load_settings
6 | from decode_ngram import get_ngram_cost, parse_srilm
7 | from predict import load_dictionary
8 | from rnn_predictor import RNNPredictor
9 |
10 |
11 | def match_predictions(rnn_predictor, dictionary, vocabulary, ngrams, words):
12 | state = None
13 |
14 | for i in range(len(words) - 2):
15 | # RNN prediction
16 | word_id = dictionary[words[i]]
17 | predictions, state = rnn_predictor.predict([word_id], state)
18 | rnn_prediction = predictions[0]
19 |
20 | # N-gram prediction
21 | context = words[max(i - 3, 0):i + 1]
22 | ngram_prediction = np.zeros(len(rnn_prediction))
23 | for word in list(dictionary.values()):
24 | history = tuple(context + [word])
25 | probability = get_ngram_cost(ngrams, history)
26 | word_id = dictionary[word]
27 | ngram_prediction[word_id] = probability
28 |
29 | interpolation = -np.log((np.exp(-rnn_prediction) + np.exp(-ngram_prediction)) / 2.0)
30 | prediction = vocabulary[np.argmin(interpolation)]
31 | yield prediction == words[i + 1]
32 |
33 |
34 | def main():
35 | parser = argparse.ArgumentParser()
36 | parser.add_argument('test_file', type=argparse.FileType('r'))
37 | parser.add_argument('model_directory')
38 | parser.add_argument('ngram_file', type=argparse.FileType('r'))
39 | parser.add_argument('--model_file')
40 | args = parser.parse_args()
41 |
42 | # Load settings and vocabulary
43 | settings = load_settings(args.model_directory)
44 | dictionary, vocabulary = load_dictionary(args.model_directory)
45 |
46 | # Create model and load parameters
47 | rnn_predictor = RNNPredictor(settings.vocabulary_size, settings.hidden_size, settings.layer_size, settings.cell_type)
48 | if args.model_file:
49 | rnn_predictor.restore_from_file(args.model_file)
50 | else:
51 | rnn_predictor.restore_from_directory(args.model_directory)
52 |
53 | # Load N-gram model
54 | ngrams = parse_srilm(args.ngram_file)
55 |
56 | all_predictions = 0
57 | correct_predictions = 0
58 |
59 | for line in args.test_file:
60 | line = line.rstrip('\n')
61 | words = line.split(' ')
62 | words = ['_BOS/_BOS'] + words + ['_EOS/_EOS']
63 |
64 | result = list(match_predictions(rnn_predictor, dictionary, vocabulary, ngrams, words))
65 | all_predictions += len(result)
66 | correct_predictions += sum(result)
67 | print(correct_predictions / all_predictions, end='\r')
68 |
69 | print(correct_predictions / all_predictions)
70 |
71 |
72 | if __name__ == '__main__':
73 | main()
74 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import json
4 | import math
5 | import os
6 | import time
7 |
8 | from utility import load_train_data
9 | from rnn_trainer import RNNTrainer
10 |
11 |
12 | def train_epoch(rnn_trainer, train_data, model_directory, epoch):
13 | total_loss = 0.0
14 | start_time = time.time()
15 |
16 | for batch, (input_, output_) in enumerate(train_data):
17 | start_time_batch = time.time()
18 | loss, gradient_norm = rnn_trainer.train(input_, output_)
19 | train_time_batch = time.time() - start_time_batch
20 | total_loss += loss
21 | perplexity = math.exp(loss)
22 | print('training batch: {:} perplexity: {:.2f} time: {:.2f}'.format(batch, perplexity, train_time_batch), end='\r')
23 |
24 | train_time = time.time() - start_time
25 | perplexity = math.exp(total_loss / len(train_data))
26 |
27 | log_text = 'training epoch: {} perplexity: {:.2f} \n'.format(epoch, perplexity)
28 | log_text += 'training total time: {:.2f} average time: {:.2f}'.format(train_time, train_time / len(train_data))
29 | print(log_text)
30 |
31 | # Save model every epoch
32 | model_path = os.path.join(model_directory, 'model.ckpt')
33 | rnn_trainer.save(model_path, epoch)
34 |
35 | return perplexity, train_time
36 |
37 |
38 | def main():
39 | parser = argparse.ArgumentParser()
40 | parser.add_argument('train_file')
41 | parser.add_argument('model_directory')
42 | parser.add_argument('--sentence_size', type=int, default=30)
43 | parser.add_argument('--vocabulary_size', type=int, default=50000)
44 | parser.add_argument('--batch_size', type=int, default=50)
45 | parser.add_argument('--hidden_size', type=int, default=400)
46 | parser.add_argument('--layer_size', type=int, default=1)
47 | parser.add_argument('--epoch_size', type=int, default=10)
48 | parser.add_argument('--clip_norm', type=float, default=5)
49 | parser.add_argument('--keep_prob', type=float, default=0.5)
50 | parser.add_argument('--cell_type', default='gru')
51 | parser.add_argument('--optimizer_type', default='adam')
52 | parser.add_argument('--max_keep', type=int, default=0)
53 | args = parser.parse_args()
54 |
55 | # Print settings
56 | print(json.dumps(vars(args), indent=4))
57 |
58 | # Load and preprocess training data
59 | train_data = load_train_data(args)
60 | print('number of batches:', len(train_data))
61 |
62 | # Create RNN model for training
63 | rnn_trainer = RNNTrainer(args.batch_size, args.sentence_size, args.vocabulary_size, args.hidden_size,
64 | args.layer_size, args.cell_type, args.optimizer_type, args.clip_norm,
65 | args.keep_prob, args.max_keep)
66 |
67 | start_time = time.time()
68 |
69 | for epoch in range(args.epoch_size):
70 | # Train one epoch and save the model
71 | train_epoch(rnn_trainer, train_data, args.model_directory, epoch)
72 |
73 | total_time = time.time() - start_time
74 | print('total training time:', total_time)
75 | print()
76 |
77 | rnn_trainer.close()
78 |
79 | if __name__ == '__main__':
80 | main()
81 |
--------------------------------------------------------------------------------
/grid_search.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import json
4 | import os
5 | import time
6 | import itertools
7 | from experiment import experiment
8 |
9 |
10 | def grid_search(args):
11 | start_time = time.time()
12 |
13 | results = []
14 |
15 | # split args to default settings and hyperparameters
16 | default_settings = {}
17 | search_settings = {}
18 | for name, value in vars(args).items():
19 | if type(value) == list and len(value) == 1:
20 | value = value[0]
21 | if type(value) != list:
22 | default_settings[name] = value
23 | else:
24 | search_settings[name] = value
25 |
26 | # Search hyperparameters
27 | for values in itertools.product(*search_settings.values()):
28 | # Merge default and search settings
29 | hyperparameters = dict(zip(search_settings.keys(), values))
30 | merged_settings = {**default_settings, **hyperparameters}
31 | settings = argparse.Namespace(**merged_settings)
32 |
33 | # Set directory path
34 | directory_name = '-'.join(name + str(value) for name, value in hyperparameters.items())
35 | settings.model_directory = os.path.join(settings.model_directory, directory_name)
36 |
37 | # Run experiment
38 | metrics, epoch = experiment(settings)
39 | result = (hyperparameters, metrics, epoch)
40 | results.append(result)
41 |
42 | # print best experiment result
43 | hyperparameters, metrics, epoch = max(results, key=lambda x: x[1][2])
44 | print('best experiment settings in epoch', epoch)
45 | print(json.dumps(hyperparameters, indent=4))
46 | print('best experiment metrics: precision: {:.2f} recall: {:.2f} f-score: {:.2f} accuracy: {:.2f}'.format(*metrics))
47 |
48 | # save all results
49 | results_path = os.path.join(args.model_directory, 'results.json')
50 | json.dump(results, open(results_path, 'w'), indent=4)
51 |
52 | # Print total time
53 | total_time = time.time() - start_time
54 | print('total grid search time:', total_time)
55 |
56 |
57 | def main():
58 | parser = argparse.ArgumentParser()
59 | # mandatory parameters
60 | parser.add_argument('train_file')
61 | parser.add_argument('valid_target_file')
62 | parser.add_argument('valid_source_file')
63 | parser.add_argument('model_directory')
64 | # grid search parameters
65 | parser.add_argument('--hidden_size', type=int, default=400, nargs='*')
66 | parser.add_argument('--layer_size', type=int, default=1, nargs='*')
67 | parser.add_argument('--keep_prob', type=float, default=0.5, nargs='*')
68 | parser.add_argument('--cell_type', default='gru', nargs='*')
69 | parser.add_argument('--optimizer_type', default='adam', nargs='*')
70 | parser.add_argument('--vocabulary_size', type=int, default=50000, nargs='*')
71 | parser.add_argument('--sentence_size', type=int, default=30, nargs='*')
72 | parser.add_argument('--batch_size', type=int, default=50, nargs='*')
73 | parser.add_argument('--clip_norm', type=float, default=5, nargs='*')
74 | # optional parameters
75 | parser.add_argument('--epoch_size', type=int, default=10)
76 | parser.add_argument('--max_keep', type=int, default=0)
77 | parser.add_argument('--beam_size', type=int, default=5)
78 | parser.add_argument('--viterbi_size', type=int, default=1)
79 | args = parser.parse_args()
80 |
81 |     # Run grid search. This might take a long time to complete.
82 | grid_search(args)
83 |
84 |
85 | if __name__ == '__main__':
86 | main()
87 |
--------------------------------------------------------------------------------
/rnn_trainer.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | class RNNTrainer:
5 | # TODO split to separate functions
6 | # TODO support truncated BPTT
7 | def __init__(self, batch_size, sentence_size, vocabulary_size, hidden_size, layer_size,
8 | cell_type, optimizer_type, clip_norm, keep_prob, max_keep):
9 |
10 | with tf.Graph().as_default():
11 | # Placeholders for training data
12 | self.input = tf.placeholder(tf.int32, [batch_size, sentence_size])
13 | self.output = tf.placeholder(tf.int32, [batch_size, sentence_size])
14 |
15 | # Lookup word embedding
16 | embedding = tf.Variable(tf.truncated_normal([vocabulary_size, hidden_size], stddev=0.01), name='embedding')
17 | inputs = tf.nn.embedding_lookup(embedding, self.input)
18 | inputs = tf.nn.dropout(inputs, keep_prob)
19 |
20 | # Create and connect RNN cells
21 | if cell_type == 'lstm':
22 | cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size)
23 | elif cell_type == 'rnn':
24 | cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
25 | else:
26 | cell = tf.nn.rnn_cell.GRUCell(hidden_size)
27 |
28 | # Dropout output of the RNN cell
29 | cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=keep_prob)
30 |
31 | # Stack multiple RNN cells
32 | if layer_size > 1:
33 | cell = tf.nn.rnn_cell.MultiRNNCell([cell] * layer_size)
34 |
35 | rnn_inputs = [tf.squeeze(input_, [1]) for input_ in tf.split(1, sentence_size, inputs)]
36 | outputs, _ = tf.nn.rnn(cell, rnn_inputs, dtype=tf.float32)
37 |
38 | # Predict distribution over next word
39 | output = tf.reshape(tf.concat(1, outputs), [-1, hidden_size])
40 | softmax_w = tf.Variable(tf.truncated_normal([hidden_size, vocabulary_size], stddev=0.01), name='softmax_w')
41 | softmax_b = tf.Variable(tf.zeros(shape=[vocabulary_size]), name='softmax_b')
42 | logits = tf.matmul(output, softmax_w) + softmax_b
43 |
44 | # Define loss function and optimizer
45 | self.loss = tf.nn.seq2seq.sequence_loss(
46 | [logits],
47 | [tf.reshape(self.output, [-1])],
48 | [tf.ones([batch_size * sentence_size], dtype=tf.float32)])
49 |
50 | # Apply gradient clipping to address gradient explosion
51 | trainable_variables = tf.trainable_variables()
52 | gradients = tf.gradients(self.loss, trainable_variables)
53 | clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients, clip_norm)
54 | if optimizer_type == 'sgd':
55 | optimizer = tf.train.GradientDescentOptimizer(0.01)
56 | elif optimizer_type == 'adagrad':
57 | optimizer = tf.train.AdagradOptimizer(0.01)
58 | elif optimizer_type == 'rmsprop':
59 | optimizer = tf.train.RMSPropOptimizer(0.01)
60 | else:
61 | optimizer = tf.train.AdamOptimizer()
62 | self.train_step = optimizer.apply_gradients(zip(clipped_gradients, trainable_variables))
63 |
64 | # Keep latest max_keep checkpoints
65 | self.saver = tf.train.Saver(trainable_variables, max_to_keep=max_keep)
66 | self.session = tf.Session()
67 | self.session.run(tf.initialize_all_variables())
68 |
69 | def train(self, input_, output_):
70 | _, loss, gradient_norm = self.session.run([self.train_step, self.loss, self.gradient_norm],
71 | {self.input: input_, self.output: output_})
72 | return loss, gradient_norm
73 |
74 | def save(self, model_path, epoch):
75 | self.saver.save(self.session, model_path, global_step=epoch)
76 |
77 | def close(self):
78 | self.session.close()
79 |
--------------------------------------------------------------------------------
/utility.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | import os
3 | import json
4 |
5 |
6 | def parse_file(file):
7 | for line in file:
8 | line = line.rstrip('\n')
9 | sentence = line.split(' ')
10 | yield sentence
11 |
12 |
13 | # TODO: the current method does not allow the model to learn across split boundaries beyond a bigram.
14 | def adjust_size(sentences, sentence_size):
15 | # Increment sentence size for shifting output later
16 | sentence_size_plus = sentence_size + 1
17 |
18 | for sentence in sentences:
19 | # Insert BOS = Beginning Of Sentence
20 | sentence.insert(0, '_BOS/_BOS')
21 |
22 | # Split long sentence allowing overlap of 1 word
23 | while len(sentence) >= sentence_size_plus:
24 | yield sentence[:sentence_size_plus]
25 | sentence = sentence[sentence_size:]
26 |
27 | # Do not yield EOS-only sentence
28 | if sentence:
29 | # Insert EOS = End Of Sentence
30 | sentence.append('_EOS/_EOS')
31 |
32 | if len(sentence) < sentence_size_plus:
33 | # Padding sentence to make its size sentence_size_plus
34 | sentence += ['_PAD/_PAD'] * (sentence_size_plus - len(sentence))
35 | yield sentence
36 |
37 |
38 | def create_vocabulary(sentences, vocabulary_size):
39 | # Create list of words indexed by word ID
40 | counter = Counter(word for words in sentences for word in words)
41 | most_common = counter.most_common(vocabulary_size - 1)
42 | vocabulary = [word for word, count in most_common]
43 | vocabulary.insert(0, '_UNK/_UNK')
44 | return vocabulary
45 |
46 |
47 | def convert_to_ids(sentences, vocabulary):
48 | dictionary = dict((word, word_id) for word_id, word in enumerate(vocabulary))
49 |
50 | for sentence in sentences:
51 | word_ids = []
52 |
53 | for word in sentence:
54 | if word in dictionary:
55 | word_id = dictionary[word]
56 | else:
57 | word_id = dictionary['_UNK/_UNK']
58 | word_ids.append(word_id)
59 |
60 | yield word_ids
61 |
62 |
63 | # TODO: current batching ignores sentences that don't fit into the last batch.
64 | def create_batches(sentences, batch_size):
65 | all_batches = int(len(sentences) / batch_size)
66 |
67 | for i in range(all_batches):
68 | batch_sentences = sentences[i * batch_size:(i + 1) * batch_size]
69 | batch_input = []
70 | batch_output = []
71 |
72 | for sentence in batch_sentences:
73 | # Shift sentence by 1 time step
74 | input_ = sentence[:-1]
75 | output_ = sentence[1:]
76 |
77 | batch_input.append(input_)
78 | batch_output.append(output_)
79 |
80 | yield batch_input, batch_output
81 |
82 |
83 | def save_metadata(args, vocabulary):
84 | # Create directory if not exists
85 | if not os.path.exists(args.model_directory):
86 | os.makedirs(args.model_directory)
87 |
88 | # Save settings
89 | settings_path = os.path.join(args.model_directory, 'settings.json')
90 | with open(settings_path, 'w') as settings_file:
91 | json.dump(vars(args), settings_file, indent=4)
92 |
93 | # Save vocabulary
94 | vocabulary_path = os.path.join(args.model_directory, 'vocabulary.txt')
95 | with open(vocabulary_path, 'w') as vocabulary_file:
96 | vocabulary_file.write('\n'.join(vocabulary))
97 |
98 |
99 | def load_train_data(args):
100 | sentences = parse_file(open(args.train_file))
101 | sentences = list(adjust_size(sentences, args.sentence_size))
102 | vocabulary = create_vocabulary(sentences, args.vocabulary_size)
103 | sentences = list(convert_to_ids(sentences, vocabulary))
104 | train_data = list(create_batches(sentences, args.batch_size))
105 | save_metadata(args, vocabulary)
106 | return train_data
107 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural IME: Neural Input Method Engine
2 | A Japanese input method engine can reach the next level with deep learning technology.
3 |
4 | # Prerequisite
5 | You need the following software to use Neural IME.
6 |
7 | * Python 3.5
8 | * TensorFlow 0.10
9 |
10 | The developer uses Mac OS X 10.11.4, Anaconda 4.1 and PyCharm 5.0.4, but it should work elsewhere.
11 |
12 |
13 | # Experimental results
14 | The neural model outperformed the N-gram model on the reference corpus, as shown below.
15 |
16 | | Metrics | N-gram | RNN |
17 | |:-------------------:|:------:|:---------:|
18 | | Sentence Accuracy | 41.5% | **44.2%** |
19 | | Prediction Accuracy | 22.9% | **26.7%** |
20 |
21 | # Training your own models
22 | For training and testing your own models, you need annotated data
23 | such as the DVD version of BCCWJ (Balanced Corpus of Contemporary Written Japanese).
24 |
25 | > http://pj.ninjal.ac.jp/corpus_center/bccwj/en/
26 |
27 | You probably want a modern GPU for faster training; the developer uses a p2.xlarge instance on AWS.
28 |
29 | ## Preparing your data
30 | Training data is a UTF-8 text file in which each line corresponds to a sentence.
31 | A sentence is segmented into words by space characters, and a word is a pair of *target*
32 | (i.e. Kanji, Hiragana or Katakana) and *source* (Hiragana), joined by a slash character.
33 |
34 | > 私/わたし の/の 名前/なまえ は/は 中野/なかの です/です 。/。
35 |
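For illustration, a training line can be split into target/source pairs the same way `split_source_target.py` does; a minimal sketch:

    line = '私/わたし の/の 名前/なまえ は/は 中野/なかの です/です 。/。'
    pairs = [word.split('/', 1) for word in line.split(' ')]
    # [['私', 'わたし'], ['の', 'の'], ['名前', 'なまえ'], ...]
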
36 | Test data should contain sentences different from the training data, but ideally from the same domain.
37 | The source file should contain source sentences without spaces.
38 |
39 | > きょうのてんきははれです。
40 |
41 | The target file should contain target sentences without spaces.
42 |
43 | > 今日の天気は晴れです。
44 |
45 | ## Pre-processing BCCWJ
46 | The developer uses the human-annotated part of BCCWJ as the training and testing corpus.
47 | You can use the scripts in this repository to pre-process the XML files after extracting them from the compressed archive.
48 | For example, the following commands parse the corpus, split it into train, test and validation sets, and then split the test set into target and source files.
49 |
50 | parse_bccwj.py 'BCCWJ/CORE/M-XML/*.xml' > parsed.txt
51 |
52 | split_train_test_valid.py parsed.txt train.txt test.txt valid.txt
53 |
54 | split_source_target.py test.txt test.target.txt test.source.txt
55 |
56 |
57 | ## Training neural models
58 | Now you can train your own model with default parameters.
59 |
60 | train.py train.txt model
61 |
62 | See the help for optional parameters such as the number of hidden units and the dropout probability.
63 |
64 | train.py --help
65 |
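For example, the following flags (all defined in `train.py`) train a two-layer LSTM with 200 hidden units for 20 epochs:

    train.py train.txt model --cell_type lstm --layer_size 2 --hidden_size 200 --epoch_size 20
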
66 | ## Decoding sentences
67 | Once you have trained your model, you can decode sentences with it.
68 |
69 | decode.py model
70 |
71 | Type a source sentence on your console, and it will show the decoded sentence like this:
72 |
73 | きょうのてんきははれです。
74 | 今日の天気は晴れです。
75 | きょじんにせんせい
76 | 巨人に先制
77 |
78 | Alternatively, you can give file names for input and output.
79 |
80 | decode.py model --input_file test.source.txt --output_file model/test.decode.txt
81 |
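You can also drive the decoder from Python; a minimal sketch, assuming a model trained into the `model` directory:

    from decode import load_settings, load_dictionary, decode
    from rnn_predictor import RNNPredictor

    settings = load_settings('model')
    dictionary = load_dictionary('model')
    predictor = RNNPredictor(settings.vocabulary_size, settings.hidden_size, settings.layer_size, settings.cell_type)
    predictor.restore_from_directory('model')

    # decode returns the top result plus the n-best candidates, lattice and search queues
    result, candidates, lattice, queues = decode('きょうのてんきははれです。', dictionary, predictor, 5, 1)
    print(result)
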
82 | You can trade decoding time for accuracy by tuning pruning parameters such as the beam size and Viterbi size.
83 | For example, the following option is faster than the default beam size of 5, but less accurate.
84 |
85 | decode.py model --beam_size 1
86 |
87 | ## Evaluating results
88 | You can evaluate decoded results if you have target sentences as reference.
89 |
90 | evaluate.py model/test.decode.txt test.target.txt
91 |
92 | This command gives something like this:
93 |
94 | > precision: 93.59 recall: 93.58 f-score: 93.59 accuracy: 34.06
95 |
96 | Precision, recall and F-score are character-based metrics computed from the longest common subsequence,
97 | and accuracy is a sentence-level metric.
98 |
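The same metrics can also be computed programmatically with the `evaluate` function from `evaluate.py`; a minimal sketch with two made-up sentences:

    from evaluate import evaluate

    system = ['今日の天気は晴れです。']
    reference = ['今日の天気は曇りです。']
    precision, recall, fscore, accuracy = evaluate(system, reference)
    # The longest common subsequence covers 9 of 11 characters,
    # so precision, recall and f-score are all about 81.8 and accuracy is 0.0.
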
99 | ## Hyperparameter search
100 | You can use the grid search script to find the best hyperparameters.
101 |
102 |     grid_search.py train.txt valid.target.txt valid.source.txt model --hidden_size 50 100 200 400
103 |
104 | ## Training N-gram models
105 | In order to train N-gram models as a baseline for comparison with the neural models, you need to install and use the SRILM toolkit.
106 |
107 | > http://www.speech.sri.com/projects/srilm/
108 |
109 | Once installed, you can run the following command to train the model.
110 |
111 | ngram-count -text train.txt -lm ngram.txt -kndiscount -order 2
112 |
113 | Now you can decode sentences using the N-gram model.
114 |
115 | decode_ngram.py ngram.txt
116 |
117 | Or you can combine the neural model and the N-gram model.
118 |
119 | decode_both.py neural_model ngram.txt
120 |
121 |
122 | ## Reference
123 | > Yoh Okuno, Neural IME: Neural Input Method Engine, The 8th Input Method Workshop, 2016.
124 |
--------------------------------------------------------------------------------
/experiment.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import json
4 | import os
5 | import time
6 |
7 | from utility import load_train_data
8 | from rnn_trainer import RNNTrainer
9 | from train import train_epoch
10 | from rnn_predictor import RNNPredictor
11 | from decode import load_dictionary, decode
12 | from evaluate import evaluate
13 |
14 |
15 | def decode_all(rnn_predictor, valid_source_data, dictionary, beam_size, viterbi_size):
16 | start_time = time.time()
17 | system = []
18 | for i, source in enumerate(valid_source_data):
19 | start_time_sentence = time.time()
20 | top_result, _, _, _ = decode(source, dictionary, rnn_predictor, beam_size, viterbi_size)
21 | decode_time_sentence = time.time() - start_time_sentence
22 | print('decoding sentence: {} time: {:.2f}'.format(i, decode_time_sentence), end='\r')
23 | system.append(top_result)
24 |
25 | decode_time = time.time() - start_time
26 | return system, decode_time
27 |
28 |
29 | def train(rnn_trainer, rnn_predictor, train_data, valid_target_data, valid_source_data, dictionary,
30 | epoch_size, model_directory, beam_size, viterbi_size):
31 | start_time = time.time()
32 | log_path = os.path.join(model_directory, 'log.txt')
33 | log_file = open(log_path, 'w')
34 | best_epoch = None
35 | best_metrics = None
36 |
37 | for epoch in range(epoch_size):
38 | # Train one epoch and save the model
39 | train_epoch(rnn_trainer, train_data, model_directory, epoch)
40 |
41 | # Decode all sentences
42 | rnn_predictor.restore_from_directory(model_directory)
43 | system, decode_time = decode_all(rnn_predictor, valid_source_data, dictionary, beam_size, viterbi_size)
44 |
45 | # Evaluate results
46 | metrics = evaluate(system, valid_target_data)
47 |
48 | # Print metrics
49 | log_text = 'decoding precision: {:.2f} recall: {:.2f} f-score: {:.2f} accuracy: {:.2f}\n'.format(*metrics)
50 | log_text += 'decoding total time: {:.2f} average time: {:.2f}'.format(decode_time, decode_time / len(system))
51 | print(log_text)
52 | print(log_text, file=log_file)
53 |
54 | # Write decoded results to file
55 | decode_path = os.path.join(model_directory, 'decode-{}.txt'.format(epoch))
56 | with open(decode_path, 'w') as file:
57 | file.write('\n'.join(system))
58 |
59 | # Update best epoch
60 |         if best_epoch is None or best_metrics[2] < metrics[2]:
61 | best_epoch = epoch
62 | best_metrics = metrics
63 |
64 | total_time = time.time() - start_time
65 | print('best epoch:', best_epoch)
66 | print('best epoch metrics: precision: {:.2f} recall: {:.2f} f-score: {:.2f} accuracy: {:.2f}'.format(*best_metrics))
67 | print('total experiment time:', total_time)
68 | print()
69 | return best_metrics, best_epoch
70 |
71 |
72 | def experiment(settings):
73 | # Print settings
74 | print(json.dumps(vars(settings), indent=4))
75 |
76 | # Load and preprocess training data
77 | train_data = load_train_data(settings)
78 | print('number of batches:', len(train_data))
79 |
80 | # Load validation data
81 | valid_target_data = [line.rstrip('\n') for line in open(settings.valid_target_file)]
82 | valid_source_data = [line.rstrip('\n') for line in open(settings.valid_source_file)]
83 |
84 | # Load dictionary for decoding
85 | dictionary = load_dictionary(settings.model_directory)
86 |
87 | # Create RNN model for training
88 | rnn_trainer = RNNTrainer(settings.batch_size, settings.sentence_size, settings.vocabulary_size, settings.hidden_size,
89 | settings.layer_size, settings.cell_type, settings.optimizer_type, settings.clip_norm,
90 | settings.keep_prob, settings.max_keep)
91 |
92 | # Create RNN model for prediction
93 | rnn_predictor = RNNPredictor(settings.vocabulary_size, settings.hidden_size, settings.layer_size, settings.cell_type)
94 |
95 | # Run experiment
96 | result = train(rnn_trainer, rnn_predictor, train_data, valid_target_data, valid_source_data, dictionary,
97 | settings.epoch_size, settings.model_directory, settings.beam_size, settings.viterbi_size)
98 |
99 | rnn_trainer.close()
100 | rnn_predictor.close()
101 | return result
102 |
103 |
104 | def main():
105 | parser = argparse.ArgumentParser()
106 | parser.add_argument('train_file')
107 | parser.add_argument('valid_target_file')
108 | parser.add_argument('valid_source_file')
109 | parser.add_argument('model_directory')
110 | parser.add_argument('--sentence_size', type=int, default=30)
111 | parser.add_argument('--vocabulary_size', type=int, default=50000)
112 | parser.add_argument('--batch_size', type=int, default=50)
113 | parser.add_argument('--hidden_size', type=int, default=400)
114 | parser.add_argument('--layer_size', type=int, default=1)
115 | parser.add_argument('--epoch_size', type=int, default=10)
116 | parser.add_argument('--clip_norm', type=float, default=5)
117 | parser.add_argument('--keep_prob', type=float, default=0.5)
118 | parser.add_argument('--cell_type', default='gru')
119 | parser.add_argument('--optimizer_type', default='adam')
120 | parser.add_argument('--max_keep', type=int, default=0)
121 | parser.add_argument('--beam_size', type=int, default=5)
122 | parser.add_argument('--viterbi_size', type=int, default=1)
123 | args = parser.parse_args()
124 | experiment(args)
125 |
126 | if __name__ == '__main__':
127 | main()
128 |
--------------------------------------------------------------------------------
/decode_ngram.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import sys
4 | import heapq
5 | import operator
6 | import math
7 | from collections import defaultdict
8 |
9 |
10 | def parse_ngram(ngram):
11 | for word in ngram.split(' '):
12 |         if word == '<s>':
13 | yield '_BOS', '_BOS'
14 |         elif word == '</s>':
15 | yield '_EOS', '_EOS'
16 | else:
17 | yield tuple(word.split('/', 1))
18 |
19 |
20 | def parse_srilm(file):
21 | ngrams = {}
22 | for line in file:
23 | line = line.rstrip('\n')
24 | fields = line.split('\t', 2)
25 |
26 | if len(fields) < 2:
27 | continue
28 |
29 | if len(fields) == 2:
30 | logprob, ngram = fields
31 | backoff = None
32 | elif len(fields) == 3:
33 | logprob, ngram, backoff = fields
34 | backoff = -math.log(10 ** float(backoff))
35 | cost = -math.log(10 ** float(logprob))
36 | ngram = tuple(parse_ngram(ngram))
37 | ngrams[ngram] = (cost, backoff)
38 | return ngrams
39 |
40 |
41 | def create_dictionary(ngrams):
42 | dictionary = defaultdict(list)
43 | for ngram in ngrams.keys():
44 | if len(ngram) == 1:
45 | target, source = ngram[0]
46 | dictionary[source].append(target)
47 | return dictionary
48 |
49 |
50 | def create_lattice(input_, dictionary):
51 | lattice = [[[] for _ in range(len(input_) + 1)] for _ in range(len(input_) + 2)]
52 |
53 | for i in range(1, len(input_) + 1):
54 | for j in range(i):
55 | source = input_[j:i]
56 | if source in dictionary:
57 | for target in dictionary[source]:
58 | lattice[i][j].append((target, source))
59 | elif len(source) == 1:
60 | lattice[i][j].append((source, source))
61 |
62 | lattice[-1][-1].append(('_EOS', '_EOS'))
63 | return lattice
64 |
65 |
66 | def initialize_queues(lattice):
67 | # A hypothesis is tuple of (cost, history)
68 | queues = [[] for _ in range(len(lattice))]
69 | bos_hypothesis = (0.0, [('_BOS', '_BOS')])
70 | queues[0].append(bos_hypothesis)
71 | return queues
72 |
73 |
74 | def get_ngram_cost(ngrams, history):
75 | if type(history) is list:
76 | history = tuple(history)
77 | if history in ngrams:
78 | cost, _ = ngrams[history]
79 | return cost
80 |
81 | if len(history) == 1:
82 | return 100.0
83 |
84 | return get_ngram_cost(ngrams, history[1:])
85 |
86 |
87 | def search(lattice, ngrams, queues, beam_size, viterbi_size):
88 | for i in range(len(lattice)):
89 | for j in range(len(lattice[i])):
90 | for target, source in lattice[i][j]:
91 |
92 | word_queue = []
93 | for previous_cost, previous_history in queues[j]:
94 | history = previous_history + [(target, source)]
95 | cost = previous_cost + get_ngram_cost(ngrams, tuple(history[-3:]))
96 | hypothesis = (cost, history)
97 | word_queue.append(hypothesis)
98 |
99 | # prune word_queue to viterbi size
100 | if viterbi_size > 0:
101 | word_queue = heapq.nsmallest(viterbi_size, word_queue, key=operator.itemgetter(0))
102 |
103 | queues[i] += word_queue
104 |
105 | # prune queues[i] to beam size
106 | if beam_size > 0:
107 | queues[i] = heapq.nsmallest(beam_size, queues[i], key=operator.itemgetter(0))
108 | return queues
109 |
110 |
111 | def decode(input_, dictionary, ngrams, beam_size, viterbi_size):
112 | lattice = create_lattice(input_, dictionary)
113 | queues = initialize_queues(lattice)
114 | queue = search(lattice, ngrams, queues, beam_size, viterbi_size)
115 |
116 | candidates = []
117 | for cost, history in queue[-1]:
118 | result = ''.join(target for target, source in history[1:-1])
119 | candidates.append((result, cost))
120 |
121 | top_result = candidates[0][0]
122 | return top_result, candidates, lattice, queues
123 |
124 |
125 | def main():
126 | parser = argparse.ArgumentParser()
127 | parser.add_argument('ngram_file')
128 | parser.add_argument('--input_file', type=argparse.FileType('r'), default=sys.stdin)
129 | parser.add_argument('--output_file', type=argparse.FileType('w'), default=sys.stdout)
130 | parser.add_argument('--beam_size', type=int, default=5)
131 | parser.add_argument('--viterbi_size', type=int, default=1)
132 | parser.add_argument('--print_nbest', action='store_true')
133 | parser.add_argument('--print_queue', action='store_true')
134 | parser.add_argument('--print_lattice', action='store_true')
135 | args = parser.parse_args()
136 |
137 | ngrams = parse_srilm(open(args.ngram_file))
138 | dictionary = create_dictionary(ngrams)
139 |
140 | for line in args.input_file:
141 | line = line.rstrip('\n')
142 | result, candidates, lattice, queues = decode(line, dictionary, ngrams, args.beam_size, args.viterbi_size)
143 |
144 | # Print decoded results
145 | if not args.print_nbest:
146 | print(result, file=args.output_file)
147 | else:
148 | for string, cost in candidates:
149 | print(string, cost, file=args.output_file)
150 |
151 | # Print lattice for debug
152 | if args.print_lattice:
153 | for i in range(len(lattice)):
154 | for j in range(len(lattice[i])):
155 | print('i = {}, j = {}'.format(i, j), file=args.output_file)
156 | for target, source in lattice[i][j]:
157 | print(target, source, file=args.output_file)
158 |
159 | # Print queues for debug
160 | if args.print_queue:
161 | for i, queue in enumerate(queues):
162 | print('queue', i, file=args.output_file)
163 | for cost, history in queue:
164 | print(cost, history, file=args.output_file)
165 |
166 | if __name__ == '__main__':
167 | main()
168 |
--------------------------------------------------------------------------------
/decode_both.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import sys
4 | import heapq
5 | import operator
6 | import math
7 |
8 | from rnn_predictor import RNNPredictor
9 | from decode import load_settings, load_dictionary
10 | from decode_ngram import parse_srilm, get_ngram_cost
11 |
12 |
13 | def create_lattice(input_, dictionary):
14 | lattice = [[[] for _ in range(len(input_) + 1)] for _ in range(len(input_) + 2)]
15 | _, unk_id = dictionary['_UNK'][0]
16 |
17 | for i in range(1, len(input_) + 1):
18 | for j in range(i):
19 | source = input_[j:i]
20 | if source in dictionary:
21 | for target, word_id in dictionary[source]:
22 | lattice[i][j].append((target, source, word_id))
23 | elif len(source) == 1:
24 | # Create _UNK node with verbatim target when single character key is not found in the dictionary.
25 | lattice[i][j].append((source, source, unk_id))
26 |
27 | _, eos_id = dictionary['_EOS'][0]
28 | lattice[-1][-1].append(('_EOS', '_EOS', eos_id))
29 | return lattice
30 |
31 |
32 | def initialize_queues(lattice, rnn_predictor, dictionary):
33 | # A hypothesis is tuple of (cost, history, state, prediction)
34 | _, bos_id = dictionary['_BOS'][0]
35 | bos_predictions, bos_states = rnn_predictor.predict([bos_id])
36 |     bos_hypothesis = (0.0, [('_BOS', '_BOS')], bos_states[0], bos_predictions[0])
37 | queues = [[] for _ in range(len(lattice))]
38 | queues[0].append(bos_hypothesis)
39 | return queues
40 |
41 |
42 | def interpolate(rnn_cost, ngram_cost):
43 | # Linear interpolation needs to be done in probability space, not log probability space
44 | return -math.log((math.exp(-rnn_cost) + math.exp(-ngram_cost)) / 2.0)
45 |
46 |
47 | def search(lattice, queues, rnn_predictor, ngrams, beam_size, viterbi_size):
48 | # Breadth first search with beam pruning and viterbi-like pruning
49 | for i in range(len(lattice)):
50 | queue = []
51 |
52 | # create hypotheses without predicting next word
53 | for j in range(len(lattice[i])):
54 | for target, source, word_id in lattice[i][j]:
55 |
56 | word_queue = []
57 | for previous_cost, previous_history, previous_state, previous_prediction in queues[j]:
58 | history = previous_history + [(target, source)]
59 | cost = previous_cost + interpolate(previous_prediction[word_id], get_ngram_cost(ngrams, history))
60 | # Temporal hypothesis is tuple of (cost, history, word_id, previous_state)
61 | # Lazy prediction replaces word_id and previous_state to state and prediction
62 | hypothesis = (cost, history, word_id, previous_state)
63 | word_queue.append(hypothesis)
64 |
65 | # prune word_queue to viterbi size
66 | if viterbi_size > 0:
67 | word_queue = heapq.nsmallest(viterbi_size, word_queue, key=operator.itemgetter(0))
68 |
69 | queue += word_queue
70 |
71 | # prune queue to beam size
72 | if beam_size > 0:
73 | queue = heapq.nsmallest(beam_size, queue, key=operator.itemgetter(0))
74 |
75 | # predict next word and state before continue
76 | for cost, history, word_id, previous_state in queue:
77 | predictions, states = rnn_predictor.predict([word_id], [previous_state])
78 | hypothesis = (cost, history, states[0], predictions[0])
79 | queues[i].append(hypothesis)
80 |
81 | return queues
82 |
83 |
84 | def decode(source, dictionary, rnn_predictor, ngrams, beam_size, viterbi_size):
85 | lattice = create_lattice(source, dictionary)
86 | queues = initialize_queues(lattice, rnn_predictor, dictionary)
87 | queues = search(lattice, queues, rnn_predictor, ngrams, beam_size, viterbi_size)
88 |
89 | candidates = []
90 | for cost, history, _, _ in queues[-1]:
91 | result = ''.join(target for target, source in history[1:-1])
92 | candidates.append((result, cost))
93 |
94 | top_result = candidates[0][0]
95 | return top_result, candidates, lattice, queues
96 |
97 |
98 | def main():
99 | parser = argparse.ArgumentParser()
100 | parser.add_argument('model_directory')
101 | parser.add_argument('ngram_file')
102 | parser.add_argument('--model_file')
103 | parser.add_argument('--input_file', type=argparse.FileType('r'), default=sys.stdin)
104 | parser.add_argument('--output_file', type=argparse.FileType('w'), default=sys.stdout)
105 | parser.add_argument('--beam_size', type=int, default=5)
106 | parser.add_argument('--viterbi_size', type=int, default=1)
107 | parser.add_argument('--print_nbest', action='store_true')
108 | parser.add_argument('--print_queue', action='store_true')
109 | parser.add_argument('--print_lattice', action='store_true')
110 | args = parser.parse_args()
111 |
112 | # Load settings and vocabulary
113 | settings = load_settings(args.model_directory)
114 | dictionary = load_dictionary(args.model_directory)
115 |
116 | # Create model and load parameters
117 | rnn_predictor = RNNPredictor(settings.vocabulary_size, settings.hidden_size, settings.layer_size, settings.cell_type)
118 | if args.model_file:
119 | rnn_predictor.restore_from_file(args.model_file)
120 | else:
121 | rnn_predictor.restore_from_directory(args.model_directory)
122 |
123 | # Load ngram file in SRILM format
124 | ngrams = parse_srilm(open(args.ngram_file))
125 |
126 | # Iterate input file line by line
127 | for line in args.input_file:
128 | line = line.rstrip('\n')
129 |
130 | # Decode - this might take some time
131 | result, candidates, lattice, queues = decode(line, dictionary, rnn_predictor, ngrams, args.beam_size, args.viterbi_size)
132 |
133 | # Print decoded results
134 | if not args.print_nbest:
135 | print(result, file=args.output_file)
136 | else:
137 | for string, cost in candidates:
138 | print(string, cost, file=args.output_file)
139 |
140 | # Print lattice for debug
141 | if args.print_lattice:
142 | for i in range(len(lattice)):
143 | for j in range(len(lattice[i])):
144 | print('i = {}, j = {}'.format(i, j), file=args.output_file)
145 | for target, source, word_id in lattice[i][j]:
146 | print(target, source, word_id, file=args.output_file)
147 |
148 | # Print queues for debug
149 | if args.print_queue:
150 | for i, queue in enumerate(queues):
151 | print('queue', i, file=args.output_file)
152 | for cost, history, state, prediction in queue:
153 | print(cost, history, file=args.output_file)
154 |
155 |
156 | if __name__ == '__main__':
157 | main()
158 |
--------------------------------------------------------------------------------
/decode.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import sys
4 | import os
5 | import json
6 | import collections
7 | import heapq
8 | import operator
9 | from rnn_predictor import RNNPredictor
10 |
11 |
12 | def load_settings(model_directory):
13 | settings_path = os.path.join(model_directory, 'settings.json')
14 | settings = json.load(open(settings_path))
15 | return argparse.Namespace(**settings)
16 |
17 |
18 | def load_dictionary(model_directory):
19 | vocabulary_path = os.path.join(model_directory, 'vocabulary.txt')
20 | vocabulary = []
21 | for line in open(vocabulary_path):
22 | line = line.rstrip('\n')
23 | target, source = line.split('/', 1)
24 | vocabulary.append((target, source))
25 |
26 | dictionary = collections.defaultdict(list)
27 | for i, (target, source) in enumerate(vocabulary):
28 | dictionary[source].append((target, i))
29 |
30 | return dictionary
31 |
32 |
33 | def create_lattice(input_, dictionary):
34 | lattice = [[[] for _ in range(len(input_) + 1)] for _ in range(len(input_) + 2)]
35 | _, unk_id = dictionary['_UNK'][0]
36 |
37 | for i in range(1, len(input_) + 1):
38 | for j in range(i):
39 | key = input_[j:i]
40 | if key in dictionary:
41 | for target, word_id in dictionary[key]:
42 | lattice[i][j].append((target, word_id))
43 | elif len(key) == 1:
44 | # Create _UNK node with verbatim target when single character key is not found in the dictionary.
45 | lattice[i][j].append((key, unk_id))
46 |
47 | _, eos_id = dictionary['_EOS'][0]
48 | lattice[-1][-1].append(('', eos_id))
49 | return lattice
50 |
51 |
52 | def initialize_queues(lattice, rnn_predictor, dictionary):
53 | # Initialize priority queues for keeping hypotheses
54 | # A hypothesis is a tuple of (cost, string, state, prediction)
55 | # cost is total negative log probability
56 | # state.shape == [hidden_size * layer_size]
57 | # prediction.shape == [vocabulary_size]
58 | _, bos_id = dictionary['_BOS'][0]
59 | bos_predictions, bos_states = rnn_predictor.predict([bos_id])
60 | bos_hypothesis = (0.0, '', bos_states[0], bos_predictions[0])
61 | queues = [[] for _ in range(len(lattice))]
62 | queues[0].append(bos_hypothesis)
63 | return queues
64 |
65 |
66 | def simple_search(lattice, queues, rnn_predictor, beam_size):
67 | # Simple but slow implementation of beam search
68 | for i in range(len(lattice)):
69 | for j in range(len(lattice[i])):
70 | for target, word_id in lattice[i][j]:
71 | for previous_cost, previous_string, previous_state, previous_prediction in queues[j]:
72 | cost = previous_cost + previous_prediction[word_id]
73 | string = previous_string + target
74 | predictions, states = rnn_predictor.predict([word_id], [previous_state])
75 | hypothesis = (cost, string, states[0], predictions[0])
76 | queues[i].append(hypothesis)
77 |
78 | # prune queues[i] to beam size
79 | queues[i] = heapq.nsmallest(beam_size, queues[i], key=operator.itemgetter(0))
80 | return queues
81 |
82 |
83 | def search(lattice, queues, rnn_predictor, beam_size, viterbi_size):
84 | # Breadth first search with beam pruning and viterbi-like pruning
85 | for i in range(len(lattice)):
86 | queue = []
87 |
88 | # create hypotheses without predicting next word
89 | for j in range(len(lattice[i])):
90 | for target, word_id in lattice[i][j]:
91 |
92 | word_queue = []
93 | for previous_cost, previous_string, previous_state, previous_prediction in queues[j]:
94 | cost = previous_cost + previous_prediction[word_id]
95 | string = previous_string + target
96 | hypothesis = (cost, string, word_id, previous_state)
97 | word_queue.append(hypothesis)
98 |
99 | # prune word_queue to viterbi size
100 | if viterbi_size > 0:
101 | word_queue = heapq.nsmallest(viterbi_size, word_queue, key=operator.itemgetter(0))
102 |
103 | queue += word_queue
104 |
105 | # prune queue to beam size
106 | if beam_size > 0:
107 | queue = heapq.nsmallest(beam_size, queue, key=operator.itemgetter(0))
108 |
109 | # predict next word and state before continue
110 | for cost, string, word_id, previous_state in queue:
111 | predictions, states = rnn_predictor.predict([word_id], [previous_state])
112 | hypothesis = (cost, string, states[0], predictions[0])
113 | queues[i].append(hypothesis)
114 |
115 | return queues
116 |
117 |
118 | def decode(source, dictionary, rnn_predictor, beam_size, viterbi_size):
119 | lattice = create_lattice(source, dictionary)
120 | queues = initialize_queues(lattice, rnn_predictor, dictionary)
121 | queues = search(lattice, queues, rnn_predictor, beam_size, viterbi_size)
122 |
123 | candidates = []
124 | for cost, string, _, _ in queues[-1]:
125 | candidates.append((string, cost))
126 |
127 | top_result = candidates[0][0]
128 | return top_result, candidates, lattice, queues
129 |
130 |
131 | def main():
132 | parser = argparse.ArgumentParser()
133 | parser.add_argument('model_directory')
134 | parser.add_argument('--model_file')
135 | parser.add_argument('--input_file', type=argparse.FileType('r'), default=sys.stdin)
136 | parser.add_argument('--output_file', type=argparse.FileType('w'), default=sys.stdout)
137 | parser.add_argument('--beam_size', type=int, default=5)
138 | parser.add_argument('--viterbi_size', type=int, default=1)
139 | parser.add_argument('--print_nbest', action='store_true')
140 | parser.add_argument('--print_queue', action='store_true')
141 | parser.add_argument('--print_lattice', action='store_true')
142 | args = parser.parse_args()
143 |
144 | # Load settings and vocabulary
145 | settings = load_settings(args.model_directory)
146 | dictionary = load_dictionary(args.model_directory)
147 |
148 | # Create model and load parameters
149 | rnn_predictor = RNNPredictor(settings.vocabulary_size, settings.hidden_size, settings.layer_size, settings.cell_type)
150 | if args.model_file:
151 | rnn_predictor.restore_from_file(args.model_file)
152 | else:
153 | rnn_predictor.restore_from_directory(args.model_directory)
154 |
155 | # Iterate input file line by line
156 | for line in args.input_file:
157 | line = line.rstrip('\n')
158 |
159 | # Decode - this might take ~10 seconds per line
160 | result, candidates, lattice, queues = decode(line, dictionary, rnn_predictor, args.beam_size, args.viterbi_size)
161 |
162 | # Print decoded results
163 | if not args.print_nbest:
164 | print(result, file=args.output_file)
165 | else:
166 | for string, cost in candidates:
167 | print(string, cost, file=args.output_file)
168 |
169 | # Print lattice for debug
170 | if args.print_lattice:
171 | for i in range(len(lattice)):
172 | for j in range(len(lattice[i])):
173 | print('i = {}, j = {}'.format(i, j), file=args.output_file)
174 | for target, word_id in lattice[i][j]:
175 | print(target, word_id, file=args.output_file)
176 |
177 | # Print queues for debug
178 | if args.print_queue:
179 | for i, queue in enumerate(queues):
180 | print('queue', i, file=args.output_file)
181 | for cost, string, state, prediction in queue:
182 | print(string, cost, file=args.output_file)
183 |
184 | if __name__ == '__main__':
185 | main()
186 |
--------------------------------------------------------------------------------