├── .gitignore ├── .gitmodules ├── README.md ├── assets └── introduction.png ├── config.py ├── dataset.py ├── examples ├── paper_in.map ├── paper_out.map ├── simple.karel └── simple.map ├── main.py ├── models ├── __init__.py ├── decoder.py ├── encoder.py └── main.py ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Log 2 | logs 3 | 4 | # Data 5 | data 6 | *.txt 7 | *.npz 8 | *.out 9 | 10 | # ply 11 | karel 12 | *parser.out 13 | *parsetab.py 14 | 15 | # Created by https://www.gitignore.io/api/python,vim 16 | 17 | ### Python ### 18 | # Byte-compiled / optimized / DLL files 19 | __pycache__/ 20 | *.py[cod] 21 | *$py.class 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | env/ 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | *.egg-info/ 42 | .installed.cfg 43 | *.egg 44 | 45 | # PyInstaller 46 | # Usually these files are written by a python script from a template 47 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 48 | *.manifest 49 | *.spec 50 | 51 | # Installer logs 52 | pip-log.txt 53 | pip-delete-this-directory.txt 54 | 55 | # Unit test / coverage reports 56 | htmlcov/ 57 | .tox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *,cover 64 | .hypothesis/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # dotenv 97 | .env 98 | 99 | # virtualenv 100 | .venv/ 101 | venv/ 102 | ENV/ 103 | 104 | # Spyder project settings 105 | .spyderproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | 111 | ### Vim ### 112 | # swap 113 | [._]*.s[a-v][a-z] 114 | [._]*.sw[a-p] 115 | [._]s[a-v][a-z] 116 | [._]sw[a-p] 117 | # session 118 | Session.vim 119 | # temporary 120 | .netrwhist 121 | *~ 122 | # auto-generated tag files 123 | tags 124 | 125 | .DS_Store 126 | 127 | # End of https://www.gitignore.io/api/python,vim 128 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/program-synthesis-rl-tensorflow/2bd04374e6204838038974bc3f07cdf024420936/.gitmodules -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Leveraging Grammar and Reinforcement Learning for Neural Program Synthesis 2 | 3 | TensorFlow implementation of [Leveraging Grammar and Reinforcement Learning for Neural Program Synthesis](https://openreview.net/forum?id=H1Xw62kRZ). 4 | 5 | ![introduction](./assets/introduction.png) 6 | 7 | 8 | ## Requirements 9 | 10 | - Python 2.7+ 11 | - [tqdm](https://github.com/tqdm/tqdm) 12 | - [karel](https://github.com/carpedm20/karel) 13 | - [TensorFlow](https://www.tensorflow.org/) 1.4.1 14 | 15 | ## Usage 16 | 17 | Prepare with: 18 | 19 | $ pip install -r requirements.txt 20 | 21 | To generate datasets: 22 | 23 | $ python dataset.py --data_dir=data --max_depth=5 24 | 25 | To train a model: 26 | 27 | $ python main.py 28 | $ tensorboard --logdir=logs --host=0.0.0.0 29 | 30 | 31 | ## Results 32 | 33 | Currently, only maximum likelihood optimization is implemented. Expected correctness and RL is in progress. 34 | 35 | (in progress) 36 | 37 | 38 | ## References 39 | 40 | - [Karel dataset](https://github.com/carpedm20/karel) 41 | - [Neural Program Meta-Induction](https://arxiv.org/abs/1710.04157) 42 | 43 | 44 | ## Author 45 | 46 | Taehoon Kim / [@carpedm20](http://carpedm20.github.io/) 47 | -------------------------------------------------------------------------------- /assets/introduction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/program-synthesis-rl-tensorflow/2bd04374e6204838038974bc3f07cdf024420936/assets/introduction.png -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | import argparse 3 | 4 | from utils import str2bool 5 | 6 | arg_lists = [] 7 | parser = argparse.ArgumentParser() 8 | 9 | def add_argument_group(name): 10 | arg = parser.add_argument_group(name) 11 | arg_lists.append(arg) 12 | return arg 13 | 14 | 15 | ################ 16 | # Network 17 | ################ 18 | 19 | net_arg = add_argument_group('Network') 20 | net_arg.add_argument('--use_syntax', type=str2bool, default=False) 21 | 22 | ################ 23 | # Data 24 | ################ 25 | 26 | data_arg = add_argument_group('Data') 27 | data_arg.add_argument('--data_dir', type=str, default='data') 28 | data_arg.add_argument('--data_ext', type=str, default='npz') 29 | data_arg.add_argument('--mode', type=str, default='token', choices=['text', 'token']) 30 | data_arg.add_argument('--beautify', type=str2bool, default=False) 31 | 32 | # grid world 33 | data_arg.add_argument('--world_height', type=int, default=8, help='Height of square grid world') 34 | data_arg.add_argument('--world_width', type=int, default=8, help='Width of square grid world') 35 | data_arg.add_argument('--max_marker_in_cell', type=int, default=8) 36 | data_arg.add_argument('--wall_ratio', type=float, default=0.1) 37 | data_arg.add_argument('--marker_ratio', type=float, default=0.1) 38 | 39 | # # of data 40 | data_arg.add_argument('--num_train', type=int, default=1000000) 41 | data_arg.add_argument('--num_test', type=int, default=5000) 42 | data_arg.add_argument('--num_val', type=int, default=5000) 43 | data_arg.add_argument('--num_spec', type=int, default=5) 44 | data_arg.add_argument('--num_heldout', type=int, default=1) 45 | 46 | # program limitations 47 | data_arg.add_argument('--max_depth', type=int, default=4) 48 | data_arg.add_argument('--min_move', type=int, default=0) 49 | data_arg.add_argument('--max_func_call', type=int, default=50, 50 | help="Max # of function call in a single run") 51 | 52 | ################ 53 | # Train 54 | ################ 55 | 56 | train_arg = add_argument_group('Train') 57 | train_arg.add_argument('--base_dir', type=str, default='logs') 58 | train_arg.add_argument('--model_path', type=str, default=None, 59 | help='default is {config.base_dir}/{config.tag}_{timestring}') 60 | train_arg.add_argument('--pretrain_path', type=str, default=None) 61 | train_arg.add_argument('--epoch', type=int, default=100) 62 | train_arg.add_argument('--lr', type=float, default=0.001) 63 | train_arg.add_argument('--seed', type=int, default=123) 64 | train_arg.add_argument('--use_rl', type=str2bool, default=False) 65 | train_arg.add_argument('--batch_size', type=int, default=32) 66 | train_arg.add_argument('--max_step', type=int, default=100000000) 67 | 68 | ################ 69 | # Test 70 | ################ 71 | 72 | test_arg = add_argument_group('Test') 73 | test_arg.add_argument('--world', type=str, default=None) 74 | 75 | ################ 76 | # ETC 77 | ################ 78 | 79 | etc_arg = add_argument_group('ETC') 80 | etc_arg.add_argument('--train', type=str2bool, default=True, 81 | help='whether run under train or test mode') 82 | etc_arg.add_argument('--tag', type=str, default='karel') 83 | etc_arg.add_argument('--log_step', type=int, default=100) 84 | etc_arg.add_argument('--max_summary', type=int, default=3) 85 | etc_arg.add_argument('--debug', type=str2bool, default=False) 86 | etc_arg.add_argument('--parser_debug', type=str2bool, default=False) 87 | 88 | 89 | def get_config(): 90 | config, unparsed = parser.parse_known_args() 91 | return config, unparsed 92 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | from collections import namedtuple 5 | 6 | from utils import get_rng 7 | from karel import KarelForSynthesisParser 8 | from karel import str2bool, makedirs, pprint, beautify, TimeoutError 9 | 10 | Data = namedtuple('Data', 'input, output, code') 11 | 12 | def try_beautify(x): 13 | try: 14 | x = beautify(x) 15 | except: 16 | pass 17 | return x 18 | 19 | class Dataset(object): 20 | tokens = [] 21 | idx_to_token = {} 22 | 23 | def __init__(self, config, rng=None, load=True, shuffle=False): 24 | self.config = config 25 | self.rng = get_rng(rng) 26 | 27 | self.inputs, self.outputs, self.codes, self.code_lengths = {}, {}, {}, {} 28 | self.input_strings, self.output_strings = {}, {} 29 | self.with_input_string = False 30 | 31 | self.iterator = {} 32 | self._inputs, self._outputs, self._codes, self._code_lengths = {}, {}, {}, {} 33 | self._input_strings, self._output_strings = {}, {} 34 | 35 | self.data_names = ['train', 'test', 'val'] 36 | self.data_paths = { 37 | key: os.path.join(config.data_dir, '{}.{}'.format(key, config.data_ext)) \ 38 | for key in self.data_names 39 | } 40 | 41 | if load: 42 | self.load_data() 43 | for name in self.data_names: 44 | self.build_tf_data(name) 45 | if shuffle: 46 | self.shuffle() 47 | 48 | def build_tf_data(self, name): 49 | if self.config.train: 50 | batch_size = self.config.batch_size 51 | else: 52 | batch_size = 1 53 | 54 | # inputs, outputs 55 | data = [ 56 | self._inputs[name], self._outputs[name], self._code_lengths[name] 57 | ] 58 | if self.with_input_string: 59 | data.extend([self._input_strings[name], self._output_strings[name]]) 60 | 61 | in_out = tf.data.Dataset.from_tensor_slices(tuple(data)).repeat() 62 | batched_in_out = in_out.batch(batch_size) 63 | 64 | # codes 65 | code = tf.data.Dataset.from_generator(lambda: self._codes[name], tf.int32).repeat() 66 | batched_code = code.padded_batch(batch_size, padded_shapes=[None]) 67 | 68 | batched_data = tf.data.Dataset.zip((batched_in_out, batched_code)) 69 | iterator = batched_data.make_initializable_iterator() 70 | 71 | if self.with_input_string: 72 | (inputs, outputs, code_lengths, input_strings, output_strings), codes = iterator.get_next() 73 | 74 | input_strings = tf.cast(input_strings, tf.string) 75 | output_strings = tf.cast(output_strings, tf.string) 76 | else: 77 | (inputs, outputs, code_lengths), codes = iterator.get_next() 78 | 79 | inputs = tf.cast(inputs, tf.float32) 80 | outputs = tf.cast(outputs, tf.float32) 81 | code_lengths = tf.cast(code_lengths, tf.int32) 82 | 83 | self.inputs[name] = inputs 84 | self.outputs[name] = outputs 85 | self.codes[name] = codes 86 | self.code_lengths[name] = code_lengths 87 | self.iterator[name] = iterator 88 | 89 | if self.with_input_string: 90 | self.input_strings[name] = input_strings 91 | self.output_strings[name] = output_strings 92 | 93 | def get_data(self, name): 94 | data = { 95 | 'inputs': self.inputs[name], 96 | 'outputs': self.outputs[name], 97 | 'codes': self.codes[name], 98 | 'code_lengths': self.code_lengths[name], 99 | 'iterator': self.iterator[name] 100 | } 101 | if self.with_input_string: 102 | data.update({ 103 | 'input_strings': self.input_strings[name], 104 | 'output_strings': self.output_strings[name], 105 | }) 106 | return data 107 | 108 | def count(self, name): 109 | return len(self._inputs[name]) 110 | 111 | def shuffle(self): 112 | raise NotImplementedError 113 | 114 | def load_data(self): 115 | raise NotImplementedError 116 | 117 | @property 118 | def num_token(self): 119 | return len(self.tokens) 120 | 121 | def _idx_to_text(self, idxes, beautify): 122 | code = " ".join(self.token_to_text[ 123 | self.parser.idx_to_token_details[idx]] for idx in idxes).replace("\\", "") 124 | if beautify: 125 | code = try_beautify(code) 126 | return code 127 | 128 | def idx_to_text(self, idxes, markdown=False, beautify=False): 129 | if hasattr(idxes[0], '__len__'): 130 | if markdown: 131 | strings = ["\t{}".format(self._idx_to_text(idxes, beautify) \ 132 | .replace('\n', '\n\t')) for idxes in idxes] 133 | else: 134 | strings = [self._idx_to_text(idxes, beautify) for idxes in idxes] 135 | else: 136 | strings = self._idx_to_text(idxes, beautify) 137 | return np.array(strings) 138 | 139 | def run_and_test(self, batch_code, batch_example, **kwargs): 140 | batch_output = [] 141 | for code, examples in zip(batch_code, batch_example): 142 | outputs = [] 143 | tokens = [token.decode("utf-8") for token in code.split()] 144 | 145 | try: 146 | code = " ".join([token for token in tokens[:tokens.index('END')]]) 147 | 148 | for state in examples: 149 | try: 150 | self.parser.new_game(state=state) 151 | self.parser.run(code, **kwargs) 152 | output = self.parser.draw_for_tensorboard() 153 | except TimeoutError: 154 | output = 'time' 155 | except TypeError: 156 | output = 'type' 157 | except ValueError: 158 | output = 'value' 159 | outputs.append(output) 160 | except ValueError: 161 | outputs = ['no_end'] * len(examples) 162 | 163 | batch_output.append(outputs) 164 | return np.array(batch_output) 165 | 166 | 167 | class KarelDataset(Dataset): 168 | def __init__(self, config, *args, **kwargs): 169 | super(KarelDataset, self).__init__(config, *args, **kwargs) 170 | 171 | self.parser = KarelForSynthesisParser( 172 | rng=self.rng, max_func_call=config.max_func_call, debug=config.debug) 173 | 174 | self.tokens = ['END'] + self.parser.tokens_details 175 | self.token_to_text = { 'END': 'END' } 176 | 177 | for token in self.tokens: 178 | if token in ['END']: 179 | continue 180 | elif token.startswith('INT'): 181 | self.token_to_text[token] = token.replace('INT', self.parser.INT_PREFIX) 182 | continue 183 | 184 | item = getattr(self.parser, 't_{}'.format(token)) 185 | if callable(item): 186 | self.token_to_text[token] = token 187 | else: 188 | self.token_to_text[token] = item 189 | 190 | def load_data(self): 191 | self.data = {} 192 | for name in self.data_names: 193 | data = np.load(self.data_paths[name]) 194 | self._inputs[name] = data['inputs'] 195 | self._outputs[name] = data['outputs'] 196 | self._codes[name] = data['codes'] 197 | self._code_lengths[name] = data['code_lengths'] 198 | 199 | if 'input_strings' in data: 200 | self.with_input_string = True 201 | self._input_strings[name] = data['input_strings'] 202 | self._output_strings[name] = data['output_strings'] 203 | 204 | def shuffle(self): 205 | for name in self.data_names: 206 | self.rng.shuffle(self._inputs[name]) 207 | self.rng.shuffle(self._outputs[name]) 208 | self.rng.shuffle(self._codes[name]) 209 | 210 | 211 | if __name__ == '__main__': 212 | import os 213 | import argparse 214 | import numpy as np 215 | 216 | try: 217 | from tqdm import trange 218 | except: 219 | trange = lambda x, desc: range(x) 220 | 221 | from config import get_config 222 | config, _ = get_config() 223 | 224 | dataset = KarelDataset(config, load=False) 225 | parser = dataset.parser 226 | 227 | # Make directories 228 | makedirs(config.data_dir) 229 | datasets = ['train', 'test', 'val'] 230 | 231 | # Generate datasets 232 | 233 | def generate(): 234 | parser.flush_hit_info() 235 | code = parser.random_code( 236 | stmt_max_depth=config.max_depth, 237 | min_move=config.min_move, 238 | create_hit_info=True) 239 | return code 240 | 241 | if config.mode == 'text': 242 | for name in datasets: 243 | data_num = getattr(config, "num_{}".format(name)) 244 | 245 | text = "" 246 | text_path = os.path.join(config.data_dir, "{}.txt".format(name)) 247 | 248 | for _ in trange(data_num, desc=name): 249 | code = generate() 250 | if config.beautify: 251 | code = beautify(code) 252 | text += code + "\n" 253 | 254 | with open(text_path, 'w') as f: 255 | f.write(text) 256 | else: 257 | for name in datasets: 258 | data_num = getattr(config, "num_{}".format(name)) 259 | 260 | inputs, outputs, codes, code_lengths = [], [], [], [] 261 | input_strings, output_strings = [], [] 262 | 263 | for _ in trange(data_num, desc=name): 264 | while True: 265 | input_examples, output_examples = [], [] 266 | input_string_examples, output_string_examples = [], [] 267 | 268 | code = generate() 269 | #pprint(code) 270 | 271 | num_code_error, resample_code = 0, False 272 | while len(input_examples) < config.num_spec + config.num_heldout: 273 | if num_code_error > 5: 274 | resample_code = True 275 | break 276 | 277 | parser.new_game( 278 | world_size=(config.world_width, config.world_height), 279 | wall_ratio=config.wall_ratio, marker_ratio=config.marker_ratio) 280 | input_string = parser.draw_for_tensorboard() 281 | input_state = parser.get_state() 282 | 283 | try: 284 | parser.run(code, debug=config.parser_debug) 285 | output_state = parser.get_state() 286 | output_string = parser.draw_for_tensorboard() 287 | except TimeoutError: 288 | num_code_error += 1 289 | continue 290 | except IndexError: 291 | num_code_error += 1 292 | continue 293 | 294 | # input/output pair should be different 295 | if np.array_equal(input_state, output_state): 296 | num_code_error += 1 297 | continue 298 | 299 | input_examples.append(input_state) 300 | input_string_examples.append(input_string) 301 | output_examples.append(output_state) 302 | output_string_examples.append(output_string) 303 | 304 | # if there is at least one contionals 305 | if len(parser.hit_info) > 0: 306 | # if there are contionals not hitted 307 | if max(parser.hit_info.values()) > 0: 308 | continue 309 | 310 | if resample_code: 311 | continue 312 | 313 | inputs.append(input_examples) 314 | outputs.append(output_examples) 315 | 316 | input_strings.append(input_string_examples) 317 | output_strings.append(output_string_examples) 318 | 319 | token_idxes = parser.lex_to_idx(code, details=True) 320 | 321 | # Add END tokens for seq2seq prediction 322 | token_idxes = np.array(token_idxes, dtype=np.uint8) 323 | token_idxes_with_end = np.append( 324 | token_idxes, parser.token_to_idx_details['END']) 325 | 326 | codes.append(token_idxes_with_end) 327 | code_lengths.append(len(token_idxes_with_end)) 328 | break 329 | 330 | npz_path = os.path.join(config.data_dir, name) 331 | np.savez(npz_path, 332 | inputs=inputs, 333 | outputs=outputs, 334 | input_strings=input_strings, 335 | output_strings=output_strings, 336 | codes=codes, 337 | code_lengths=code_lengths) 338 | -------------------------------------------------------------------------------- /examples/paper_in.map: -------------------------------------------------------------------------------- 1 | .... 2 | .... 3 | .>.. 4 | .... 5 | -------------------------------------------------------------------------------- /examples/paper_out.map: -------------------------------------------------------------------------------- 1 | .... 2 | .oo. 3 | .>o. 4 | .... 5 | -------------------------------------------------------------------------------- /examples/simple.karel: -------------------------------------------------------------------------------- 1 | move() 2 | move() 3 | repeat(3): 4 | turn_left() 5 | move() 6 | move() 7 | -------------------------------------------------------------------------------- /examples/simple.map: -------------------------------------------------------------------------------- 1 | ...... 2 | >..... 3 | ...### 4 | .o.### 5 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from utils import prepare, set_random_seed 2 | 3 | import os 4 | import sys 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from trainer import Trainer 9 | from config import get_config 10 | 11 | config = None 12 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 13 | 14 | def main(_): 15 | prepare(config) 16 | rng = set_random_seed(config.seed) 17 | 18 | sess_config = tf.ConfigProto( 19 | log_device_placement=False, 20 | allow_soft_placement=True) 21 | sess_config.gpu_options.allow_growth=True 22 | 23 | trainer = Trainer(config, rng) 24 | 25 | with tf.Session(config=sess_config) as sess: 26 | if config.train: 27 | trainer.train(sess) 28 | else: 29 | if not config.map: 30 | raise Exception("[!] You should specify `map` to synthesize a program") 31 | trainer.synthesize(sess, config.map) 32 | 33 | if __name__ == "__main__": 34 | config, unparsed = get_config() 35 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 36 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import Model 2 | -------------------------------------------------------------------------------- /models/decoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.python.util import nest 4 | from tensorflow.contrib.seq2seq import Helper, TrainingHelper, BasicDecoder 5 | from tensorflow.contrib.rnn import \ 6 | RNNCell, LSTMCell, MultiRNNCell, OutputProjectionWrapper 7 | 8 | from .encoder import linear, int_shape 9 | 10 | def decoder(num_examples, codes, code_lengths, 11 | encoder_out, dataset, config, train_or_test): 12 | """ codes: [B, L] 13 | encoder_out: [BxN, 512] 14 | """ 15 | batch_size = tf.shape(codes)[0] 16 | batch_times_N = batch_size * num_examples 17 | 18 | num_token = dataset.num_token 19 | 20 | with tf.variable_scope("decoder"): 21 | # [BxN, L, 512] 22 | tiled_encoder_out = tf.tile( 23 | tf.expand_dims(encoder_out, 1), 24 | [1, tf.shape(codes)[1], 1], 25 | name="tiled_encoder_out") 26 | 27 | embed = tf.get_variable( 28 | 'embedding', [dataset.num_token, 256], dtype=tf.float32, 29 | initializer=tf.truncated_normal_initializer(stddev=0.5)) 30 | 31 | # [B, L, 256] 32 | code_embed = tf.nn.embedding_lookup(embed, codes) 33 | # [BxN, L, 256] 34 | tiled_code_embed = tf.tile( 35 | code_embed, [num_examples, 1, 1], name="tiled_code_embed") 36 | # [BxN, 256] 37 | start_code_embed = tf.tile( 38 | [[0.0]], [batch_times_N, 256], name="start_token") 39 | 40 | # [BxN, L, 768] 41 | rnn_train_inputs = tf.concat( 42 | [tiled_encoder_out, tiled_code_embed], -1, name="rnn_input") 43 | 44 | decoder_cell = MultiRNNCell([ 45 | LSTMCell(256), 46 | NaiveOutputProjectionWrapper( 47 | MaxPoolWrapper(LSTMCell(256), num_examples), 48 | num_token)], state_is_tuple=True) 49 | 50 | # [BxN, L, 256] -> [B, L, 256] #-> [BxN, L, 256] 51 | decoder_logits, decoder_argamx = build_rnn( 52 | train_or_test, 53 | cell=decoder_cell, 54 | rnn_train_inputs=rnn_train_inputs, 55 | start_code_embed=start_code_embed, 56 | batch_size=batch_times_N, 57 | target_lengths=code_lengths, 58 | embedding=embed, 59 | encoder_out=encoder_out, 60 | name="decoder_rnn") 61 | 62 | # [BxN, L, 256] -> [B, L, 256] 63 | decoder_logits = decoder_logits[:batch_size] 64 | decoder_argamx = decoder_argamx[:batch_size] 65 | 66 | if config.use_syntax: 67 | syntax_cell = MultiRNNCell([ 68 | LSTMCell(256), 69 | LSTMCell(256)], state_is_tuple=True) 70 | 71 | # [B, L, 256] 72 | syntax_out = build_rnn( 73 | train_or_test, 74 | cell=syntax_cell, 75 | rnn_train_inputs=code_embed, 76 | start_code_embed=start_code_embed, 77 | batch_size=batch_size, 78 | target_lengths=code_lengths, 79 | name="syntax_rnn") 80 | 81 | syntax_logits = linear(max_pool, num_token, "out") 82 | 83 | raise NotImplementedError("TODO") 84 | decoder_logits = decoder_logits + syntax_logits 85 | 86 | return decoder_logits, decoder_argamx 87 | 88 | def build_rnn(train_or_test, cell, rnn_train_inputs, start_code_embed, 89 | batch_size, target_lengths, embedding, encoder_out, name): 90 | 91 | if train_or_test: 92 | helper = DefaultZeroInputTrainingHelper( 93 | rnn_train_inputs, target_lengths, encoder_out, start_code_embed) 94 | else: 95 | helper = TestEmbeddingConcatHelper( 96 | batch_size, embedding, encoder_out, start_code_embed) 97 | 98 | initial_state = cell.zero_state( 99 | batch_size=batch_size, dtype=tf.float32) 100 | 101 | (decoder_outputs, decoder_samples), final_decoder_state, _ = \ 102 | tf.contrib.seq2seq.dynamic_decode( 103 | BasicDecoder(cell, helper, initial_state), scope=name) 104 | 105 | return decoder_outputs, decoder_samples 106 | 107 | 108 | class DefaultZeroInputTrainingHelper(TrainingHelper): 109 | def __init__(self, inputs, sequence_length, encoder_out, start_code_embed, 110 | time_major=False, name=None): 111 | super(DefaultZeroInputTrainingHelper, self). \ 112 | __init__(inputs, sequence_length, time_major, name) 113 | 114 | self._start_inputs = tf.concat([ 115 | encoder_out, start_code_embed], -1) 116 | 117 | def initialize(self, name=None): 118 | finished = tf.tile([False], [self._batch_size]) 119 | return (finished, self._start_inputs) 120 | 121 | 122 | class TestEmbeddingConcatHelper(Helper): 123 | def __init__(self, batch_size, embedding, encoder_out, start_code_embed): 124 | # batch_times_N 125 | self._batch_size = batch_size 126 | self._encoder_out = encoder_out 127 | self._start_code_embed = start_code_embed 128 | 129 | self._embedding_fn = ( 130 | lambda ids: tf.nn.embedding_lookup(embedding, ids)) 131 | self._start_inputs = tf.concat([ 132 | self._encoder_out, self._start_code_embed], -1) 133 | self._end_token = 0 134 | 135 | @property 136 | def batch_size(self): 137 | return self._batch_size 138 | 139 | @property 140 | def sample_ids_shape(self): 141 | return tf.TensorShape([]) 142 | 143 | @property 144 | def sample_ids_dtype(self): 145 | return np.int32 146 | 147 | def initialize(self, name=None): 148 | finished = tf.tile([False], [self._batch_size]) 149 | return (finished, self._start_inputs) 150 | 151 | def sample(self, time, outputs, state, name=None): 152 | #del time, state # unused by sample_fn 153 | sample_ids = tf.cast(tf.argmax(outputs, axis=-1), tf.int32) 154 | return sample_ids 155 | 156 | def next_inputs(self, time, outputs, state, sample_ids, name=None): 157 | #del time, outputs # unused by next_inputs_fn 158 | finished = tf.equal(sample_ids, self._end_token) 159 | all_finished = tf.reduce_all(finished) 160 | sampled_embed = tf.cond( 161 | all_finished, 162 | lambda: self._start_code_embed, 163 | lambda: self._embedding_fn(sample_ids)) 164 | next_inputs = tf.concat([self._encoder_out, sampled_embed], -1) 165 | return (finished, next_inputs, state) 166 | 167 | 168 | class MaxPoolWrapper(RNNCell): 169 | def __init__(self, cell, num_examples, reuse=None): 170 | super(MaxPoolWrapper, self).__init__(_reuse=reuse) 171 | self._cell = cell 172 | self._num_examples = num_examples 173 | 174 | @property 175 | def state_size(self): 176 | return self._cell.state_size 177 | 178 | @property 179 | def output_size(self): 180 | return self._output_size 181 | 182 | def zero_state(self, batch_size, dtype): 183 | with tf.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 184 | return self._cell.zero_state(batch_size, dtype) 185 | 186 | def call(self, inputs, state): 187 | output, res_state = self._cell(inputs, state) 188 | B_times_N, cell_dim = int_shape(output) 189 | 190 | # [BxN, 256] -> [B, N, 256] 191 | decoder_out = tf.reshape( 192 | output, [-1, self._num_examples, cell_dim]) 193 | 194 | # [B, 256] 195 | max_pool = tf.reduce_max(decoder_out, 1) 196 | 197 | # [Bx2, 256] 198 | max_pool = tf.tile(max_pool, [self._num_examples, 1]) 199 | 200 | return max_pool, res_state 201 | 202 | 203 | class NaiveOutputProjectionWrapper(OutputProjectionWrapper): 204 | def __init__(self, cell, output_size, activation=None, reuse=None): 205 | try: 206 | super(NaiveOutputProjectionWrapper, self). \ 207 | __init__(cell, output_size, activation, reuse) 208 | except TypeError: 209 | pass 210 | 211 | self._cell = cell 212 | self._output_size = output_size 213 | self._activation = activation 214 | self._linear = None 215 | -------------------------------------------------------------------------------- /models/encoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.contrib.rnn import LSTMCell, MultiRNNCell 4 | 5 | def encoder(inputs, outputs, codes, dataset): 6 | input_shape = int_shape(inputs) 7 | output_shape = int_shape(outputs) 8 | 9 | # make examples to batch. batch_size := num_examples * batch_size 10 | flat_inputs = tf.reshape(inputs, [-1] + input_shape[2:]) 11 | flat_outputs = tf.reshape(outputs, [-1] + output_shape[2:]) 12 | 13 | with tf.variable_scope("encoder"): 14 | input_enc = conv2d(flat_inputs, 32, name="input") 15 | output_enc = conv2d(flat_outputs, 32, name="output") 16 | 17 | conv_fn = lambda x, name: conv2d(x, 64, name=name) 18 | 19 | grid_enc = tf.concat([input_enc, output_enc], -1) 20 | res1 = residual_block(grid_enc, conv_fn, 3, "res1") 21 | res2 = residual_block(res1, conv_fn, 3, "res2") 22 | 23 | # [BxN, 512] 24 | cnn_out = linear(flatten(res2), 512, "cnn_out") 25 | 26 | return cnn_out 27 | 28 | def linear(x, dim, name): 29 | return tf.layers.dense(x, dim, name=name) 30 | 31 | def flatten(x): 32 | shape = int_shape(x) 33 | last_dim = np.prod(shape[1:]) 34 | return tf.reshape(x, [-1, last_dim], name="flat") 35 | 36 | def residual_block( 37 | x, conv_fn, depth, name="res_block"): 38 | with tf.variable_scope(name): 39 | out = x 40 | for idx in range(depth): 41 | out = conv_fn(out, name="conv{}".format(idx)) 42 | out += x 43 | return out 44 | 45 | def conv2d( 46 | x, 47 | filters, 48 | kernel_size=(3, 3), 49 | activation=tf.nn.relu, 50 | padding='same', 51 | name="conv2d"): 52 | out = tf.layers.conv2d( 53 | x, 54 | filters=filters, 55 | kernel_size=kernel_size, 56 | activation=activation, 57 | padding=padding, 58 | name=name, 59 | ) 60 | return out 61 | 62 | def int_shape(x): 63 | return list(x.get_shape().as_list()) 64 | -------------------------------------------------------------------------------- /models/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import tensorflow as tf 4 | 5 | from .decoder import decoder 6 | from .encoder import encoder 7 | 8 | logger = logging.getLogger("main") 9 | 10 | def model(inputs, outputs, codes, code_lengths, dataset, config, train): 11 | num_examples = tf.shape(inputs)[1] # config.num_spec + config.num_heldout 12 | 13 | encoder_out = encoder( 14 | inputs, outputs, codes, dataset) 15 | decoder_out = decoder( 16 | num_examples, codes, code_lengths, encoder_out, dataset, config, train) 17 | return decoder_out 18 | 19 | model_fn = tf.make_template('model', model) 20 | 21 | class Model(object): 22 | def __init__(self, dataset, config, global_step, train): 23 | self.train = train 24 | self.config = config 25 | self.dataset = dataset 26 | self.global_step = global_step 27 | 28 | data = dataset.get_data('train' if train else 'test') 29 | 30 | self.inputs = data['inputs'] 31 | self.outputs = data['outputs'] 32 | self.codes = data['codes'] 33 | self.code_lengths = data['code_lengths'] 34 | self.iterator = data['iterator'] 35 | 36 | self.inputs_without_heldout = self.inputs[:,:config.num_spec] 37 | self.outputs_without_heldout = self.outputs[:,:config.num_spec] 38 | 39 | if dataset.with_input_string: 40 | self.input_strings = data['input_strings'] 41 | self.output_strings = data['output_strings'] 42 | else: 43 | self.input_strings = None 44 | self.output_strings = None 45 | 46 | self.logits, self.argmax = model_fn( 47 | self.inputs_without_heldout, 48 | outputs=self.outputs_without_heldout, 49 | codes=self.codes, 50 | code_lengths=self.code_lengths, 51 | dataset=dataset, 52 | config=config, 53 | train=train) 54 | 55 | self.loss_mask = tf.sequence_mask(self.code_lengths, dtype=tf.float32) 56 | self.mle_loss = tf.contrib.seq2seq.sequence_loss( 57 | logits=self.logits, 58 | targets=self.codes, 59 | weights=self.loss_mask, 60 | name="MLE_loss") 61 | 62 | if config.use_rl: 63 | self.loss = None 64 | else: 65 | self.loss = self.mle_loss 66 | 67 | self.build_summary() 68 | 69 | def build_optim(self): 70 | self.optim = tf.train.AdamOptimizer(self.config.lr) \ 71 | .minimize(self.loss, self.global_step) 72 | 73 | def build_summary(self): 74 | max_summary = self.config.max_summary 75 | 76 | idx_to_code = lambda x: self.dataset.idx_to_text(x, markdown=True, beautify=False) 77 | idx_to_beatified_code= lambda x: self.dataset.idx_to_text(x, markdown=True, beautify=True) 78 | 79 | gt_codes = tf.py_func( 80 | idx_to_code, [self.codes[:max_summary]], 81 | tf.string, name="gt_codes") 82 | argmax_codes = tf.py_func( 83 | idx_to_code, [self.argmax[:max_summary]], 84 | tf.string, name="argmax_codes") 85 | 86 | gt_clean_codes = tf.py_func( 87 | idx_to_beatified_code, [self.codes[:max_summary]], 88 | tf.string, name="gt_clean_codes") 89 | argmax_clean_codes = tf.py_func( 90 | idx_to_beatified_code, [self.argmax[:max_summary]], 91 | tf.string, name="argmax_clean_codes") 92 | 93 | # Test spec + heldout examples 94 | run = lambda x, y: self.dataset.run_and_test(x, y) 95 | outputs_of_argmax_codes = tf.py_func( 96 | run, [argmax_codes, self.inputs[:max_summary]], 97 | tf.string, name="outputs_of_argmax_codes") 98 | 99 | summaries = [ 100 | tf.summary.scalar("loss/total", self.loss), 101 | tf.summary.scalar("loss/mle", self.mle_loss), 102 | tf.summary.text("clean_code_gt", gt_clean_codes), 103 | tf.summary.text("clean_code_argmax", argmax_clean_codes), 104 | tf.summary.text("code_gt", gt_codes), 105 | tf.summary.text("code_argmax", argmax_codes), 106 | tf.summary.text("outputs_pred_until_END", outputs_of_argmax_codes), 107 | ] 108 | if self.input_strings is not None: 109 | output_match = tf.cast(tf.equal( 110 | outputs_of_argmax_codes, 111 | self.output_strings[:max_summary]), tf.float32) 112 | 113 | def bool_to_str(x): 114 | array = x.astype(int).astype(str) 115 | array[array == '0'] = 'x' 116 | array[array == '1'] = 'o' 117 | return array 118 | 119 | output_match_strings = tf.py_func(bool_to_str, [output_match], tf.string) 120 | 121 | self.code_match = tf.cast(tf.equal(self.argmax, self.codes), tf.float32) 122 | self.code_match = tf.reduce_mean( 123 | tf.minimum(self.code_match + (1 - self.loss_mask), 1)) 124 | 125 | summaries.extend([ 126 | tf.summary.text("inputs_gt", self.input_strings[:max_summary]), 127 | tf.summary.text("outputs_gt", self.output_strings[:max_summary]), 128 | tf.summary.text("outputs_match", tf.cast(output_match_strings, tf.string)), 129 | 130 | tf.summary.scalar("output_match/spec", 131 | tf.reduce_mean(output_match[:,:self.config.num_spec])), 132 | tf.summary.scalar("output_match/heldout", 133 | tf.reduce_mean(output_match[:,:self.config.num_heldout])), 134 | tf.summary.scalar("output_match/total", 135 | tf.reduce_mean(output_match)), 136 | 137 | tf.summary.scalar("code/match", self.code_match), 138 | #tf.summary.scalar("metric/generalization", tf.reduce_mean(output_match)), 139 | ]) 140 | 141 | self.summary_op = tf.summary.merge(summaries) 142 | 143 | summary_path = os.path.join( 144 | self.config.model_path, 'train' if self.train else 'test') 145 | self.writer = tf.summary.FileWriter(summary_path) 146 | 147 | def run(self, sess, with_update=True, with_summary=False): 148 | fetches = { 'step': self.global_step } 149 | 150 | if with_update: 151 | fetches['optim'] = self.optim 152 | if with_summary: 153 | fetches['summary'] = self.summary_op 154 | fetches['loss'] = self.loss 155 | 156 | out = sess.run(fetches) 157 | step = out['step'] 158 | 159 | if with_summary: 160 | self.writer.add_summary(out['summary'], step) 161 | logger.info("[INFO] loss: {:.4f}".format(out['loss'])) 162 | 163 | return step 164 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import trange 3 | import tensorflow as tf 4 | from models import encoder, decoder 5 | 6 | from models import Model 7 | from dataset import KarelDataset 8 | 9 | class Trainer(object): 10 | def __init__(self, config, rng=None): 11 | self.config = config 12 | self.dataset = KarelDataset(config, rng) 13 | 14 | self.global_step = tf.train.get_or_create_global_step() 15 | 16 | self.train_model = Model(self.dataset, self.config, self.global_step, train=True) 17 | self.train_model.build_optim() 18 | 19 | def train(self, sess): 20 | sess.run(tf.global_variables_initializer()) 21 | sess.run(self.train_model.iterator.initializer) 22 | 23 | for _ in trange(self.config.max_step): 24 | step = self.train_model.run(sess) 25 | 26 | if step % self.config.log_step == 0: 27 | step = self.train_model.run( 28 | sess, with_update=False, with_summary=True) 29 | 30 | def test(self): 31 | pass 32 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import errno 4 | import signal 5 | import logging 6 | import numpy as np 7 | import tensorflow as tf 8 | from functools import wraps 9 | from datetime import datetime 10 | from pyparsing import nestedExpr 11 | 12 | logger = logging.getLogger("main") 13 | 14 | 15 | def get_time(): 16 | return datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 17 | 18 | def str2bool(v): 19 | return v.lower() in ('true', '1') 20 | 21 | class Tcolors: 22 | CYAN = '\033[1;30m' 23 | HEADER = '\033[95m' 24 | OKBLUE = '\033[94m' 25 | OKGREEN = '\033[92m' 26 | WARNING = '\033[93m' 27 | FAIL = '\033[91m' 28 | ENDC = '\033[0m' 29 | BOLD = '\033[1m' 30 | UNDERLINE = '\033[4m' 31 | 32 | class TimeoutError(Exception): 33 | pass 34 | 35 | def timeout_fn(timeout=10, error_message=os.strerror(errno.ETIME)): 36 | def decorator(func): 37 | def _handle_timeout(signum, frame): 38 | raise TimeoutError(error_message) 39 | 40 | def wrapper(*args, **kwargs): 41 | signal.signal(signal.SIGALRM, _handle_timeout) 42 | signal.setitimer(signal.ITIMER_REAL, timeout) #used timer instead of alarm 43 | try: 44 | result = func(*args, **kwargs) 45 | finally: 46 | signal.alarm(0) 47 | return result 48 | return wraps(func)(wrapper) 49 | return decorator 50 | 51 | def beautify_fn(inputs, indent=1, tabspace=2): 52 | lines, queue = [], [] 53 | space = tabspace * " " 54 | 55 | for item in inputs: 56 | if item == ";": 57 | lines.append(" ".join(queue)) 58 | queue = [] 59 | elif type(item) == str: 60 | queue.append(item) 61 | else: 62 | lines.append(" ".join(queue + ["{"])) 63 | queue = [] 64 | 65 | inner_lines = beautify_fn(item, indent=indent+1, tabspace=tabspace) 66 | lines.extend([space + line for line in inner_lines[:-1]]) 67 | lines.append(inner_lines[-1]) 68 | 69 | if len(queue) > 0: 70 | lines.append(" ".join(queue)) 71 | 72 | return lines + ["}"] 73 | 74 | def pprint(code, *args, **kwargs): 75 | print(beautify(code, *args, **kwargs)) 76 | 77 | def beautify(code, tabspace=2): 78 | array = nestedExpr('{','}').parseString("{"+code+"}").asList() 79 | lines = beautify_fn(array[0]) 80 | return "\n".join(lines[:-1]).replace(' ( ', '(').replace(' )', ')') 81 | 82 | def makedirs(path): 83 | if not os.path.exists(path): 84 | logger.info("[MAKE] directory: {}".format(path)) 85 | os.makedirs(path) 86 | 87 | def get_rng(rng, seed=123): 88 | if rng is None: 89 | rng = np.random.RandomState(seed) 90 | return rng 91 | 92 | def load_config(config, skip_list=[]): 93 | config_keys = vars(config).keys() 94 | config_path = os.path.join(config.load_path, PARAMS_NAME) 95 | 96 | with open(path) as fp: 97 | new_config = json.load(fp) 98 | 99 | for key, value in new_config.items(): 100 | if key in skip_list: 101 | continue 102 | 103 | original_value = getattr(config, key) 104 | if original_value != value: 105 | logger.info("[UPDATE] config `{}`: {} -> {}".format(key, getattr(config, key), value)) 106 | setattr(config, key, value) 107 | 108 | def save_config(config, config_filename='config.json'): 109 | config_path = os.path.join(config.model_path, config_filename) 110 | with open(config_path, 'w') as fp: 111 | json.dump(config.__dict__, fp, indent=4, sort_keys=True) 112 | logger.info('[SAVE] config: {}'.format(config_path)) 113 | 114 | def prepare(config): 115 | formatter = logging.Formatter( 116 | "%(levelname)s:%(asctime)s::%(message)s") 117 | handler = logging.StreamHandler() 118 | handler.setFormatter(formatter) 119 | 120 | logger.addHandler(handler) 121 | logger.setLevel(tf.logging.INFO) 122 | 123 | if config.model_path is None: 124 | config.model_name = "{}_{}".format(config.tag, get_time()) 125 | config.model_path = os.path.join(config.base_dir, config.model_name) 126 | 127 | makedirs(config.model_path) 128 | save_config(config) 129 | else: 130 | if not config.model_path.startswith(config.base_dir): 131 | new_path = os.path.join(config.base_dir, new_path) 132 | setattr(config, 'model_path', new_path) 133 | logger.info("[SET] model_path: {}".format(config.model_path)) 134 | 135 | def set_random_seed(seed): 136 | tf.set_random_seed(seed) 137 | rng = np.random.RandomState(seed) 138 | return rng 139 | --------------------------------------------------------------------------------