├── .gitignore
├── README.md
├── config
│   ├── check_tiny.yml
│   ├── config.yml
│   ├── kor_ballad.yml
│   └── shakespeare.yml
├── data
│   ├── lyricskor
│   │   └── input.txt
│   ├── tiny_lyricskor
│   │   └── input.txt
│   └── tiny_shakespeare
│       └── input.txt
├── data_loader.py
├── dataset.py
├── experiment.py
├── generator.py
├── hook.py
├── images
│   ├── kino-samhangsi-example1.png
│   └── kino-samhangsi-example2.png
├── main.py
└── model.py

/.gitignore:
--------------------------------------------------------------------------------
*.pkl
*.npy
/checkpoints
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# char-rnn [![hb-research](https://img.shields.io/badge/hb--research-experiment-green.svg?style=flat&colorA=448C57&colorB=555555)](https://github.com/hb-research)

This code implements a multi-layer Recurrent Neural Network (RNN, LSTM, or GRU) for training and sampling from character-level language models.

## Requirements

- Python 3.6
- TensorFlow 1.4
- hb-config

## Features

- Uses higher-level APIs in TensorFlow
  - [Estimator](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator)
  - [Experiment](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/Experiment)
  - [Dataset](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset)
- Korean SamhangSi (similar to an acrostic poem)

## Config

Example: `check_tiny.yml`

```yml
data:
  data_dir: 'data/tiny_lyricskor'
model:
  batch_size: 4
  input_keep_prob: 0.8
  log_dir: 'logs'
  num_layers: 1
  output_keep_prob: 0.8
  rnn_size: 64
  seq_length: 20
train:
  train_steps: 10000
  model_dir: 'tiny_checkpoints'
  save_every: 1
  learning_rate: 0.001
  loss_hook_n_iter: 100
  check_hook_n_iter: 1000
  min_eval_frequency: 100
```
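The config files live under `config/` and are loaded by name with hb-config. A rough sketch of how the values above are read (the name-to-file lookup shown here is hb-config's convention, not code in this repository):

```python
from hbconfig import Config

Config("check_tiny")               # assumed to resolve to config/check_tiny.yml (same name as --config)

print(Config.data.data_dir)        # 'data/tiny_lyricskor'
print(Config.model.batch_size)     # 4
print(Config.train.train_steps)    # 10000
```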
## Usage

First, check that the model is valid.

```bash
python main.py --config check_tiny --mode train
```

Then, train the model.

```bash
python main.py --config kor_ballad --mode train_and_evaluate
```

After training, generate a Korean SamhangSi.

```bash
python generator.py --config kor_ballad --word 삼행시
```

### SamhangSi Examples

- 삼행시

```
삼이야 그리움이 좇아 사랑은늘 도망가
행른 잊어버리고 그대 이 세상
시제 너의 곁을 떠나면 빗물에 꽃씨하나 흘러가듯
```

- 기계

```
기를 바라보네 두 손 잡고 고개 끄덕여 달라 하기에
계 울고 싶어 내 맘을 떠나가던 날
```

- 여름

```
여도 지금하럼 커피는 날개니
름다웠던 그대모습 다시 볼 수 없는것 알아요
```

- 커피

```
커나가 그래 돌아서 눈 감으면 잊을까
피고 내가 가고 싶지 아파 만날 날 기다려왔어
```

### Example with kino-bot

![images](images/kino-samhangsi-example1.png)

![images](images/kino-samhangsi-example2.png)

## Reference

- [sherjilozair/char-rnn-tensorflow](https://github.com/sherjilozair/char-rnn-tensorflow)
- [insikk/kor-char-rnn-tensorflow](https://github.com/insikk/kor-char-rnn-tensorflow)
- [Higher-Level APIs in TensorFlow](https://medium.com/onfido-tech/higher-level-apis-in-tensorflow-67bfb602e6c0)
--------------------------------------------------------------------------------
/config/check_tiny.yml:
--------------------------------------------------------------------------------
data:
  data_dir: 'data/tiny_lyricskor'
model:
  batch_size: 4
  input_keep_prob: 0.8
  log_dir: 'logs'
  num_layers: 1
  output_keep_prob: 0.8
  rnn_size: 64
  seq_length: 20
train:
  train_steps: 10000
  model_dir: 'tiny_checkpoints'
  save_every: 1
  learning_rate: 0.001
  loss_hook_n_iter: 100
  check_hook_n_iter: 1000
  min_eval_frequency: 100
--------------------------------------------------------------------------------
/config/config.yml:
--------------------------------------------------------------------------------
data:
  data_dir: 'data/tiny_lyricskor'
model:
  batch_size: 4
  grad_clip: 5.0
  input_keep_prob: 0.8
  log_dir: 'logs'
  num_layers: 1
  output_keep_prob: 0.8
  rnn_size: 64
  seq_length: 20
train:
  train_steps: 10000
  model_dir: 'tiny_checkpoints'
  save_every: 1
  learning_rate: 0.001
  loss_hook_n_iter: 100
  check_hook_n_iter: 1000
  min_eval_frequency: 100
--------------------------------------------------------------------------------
/config/kor_ballad.yml:
--------------------------------------------------------------------------------
data:
  data_dir: 'data/lyricskor'
model:
  batch_size: 32
  grad_clip: 5.0
  input_keep_prob: 0.8
  log_dir: 'logs'
  num_layers: 3
  output_keep_prob: 0.8
  rnn_size: 512
  seq_length: 100
train:
  train_steps: 200000
  model_dir: 'ballad_checkpoints'
  save_every: 1000
  learning_rate: 0.001
  loss_hook_n_iter: 1000
  check_hook_n_iter: 2000
  min_eval_frequency: 1000
--------------------------------------------------------------------------------
/config/shakespeare.yml:
--------------------------------------------------------------------------------
data:
  data_dir: 'data/tiny_shakespeare'
model:
  batch_size: 32
  grad_clip: 5.0
  input_keep_prob: 0.8
  log_dir: 'logs'
  num_layers: 3
  output_keep_prob: 0.8
  rnn_size: 512
  seq_length: 100
train:
  train_steps: 100000
  model_dir: 'shakespeare_checkpoints'
  save_every: 1000
  learning_rate: 0.001
  loss_hook_n_iter: 1000
  check_hook_n_iter: 2000
  min_eval_frequency: 1000
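Only the `model` section of each config above is handed to the network directly: `main.py` and `generator.py` (below) convert it into `tf.contrib.training.HParams`, so every key becomes an attribute of `params` inside `model_fn`. A small sketch of that conversion (illustrative; the values shown follow `kor_ballad.yml`):

```python
import tensorflow as tf
from hbconfig import Config

Config("kor_ballad")
params = tf.contrib.training.HParams(**Config.model.to_dict())

print(params.rnn_size, params.num_layers, params.seq_length)  # 512 3 100
print(Config.train.model_dir)                                  # 'ballad_checkpoints'
```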
--------------------------------------------------------------------------------
/data/tiny_lyricskor/input.txt:
--------------------------------------------------------------------------------
내 곁에서 떠나가지 말아요
그대없는 밤은 너무 쓸쓸해
그대가 더 잘 알고 있잖아요
제발 아무말도 하지 말아요
나약한 내가 뭘 할수 있을까 생각을 해봐
그대가 내겐 전부였었는데 음~오
제발 내 곁에서 떠나가지 말아요
그대없는 밤은 너무 싫어
우~우~우~ 돌이킬수 없는 그대 마음
우~우~우~ 이제와서 다시 어쩌려나
슬픔마음도 이젠 소용없네

내 곁에서 떠나가지 말아요
그대없는 밤은 너무 쓸쓸해
그대가 더 잘 알고 있잖아요
제발 아무말도 하지 말아요
나약한 내가 뭘 할수 있을까 생각을 해봐
그대가 내겐 전부였었는데 음~오
제발 내 곁에서 떠나가지 말아요
그대없는 밤은 너무 싫어
우~우~우~ 돌이킬수 없는 그대 마음
우~우~우~ 이제와서 다시 어쩌려나
슬픔마음도 이젠


조용한 밤하늘에
아름다운 별빛이
멀리 있는 창가에도
소리 없이 비추고
한낮의 기억들은
어디론가 사라져
꿈을 꾸는 저 하늘만
바라보고 있어요
부드러운 노래 소리에
내 마음은 아이처럼
파란 추억의 바다로
뛰어가고 있네요
깊은 밤 아름다운 그 시간은
이렇게 찾아와 마음을 물들이고
영원한 여름밤의 꿈을
기억하고 있어요
다시 아침이 밝아와도
잊혀지지 않도록
부드러운 노래 소리에
내 마음은 아이처럼
파란 추억의 바다로
뛰어가고 있네요
깊은 밤 아름다운 그 시간은
이렇게 찾아와 마음을 물들이고
영원한 여름밤의 꿈을
기억하고 있어요
다시 아침이 밝아와도
잊혀지지 않도록

나의 하늘을 본 적이 있을까
조각 구름과 빛나는 별들이
끝없이 펼쳐 있는
구석진 그 하늘 어디선가
내 노래는 널 부르고 있음을
너는 듣고 있는지
나의 정원을 본 적이 있을까
국화와 장미 예쁜 사루비아가
끝없이 피어 있는
언제든 그 문은 열려 있고
그 향기는 널 부르고 있음을
넌 알고 있는지
나의 어릴 적 내 꿈만큼이나
아름다운 가을 하늘이랑
네가 그것들과 손잡고
고요한 달빛으로 내게 오면
내 여린 마음으로 피워낸
나의 사랑을
너에게 꺾어줄게
나의 어릴 적 내 꿈만큼이나
아름다운 가을 하늘이랑
네가 그것들과 손잡고
고요한 달빛으로 내게 오면
내 여린 마음으로 피워낸
나의 사랑을
너에게 꺾어줄게

사랑은 그렇게 잊고 사는 것
말할 수 없는 게 너무도 많았어
너무도 많은 말에
우리는 지쳐 지쳐 지쳐 지쳐
하늘을 볼 수 없이 너무도 부끄러워
나나나나 찾고 싶어
나나나나 가고 싶어
헤어나질 못할 사람들 속에 묻혀
우리도 그렇게 잊고 사는 것
하늘을 볼 수 없이
모두가 지쳐 지쳐 지쳐 지쳐
오늘도 어제처럼 동녘에 해는 떠도
나나나나 보고 싶어
나나나나 끝이 없는
나나나나 내 꿈들을
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
import codecs
import os
import collections

from six.moves import cPickle
import numpy as np


class TextLoader():

    def __init__(self, data_dir, batch_size=None, seq_length=None, encoding='utf-8'):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.encoding = encoding

        input_file = os.path.join(data_dir, "input.txt")
        vocab_file = os.path.join(data_dir, "vocab.pkl")
        tensor_file = os.path.join(data_dir, "data.npy")

        if not (os.path.exists(vocab_file) and os.path.exists(tensor_file)):
            print("reading text file")
            self.preprocess(input_file, vocab_file, tensor_file)
        else:
            print("loading preprocessed files")
            self.load_preprocessed(vocab_file, tensor_file)

    def preprocess(self, input_file, vocab_file, tensor_file):
        with codecs.open(input_file, "r", encoding=self.encoding) as f:
            data = f.read()
        counter = collections.Counter(data)
        count_pairs = sorted(counter.items(), key=lambda x: -x[1])
        self.chars, _ = zip(*count_pairs)
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        with open(vocab_file, 'wb') as f:
            cPickle.dump(self.chars, f)
        self.tensor = np.array(list(map(self.vocab.get, data)))
        np.save(tensor_file, self.tensor)

    def load_preprocessed(self, vocab_file, tensor_file):
        with open(vocab_file, 'rb') as f:
            self.chars = cPickle.load(f)
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        self.tensor = np.load(tensor_file)

    def make_train_and_test_set(self, train_size=0.8, test_size=0.2):
        self.num_batches = int(self.tensor.size / (self.batch_size *
                                                   self.seq_length))

        # When the data (tensor) is too small, give a clearer error message.
        if self.num_batches == 0:
            raise ValueError("Not enough data. Make seq_length and batch_size smaller.")
        if train_size + test_size > 1:
            raise ValueError("train_size + test_size must not exceed 1.")

        self.tensor = self.tensor[:self.num_batches * self.batch_size * self.seq_length]
        xdata = self.tensor
        ydata = np.copy(self.tensor)
        ydata[:-1] = xdata[1:]
        ydata[-1] = xdata[0]

        self.X = xdata
        self.y = ydata

        train_length = int(len(self.X) / self.seq_length * train_size) * self.seq_length
        test_length = int(len(self.X) / self.seq_length * test_size) * self.seq_length

        # First train_size portion for training, last test_size portion for evaluation.
        train_X = self.X[:train_length]
        train_y = self.y[:train_length]

        test_X = self.X[-test_length:]
        test_y = self.y[-test_length:]

        return train_X, test_X, train_y, test_y

    def create_batches(self):
        self.num_batches = int(self.tensor.size / (self.batch_size *
                                                   self.seq_length))

        self.X_batches = np.split(self.X.reshape(self.batch_size, -1),
                                  self.num_batches, 1)
        self.y_batches = np.split(self.y.reshape(self.batch_size, -1),
                                  self.num_batches, 1)
        self.reset_batch_pointer()

    def next_batch(self):
        X, y = self.X_batches[self.pointer], self.y_batches[self.pointer]
        self.pointer += 1
        return X, y

    def reset_batch_pointer(self):
        self.pointer = 0
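A quick usage sketch for `TextLoader` against the tiny Korean lyrics file above. The path and sizes follow `check_tiny.yml`; the snippet itself is illustrative and not part of the repository:

```python
from data_loader import TextLoader

loader = TextLoader("data/tiny_lyricskor", batch_size=4, seq_length=20)
print(loader.vocab_size)            # number of distinct characters in input.txt

train_X, test_X, train_y, test_y = loader.make_train_and_test_set()
# y is x shifted left by one character, so the model learns to predict the next character.
print(train_X[:5])
print(train_y[:5])
```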
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
from hbconfig import Config
import tensorflow as tf


class IteratorInitializerHook(tf.train.SessionRunHook):
    """Hook to initialize the data iterator after the Session is created."""

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None

    def after_create_session(self, session, coord):
        """Initialize the iterator after the session has been created."""
        self.iterator_initializer_func(session)


def get_train_inputs(X, y):

    iterator_initializer_hook = IteratorInitializerHook()

    def train_inputs():
        with tf.name_scope('training'):

            nonlocal X
            nonlocal y

            X = X.reshape([-1, Config.model.seq_length])
            y = y.reshape([-1, Config.model.seq_length])

            # Define placeholders
            input_placeholder = tf.placeholder(tf.int32, X.shape)
            output_placeholder = tf.placeholder(tf.int32, y.shape)

            # Build dataset iterator
            dataset = tf.data.Dataset.from_tensor_slices(
                (input_placeholder, output_placeholder))
            dataset = dataset.repeat(None)  # Infinite iterations
            dataset = dataset.shuffle(buffer_size=10000)
            dataset = dataset.batch(Config.model.batch_size)
            iterator = dataset.make_initializable_iterator()
            next_X, next_y = iterator.get_next()

            tf.identity(next_X[0], 'input_0')
            tf.identity(next_y[0], 'output_0')

            # Set runhook to initialize iterator
            iterator_initializer_hook.iterator_initializer_func = \
                lambda sess: sess.run(
                    iterator.initializer,
                    feed_dict={input_placeholder: X,
                               output_placeholder: y})

            # Return batched (features, labels)
            return next_X, next_y

    # Return function and hook
    return train_inputs, iterator_initializer_hook


def get_test_inputs(X, y):

    iterator_initializer_hook = IteratorInitializerHook()

    def test_inputs():
        with tf.name_scope('test'):

            nonlocal X
            nonlocal y

            X = X.reshape([-1, Config.model.seq_length])
            y = y.reshape([-1, Config.model.seq_length])

            # Define placeholders
            input_placeholder = tf.placeholder(tf.int32, X.shape)
            output_placeholder = tf.placeholder(tf.int32, y.shape)

            # Build dataset iterator
            dataset = tf.data.Dataset.from_tensor_slices(
                (input_placeholder, output_placeholder))
            dataset = dataset.repeat(None)  # Infinite iterations
            dataset = dataset.shuffle(buffer_size=10000)
            dataset = dataset.batch(Config.model.batch_size)
            iterator = dataset.make_initializable_iterator()
            next_X, next_y = iterator.get_next()

            tf.identity(next_X[0], 'input_0')
            tf.identity(next_y[0], 'output_0')

            # Set runhook to initialize iterator
            iterator_initializer_hook.iterator_initializer_func = \
                lambda sess: sess.run(
                    iterator.initializer,
                    feed_dict={input_placeholder: X,
                               output_placeholder: y})

            # Return batched (features, labels)
            return next_X, next_y

    # Return function and hook
    return test_inputs, iterator_initializer_hook
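The returned hook must travel with its input function: the iterator is only initialized (and the numpy arrays fed through the placeholders) once the session exists. The repository wires this through `tf.contrib.learn.Experiment` in `experiment.py` below; the sketch here shows the equivalent direct `Estimator` call and is an illustration under the assumption that the configs resolve as in `main.py`, not code from the repo:

```python
from hbconfig import Config
import tensorflow as tf

import dataset
from data_loader import TextLoader
from model import CharRNN

Config("check_tiny")
params = tf.contrib.training.HParams(**Config.model.to_dict())

loader = TextLoader(Config.data.data_dir,
                    batch_size=params.batch_size,
                    seq_length=params.seq_length)
Config.data.vocab_size = loader.vocab_size
train_X, _, train_y, _ = loader.make_train_and_test_set()

train_input_fn, train_input_hook = dataset.get_train_inputs(train_X, train_y)

estimator = tf.estimator.Estimator(model_fn=CharRNN().model_fn,
                                   model_dir=Config.train.model_dir,
                                   params=params)
# The hook runs iterator.initializer right after session creation,
# so it must be passed alongside the input_fn it belongs to.
estimator.train(input_fn=train_input_fn, hooks=[train_input_hook], steps=1000)
```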
--------------------------------------------------------------------------------
/experiment.py:
--------------------------------------------------------------------------------
from hbconfig import Config
import tensorflow as tf

from data_loader import TextLoader
import dataset
from model import CharRNN
import hook


def experiment_fn(run_config, params):

    char_rnn = CharRNN()
    estimator = tf.estimator.Estimator(
        model_fn=char_rnn.model_fn,
        model_dir=Config.train.model_dir,
        params=params,
        config=run_config)

    data_loader = TextLoader(Config.data.data_dir,
                             batch_size=params.batch_size,
                             seq_length=params.seq_length)
    Config.data.vocab_size = data_loader.vocab_size

    train_X, test_X, train_y, test_y = data_loader.make_train_and_test_set()

    train_input_fn, train_input_hook = dataset.get_train_inputs(train_X, train_y)
    test_input_fn, test_input_hook = dataset.get_test_inputs(test_X, test_y)

    experiment = tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=test_input_fn,
        train_steps=Config.train.train_steps,
        #min_eval_frequency=Config.train.min_eval_frequency,
        train_monitors=[
            train_input_hook,
            hook.print_variables(
                variables=['training/output_0', 'prediction_0'],
                vocab=data_loader.vocab,
                every_n_iter=Config.train.check_hook_n_iter)],
        eval_hooks=[test_input_hook],
        #eval_steps=None
    )
    return experiment
--------------------------------------------------------------------------------
/generator.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import argparse

from hbconfig import Config
import numpy as np
import tensorflow as tf

from data_loader import TextLoader
from model import CharRNN


class SamhangSiGenerator:

    SENTENCE_LENGTH = 20

    def __init__(self):
        self._set_data()
        self._make_estimator()

    def _set_data(self):
        data_loader = TextLoader(Config.data.data_dir)
        Config.data.vocab_size = data_loader.vocab_size

        def get_rev_vocab(vocab):
            if vocab is None:
                return None
            return {idx: key for key, idx in vocab.items()}

        self.vocab = data_loader.vocab
        self.rev_vocab = get_rev_vocab(data_loader.vocab)

    def _make_estimator(self):
        params = tf.contrib.training.HParams(**Config.model.to_dict())
        run_config = tf.contrib.learn.RunConfig(
            model_dir=Config.train.model_dir)

        char_rnn = CharRNN()
        self.estimator = tf.estimator.Estimator(
            model_fn=char_rnn.model_fn,
            model_dir=Config.train.model_dir,
            params=params,
            config=run_config)

    def generate(self, word):
        result = ""
        for char in word:
            result += self._generate_sentence(char)
        return self._combine_sentence(result, word)

    def _generate_sentence(self, char):

        if char not in self.vocab:
            raise ValueError(f"'{char}' is not in the training vocabulary.")

        sample = self.vocab[char]
        sentence = [sample]

        for _ in range(self.SENTENCE_LENGTH):
            X = np.zeros((1, 1), dtype=np.int32)
            X[0, 0] = sample

            predict_input_fn = tf.estimator.inputs.numpy_input_fn(
                x={"input_data": X},
                num_epochs=1,
                shuffle=False)

            result = self.estimator.predict(input_fn=predict_input_fn)
            probs = next(result)["probs"]

            def weighted_pick(weights):
                t = np.cumsum(weights)
                s = np.sum(weights)
                return int(np.searchsorted(t, np.random.rand(1) * s))

            sample = weighted_pick(probs)
            sentence.append(sample)

        sentence = list(map(lambda sample: self.rev_vocab.get(sample, ''), sentence))
        sentence = "".join(sentence)
        return sentence

    def _combine_sentence(self, result, word):
        print("word: " + word)
        result = result.replace("\n", " ")
        for char in word[1:]:
            result = result.replace(char, "\n" + char, 1)
        return result


def main(word):
    samhangsi_generator = SamhangSiGenerator()
    result = samhangsi_generator.generate(word)
    print(result)


if __name__ == '__main__':

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--config', type=str, default='config',
                        help='config file name')
    parser.add_argument('--word', type=str, default='삼행시',
                        help='Input Korean word (ex. 삼행시)')
    args = parser.parse_args()

    Config(args.config)
    Config.model.batch_size = 1
    Config.model.seq_length = 1
    print("Config: ", Config)

    main(args.word)
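`weighted_pick` above is plain sampling from the categorical distribution defined by the softmax output (rather than greedy argmax), which is what keeps the generated lines varied. A tiny standalone demonstration, illustrative only:

```python
import numpy as np

probs = np.array([0.1, 0.6, 0.3])

def weighted_pick(weights):
    # Sample an index with probability proportional to its weight
    # (equivalent to np.random.choice over the normalized weights).
    t = np.cumsum(weights)
    s = np.sum(weights)
    return int(np.searchsorted(t, np.random.rand(1) * s))

samples = [weighted_pick(probs) for _ in range(10000)]
print([samples.count(i) / len(samples) for i in range(3)])  # roughly [0.1, 0.6, 0.3]
```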
--------------------------------------------------------------------------------
/hook.py:
--------------------------------------------------------------------------------
import tensorflow as tf


def print_variables(variables, vocab=None, every_n_iter=100):

    return tf.train.LoggingTensorHook(
        variables,
        every_n_iter=every_n_iter,
        formatter=format_variable(variables, vocab=vocab))


def format_variable(keys, vocab=None):
    rev_vocab = get_rev_vocab(vocab)

    def to_str(sequence):
        tokens = [rev_vocab.get(x, '') for x in sequence]
        return ''.join(tokens)

    def format(values):
        result = []
        for key in keys:
            if vocab is None:
                result.append(f"{key} = {values[key]}")
            else:
                result.append(f"{key} = {to_str(values[key])}")
        # LoggingTensorHook expects the formatter to return the string to log.
        return '\n - '.join(result)

    return format


def get_rev_vocab(vocab):
    if vocab is None:
        return None
    return {idx: key for key, idx in vocab.items()}
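A small standalone illustration of what `format_variable` produces. The toy vocabulary here is made up for the example; in the real pipeline the vocab comes from `TextLoader`, and `print_variables` wraps this formatter in a `LoggingTensorHook` that is attached as a train monitor in `experiment.py`:

```python
import hook

vocab = {'가': 0, '나': 1, '다': 2}          # toy char -> id mapping, for illustration only
fmt = hook.format_variable(['prediction_0'], vocab=vocab)

# The hook calls the formatter with {tensor_name: fetched_value} every N steps.
print(fmt({'prediction_0': [0, 2, 1]}))      # -> prediction_0 = 가다나
```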
--------------------------------------------------------------------------------
/images/kino-samhangsi-example1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DongjunLee/char-rnn-tensorflow/535a949e28e7b00d2408978049bf0baf22d76af9/images/kino-samhangsi-example1.png
--------------------------------------------------------------------------------
/images/kino-samhangsi-example2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DongjunLee/char-rnn-tensorflow/535a949e28e7b00d2408978049bf0baf22d76af9/images/kino-samhangsi-example2.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import argparse

from hbconfig import Config
import tensorflow as tf

import experiment


def main(mode):
    params = tf.contrib.training.HParams(**Config.model.to_dict())

    run_config = tf.contrib.learn.RunConfig(
        model_dir=Config.train.model_dir)

    tf.contrib.learn.learn_runner.run(
        experiment_fn=experiment.experiment_fn,
        run_config=run_config,
        schedule=mode,
        hparams=params
    )


if __name__ == '__main__':

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--config', type=str, default='config',
                        help='config file name')
    parser.add_argument('--mode', type=str, default='train',
                        help='Mode (train/evaluate/train_and_evaluate)')
    args = parser.parse_args()

    tf.logging.set_verbosity(tf.logging.INFO)

    Config(args.config)
    print("Config: ", Config)

    main(args.mode)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
from hbconfig import Config
import numpy as np
import tensorflow as tf


class CharRNN:

    def __init__(self):
        pass

    def model_fn(self, mode, features, labels, params):
        self.mode = mode
        self.params = params

        self.input_data = features
        self.targets = labels

        if type(features) == dict:
            self.input_data = features["input_data"]

        self.build_graph()

        if mode == tf.estimator.ModeKeys.PREDICT:
            return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={"probs": self.probs})

        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=None,
            loss=self.loss,
            train_op=self.train_op
        )

    def build_graph(self):
        self._create_embedding()
        self._create_rnn_cell()
        self._create_inference()
        self._create_predictions()

        if self.mode != tf.estimator.ModeKeys.PREDICT:
            self._create_loss()
            self._create_train_op()

    def _create_embedding(self):
        self.embedding = tf.get_variable("embedding", [Config.data.vocab_size, self.params.rnn_size])

    def _create_rnn_cell(self):
        cells = []
        for _ in range(self.params.num_layers):
            cell = tf.contrib.rnn.GRUCell(self.params.rnn_size)
            if self.mode == tf.estimator.ModeKeys.TRAIN:
                cell = tf.contrib.rnn.DropoutWrapper(cell,
                                                     input_keep_prob=self.params.input_keep_prob,
                                                     output_keep_prob=self.params.output_keep_prob)
            cells.append(cell)
        self.rnn_cells = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)
        self.initial_state = self.rnn_cells.zero_state(self.params.batch_size, tf.float32)

    def _create_inference(self):

        with tf.variable_scope('rnnlm'):
            softmax_w = tf.get_variable("softmax_w",
                                        [self.params.rnn_size, Config.data.vocab_size])
            softmax_b = tf.get_variable("softmax_b", [Config.data.vocab_size])

            inputs = tf.nn.embedding_lookup(self.embedding, self.input_data)

            if self.mode == tf.estimator.ModeKeys.TRAIN and self.params.output_keep_prob:
                inputs = tf.nn.dropout(inputs, self.params.output_keep_prob)

            inputs = tf.split(inputs, self.params.seq_length, 1)
            inputs = [tf.squeeze(input_, [1]) for input_ in inputs]

            def loop(prev, _):
                # At inference time, feed the previous argmax prediction back in as the next input.
                prev = tf.matmul(prev, softmax_w) + softmax_b
                prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
                return tf.nn.embedding_lookup(self.embedding, prev_symbol)

            is_training = self.mode == tf.estimator.ModeKeys.TRAIN
            outputs, last_state = tf.contrib.legacy_seq2seq.rnn_decoder(
                inputs, self.initial_state, self.rnn_cells,
                loop_function=loop if not is_training else None, scope='rnnlm')
            output = tf.reshape(tf.concat(outputs, 1), [-1, self.params.rnn_size])

            self.logits = tf.matmul(output, softmax_w) + softmax_b
            self.probs = tf.nn.softmax(self.logits, name="probs")

    def _create_predictions(self):
        self.predictions = tf.argmax(self.probs, axis=1)
        tf.identity(self.predictions[:self.params.seq_length], 'prediction_0')

    def _create_loss(self):
        sequence_loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
            [self.logits],
            [tf.reshape(self.targets, [-1])],
            [tf.ones([self.params.batch_size * self.params.seq_length])])
        self.loss = tf.reduce_sum(sequence_loss, name="loss/reduce_sum") \
            / self.params.batch_size / self.params.seq_length

    def _create_train_op(self):
        self.train_op = tf.contrib.layers.optimize_loss(
            loss=self.loss,
            global_step=tf.contrib.framework.get_global_step(),
            optimizer=tf.train.AdamOptimizer,
            learning_rate=Config.train.learning_rate
        )
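One loose end: the non-tiny configs define `model.grad_clip: 5.0`, but `_create_train_op` above never applies it. If clipping is wanted, one option (an assumption, not code from this repository) is to pass the value through `optimize_loss`'s `clip_gradients` argument, which clips gradients by global norm:

```python
# Hypothetical variant of CharRNN._create_train_op that actually uses Config.model.grad_clip.
def _create_train_op(self):
    self.train_op = tf.contrib.layers.optimize_loss(
        loss=self.loss,
        global_step=tf.contrib.framework.get_global_step(),
        optimizer=tf.train.AdamOptimizer,
        learning_rate=Config.train.learning_rate,
        clip_gradients=Config.model.grad_clip)   # global-norm gradient clipping
```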
--------------------------------------------------------------------------------