├── LICENSE ├── README.md ├── asset ├── generalization.png ├── model.jpg └── teaser.png ├── evaler.py ├── karel_env ├── README.md ├── __init__.py ├── add_per.py ├── append_demonstration.py ├── asset │ └── texture.hdf5 ├── dataset_karel.py ├── dsl │ ├── __init__.py │ ├── dsl_base.py │ ├── dsl_enum_program.py │ ├── dsl_parse.py │ ├── dsl_prob.py │ ├── dsl_prob_syntax.py │ ├── dsl_syntax.py │ └── third_party │ │ ├── __init__.py │ │ └── yacc.py ├── generate_dataset.sh ├── generator.py ├── input_ops_karel.py ├── karel.py ├── karel_util.py ├── state_generator.py ├── tool │ ├── eval_execution.py │ ├── inspect_output.py │ └── visualize_data.py └── util.py ├── models ├── __init__.py ├── baselines │ ├── __init__.py │ ├── model_induction.py │ ├── model_summarizer.py │ └── model_synthesis.py ├── model_full.py ├── ops.py ├── seq2seq_helper.py └── util.py ├── requirements.txt ├── trainer.py └── vizdoom_env ├── README.md ├── __init__.py ├── asset ├── default.cfg └── scenarios │ ├── README.md │ ├── basic.cfg │ ├── basic.wad │ ├── bots.cfg │ ├── cig.cfg │ ├── cig.wad │ ├── cig_with_unknown.wad │ ├── deadly_corridor.cfg │ ├── deadly_corridor.wad │ ├── deathmatch.cfg │ ├── deathmatch.wad │ ├── deathmatch.wad.bak │ ├── defend_the_center.cfg │ ├── defend_the_center.wad │ ├── defend_the_line.cfg │ ├── defend_the_line.wad │ ├── doom2.wad │ ├── doom_state.wad │ ├── health_gathering.cfg │ ├── health_gathering.wad │ ├── health_gathering_supreme.cfg │ ├── health_gathering_supreme.wad │ ├── learning.cfg │ ├── multi.cfg │ ├── multi_deathmatch.wad │ ├── multi_duel.cfg │ ├── multi_duel.wad │ ├── my_way_home.cfg │ ├── my_way_home.wad │ ├── predict_position.cfg │ ├── predict_position.wad │ ├── rocket_basic.cfg │ ├── rocket_basic.wad │ ├── simpler_basic.cfg │ ├── simpler_basic.wad │ ├── take_cover.cfg │ └── take_cover.wad ├── dataset_vizdoom.py ├── dsl ├── __init__.py ├── dsl_enum_program.py ├── dsl_hit_analysis.py ├── dsl_parse.py ├── random_code_generator.py ├── random_code_generator_ifelse.py └── vocab.py ├── generate_dataset.sh ├── generator.py ├── generator_ifelse.py ├── input_ops_vizdoom.py ├── measure_program_fix_accuracy.py ├── merge_datasets.py ├── util.py └── vizdoom_env.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Shao-Hua Sun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /asset/generalization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/asset/generalization.png -------------------------------------------------------------------------------- /asset/model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/asset/model.jpg -------------------------------------------------------------------------------- /asset/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/asset/teaser.png -------------------------------------------------------------------------------- /karel_env/README.md: -------------------------------------------------------------------------------- 1 | # Karel Environment 2 | 3 | This directory contains the code for the Karel environment, including: 4 | - A random program and demonstration generator 5 | - A domain-specific language interpreter 6 | 7 | ## Dataset generation 8 | The dataset used in the paper is generated with the following script: 9 | ```bash 10 | ./karel_env/generate_dataset.sh 11 | ``` 12 | ## Domain specific language 13 | The interpreter and random program generator for the Karel domain-specific language (DSL) are in the [dsl directory](./dsl). 14 | You can find the detailed definition of the DSL in the supplementary material of the paper. 15 | -------------------------------------------------------------------------------- /karel_env/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/karel_env/__init__.py -------------------------------------------------------------------------------- /karel_env/add_per.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import h5py 3 | import os 4 | import numpy as np 5 | from karel import Karel_world 6 | 7 | 8 | parser = argparse.ArgumentParser( 9 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 10 | parser.add_argument('--dir_name', type=str, default='datasets/karel_dataset', 11 | help=" ") 12 | args = parser.parse_args() 13 | 14 | 15 | def s2per(demo_data, demo_data_len): 16 | per_stack = [] 17 | for i in range(demo_data.shape[0]): 18 | per_stack_t = [] 19 | for j in range(demo_data.shape[1]): 20 | if j < demo_data_len[i]: 21 | s = demo_data[i, j] 22 | k = Karel_world(s) 23 | per = np.array([k.front_is_clear(), k.left_is_clear(), 24 | k.right_is_clear(), k.marker_present(), 25 | k.no_marker_present()]) 26 | else: 27 | per = np.zeros([5]) 28 | per_stack_t.append(per) 29 | per_stack.append(np.stack(per_stack_t)) 30 | per_stack = np.stack(per_stack) 31 | return per_stack 32 | 33 | # Your dataset path to data.hdf5 34 | dataset_path_all = [os.path.join(args.dir_name, 'data.hdf5')] 35 | for dataset_path in dataset_path_all: 36 | f = h5py.File(dataset_path, 'r+') 37 | count = 0 38 | for key in f.keys(): 39 | count += 1 40 | print('{}: {}/{}'.format(dataset_path, count-1, len(f.keys())-1)) 41 | if not key == 'data_info': 42 | per = s2per(f[key]['s_h'], f[key]['s_h_len']) 43 | try: 44 | f.__delitem__(key+'/per') 45 | except:
46 | pass 47 | f[key]['per'] = per 48 | 49 | try: 50 | per = s2per(f[key]['test_s_h'], f[key]['test_s_h_len']) 51 | try: 52 | f.__delitem__(key+'/test_per') 53 | except: 54 | pass 55 | f[key]['test_per'] = per 56 | except: 57 | pass 58 | f.close() 59 | -------------------------------------------------------------------------------- /karel_env/append_demonstration.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import h5py 6 | import os 7 | import argparse 8 | import progressbar 9 | 10 | import numpy as np 11 | 12 | from dsl import get_KarelDSL 13 | from util import log 14 | 15 | import karel 16 | 17 | """ Purpose of file is to append test demonstration data to existing dataset 18 | """ 19 | 20 | 21 | class KarelStateGenerator(object): 22 | def __init__(self, seed=None): 23 | self.rng = np.random.RandomState(seed) 24 | 25 | # generate an initial env 26 | def generate_single_state(self, h=8, w=8, wall_prob=0.1): 27 | s = np.zeros([h, w, 16]) > 0 28 | # Wall 29 | s[:, :, 4] = self.rng.rand(h, w) > 1 - wall_prob 30 | s[0, :, 4] = True 31 | s[h-1, :, 4] = True 32 | s[:, 0, 4] = True 33 | s[:, w-1, 4] = True 34 | # Karel initial location 35 | valid_loc = False 36 | while(not valid_loc): 37 | y = self.rng.randint(0, h) 38 | x = self.rng.randint(0, w) 39 | if not s[y, x, 4]: 40 | valid_loc = True 41 | s[y, x, self.rng.randint(0, 4)] = True 42 | # Marker: num of max marker == 1 for now 43 | s[:, :, 6] = (self.rng.rand(h, w) > 0.9) * (s[:, :, 4] == False) > 0 44 | s[:, :, 5] = 1 - (np.sum(s[:, :, 6:], axis=-1) > 0) > 0 45 | assert np.sum(s[:, :, 5:]) == h*w, np.sum(s[:, :, :5]) 46 | marker_weight = np.reshape(np.array(range(11)), (1, 1, 11)) 47 | return s, y, x, np.sum(s[:, :, 4]), np.sum(marker_weight*s[:, :, 5:]) 48 | 49 | 50 | def generator(config): 51 | dir_name = config.dir_name 52 | h = config.height 53 | w = config.width 54 | c = len(karel.state_table) 55 | 56 | wall_prob = config.wall_prob 57 | 58 | # output files 59 | f = h5py.File(os.path.join(dir_name, 'data.hdf5'), 'r+') 60 | dsl_type = f['data_info']['dsl_type'].value 61 | 62 | with open(os.path.join(dir_name, 'id.txt'), 'r') as id_file: 63 | ids = [s.strip() for s in id_file.readlines() if s] 64 | 65 | num_train = f['data_info']['num_train'].value 66 | num_test = f['data_info']['num_test'].value 67 | num_val = f['data_info']['num_val'].value 68 | num_total = num_train + num_test + num_val 69 | 70 | # progress bar 71 | bar = progressbar.ProgressBar(maxval=100, 72 | widgets=[progressbar.Bar('=', '[', ']'), ' ', 73 | progressbar.Percentage()]) 74 | bar.start() 75 | 76 | dsl = get_KarelDSL(dsl_type=dsl_type, seed=config.seed) 77 | s_gen = KarelStateGenerator(seed=config.seed) 78 | karel_world = karel.Karel_world() 79 | 80 | count = 0 81 | max_demo_length_in_dataset = -1 82 | max_program_length_in_dataset = -1 83 | for id_ in ids: 84 | grp = f[id_] 85 | # Reads a single program 86 | program_seq = grp['program'].value 87 | program_code = dsl.intseq2str(program_seq) 88 | 89 | test_s_h_list = [] 90 | a_h_list = [] 91 | num_demo = 0 92 | while num_demo < config.num_test_demo_per_program: 93 | try: 94 | s, _, _, _, _ = s_gen.generate_single_state(h, w, wall_prob) 95 | karel_world.set_new_state(s) 96 | s_h = dsl.run(karel_world, program_code) 97 | except RuntimeError: 98 | pass 99 | else: 100 | if len(karel_world.s_h) <= config.max_demo_length and \ 101 | len(karel_world.s_h) >= 
config.min_demo_length: 102 | test_s_h_list.append(np.stack(karel_world.s_h, axis=0)) 103 | a_h_list.append(np.array(karel_world.a_h)) 104 | num_demo += 1 105 | 106 | len_test_s_h = np.array([s_h.shape[0] for s_h in test_s_h_list], dtype=np.int16) 107 | 108 | demos_test_s_h = np.zeros([num_demo, np.max(len_test_s_h), h, w, c], dtype=bool) 109 | for i, s_h in enumerate(test_s_h_list): 110 | demos_test_s_h[i, :s_h.shape[0]] = s_h 111 | 112 | len_a_h = np.array([a_h.shape[0] for a_h in a_h_list], dtype=np.int16) 113 | 114 | demos_a_h = np.zeros([num_demo, np.max(len_a_h)], dtype=np.int8) 115 | for i, a_h in enumerate(a_h_list): 116 | demos_a_h[i, :a_h.shape[0]] = a_h 117 | 118 | max_demo_length_in_dataset = max(max_demo_length_in_dataset, np.max(len_test_s_h)) 119 | max_program_length_in_dataset = max(max_program_length_in_dataset, program_seq.shape[0]) 120 | 121 | try: 122 | f.__delitem__(id_+'/test_s_h_len') 123 | f.__delitem__(id_+'/test_a_h_len') 124 | f.__delitem__(id_+'/test_s_h') 125 | f.__delitem__(id_+'/test_a_h') 126 | except: 127 | pass 128 | 129 | # Save testing state 130 | grp['test_s_h_len'] = len_test_s_h 131 | grp['test_a_h_len'] = len_a_h 132 | grp['test_s_h'] = demos_test_s_h 133 | grp['test_a_h'] = demos_a_h 134 | # progress bar 135 | count += 1 136 | if count % (num_total / 100) == 0: 137 | bar.update(count / (num_total / 100)) 138 | 139 | try: 140 | f.__delitem__('data_info/num_test_demo_per_program') 141 | except: 142 | pass 143 | f['data_info']['num_test_demo_per_program'] = config.num_test_demo_per_program 144 | bar.finish() 145 | f.close() 146 | id_file.close() 147 | log.info('Dataset generated under {} with {}' 148 | ' samples ({} for training and {} for testing ' 149 | 'and {} for val'.format(dir_name, num_total, 150 | num_train, num_test, num_val)) 151 | 152 | 153 | def main(): 154 | parser = argparse.ArgumentParser( 155 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 156 | parser.add_argument('--dir_name', type=str, default='datasets/karel_dataset', 157 | help=" ") 158 | parser.add_argument('--height', type=int, default=8, 159 | help='height of square grid world') 160 | parser.add_argument('--width', type=int, default=8, 161 | help='width of square grid world') 162 | parser.add_argument('--wall_prob', type=float, default=0.1, 163 | help='probability of wall generation') 164 | parser.add_argument('--seed', type=int, default=123, help='seed') 165 | parser.add_argument('--min_max_demo_length_for_program', type=int, default=2) 166 | parser.add_argument('--min_demo_length', type=int, default=8, 167 | help='min demo length') 168 | parser.add_argument('--max_demo_length', type=int, default=20, 169 | help='max demo length') 170 | parser.add_argument('--num_test_demo_per_program', type=int, default=5, 171 | help='number of unseen demonstrations') 172 | args = parser.parse_args() 173 | 174 | generator(args) 175 | 176 | if __name__ == '__main__': 177 | main() 178 | -------------------------------------------------------------------------------- /karel_env/asset/texture.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/karel_env/asset/texture.hdf5 -------------------------------------------------------------------------------- /karel_env/dataset_karel.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ 
import print_function 4 | 5 | import os.path as osp 6 | import numpy as np 7 | import h5py 8 | from karel_env.util import log 9 | 10 | 11 | rs = np.random.RandomState(123) 12 | 13 | 14 | class Dataset(object): 15 | 16 | def __init__(self, ids, dataset_path, name='default', num_k=10, is_train=True): 17 | self._ids = list(ids) 18 | self.name = name 19 | self.num_k = num_k 20 | self.is_train = is_train 21 | 22 | filename = 'data.hdf5' 23 | file = osp.join(dataset_path, filename) 24 | log.info("Reading %s ...", file) 25 | 26 | self.data = h5py.File(file, 'r') 27 | self.dsl_type = self.data['data_info']['dsl_type'].value 28 | self.num_demo = int(self.data['data_info']['num_demo_per_program'].value) 29 | self.max_demo_len = int(self.data['data_info']['max_demo_length'].value) 30 | self.max_program_len = int(self.data['data_info']['max_program_length'].value) 31 | self.num_program_tokens = int(self.data['data_info']['num_program_tokens'].value) 32 | self.num_action_tokens = int(self.data['data_info']['num_action_tokens'].value) 33 | if 'env_type' in self.data['data_info']: 34 | self.env_type = self.data['data_info']['env_type'].value 35 | else: self.env_type = None 36 | log.info("Reading Done: %s", file) 37 | 38 | def get_data(self, id, order=None): 39 | # preprocessing and data augmentation 40 | 41 | # each data point consist of a program + k demo 42 | 43 | # dim: [one hot dim of program tokens, program len] 44 | program_tokens = self.data[id]['program'].value 45 | program = np.zeros([self.num_program_tokens, self.max_program_len], dtype=bool) 46 | program[:, :len(program_tokens)][program_tokens, np.arange(len(program_tokens))] = 1 47 | padded_program_tokens = np.zeros([self.max_program_len], dtype=program_tokens.dtype) 48 | padded_program_tokens[:len(program_tokens)] = program_tokens 49 | 50 | demo_data = self.data[id]['s_h'].value 51 | test_demo_data = self.data[id]['test_s_h'].value 52 | 53 | if 'p_v_h' in self.data[id]: 54 | per_data = self.data[id]['p_v_h'].value 55 | test_per_data = self.data[id]['test_p_v_h'].value 56 | else: 57 | per_data = self.data[id]['per'].value 58 | test_per_data = self.data[id]['test_per'].value 59 | 60 | sz = demo_data.shape 61 | demo = np.zeros([sz[0], self.max_demo_len, sz[2], sz[3], sz[4]], dtype=demo_data.dtype) 62 | demo[:, :sz[1], :, :, :] = demo_data 63 | sz = test_demo_data.shape 64 | test_demo = np.zeros([sz[0], self.max_demo_len, sz[2], sz[3], sz[4]], dtype=demo_data.dtype) 65 | test_demo[:, :sz[1], :, :, :] = test_demo_data 66 | # dim: [k, action_space, max len of demo - 1] 67 | action_history_tokens = self.data[id]['a_h'].value 68 | action_history = [] 69 | for a_h_tokens in action_history_tokens: 70 | # num_action_tokens + 1 is token which is required for detecting 71 | # the end of the sequence. Even though the original length of the 72 | # action history is max_demo_len - 1, we make it max_demo_len, by 73 | # including the last token. 
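# For example (illustrative): with num_action_tokens = 5 and a_h_tokens = [0, 2],
# rows 0 and 1 of a_h below are one-hot at columns 0 and 2, row 2 is one-hot at
# column 5 (the end-of-sequence token), and the remaining rows stay zero as
# padding up to max_demo_len.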
74 | a_h = np.zeros([self.max_demo_len, self.num_action_tokens + 1], dtype=bool) 75 | a_h[:len(a_h_tokens), :][np.arange(len(a_h_tokens)), a_h_tokens] = 1 76 | a_h[len(a_h_tokens), self.num_action_tokens] = 1 # 77 | action_history.append(a_h) 78 | action_history = np.stack(action_history, axis=0) 79 | padded_action_history_tokens = np.argmax(action_history, axis=2) 80 | 81 | # dim: [test_k, action_space, max len of demo - 1] 82 | test_action_history_tokens = self.data[id]['test_a_h'].value 83 | test_action_history = [] 84 | for test_a_h_tokens in test_action_history_tokens: 85 | # num_action_tokens + 1 is token which is required for detecting 86 | # the end of the sequence. Even though the original length of the 87 | # action history is max_demo_len - 1, we make it max_demo_len, by 88 | # including the last token. 89 | test_a_h = np.zeros([self.max_demo_len, self.num_action_tokens + 1], dtype=bool) 90 | test_a_h[:len(test_a_h_tokens), :][np.arange(len(test_a_h_tokens)), test_a_h_tokens] = 1 91 | test_a_h[len(test_a_h_tokens), self.num_action_tokens] = 1 # 92 | test_action_history.append(test_a_h) 93 | test_action_history = np.stack(test_action_history, axis=0) 94 | padded_test_action_history_tokens = np.argmax(test_action_history, axis=2) 95 | 96 | # program length: [1] 97 | program_length = np.array([len(program_tokens)], dtype=np.float32) 98 | 99 | # len of each demo. dim: [k] 100 | demo_length = self.data[id]['s_h_len'].value 101 | test_demo_length = self.data[id]['test_s_h_len'].value 102 | 103 | # per 104 | pad_per_data = np.pad( 105 | per_data, ((0, 0), (0, self.max_demo_len-per_data.shape[1]), (0, 0)), 106 | mode='constant', constant_values=0) 107 | pad_test_per_data = np.pad( 108 | test_per_data, ((0, 0), (0, self.max_demo_len-test_per_data.shape[1]), (0, 0)), 109 | mode='constant', constant_values=0) 110 | 111 | return program, padded_program_tokens, demo[:self.num_k], test_demo, \ 112 | action_history[:self.num_k], padded_action_history_tokens[:self.num_k], \ 113 | test_action_history, padded_test_action_history_tokens, \ 114 | program_length, demo_length[:self.num_k], test_demo_length, \ 115 | pad_per_data[:self.num_k], pad_test_per_data 116 | 117 | @property 118 | def ids(self): 119 | return self._ids 120 | 121 | def __len__(self): 122 | return len(self.ids) 123 | 124 | def __repr__(self): 125 | return 'Dataset (%s, %d examples)' % ( 126 | self.name, 127 | len(self) 128 | ) 129 | 130 | 131 | def create_default_splits(dataset_path, num_k=10, is_train=True): 132 | ids_train, ids_test, ids_val = all_ids(dataset_path) 133 | 134 | dataset_train = Dataset(ids_train, dataset_path, name='train', 135 | num_k=num_k, is_train=is_train) 136 | dataset_test = Dataset(ids_test, dataset_path, name='test', 137 | num_k=num_k, is_train=is_train) 138 | dataset_val = Dataset(ids_val, dataset_path, name='val', 139 | num_k=num_k, is_train=is_train) 140 | return dataset_train, dataset_test, dataset_val 141 | 142 | 143 | def all_ids(dataset_path): 144 | with h5py.File(osp.join(dataset_path, 'data.hdf5'), 'r') as f: 145 | num_train = int(f['data_info']['num_train'].value) 146 | num_test = int(f['data_info']['num_test'].value) 147 | num_val = int(f['data_info']['num_val'].value) 148 | 149 | with open(osp.join(dataset_path, 'id.txt'), 'r') as fp: 150 | ids_total = [s.strip() for s in fp.readlines() if s] 151 | 152 | ids_train = ids_total[:num_train] 153 | ids_test = ids_total[num_train: num_train + num_test] 154 | ids_val = ids_total[num_train + num_test: num_train + num_test + num_val] 155 | 156 | 
rs.shuffle(ids_train) 157 | rs.shuffle(ids_test) 158 | rs.shuffle(ids_val) 159 | 160 | return ids_train, ids_test, ids_val 161 | -------------------------------------------------------------------------------- /karel_env/dsl/__init__.py: -------------------------------------------------------------------------------- 1 | from dsl_prob import KarelDSLProb 2 | 3 | from dsl_prob_syntax import KarelDSLProbSyntax 4 | 5 | 6 | def get_KarelDSL(dsl_type='prob', seed=None): 7 | if dsl_type == 'prob': 8 | return KarelDSLProb(seed=seed) 9 | else: 10 | raise ValueError('Undefined dsl type') 11 | 12 | 13 | def get_KarelDSLSyntax(dsl_type='prob', seed=None): 14 | if dsl_type == 'prob': 15 | return KarelDSLProbSyntax(seed=seed) 16 | else: 17 | raise ValueError('Undefined dsl syntax type') 18 | -------------------------------------------------------------------------------- /karel_env/dsl/dsl_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Domain specific language for Karel Environment 3 | 4 | Code is adapted from https://github.com/carpedm20/karel 5 | """ 6 | 7 | import numpy as np 8 | import ply.lex as lex 9 | from functools import wraps 10 | 11 | from third_party import yacc 12 | 13 | MIN_INT = 0 14 | MAX_INT = 19 15 | INT_PREFIX = 'R=' 16 | 17 | 18 | class KarelDSLBase(object): 19 | 20 | def get_yacc(self): 21 | self.yacc, self.grammar = yacc.yacc( 22 | module=self, 23 | tabmodule="_parsetab", 24 | with_grammar=True) 25 | 26 | def __init__(self, seed=None): 27 | self.lexer = lex.lex(module=self) 28 | self.get_yacc() 29 | 30 | self.prodnames = self.grammar.Prodnames 31 | self.call_counter = [0] 32 | self.max_func_call = 100 33 | self.rng = np.random.RandomState(seed) 34 | 35 | self.construct_vocab() 36 | 37 | def callout(f): 38 | @wraps(f) 39 | def wrapped(*args, **kwargs): 40 | if self.call_counter[0] > self.max_func_call: 41 | raise RuntimeError("Program execution timeout.") 42 | r = f(*args, **kwargs) 43 | self.call_counter[0] += 1 44 | return r 45 | return wrapped 46 | 47 | self.callout = callout 48 | 49 | def construct_vocab(self): 50 | self.token2int = [] 51 | self.int2token = [] 52 | for term in self.tokens: 53 | token = getattr(self, 't_{}'.format(term)) 54 | if callable(token): 55 | if token == self.t_INT: 56 | for i in range(MIN_INT, MAX_INT + 1): 57 | self.int2token.append("{}{}".format(INT_PREFIX, i)) 58 | else: 59 | self.int2token.append(str(token).replace('\\', '')) 60 | self.token2int = {v: i for i, v in enumerate(self.int2token)} 61 | 62 | def str2intseq(self, code): 63 | return [self.token2int[t] for t in code.split()] 64 | 65 | def code2intseq(self, code): 66 | return [self.token2int[t] for t in code.split()] 67 | 68 | def intseq2str(self, intseq): 69 | return ' '.join([self.int2token[i] for i in intseq]) 70 | 71 | conditional_functions = [] 72 | 73 | action_functions = [] 74 | 75 | ######### 76 | # lexer 77 | ######### 78 | 79 | def t_error(self, t): 80 | t.lexer.skip(1) 81 | raise RuntimeError('Syntax Error') 82 | 83 | ######### 84 | # parser 85 | ######### 86 | 87 | def p_error(self, p): 88 | raise RuntimeError('Syntax Error') 89 | 90 | def random_code(self, start_token="prog", depth=0, max_depth=6, nesting_depth=0, max_nesting_depth=4): 91 | code = " ".join(self.random_tokens(start_token, depth, max_depth, nesting_depth, max_nesting_depth)) 92 | 93 | return code 94 | 95 | def parse(self, code, **kwargs): 96 | self.call_counter = [0] 97 | self.error = False 98 | program = self.yacc.parse(code, **kwargs) 99 | return program 100 | 
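# A minimal usage sketch for parse()/run(), mirroring the pattern in
# karel_env/generator.py (illustrative only; `initial_state` is a hypothetical
# [h, w, 16] boolean grid built by a state generator):
#
#   dsl = get_KarelDSL(dsl_type='prob', seed=123)   # from karel_env/dsl/__init__.py
#   world = karel.Karel_world()
#   world.set_new_state(initial_state)
#   code = dsl.random_code(max_depth=6, max_nesting_depth=4)
#   s_h = dsl.run(world, code)     # executes the program; raises RuntimeError on timeout
#   seq = dsl.code2intseq(code)    # integer token sequence stored as 'program' in the dataset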
101 | def run(self, karel_world, code, **kwargs): 102 | self.call_counter = [0] 103 | program = self.parse(code, **kwargs) 104 | 105 | # run program 106 | karel_world.clear_history() 107 | program(karel_world) 108 | return karel_world.s_h 109 | -------------------------------------------------------------------------------- /karel_env/dsl/dsl_enum_program.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def check_and_apply(queue, rule): 4 | r = rule[0].split() 5 | l = len(r) 6 | if len(queue) >= l: 7 | t = queue[-l:] 8 | if list(zip(*t)[0]) == r: 9 | new_t = rule[1](list(zip(*t)[1])) 10 | del queue[-l:] 11 | queue.extend(new_t) 12 | return True 13 | return False 14 | 15 | rules = [] 16 | 17 | # k, n, s = fn(k, n) 18 | # k: karel_world 19 | # n: num_call 20 | # s: success 21 | # c: condition [True, False] 22 | MAX_WHILE = 100 23 | 24 | 25 | def r_prog(t): 26 | stmt = t[3] 27 | 28 | return [('prog', stmt(0, 0))] 29 | rules.append(('DEF run m( stmt m)', r_prog)) 30 | 31 | 32 | def r_stmt(t): 33 | stmt = t[0] 34 | 35 | def fn(k, n): 36 | return stmt(k, n) 37 | return [('stmt', fn)] 38 | rules.append(('while_stmt', r_stmt)) 39 | rules.append(('repeat_stmt', r_stmt)) 40 | rules.append(('stmt_stmt', r_stmt)) 41 | rules.append(('action', r_stmt)) 42 | rules.append(('if_stmt', r_stmt)) 43 | rules.append(('ifelse_stmt', r_stmt)) 44 | 45 | 46 | def r_stmt_stmt(t): 47 | stmt1, stmt2 = t[0], t[1] 48 | 49 | def fn(k, n): 50 | return stmt1(k, n) + stmt2(k, n) 51 | return [('stmt_stmt', fn)] 52 | rules.append(('stmt stmt', r_stmt_stmt)) 53 | 54 | 55 | def r_if(t): 56 | cond, stmt = t[2], t[5] 57 | 58 | def fn(k, n): 59 | return ['if'] + cond(k, n) + stmt(k, n) 60 | return [('if_stmt', fn)] 61 | rules.append(('IF c( cond c) i( stmt i)', r_if)) 62 | 63 | 64 | def r_ifelse(t): 65 | cond, stmt1, stmt2 = t[2], t[5], t[9] 66 | 67 | def fn(k, n): 68 | stmt1_out = stmt1(k, n) 69 | stmt2_out = stmt2(k, n) 70 | if stmt1_out == stmt2_out: 71 | return stmt1_out 72 | cond_out = cond(k, n) 73 | if cond_out[0] == 'not': 74 | else_cond = ['if'] + cond_out[1:] 75 | else: 76 | else_cond = ['if', 'not'] + cond_out 77 | return ['if'] + cond_out + stmt1_out + else_cond + stmt2_out 78 | return [('ifelse_stmt', fn)] 79 | rules.append(('IFELSE c( cond c) i( stmt i) ELSE e( stmt e)', r_ifelse)) 80 | 81 | 82 | def r_while(t): 83 | cond, stmt = t[2], t[5] 84 | 85 | def fn(k, n): 86 | cond_out = cond(k, n) 87 | stmt_out = stmt(k, n) 88 | while_out = [] 89 | for _ in range(MAX_WHILE): 90 | while_out.extend(['if'] + cond_out + stmt_out) 91 | return while_out 92 | return [('while_stmt', fn)] 93 | rules.append(('WHILE c( cond c) w( stmt w)', r_while)) 94 | 95 | 96 | def r_repeat(t): 97 | cste, stmt = t[1], t[3] 98 | 99 | def fn(k, n): 100 | repeat_out = [] 101 | for _ in range(cste()): 102 | repeat_out.extend(stmt(k, n)) 103 | return repeat_out 104 | return [('repeat_stmt', fn)] 105 | rules.append(('REPEAT cste r( stmt r)', r_repeat)) 106 | 107 | 108 | def r_cond1(t): 109 | cond = t[0] 110 | 111 | def fn(k, n): 112 | return cond(k, n) 113 | return [('cond', fn)] 114 | rules.append(('cond_without_not', r_cond1)) 115 | 116 | 117 | def r_cond2(t): 118 | cond = t[2] 119 | 120 | def fn(k, n): 121 | cond_out = cond(k, n) 122 | if cond_out[0] == 'not': 123 | cond_out = cond_out[1:] 124 | else: 125 | cond_out = ['not'] + cond_out 126 | return cond_out 127 | return [('cond', fn)] 128 | rules.append(('not c( cond c)', r_cond2)) 129 | 130 | 131 | def r_cond_without_not1(t): 132 | def fn(k, n): 133 | 
return ['frontIsClear'] 134 | return [('cond_without_not', fn)] 135 | rules.append(('frontIsClear', r_cond_without_not1)) 136 | 137 | 138 | def r_cond_without_not2(t): 139 | def fn(k, n): 140 | return ['leftIsClear'] 141 | return [('cond_without_not', fn)] 142 | rules.append(('leftIsClear', r_cond_without_not2)) 143 | 144 | 145 | def r_cond_without_not3(t): 146 | def fn(k, n): 147 | return ['rightIsClear'] 148 | return [('cond_without_not', fn)] 149 | rules.append(('rightIsClear', r_cond_without_not3)) 150 | 151 | 152 | def r_cond_without_not4(t): 153 | def fn(k, n): 154 | return ['markersPresent'] 155 | return [('cond_without_not', fn)] 156 | rules.append(('markersPresent', r_cond_without_not4)) 157 | 158 | 159 | def r_cond_without_not5(t): 160 | def fn(k, n): 161 | return ['not', 'markersPresent'] 162 | return [('cond_without_not', fn)] 163 | rules.append(('noMarkersPresent', r_cond_without_not5)) 164 | 165 | 166 | def r_action1(t): 167 | def fn(k, n): 168 | return ['move'] 169 | return [('action', fn)] 170 | rules.append(('move', r_action1)) 171 | 172 | 173 | def r_action2(t): 174 | def fn(k, n): 175 | return ['turnLeft'] 176 | return [('action', fn)] 177 | rules.append(('turnLeft', r_action2)) 178 | 179 | 180 | def r_action3(t): 181 | def fn(k, n): 182 | return ['turnRight'] 183 | return [('action', fn)] 184 | rules.append(('turnRight', r_action3)) 185 | 186 | 187 | def r_action4(t): 188 | def fn(k, n): 189 | return ['pickMarker'] 190 | return [('action', fn)] 191 | rules.append(('pickMarker', r_action4)) 192 | 193 | 194 | def r_action5(t): 195 | def fn(k, n): 196 | return ['putMarker'] 197 | return [('action', fn)] 198 | rules.append(('putMarker', r_action5)) 199 | 200 | 201 | def create_r_cste(number): 202 | def r_cste(t): 203 | return [('cste', lambda: number)] 204 | return r_cste 205 | for i in range(20): 206 | rules.append(('R={}'.format(i), create_r_cste(i))) 207 | 208 | 209 | def parse(program): 210 | p_tokens = program.split()[::-1] 211 | queue = [] 212 | applied = False 213 | while len(p_tokens) > 0 or len(queue) != 1: 214 | if applied: applied = False 215 | else: 216 | queue.append((p_tokens.pop(), None)) 217 | for rule in rules: 218 | applied = check_and_apply(queue, rule) 219 | if applied: break 220 | if not applied and len(p_tokens) == 0: # error parsing 221 | return None, False 222 | return queue[0][1], True 223 | -------------------------------------------------------------------------------- /karel_env/dsl/dsl_parse.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def check_and_apply(queue, rule): 4 | r = rule[0].split() 5 | l = len(r) 6 | if len(queue) >= l: 7 | t = queue[-l:] 8 | if list(zip(*t)[0]) == r: 9 | new_t = rule[1](list(zip(*t)[1])) 10 | del queue[-l:] 11 | queue.extend(new_t) 12 | return True 13 | return False 14 | 15 | rules = [] 16 | 17 | # k, n, s = fn(k, n) 18 | # k: karel_world 19 | # n: num_call 20 | # s: success 21 | # c: condition [True, False] 22 | MAX_FUNC_CALL = 100 23 | 24 | 25 | def r_prog(t): 26 | stmt = t[3] 27 | 28 | def fn(k, n): 29 | if n > MAX_FUNC_CALL: return k, n, False 30 | return stmt(k, n + 1) 31 | return [('prog', fn)] 32 | rules.append(('DEF run m( stmt m)', r_prog)) 33 | 34 | 35 | def r_stmt(t): 36 | stmt = t[0] 37 | 38 | def fn(k, n): 39 | if n > MAX_FUNC_CALL: return k, n, False 40 | return stmt(k, n + 1) 41 | return [('stmt', fn)] 42 | rules.append(('while_stmt', r_stmt)) 43 | rules.append(('repeat_stmt', r_stmt)) 44 | rules.append(('stmt_stmt', r_stmt)) 45 | 
rules.append(('action', r_stmt)) 46 | rules.append(('if_stmt', r_stmt)) 47 | rules.append(('ifelse_stmt', r_stmt)) 48 | 49 | 50 | def r_stmt_stmt(t): 51 | stmt1, stmt2 = t[0], t[1] 52 | 53 | def fn(k, n): 54 | if n > MAX_FUNC_CALL: return k, n, False 55 | k, n, s = stmt1(k, n + 1) 56 | if not s: return k, n, s 57 | if n > MAX_FUNC_CALL: return k, n, False 58 | return stmt2(k, n) 59 | return [('stmt_stmt', fn)] 60 | rules.append(('stmt stmt', r_stmt_stmt)) 61 | 62 | 63 | def r_if(t): 64 | cond, stmt = t[2], t[5] 65 | 66 | def fn(k, n): 67 | if n > MAX_FUNC_CALL: return k, n, False 68 | k, n, s, c = cond(k, n + 1) 69 | if not s: return k, n, s 70 | if c: return stmt(k, n) 71 | else: return k, n, s 72 | return [('if_stmt', fn)] 73 | rules.append(('IF c( cond c) i( stmt i)', r_if)) 74 | 75 | 76 | def r_ifelse(t): 77 | cond, stmt1, stmt2 = t[2], t[5], t[9] 78 | 79 | def fn(k, n): 80 | if n > MAX_FUNC_CALL: return k, n, False 81 | k, n, s, c = cond(k, n + 1) 82 | if not s: return k, n, s 83 | if c: return stmt1(k, n) 84 | else: return stmt2(k, n) 85 | return [('ifelse_stmt', fn)] 86 | rules.append(('IFELSE c( cond c) i( stmt i) ELSE e( stmt e)', r_ifelse)) 87 | 88 | 89 | def r_while(t): 90 | cond, stmt = t[2], t[5] 91 | 92 | def fn(k, n): 93 | if n > MAX_FUNC_CALL: return k, n, False 94 | k, n, s, c = cond(k, n) 95 | if not s: return k, n, s 96 | while(c): 97 | k, n, s = stmt(k, n) 98 | if not s: return k, n, s 99 | k, n, s, c = cond(k, n) 100 | if not s: return k, n, s 101 | return k, n, s 102 | return [('while_stmt', fn)] 103 | rules.append(('WHILE c( cond c) w( stmt w)', r_while)) 104 | 105 | 106 | def r_repeat(t): 107 | cste, stmt = t[1], t[3] 108 | 109 | def fn(k, n): 110 | if n > MAX_FUNC_CALL: return k, n, False 111 | n += 1 112 | s = True 113 | for _ in range(cste()): 114 | k, n, s = stmt(k, n) 115 | if not s: return k, n, s 116 | return k, n, s 117 | return [('repeat_stmt', fn)] 118 | rules.append(('REPEAT cste r( stmt r)', r_repeat)) 119 | 120 | 121 | def r_cond1(t): 122 | cond = t[0] 123 | 124 | def fn(k, n): 125 | if n > MAX_FUNC_CALL: return k, n, False, False 126 | return cond(k, n) 127 | return [('cond', fn)] 128 | rules.append(('cond_without_not', r_cond1)) 129 | 130 | 131 | def r_cond2(t): 132 | cond = t[2] 133 | 134 | def fn(k, n): 135 | if n > MAX_FUNC_CALL: return k, n, False, False 136 | k, n, s, c = cond(k, n) 137 | return k, n, s, not c 138 | return [('cond', fn)] 139 | rules.append(('not c( cond c)', r_cond2)) 140 | 141 | 142 | def r_cond_without_not1(t): 143 | def fn(k, n): 144 | if n > MAX_FUNC_CALL: return k, n, False, False 145 | c = k.front_is_clear() 146 | return k, n, True, c 147 | return [('cond_without_not', fn)] 148 | rules.append(('frontIsClear', r_cond_without_not1)) 149 | 150 | 151 | def r_cond_without_not2(t): 152 | def fn(k, n): 153 | if n > MAX_FUNC_CALL: return k, n, False 154 | c = k.left_is_clear() 155 | return k, n, True, c 156 | return [('cond_without_not', fn)] 157 | rules.append(('leftIsClear', r_cond_without_not2)) 158 | 159 | 160 | def r_cond_without_not3(t): 161 | def fn(k, n): 162 | if n > MAX_FUNC_CALL: return k, n, False 163 | c = k.right_is_clear() 164 | return k, n, True, c 165 | return [('cond_without_not', fn)] 166 | rules.append(('rightIsClear', r_cond_without_not3)) 167 | 168 | 169 | def r_cond_without_not4(t): 170 | def fn(k, n): 171 | if n > MAX_FUNC_CALL: return k, n, False 172 | c = k.marker_present() 173 | return k, n, True, c 174 | return [('cond_without_not', fn)] 175 | rules.append(('markersPresent', r_cond_without_not4)) 176 | 
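# Illustrative call pattern for this module's parse() (a sketch; `karel_world`
# is a Karel_world instance from karel_env/karel.py):
#
#   exec_fn, success = parse('DEF run m( move m)')
#   if success:
#       world, n, ok = exec_fn(karel_world, 0)   # ok is False on a crash or timeout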
177 | 178 | def r_cond_without_not5(t): 179 | def fn(k, n): 180 | if n > MAX_FUNC_CALL: return k, n, False 181 | c = k.no_marker_present() 182 | return k, n, True, c 183 | return [('cond_without_not', fn)] 184 | rules.append(('noMarkersPresent', r_cond_without_not5)) 185 | 186 | 187 | def r_action1(t): 188 | def fn(k, n): 189 | if n > MAX_FUNC_CALL: return k, n, False 190 | action = np.array([1, 0, 0, 0, 0]) 191 | try: k.state_transition(action) 192 | except: return k, n, False 193 | else: return k, n, True 194 | return [('action', fn)] 195 | rules.append(('move', r_action1)) 196 | 197 | 198 | def r_action2(t): 199 | def fn(k, n): 200 | if n > MAX_FUNC_CALL: return k, n, False 201 | action = np.array([0, 1, 0, 0, 0]) 202 | try: k.state_transition(action) 203 | except: return k, n, False 204 | else: return k, n, True 205 | return [('action', fn)] 206 | rules.append(('turnLeft', r_action2)) 207 | 208 | 209 | def r_action3(t): 210 | def fn(k, n): 211 | if n > MAX_FUNC_CALL: return k, n, False 212 | action = np.array([0, 0, 1, 0, 0]) 213 | try: k.state_transition(action) 214 | except: return k, n, False 215 | else: return k, n, True 216 | return [('action', fn)] 217 | rules.append(('turnRight', r_action3)) 218 | 219 | 220 | def r_action4(t): 221 | def fn(k, n): 222 | if n > MAX_FUNC_CALL: return k, n, False 223 | action = np.array([0, 0, 0, 1, 0]) 224 | try: k.state_transition(action) 225 | except: return k, n, False 226 | else: return k, n, True 227 | return [('action', fn)] 228 | rules.append(('pickMarker', r_action4)) 229 | 230 | 231 | def r_action5(t): 232 | def fn(k, n): 233 | if n > MAX_FUNC_CALL: return k, n, False 234 | action = np.array([0, 0, 0, 0, 1]) 235 | try: k.state_transition(action) 236 | except: return k, n, False 237 | else: return k, n, True 238 | return [('action', fn)] 239 | rules.append(('putMarker', r_action5)) 240 | 241 | 242 | def create_r_cste(number): 243 | def r_cste(t): 244 | return [('cste', lambda: number)] 245 | return r_cste 246 | for i in range(20): 247 | rules.append(('R={}'.format(i), create_r_cste(i))) 248 | 249 | 250 | def parse(program): 251 | p_tokens = program.split()[::-1] 252 | queue = [] 253 | applied = False 254 | while len(p_tokens) > 0 or len(queue) != 1: 255 | if applied: applied = False 256 | else: 257 | queue.append((p_tokens.pop(), None)) 258 | for rule in rules: 259 | applied = check_and_apply(queue, rule) 260 | if applied: break 261 | if not applied and len(p_tokens) == 0: # error parsing 262 | return None, False 263 | return queue[0][1], True 264 | 265 | 266 | -------------------------------------------------------------------------------- /karel_env/dsl/dsl_prob.py: -------------------------------------------------------------------------------- 1 | """ 2 | Domain specific language for Karel Environment 3 | 4 | Code is adapted from https://github.com/carpedm20/karel 5 | """ 6 | 7 | import numpy as np 8 | 9 | from dsl_base import KarelDSLBase, MIN_INT, MAX_INT, INT_PREFIX 10 | 11 | 12 | class KarelDSLProb(KarelDSLBase): 13 | tokens = [ 14 | 'DEF', 'RUN', 'M_LBRACE', 'M_RBRACE', 15 | 'MOVE', 'TURN_RIGHT', 'TURN_LEFT', 16 | 'PICK_MARKER', 'PUT_MARKER', 17 | 'R_LBRACE', 'R_RBRACE', 18 | 'INT', # 'NEWLINE', 'SEMI', 19 | 'REPEAT', 20 | 'C_LBRACE', 'C_RBRACE', 21 | 'I_LBRACE', 'I_RBRACE', 'E_LBRACE', 'E_RBRACE', 22 | 'IF', 'IFELSE', 'ELSE', 23 | 'FRONT_IS_CLEAR', 'LEFT_IS_CLEAR', 'RIGHT_IS_CLEAR', 24 | 'MARKERS_PRESENT', 'NO_MARKERS_PRESENT', 25 | 'NOT', 26 | 'W_LBRACE', 'W_RBRACE', 27 | 'WHILE', 28 | ] 29 | 30 | t_ignore = ' \t\n' 31 | 32 | 
t_M_LBRACE = 'm\(' 33 | t_M_RBRACE = 'm\)' 34 | 35 | t_C_LBRACE = 'c\(' 36 | t_C_RBRACE = 'c\)' 37 | 38 | t_R_LBRACE = 'r\(' 39 | t_R_RBRACE = 'r\)' 40 | 41 | t_W_LBRACE = 'w\(' 42 | t_W_RBRACE = 'w\)' 43 | 44 | t_I_LBRACE = 'i\(' 45 | t_I_RBRACE = 'i\)' 46 | 47 | t_E_LBRACE = 'e\(' 48 | t_E_RBRACE = 'e\)' 49 | 50 | t_DEF = 'DEF' 51 | t_RUN = 'run' 52 | t_WHILE = 'WHILE' 53 | t_REPEAT = 'REPEAT' 54 | t_IF = 'IF' 55 | t_IFELSE = 'IFELSE' 56 | t_ELSE = 'ELSE' 57 | t_NOT = 'not' 58 | 59 | t_FRONT_IS_CLEAR = 'frontIsClear' 60 | t_LEFT_IS_CLEAR = 'leftIsClear' 61 | t_RIGHT_IS_CLEAR = 'rightIsClear' 62 | t_MARKERS_PRESENT = 'markersPresent' 63 | t_NO_MARKERS_PRESENT = 'noMarkersPresent' 64 | 65 | conditional_functions = [ 66 | t_FRONT_IS_CLEAR, t_LEFT_IS_CLEAR, t_RIGHT_IS_CLEAR, 67 | t_MARKERS_PRESENT, t_NO_MARKERS_PRESENT, 68 | ] 69 | 70 | t_MOVE = 'move' 71 | t_TURN_RIGHT = 'turnRight' 72 | t_TURN_LEFT = 'turnLeft' 73 | t_PICK_MARKER = 'pickMarker' 74 | t_PUT_MARKER = 'putMarker' 75 | 76 | action_functions = [ 77 | t_MOVE, 78 | t_TURN_RIGHT, t_TURN_LEFT, 79 | t_PICK_MARKER, t_PUT_MARKER, 80 | ] 81 | 82 | ######### 83 | # lexer 84 | ######### 85 | 86 | def t_INT(self, t): 87 | r'R=\d+' 88 | 89 | value = int(t.value.replace(INT_PREFIX, '')) 90 | if not (MIN_INT <= value <= MAX_INT): 91 | raise Exception(" [!] Out of range ({} ~ {}): `{}`". 92 | format(MIN_INT, MAX_INT, value)) 93 | 94 | t.value = value 95 | return t 96 | 97 | def random_INT(self): 98 | return "{}{}".format( 99 | INT_PREFIX, 100 | self.rng.randint(MIN_INT, MAX_INT + 1)) 101 | 102 | def t_error(self, t): 103 | self.error = True 104 | t.lexer.skip(1) 105 | 106 | ######### 107 | # parser 108 | ######### 109 | 110 | prob_prog = [1.0] 111 | 112 | def p_prog(self, p): 113 | '''prog : DEF RUN M_LBRACE stmt M_RBRACE''' 114 | stmt = p[4] 115 | 116 | @self.callout 117 | def fn(karel_world): 118 | stmt(karel_world) 119 | p[0] = stmt 120 | 121 | prob_stmt = [0.1, 0.02, 0.7, 0.16, 0.01, 0.01] 122 | 123 | def p_stmt(self, p): 124 | '''stmt : while 125 | | repeat 126 | | stmt_stmt 127 | | action 128 | | if 129 | | ifelse 130 | ''' 131 | function = p[1] 132 | 133 | @self.callout 134 | def fn(karel_world): 135 | function(karel_world) 136 | p[0] = fn 137 | 138 | prob_stmt_stmt = [1.0] 139 | 140 | def p_stmt_stmt(self, p): 141 | '''stmt_stmt : stmt stmt 142 | ''' 143 | stmt1, stmt2 = p[1], p[2] 144 | 145 | @self.callout 146 | def fn(karel_world): 147 | stmt1(karel_world) 148 | stmt2(karel_world) 149 | p[0] = fn 150 | 151 | prob_if = [1.0] 152 | 153 | def p_if(self, p): 154 | '''if : IF C_LBRACE cond C_RBRACE I_LBRACE stmt I_RBRACE 155 | ''' 156 | cond, stmt = p[3], p[6] 157 | 158 | @self.callout 159 | def fn(karel_world): 160 | condition = cond(karel_world) 161 | if condition != 'timeout' and condition: 162 | stmt(karel_world) 163 | 164 | p[0] = fn 165 | 166 | prob_ifelse = [1.0] 167 | 168 | def p_ifelse(self, p): 169 | '''ifelse : IFELSE C_LBRACE cond C_RBRACE I_LBRACE stmt I_RBRACE ELSE E_LBRACE stmt E_RBRACE 170 | ''' 171 | cond, stmt1, stmt2 = p[3], p[6], p[10] 172 | 173 | @self.callout 174 | def fn(karel_world): 175 | condition = cond(karel_world) 176 | if condition != 'timeout' and condition: 177 | stmt1(karel_world) 178 | elif condition != 'timeout': 179 | stmt2(karel_world) 180 | else: 181 | return 182 | 183 | p[0] = fn 184 | 185 | prob_while = [1.0] 186 | 187 | def p_while(self, p): 188 | '''while : WHILE C_LBRACE cond C_RBRACE W_LBRACE stmt W_RBRACE 189 | ''' 190 | cond, stmt = p[3], p[6] 191 | 192 | @self.callout 193 | def 
fn(karel_world): 194 | condition = cond(karel_world) 195 | while(condition != 'timeout' and condition): 196 | stmt(karel_world) 197 | condition = cond(karel_world) 198 | 199 | p[0] = fn 200 | 201 | prob_repeat = [1.0] 202 | 203 | def p_repeat(self, p): 204 | '''repeat : REPEAT cste R_LBRACE stmt R_RBRACE 205 | ''' 206 | cste, stmt = p[2], p[4] 207 | 208 | @self.callout 209 | def fn(karel_world): 210 | for _ in range(cste()): 211 | stmt(karel_world) 212 | 213 | p[0] = fn 214 | 215 | prob_cond = [0.9, 0.1] 216 | 217 | def p_cond(self, p): 218 | '''cond : cond_without_not 219 | | NOT C_LBRACE cond_without_not C_RBRACE 220 | ''' 221 | if callable(p[1]): 222 | cond_without_not = p[1] 223 | 224 | def fn(karel_world): 225 | return cond_without_not(karel_world) 226 | p[0] = fn 227 | else: # NOT 228 | cond_without_not = p[3] 229 | 230 | def fn(karel_world): 231 | return not cond_without_not(karel_world) 232 | p[0] = fn 233 | 234 | prob_cond_without_not = [0.7, 0.1, 0.1, 0.05, 0.05] 235 | 236 | def p_cond_without_not(self, p): 237 | '''cond_without_not : FRONT_IS_CLEAR 238 | | LEFT_IS_CLEAR 239 | | RIGHT_IS_CLEAR 240 | | MARKERS_PRESENT 241 | | NO_MARKERS_PRESENT 242 | ''' 243 | cond_without_not = p[1] 244 | 245 | def fn(karel_world): 246 | if cond_without_not == self.t_FRONT_IS_CLEAR: 247 | return karel_world.front_is_clear() 248 | elif cond_without_not == self.t_LEFT_IS_CLEAR: 249 | return karel_world.left_is_clear() 250 | elif cond_without_not == self.t_RIGHT_IS_CLEAR: 251 | return karel_world.right_is_clear() 252 | elif cond_without_not == self.t_MARKERS_PRESENT: 253 | return karel_world.marker_present() 254 | elif cond_without_not == self.t_NO_MARKERS_PRESENT: 255 | return karel_world.no_marker_present() 256 | else: 257 | raise ValueError("No such condition") 258 | 259 | p[0] = fn 260 | 261 | prob_action = [0.7, 0.1, 0.1, 0.05, 0.05] 262 | 263 | def p_action(self, p): 264 | '''action : MOVE 265 | | TURN_RIGHT 266 | | TURN_LEFT 267 | | PICK_MARKER 268 | | PUT_MARKER 269 | ''' 270 | action = p[1] 271 | 272 | def fn(karel_world): 273 | action_v = np.array([self.t_MOVE, 274 | self.t_TURN_LEFT, self.t_TURN_RIGHT, 275 | self.t_PICK_MARKER, self.t_PUT_MARKER]) == action 276 | karel_world.state_transition(action_v) 277 | p[0] = fn 278 | 279 | prob_cste = [1.0] 280 | 281 | def p_cste(self, p): 282 | '''cste : INT 283 | ''' 284 | value = p[1] 285 | p[0] = lambda: int(value) 286 | 287 | def p_error(self, p): 288 | self.error = True 289 | 290 | def random_tokens(self, start_token="prog", depth=0, max_depth=6, nesting_depth=0, max_nesting_depth=4): 291 | if start_token == 'stmt': 292 | if nesting_depth > max_nesting_depth or depth > max_depth: 293 | start_token = "action" 294 | 295 | codes = [] 296 | candidates = self.prodnames[start_token] 297 | sample_prob = getattr(self, 'prob_{}'.format(start_token)) 298 | 299 | prod = candidates[self.rng.choice(range(len(candidates)), p=sample_prob)] 300 | 301 | for term in prod.prod: 302 | if term in self.prodnames: # need digging 303 | if term in ['if', 'ifelse', 'repeat', 'while']: # increase nested depth 304 | codes.extend(self.random_tokens(term, depth + 1, max_depth, nesting_depth + 1, max_nesting_depth)) 305 | else: 306 | codes.extend(self.random_tokens(term, depth + 1, max_depth, nesting_depth, max_nesting_depth)) 307 | else: 308 | token = getattr(self, 't_{}'.format(term)) 309 | if callable(token): 310 | if token == self.t_INT: 311 | token = self.random_INT() 312 | else: 313 | raise Exception(" [!] 
Undefined token `{}`".format(token)) 314 | 315 | codes.append(str(token).replace('\\', '')) 316 | 317 | return codes 318 | -------------------------------------------------------------------------------- /karel_env/dsl/dsl_syntax.py: -------------------------------------------------------------------------------- 1 | from third_party import yacc 2 | from dsl_base import KarelDSLBase 3 | 4 | 5 | class KarelDSLSyntax(KarelDSLBase): 6 | def get_yacc(self): 7 | self.yacc, self.grammar = yacc.yacc( 8 | module=self, 9 | tabmodule="_parsetab_syntax", 10 | with_grammar=True) 11 | 12 | def get_next_candidates(self, code, **kwargs): 13 | next_candidates = self.yacc.parse(code, **kwargs) 14 | return next_candidates 15 | -------------------------------------------------------------------------------- /karel_env/dsl/third_party/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /karel_env/generate_dataset.sh: -------------------------------------------------------------------------------- 1 | # Generate karel datasets 2 | python karel_env/generator.py 3 | # Append unseen demonstrations to each programs 4 | python karel_env/append_demonstration.py 5 | # Add perception primitives to each demonstrations 6 | python karel_env/add_per.py 7 | -------------------------------------------------------------------------------- /karel_env/generator.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import h5py 6 | import os 7 | import argparse 8 | import progressbar 9 | 10 | import numpy as np 11 | 12 | from dsl import get_KarelDSL 13 | from util import log 14 | 15 | import karel 16 | 17 | 18 | class KarelStateGenerator(object): 19 | def __init__(self, seed=None): 20 | self.rng = np.random.RandomState(seed) 21 | 22 | # generate an initial env 23 | def generate_single_state(self, h=8, w=8, wall_prob=0.1): 24 | s = np.zeros([h, w, 16]) > 0 25 | # Wall 26 | s[:, :, 4] = self.rng.rand(h, w) > 1 - wall_prob 27 | s[0, :, 4] = True 28 | s[h-1, :, 4] = True 29 | s[:, 0, 4] = True 30 | s[:, w-1, 4] = True 31 | # Karel initial location 32 | valid_loc = False 33 | while(not valid_loc): 34 | y = self.rng.randint(0, h) 35 | x = self.rng.randint(0, w) 36 | if not s[y, x, 4]: 37 | valid_loc = True 38 | s[y, x, self.rng.randint(0, 4)] = True 39 | # Marker: num of max marker == 1 for now 40 | s[:, :, 6] = (self.rng.rand(h, w) > 0.9) * (s[:, :, 4] == False) > 0 41 | s[:, :, 5] = 1 - (np.sum(s[:, :, 6:], axis=-1) > 0) > 0 42 | assert np.sum(s[:, :, 5:]) == h*w, np.sum(s[:, :, :5]) 43 | marker_weight = np.reshape(np.array(range(11)), (1, 1, 11)) 44 | return s, y, x, np.sum(s[:, :, 4]), np.sum(marker_weight*s[:, :, 5:]) 45 | 46 | 47 | def generator(config): 48 | dir_name = config.dir_name 49 | h = config.height 50 | w = config.width 51 | c = len(karel.state_table) 52 | wall_prob = config.wall_prob 53 | num_train = config.num_train 54 | num_test = config.num_test 55 | num_val = config.num_val 56 | num_total = num_train + num_test + num_val 57 | 58 | # output files 59 | f = h5py.File(os.path.join(dir_name, 'data.hdf5'), 'w') 60 | id_file = open(os.path.join(dir_name, 'id.txt'), 'w') 61 | 62 | # progress bar 63 | bar = progressbar.ProgressBar(maxval=100, 64 | widgets=[progressbar.Bar('=', '[', ']'), ' ', 65 | progressbar.Percentage()]) 66 | 
bar.start() 67 | 68 | dsl = get_KarelDSL(dsl_type='prob', seed=config.seed) 69 | s_gen = KarelStateGenerator(seed=config.seed) 70 | karel_world = karel.Karel_world() 71 | 72 | count = 0 73 | max_demo_length_in_dataset = -1 74 | max_program_length_in_dataset = -1 75 | seen_programs = set() 76 | while(1): 77 | # generate a single program 78 | random_code = dsl.random_code(max_depth=config.max_program_stmt_depth, 79 | max_nesting_depth=config.max_program_nesting_depth) 80 | # skip seen programs 81 | if random_code in seen_programs: 82 | continue 83 | program_seq = np.array(dsl.code2intseq(random_code), dtype=np.int8) 84 | if program_seq.shape[0] > config.max_program_length: 85 | continue 86 | 87 | s_h_list = [] 88 | a_h_list = [] 89 | num_demo = 0 90 | num_trial = 0 91 | while num_demo < config.num_demo_per_program and \ 92 | num_trial < config.max_demo_generation_trial: 93 | try: 94 | s, _, _, _, _ = s_gen.generate_single_state(h, w, wall_prob) 95 | karel_world.set_new_state(s) 96 | s_h = dsl.run(karel_world, random_code) 97 | except RuntimeError: 98 | pass 99 | else: 100 | if len(karel_world.s_h) <= config.max_demo_length and \ 101 | len(karel_world.s_h) >= config.min_demo_length: 102 | s_h_list.append(np.stack(karel_world.s_h, axis=0)) 103 | a_h_list.append(np.array(karel_world.a_h)) 104 | num_demo += 1 105 | 106 | num_trial += 1 107 | 108 | if num_demo < config.num_demo_per_program: 109 | continue 110 | 111 | len_s_h = np.array([s_h.shape[0] for s_h in s_h_list], dtype=np.int16) 112 | if np.max(len_s_h) < config.min_max_demo_length_for_program: 113 | continue 114 | 115 | demos_s_h = np.zeros([num_demo, np.max(len_s_h), h, w, c], dtype=bool) 116 | for i, s_h in enumerate(s_h_list): 117 | demos_s_h[i, :s_h.shape[0]] = s_h 118 | 119 | len_a_h = np.array([a_h.shape[0] for a_h in a_h_list], dtype=np.int16) 120 | 121 | demos_a_h = np.zeros([num_demo, np.max(len_a_h)], dtype=np.int8) 122 | for i, a_h in enumerate(a_h_list): 123 | demos_a_h[i, :a_h.shape[0]] = a_h 124 | 125 | max_demo_length_in_dataset = max(max_demo_length_in_dataset, np.max(len_s_h)) 126 | max_program_length_in_dataset = max(max_program_length_in_dataset, program_seq.shape[0]) 127 | 128 | # save the state 129 | id = 'no_{}_prog_len_{}_max_s_h_len_{}'.format( 130 | count, program_seq.shape[0], np.max(len_s_h)) 131 | id_file.write(id+'\n') 132 | grp = f.create_group(id) 133 | grp['program'] = program_seq 134 | grp['s_h_len'] = len_s_h 135 | grp['a_h_len'] = len_a_h 136 | grp['s_h'] = demos_s_h 137 | grp['a_h'] = demos_a_h 138 | seen_programs.add(random_code) 139 | # progress bar 140 | count += 1 141 | if count % (num_total / 100) == 0: 142 | bar.update(count / (num_total / 100)) 143 | if count >= num_total: 144 | grp = f.create_group('data_info') 145 | grp['max_demo_length'] = max_demo_length_in_dataset 146 | grp['dsl_type'] = 'prob' 147 | grp['max_program_length'] = max_program_length_in_dataset 148 | grp['num_program_tokens'] = len(dsl.int2token) 149 | grp['num_demo_per_program'] = config.num_demo_per_program 150 | grp['num_action_tokens'] = len(dsl.action_functions) 151 | grp['num_train'] = config.num_train 152 | grp['num_test'] = config.num_test 153 | grp['num_val'] = config.num_val 154 | bar.finish() 155 | f.close() 156 | id_file.close() 157 | log.info('Dataset generated under {} with {}' 158 | ' samples ({} for training and {} for testing ' 159 | 'and {} for val'.format(dir_name, num_total, 160 | num_train, num_test, num_val)) 161 | return 162 | 163 | 164 | def check_path(path): 165 | if not os.path.exists(path): 166 | 
os.mkdir(path) 167 | 168 | 169 | def main(): 170 | parser = argparse.ArgumentParser( 171 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 172 | parser.add_argument('--dir_name', type=str, default='karel_dataset') 173 | parser.add_argument('--height', type=int, default=8, 174 | help='height of square grid world') 175 | parser.add_argument('--width', type=int, default=8, 176 | help='width of square grid world') 177 | parser.add_argument('--num_train', type=int, default=25000, help='num train') 178 | parser.add_argument('--num_test', type=int, default=5000, help='num test') 179 | parser.add_argument('--num_val', type=int, default=5000, help='num val') 180 | parser.add_argument('--wall_prob', type=float, default=0.1, 181 | help='probability of wall generation') 182 | parser.add_argument('--seed', type=int, default=123, help='seed') 183 | parser.add_argument('--max_program_length', type=int, default=50) 184 | parser.add_argument('--max_program_stmt_depth', type=int, default=6) 185 | parser.add_argument('--max_program_nesting_depth', type=int, default=4) 186 | parser.add_argument('--min_max_demo_length_for_program', type=int, default=2) 187 | parser.add_argument('--min_demo_length', type=int, default=8, 188 | help='min demo length') 189 | parser.add_argument('--max_demo_length', type=int, default=20, 190 | help='max demo length') 191 | parser.add_argument('--num_demo_per_program', type=int, default=10, 192 | help='number of seen demonstrations') 193 | parser.add_argument('--max_demo_generation_trial', type=int, default=100) 194 | args = parser.parse_args() 195 | args.dir_name = os.path.join('datasets/', args.dir_name) 196 | check_path('datasets') 197 | check_path(args.dir_name) 198 | 199 | generator(args) 200 | 201 | if __name__ == '__main__': 202 | main() 203 | -------------------------------------------------------------------------------- /karel_env/input_ops_karel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from util import log 5 | 6 | 7 | def check_data_id(dataset, data_id): 8 | if not data_id: 9 | return 10 | 11 | wrong = [] 12 | for id in data_id: 13 | if id in dataset.data: 14 | pass 15 | else: 16 | wrong.append(id) 17 | 18 | if len(wrong) > 0: 19 | raise RuntimeError("There are %d invalid ids, including %s" % ( 20 | len(wrong), wrong[:5] 21 | )) 22 | 23 | 24 | def create_input_ops(dataset, 25 | batch_size, 26 | num_threads=16, # for creating batches 27 | is_training=False, 28 | data_id=None, 29 | scope='inputs', 30 | shuffle=True, 31 | ): 32 | ''' 33 | Return a batched tensor for the inputs from the dataset. 
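Illustrative usage (a sketch, not exercised in this file):
    dataset_train, _, _ = dataset_karel.create_default_splits(dataset_path)
    _, batch_ops = create_input_ops(dataset_train, batch_size=32, is_training=True)
The returned batch_ops['program'], batch_ops['s_h'], batch_ops['a_h'], etc. are
batched tensors; the tf.train queue runners must be started before evaluating them.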
34 | ''' 35 | input_ops = {} 36 | 37 | if data_id is None: 38 | data_id = dataset.ids 39 | log.info("input_ops [%s]: Using %d IDs from dataset", scope, len(data_id)) 40 | else: 41 | log.info("input_ops [%s]: Using specified %d IDs", scope, len(data_id)) 42 | 43 | # single operations 44 | with tf.device("/cpu:0"), tf.name_scope(scope): 45 | input_ops['id'] = tf.train.string_input_producer( 46 | tf.convert_to_tensor(data_id), capacity=128 47 | ).dequeue(name='input_ids_dequeue') 48 | 49 | p, pt, s, ts, a, at, ta, tat, pl, dl, tdl, per, tper \ 50 | = dataset.get_data(data_id[0]) 51 | 52 | def load_fn(id): 53 | # program [n, max_program_len] 54 | # program_tokens [max_program_len] 55 | # s_h [k, max_demo_len, h, w, 16] 56 | # test_s_h [test_k, max_demo_len, h, w, 16] 57 | # a_h [k, max_demo_len - 1, ac] 58 | # a_h_tokens [k, max_demo_len - 1] 59 | # test_a_h [test_k, max_demo_len - 1, ac] 60 | # test_a_h_tokens [test_k, max_demo_len - 1] 61 | # program_len [1] 62 | # demo_len [k] 63 | # test_demo_len [k] 64 | # per [k, t, c] 65 | # test_per [test_k, t, c] 66 | program, program_tokens, s_h, test_s_h, a_h, a_h_tokens, \ 67 | test_a_h, test_a_h_tokens, program_len, demo_len, test_demo_len, \ 68 | per, test_per = dataset.get_data(id) 69 | return (id, program.astype(np.float32), program_tokens.astype(np.int32), 70 | s_h.astype(np.float32), test_s_h.astype(np.float32), 71 | a_h.astype(np.float32), a_h_tokens.astype(np.int32), 72 | test_a_h.astype(np.float32), test_a_h_tokens.astype(np.int32), 73 | program_len.astype(np.float32), demo_len.astype(np.float32), 74 | test_demo_len.astype(np.float32), 75 | per.astype(np.float32), test_per.astype(np.float32)) 76 | 77 | input_ops['id'], input_ops['program'], input_ops['program_tokens'], \ 78 | input_ops['s_h'], input_ops['test_s_h'], \ 79 | input_ops['a_h'], input_ops['a_h_tokens'], \ 80 | input_ops['test_a_h'], input_ops['test_a_h_tokens'], \ 81 | input_ops['program_len'], input_ops['demo_len'], \ 82 | input_ops['test_demo_len'], input_ops['per'], input_ops['test_per'] = tf.py_func( 83 | load_fn, inp=[input_ops['id']], 84 | Tout=[tf.string, tf.float32, tf.int32, tf.float32, tf.float32, 85 | tf.float32, tf.int32, tf.float32, tf.int32, 86 | tf.float32, tf.float32, tf.float32, tf.float32, tf.float32], 87 | name='func_hp' 88 | ) 89 | 90 | input_ops['id'].set_shape([]) 91 | input_ops['program'].set_shape(list(p.shape)) 92 | input_ops['program_tokens'].set_shape(list(pt.shape)) 93 | input_ops['s_h'].set_shape(list(s.shape)) 94 | input_ops['test_s_h'].set_shape(list(ts.shape)) 95 | input_ops['a_h'].set_shape(list(a.shape)) 96 | input_ops['a_h_tokens'].set_shape(list(at.shape)) 97 | input_ops['test_a_h'].set_shape(list(ta.shape)) 98 | input_ops['test_a_h_tokens'].set_shape(list(tat.shape)) 99 | input_ops['program_len'].set_shape(list(pl.shape)) 100 | input_ops['demo_len'].set_shape(list(dl.shape)) 101 | input_ops['test_demo_len'].set_shape(list(tdl.shape)) 102 | input_ops['per'].set_shape(list(per.shape)) 103 | input_ops['test_per'].set_shape(list(tper.shape)) 104 | 105 | # batchify 106 | capacity = 2 * batch_size * num_threads 107 | min_capacity = min(int(capacity * 0.75), 1024) 108 | 109 | if shuffle: 110 | batch_ops = tf.train.shuffle_batch( 111 | input_ops, 112 | batch_size=batch_size, 113 | num_threads=num_threads, 114 | capacity=capacity, 115 | min_after_dequeue=min_capacity, 116 | ) 117 | else: 118 | batch_ops = tf.train.batch( 119 | input_ops, 120 | batch_size=batch_size, 121 | num_threads=num_threads, 122 | capacity=capacity, 123 | ) 124 | 125 | return 
input_ops, batch_ops 126 | -------------------------------------------------------------------------------- /karel_env/karel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | MAX_NUM_MARKER = 10 5 | 6 | state_table = { 7 | 0: 'Karel facing North', 8 | 1: 'Karel facing East', 9 | 2: 'Karel facing South', 10 | 3: 'Karel facing West', 11 | 4: 'Wall', 12 | 5: '0 marker', 13 | 6: '1 marker', 14 | 7: '2 markers', 15 | 8: '3 markers', 16 | 9: '4 markers', 17 | 10: '5 markers', 18 | 11: '6 markers', 19 | 12: '7 markers', 20 | 13: '8 markers', 21 | 14: '9 markers', 22 | 15: '10 markers' 23 | } 24 | action_table = { 25 | 0: 'Move', 26 | 1: 'Turn left', 27 | 2: 'Turn right', 28 | 3: 'Pick up a marker', 29 | 4: 'Put a marker' 30 | } 31 | 32 | 33 | class Karel_world(object): 34 | 35 | def __init__(self, s=None, make_error=True): 36 | if s is not None: 37 | self.set_new_state(s) 38 | self.make_error = make_error 39 | 40 | def set_new_state(self, s): 41 | self.s = s.astype(np.bool) 42 | self.s_h = [self.s.copy()] 43 | self.a_h = [] 44 | self.h = self.s.shape[0] 45 | self.w = self.s.shape[1] 46 | p_v = self.get_perception_vector() 47 | self.p_v_h = [p_v.copy()] 48 | 49 | ################################### 50 | ### Collect Demonstrations ### 51 | ################################### 52 | def clear_history(self): 53 | self.s_h = [self.s.copy()] 54 | self.a_h = [] 55 | 56 | def add_to_history(self, a_idx): 57 | self.s_h.append(self.s.copy()) 58 | self.a_h.append(a_idx) 59 | p_v = self.get_perception_vector() 60 | self.p_v_h.append(p_v.copy()) 61 | 62 | # get location (x, y) and facing {north, east, south, west} 63 | def get_location(self): 64 | x, y, z = np.where(self.s[:, :, :4] > 0) 65 | return np.asarray([x[0], y[0], z[0]]) 66 | 67 | # get the neighbor {front, left, right} loction 68 | def get_neighbor(self, face): 69 | loc = self.get_location() 70 | if face == 'front': 71 | neighbor_loc = loc[:2] + { 72 | 0: [-1, 0], 73 | 1: [0, 1], 74 | 2: [1, 0], 75 | 3: [0, -1] 76 | }[loc[2]] 77 | elif face == 'left': 78 | neighbor_loc = loc[:2] + { 79 | 0: [0, -1], 80 | 1: [-1, 0], 81 | 2: [0, 1], 82 | 3: [1, 0] 83 | }[loc[2]] 84 | elif face == 'right': 85 | neighbor_loc = loc[:2] + { 86 | 0: [0, 1], 87 | 1: [1, 0], 88 | 2: [0, -1], 89 | 3: [-1, 0] 90 | }[loc[2]] 91 | return neighbor_loc 92 | 93 | ################################### 94 | ### Perception Primitives ### 95 | ################################### 96 | # return if the neighbor {front, left, right} of Karel is clear 97 | def neighbor_is_clear(self, face): 98 | neighbor_loc = self.get_neighbor(face) 99 | if neighbor_loc[0] >= self.h or neighbor_loc[0] < 0 \ 100 | or neighbor_loc[1] >= self.w or neighbor_loc[1] < 0: 101 | return False 102 | return not self.s[neighbor_loc[0], neighbor_loc[1], 4] 103 | 104 | def front_is_clear(self): 105 | return self.neighbor_is_clear('front') 106 | 107 | def left_is_clear(self): 108 | return self.neighbor_is_clear('left') 109 | 110 | def right_is_clear(self): 111 | return self.neighbor_is_clear('right') 112 | 113 | # return if there is a marker presented 114 | def marker_present(self): 115 | loc = self.get_location() 116 | return np.sum(self.s[loc[0], loc[1], 6:]) > 0 117 | 118 | def no_marker_present(self): 119 | loc = self.get_location() 120 | return np.sum(self.s[loc[0], loc[1], 6:]) == 0 121 | 122 | def get_perception_list(self): 123 | vec = ['frontIsClear', 'leftIsClear', 124 | 'rightIsClear', 'markersPresent', 125 | 'noMarkersPresent'] 126 | return vec 
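    # State layout used by the perception primitives above and below:
    # `self.s` has shape (h, w, 16); channels 0-3 one-hot encode Karel's
    # cell and facing direction (north/east/south/west, see `state_table`),
    # channel 4 marks walls, and channels 5-15 one-hot encode the number of
    # markers in a cell (0-10). `get_perception_vector` returns its booleans
    # in the same order as `get_perception_list`: frontIsClear, leftIsClear,
    # rightIsClear, markersPresent, noMarkersPresent.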
127 | 128 | def get_perception_vector(self): 129 | vec = [self.front_is_clear(), self.left_is_clear(), 130 | self.right_is_clear(), self.marker_present(), 131 | self.no_marker_present()] 132 | return np.array(vec) 133 | 134 | ################################### 135 | ### State Transition ### 136 | ################################### 137 | # given a state and a action, return the next state 138 | def state_transition(self, a): 139 | a_idx = np.argmax(a) 140 | loc = self.get_location() 141 | 142 | if a_idx == 0: 143 | # move 144 | if self.front_is_clear(): 145 | front_loc = self.get_neighbor('front') 146 | loc_vec = self.s[loc[0], loc[1], :4] 147 | self.s[front_loc[0], front_loc[1], :4] = loc_vec 148 | self.s[loc[0], loc[1], :4] = np.zeros(4) > 0 149 | else: 150 | if self.make_error: 151 | raise RuntimeError("Failed to move.") 152 | loc_vec = np.zeros(4) > 0 153 | loc_vec[(loc[2] + 2) % 4] = True # Turn 180 154 | self.s[loc[0], loc[1], :4] = loc_vec 155 | self.add_to_history(a_idx) 156 | elif a_idx == 1 or a_idx == 2: 157 | # turn left or right 158 | loc_vec = np.zeros(4) > 0 159 | loc_vec[(a_idx * 2 - 3 + loc[2]) % 4] = True 160 | self.s[loc[0], loc[1], :4] = loc_vec 161 | self.add_to_history(a_idx) 162 | 163 | elif a_idx == 3 or a_idx == 4: 164 | # pick up or put a marker 165 | num_marker = np.argmax(self.s[loc[0], loc[1], 5:]) 166 | # just clip the num of markers for now 167 | # new_num_marker = np.clip(a_idx*2-7 + num_marker, 0, MAX_NUM_MARKER-1) 168 | new_num_marker = a_idx*2-7 + num_marker 169 | if new_num_marker < 0: 170 | if self.make_error: 171 | raise RuntimeError("No marker to pick up.") 172 | else: 173 | new_num_marker = num_marker 174 | elif new_num_marker > MAX_NUM_MARKER-1: 175 | if self.make_error: 176 | raise RuntimeError("Cannot put more marker.") 177 | else: 178 | new_num_marker = num_marker 179 | marker_vec = np.zeros(MAX_NUM_MARKER+1) > 0 180 | marker_vec[new_num_marker] = True 181 | self.s[loc[0], loc[1], 5:] = marker_vec 182 | self.add_to_history(a_idx) 183 | else: 184 | raise RuntimeError("Invalid action") 185 | return 186 | -------------------------------------------------------------------------------- /karel_env/karel_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from karel import state_table 3 | 4 | 5 | class color_code: 6 | HEADER = '\033[95m' 7 | RED = '\033[31m' 8 | GREEN = '\033[32m' 9 | BLUE = '\033[34m' 10 | PURPLE = '\033[35m' 11 | YELLOW = '\033[93m' 12 | CYAN = '\033[36m' 13 | END = '\033[0m' 14 | BOLD = '\033[1m' 15 | UNDERLINE = '\033[4m' 16 | 17 | 18 | def grid2str(grid): 19 | assert len(grid) == 16, 'Invalid representation of a grid' 20 | idx = np.argwhere(grid == np.amax(grid)).flatten().tolist() 21 | if len(idx) == 1: 22 | return state_table[idx[0]] 23 | elif len(idx) == 2: 24 | return '{} with {}'.format(state_table[idx[0]], state_table[idx[1]]) 25 | else: 26 | return 'None' 27 | 28 | 29 | # given a karel env state, return a symbol representation 30 | def state2symbol(s): 31 | KAREL = "^>v<#" 32 | for i in range(s.shape[0]): 33 | str = "" 34 | for j in range(s.shape[1]): 35 | if np.sum(s[i, j, :4]) > 0 and np.sum(s[i, j, 6:]) > 0: 36 | idx = np.argmax(s[i, j]) 37 | str += color_code.PURPLE+KAREL[idx]+color_code.END 38 | elif np.sum(s[i, j, :4]) > 0: 39 | idx = np.argmax(s[i, j]) 40 | str += color_code.BLUE+KAREL[idx]+color_code.END 41 | elif np.sum(s[i, j, 4]) > 0: 42 | str += color_code.RED+'#'+color_code.END 43 | elif np.sum(s[i, j, 6:]) > 0: 44 | str += 
color_code.GREEN+'o'+color_code.END 45 | else: 46 | str += '.' 47 | print(str) 48 | return 49 | 50 | 51 | # given a karel env state, return a visulized image 52 | def state2image(s, grid_size=10, root_dir='./'): 53 | h = s.shape[0] 54 | w = s.shape[1] 55 | img = np.ones((h*grid_size, w*grid_size, 3)) 56 | import h5py 57 | import os.path as osp 58 | f = h5py.File(osp.join(root_dir, 'asset/texture.hdf5'), 'r') 59 | wall_img = f['wall'] 60 | marker_img = f['marker'] 61 | # wall 62 | y, x = np.where(s[:, :, 4]) 63 | for i in range(len(x)): 64 | img[y[i]*grid_size:(y[i]+1)*grid_size, x[i]*grid_size:(x[i]+1)*grid_size] = wall_img 65 | # marker 66 | y, x = np.where(np.sum(s[:, :, 6:], axis=-1)) 67 | for i in range(len(x)): 68 | img[y[i]*grid_size:(y[i]+1)*grid_size, x[i]*grid_size:(x[i]+1)*grid_size] = marker_img 69 | # karel 70 | y, x = np.where(np.sum(s[:, :, :4], axis=-1)) 71 | if len(y) == 1: 72 | y = y[0] 73 | x = x[0] 74 | idx = np.argmax(s[y, x]) 75 | marker_present = np.sum(s[y, x, 6:]) > 0 76 | if marker_present: 77 | if idx == 0: 78 | img[y*grid_size:(y+1)*grid_size, x*grid_size:(x+1)*grid_size] = f['n_m'] 79 | elif idx == 1: 80 | img[y*grid_size:(y+1)*grid_size, x*grid_size:(x+1)*grid_size] = f['e_m'] 81 | elif idx == 2: 82 | img[y*grid_size:(y+1)*grid_size, x*grid_size:(x+1)*grid_size] = f['s_m'] 83 | elif idx == 3: 84 | img[y*grid_size:(y+1)*grid_size, x*grid_size:(x+1)*grid_size] = f['w_m'] 85 | else: 86 | if idx == 0: 87 | img[y*grid_size:(y+1)*grid_size, x*grid_size:(x+1)*grid_size] = f['n'] 88 | elif idx == 1: 89 | img[y*grid_size:(y+1)*grid_size, x*grid_size:(x+1)*grid_size] = f['e'] 90 | elif idx == 2: 91 | img[y*grid_size:(y+1)*grid_size, x*grid_size:(x+1)*grid_size] = f['s'] 92 | elif idx == 3: 93 | img[y*grid_size:(y+1)*grid_size, x*grid_size:(x+1)*grid_size] = f['w'] 94 | elif len(y) > 1: 95 | raise ValueError 96 | f.close() 97 | return img 98 | -------------------------------------------------------------------------------- /karel_env/state_generator.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | 7 | 8 | class KarelStateGenerator(object): 9 | def __init__(self, seed=None): 10 | self.rng = np.random.RandomState(seed) 11 | 12 | # generate an initial env 13 | def generate_single_state(self, h=8, w=8, wall_prob=0.1): 14 | s = np.zeros([h, w, 16]) > 0 15 | # Wall 16 | s[:, :, 4] = self.rng.rand(h, w) > 1 - wall_prob 17 | s[0, :, 4] = True 18 | s[h-1, :, 4] = True 19 | s[:, 0, 4] = True 20 | s[:, w-1, 4] = True 21 | # Karel initial location 22 | valid_loc = False 23 | while(not valid_loc): 24 | y = self.rng.randint(0, h) 25 | x = self.rng.randint(0, w) 26 | if not s[y, x, 4]: 27 | valid_loc = True 28 | s[y, x, self.rng.randint(0, 4)] = True 29 | # Marker: num of max marker == 1 for now 30 | s[:, :, 6] = (self.rng.rand(h, w) > 0.9) * (s[:, :, 4] == False) > 0 31 | s[:, :, 5] = 1 - (np.sum(s[:, :, 6:], axis=-1) > 0) > 0 32 | assert np.sum(s[:, :, 5:]) == h*w, np.sum(s[:, :, :5]) 33 | marker_weight = np.reshape(np.array(range(11)), (1, 1, 11)) 34 | return s, y, x, np.sum(s[:, :, 4]), np.sum(marker_weight*s[:, :, 5:]) 35 | -------------------------------------------------------------------------------- /karel_env/tool/eval_execution.py: -------------------------------------------------------------------------------- 1 | """ 2 | Eval Execution 3 | 4 | Execute output programs and then check 
execution accuracy and syntax accuracy. 5 | """ 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | 11 | import argparse 12 | import collections 13 | import h5py 14 | import os 15 | import numpy as np 16 | from tqdm import tqdm 17 | 18 | import karel 19 | from dsl import get_KarelDSL 20 | from dsl.dsl_parse import parse 21 | 22 | 23 | def GetArgument(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--data_hdf5', type=str) 26 | parser.add_argument('--output_hdf5', type=str) 27 | parser.add_argument('--output_log_path', type=str, default=None) 28 | parser.add_argument('--new_hdf5_path', type=str, default=None) 29 | parser.add_argument('--log', action='store_true', default=False) 30 | parser.add_argument('--dump', action='store_true', default=False) 31 | return parser.parse_args() 32 | 33 | 34 | class CheckProgramOutput( 35 | collections.namedtuple("CheckProgramOutput", 36 | ("data_id", "program", "syntax", "num_correct", "demo_correctness"))): 37 | pass 38 | 39 | 40 | def CheckProgram(program, data_id, num_demo, demo, demo_len, dsl, karel_world): 41 | exe, s_exe = parse(program) 42 | if not s_exe: 43 | syntax = False 44 | demo_correctness = np.array([False] * num_demo) 45 | num_correct = 0 46 | else: 47 | syntax = True 48 | demo_correctness = np.array([False] * num_demo) 49 | for k in range(num_demo): 50 | init_state = demo[k][0] 51 | karel_world.set_new_state(init_state) 52 | karel_world.clear_history() 53 | exe, s_exe = parse(program) 54 | if not s_exe: 55 | raise RuntimeError('This should be correct') 56 | karel_world, n, s_run = exe(karel_world, 0) 57 | if not s_run: 58 | demo_correctness[k] = False 59 | else: 60 | exe_result_len = len(karel_world.s_h) 61 | exe_result = np.stack(karel_world.s_h) 62 | demo_correctness[k] = (demo_len[k] == exe_result_len and 63 | np.all(demo[k][:demo_len[k]] == exe_result)) 64 | num_correct = demo_correctness.astype(np.int32).sum() 65 | return CheckProgramOutput(data_id, program, syntax, num_correct, demo_correctness) 66 | 67 | 68 | class EvaluationResult: 69 | 70 | def __init__(self, name): 71 | self.name = name 72 | self.initialize() 73 | 74 | def initialize(self): 75 | self.syntax = [] 76 | self.num_correct_count = {} 77 | self.demo_correctness = [] 78 | self.ids = [] 79 | self.programs = [] 80 | 81 | def get_program_by_id(self, id): 82 | idx = self.ids.index(id) 83 | return self.programs[idx] 84 | 85 | def get_demo_correctness_by_id(self, id): 86 | idx = self.ids.index(id) 87 | return self.demo_correctness[idx] 88 | 89 | def get_syntax_by_id(self, id): 90 | idx = self.ids.index(id) 91 | return self.syntax[idx] 92 | 93 | def add_check_outputs(self, check_output): 94 | self.syntax.append(check_output.syntax) 95 | self.num_correct_count[check_output.num_correct] = \ 96 | self.num_correct_count.get(check_output.num_correct, 0) + 1 97 | self.demo_correctness.append(check_output.demo_correctness) 98 | self.programs.append(check_output.program) 99 | self.ids.append(check_output.data_id) 100 | 101 | def summary_results(self): 102 | self.syntax_accuracy = float(self.syntax.count(True)) / len(self.syntax) 103 | self.num_correct_histogram = np.zeros( 104 | [max(self.num_correct_count) + 1], dtype=np.float) 105 | for i in range(len(self.num_correct_histogram)): 106 | self.num_correct_histogram[i] = \ 107 | self.num_correct_count.get(i, 0) 108 | self.num_correct_histogram /= \ 109 | np.array(self.num_correct_count.values()).astype(np.float).sum() 110 | 111 | def 
result_string(self): 112 | histogram_str = \ 113 | ', '.join(['{:.3f}'.format(h) for h in self.num_correct_histogram]) 114 | string = """ 115 | **{name}** 116 | syntax_accuracy: {syntax_accuracy: .3f} 117 | num_correct_histogram: [{histogram}] 118 | """.format(name=self.name, 119 | syntax_accuracy=self.syntax_accuracy, 120 | histogram=histogram_str) 121 | return string 122 | 123 | 124 | if __name__ == '__main__': 125 | args = GetArgument() 126 | 127 | if not os.path.exists(args.data_hdf5): 128 | raise ValueError( 129 | "data_hdf5 doesn't exist: {}".format(args.data_hdf5)) 130 | 131 | if not os.path.exists(args.output_hdf5): 132 | raise ValueError( 133 | "output_path doesn't exist: {}".format(args.output_hdf5)) 134 | 135 | with h5py.File(args.data_hdf5, 'r') as file_data: 136 | with h5py.File(args.output_hdf5, 'r') as file_output: 137 | data_info = file_data['data_info'] 138 | num_train_demo = data_info['num_demo_per_program'].value 139 | num_test_demo = data_info['num_test_demo_per_program'].value 140 | dsl_type = data_info['dsl_type'].value 141 | dsl = get_KarelDSL(dsl_type=dsl_type, seed=123) 142 | karel_world = karel.Karel_world() 143 | 144 | # tf means "Teacher Forcing" and greedy means "Greedy Unrolling" 145 | results = { 146 | 'train_tf_result': EvaluationResult('train_tf_result'), 147 | 'test_tf_result': EvaluationResult('test_tf_result'), 148 | 'train_greedy_result': EvaluationResult('train_greedy_result'), 149 | 'test_greedy_result': EvaluationResult('test_greedy_result') 150 | } 151 | 152 | tf_syntax = [] 153 | greedy_syntax = [] 154 | for data_id in tqdm(file_output.keys()): 155 | data = file_data[data_id] 156 | output = file_output[data_id] 157 | gt_program = dsl.intseq2str(data['program'].value) 158 | tf_program = output['program_prediction'].value 159 | greedy_program = output['greedy_prediction'].value 160 | 161 | # Train demo 162 | train_tf_out = CheckProgram( 163 | tf_program, data_id, num_train_demo, 164 | data['s_h'].value, data['s_h_len'].value, 165 | dsl, karel_world) 166 | results['train_tf_result'].add_check_outputs(train_tf_out) 167 | 168 | train_greedy_out = CheckProgram( 169 | greedy_program, data_id, num_train_demo, 170 | data['s_h'], data['s_h_len'], 171 | dsl, karel_world) 172 | results['train_greedy_result'].add_check_outputs(train_greedy_out) 173 | 174 | # Test demo 175 | test_tf_out = CheckProgram( 176 | tf_program, data_id, num_test_demo, 177 | data['test_s_h'], data['test_s_h_len'], 178 | dsl, karel_world) 179 | results['test_tf_result'].add_check_outputs(test_tf_out) 180 | 181 | test_greedy_out = CheckProgram( 182 | greedy_program, data_id, num_test_demo, 183 | data['test_s_h'], data['test_s_h_len'], 184 | dsl, karel_world) 185 | results['test_greedy_result'].add_check_outputs(test_greedy_out) 186 | 187 | for result in results.values(): 188 | result.summary_results() 189 | print(result.result_string()) 190 | 191 | if args.log: 192 | if args.output_log_path is None: 193 | args.output_log_path = "{}.eval_exe.log".format( 194 | args.output_hdf5) 195 | with open(args.output_log_path, 'w') as f: 196 | for result in results.values(): 197 | result.summary_results() 198 | f.write(result.result_string()) 199 | if args.dump: 200 | if args.new_hdf5_path is None: 201 | args.new_hdf5_path = "{}.eval_exe.hdf5".format( 202 | args.output_hdf5) 203 | with h5py.File(args.new_hdf5_path, 'w') as new_file: 204 | print('Dump result files: {}'.format(args.new_hdf5_path)) 205 | tf_train = results['train_tf_result'] 206 | tf_test = results['test_tf_result'] 207 | greedy_train = 
results['train_greedy_result'] 208 | greedy_test = results['test_greedy_result'] 209 | for data_id in tqdm(file_output.keys()): 210 | grp = new_file.create_group(data_id) 211 | grp['program_prediction'] = \ 212 | tf_train.get_program_by_id(data_id) 213 | grp['program_syntax'] = \ 214 | ('correct' if tf_train.get_syntax_by_id(data_id) 215 | else 'wrong') 216 | grp['program_is_correct_execution'] = \ 217 | tf_train.get_demo_correctness_by_id(data_id) 218 | grp['program_num_execution_correct'] = \ 219 | (tf_train.get_demo_correctness_by_id(data_id)).astype(np.int32).sum() 220 | grp['test_program_prediction'] = \ 221 | tf_test.get_program_by_id(data_id) 222 | grp['test_program_syntax'] = \ 223 | ('correct' if tf_test.get_syntax_by_id(data_id) 224 | else 'wrong') 225 | grp['test_program_is_correct_execution'] = \ 226 | tf_test.get_demo_correctness_by_id(data_id) 227 | grp['test_program_num_execution_correct'] = \ 228 | (tf_test.get_demo_correctness_by_id(data_id)).astype(np.int32).sum() 229 | grp['greedy_prediction'] = \ 230 | greedy_train.get_program_by_id(data_id) 231 | grp['greedy_syntax'] = \ 232 | ('correct' if greedy_train.get_syntax_by_id(data_id) 233 | else 'wrong') 234 | grp['greedy_is_correct_execution'] = \ 235 | greedy_train.get_demo_correctness_by_id(data_id) 236 | grp['greedy_num_execution_correct'] = \ 237 | (greedy_train.get_demo_correctness_by_id(data_id)).astype(np.int32).sum() 238 | grp['test_greedy_prediction'] = \ 239 | greedy_test.get_program_by_id(data_id) 240 | grp['test_greedy_syntax'] = \ 241 | ('correct' if greedy_test.get_syntax_by_id(data_id) 242 | else 'wrong') 243 | grp['test_greedy_is_correct_execution'] = \ 244 | greedy_test.get_demo_correctness_by_id(data_id) 245 | grp['test_greedy_num_execution_correct'] = \ 246 | (greedy_test.get_demo_correctness_by_id(data_id)).astype(np.int32).sum() 247 | -------------------------------------------------------------------------------- /karel_env/tool/visualize_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import h5py 6 | import os 7 | import argparse 8 | 9 | from prompt_toolkit import prompt 10 | 11 | from dsl import get_KarelDSL 12 | from karel_util import state2symbol 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--dir_name', type=str, default='karel_default') 18 | args = parser.parse_args() 19 | 20 | dir_name = args.dir_name 21 | data_file = os.path.join(dir_name, 'data.hdf5') 22 | id_file = os.path.join(dir_name, 'id.txt') 23 | 24 | if not os.path.exists(data_file): 25 | print("data_file path doesn't exist: {}".format(data_file)) 26 | return 27 | if not os.path.exists(id_file): 28 | print("id_file path doesn't exist: {}".format(id_file)) 29 | return 30 | 31 | f = h5py.File(data_file, 'r') 32 | ids = open(id_file, 'r').read().splitlines() 33 | 34 | dsl = get_KarelDSL(seed=123) 35 | 36 | cur_id = 0 37 | while True: 38 | print('ids / previous id: {}'.format(cur_id)) 39 | for i, id in enumerate(ids[max(cur_id - 5, 0):cur_id + 5]): 40 | print('#{}: {}'.format(max(cur_id - 5, 0) + i, id)) 41 | 42 | print('Put id you want to examine') 43 | cur_id = int(prompt(u'In: ')) 44 | 45 | print('code: {}'.format(dsl.intseq2str(f[ids[cur_id]]['program']))) 46 | print('demonstrations') 47 | for i, l in enumerate(f[ids[cur_id]]['s_h_len']): 48 | print('demo #{}: length {}'.format(i, l)) 49 | print('Put demonstration number 
[0-{}]'.format(f[ids[cur_id]]['s_h'].shape[0])) 50 | demo_idx = int(prompt(u'In: ')) 51 | seq_idx = 0 52 | 53 | print('code: {}'.format(dsl.intseq2str(f[ids[cur_id]]['program']))) 54 | state2symbol(f[ids[cur_id]]['s_h'][demo_idx][seq_idx]) 55 | seq_idx += 1 56 | while seq_idx < f[ids[cur_id]]['s_h_len'][demo_idx]: 57 | print("Press 'c' to continue and 'n' to next example") 58 | print(seq_idx, f[ids[cur_id]]['s_h_len'][demo_idx]) 59 | key = prompt(u'In: ') 60 | if key == 'c': 61 | print('code: {}'.format(dsl.intseq2str(f[ids[cur_id]]['program']))) 62 | state2symbol(f[ids[cur_id]]['s_h'][demo_idx][seq_idx]) 63 | seq_idx += 1 64 | elif key == 'n': 65 | break 66 | else: 67 | print('Wrong key') 68 | print('Demo is terminated') 69 | 70 | 71 | if __name__ == '__main__': 72 | main() 73 | -------------------------------------------------------------------------------- /karel_env/util.py: -------------------------------------------------------------------------------- 1 | """ Utilities """ 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | # Logging 8 | # ======= 9 | 10 | import logging 11 | from colorlog import ColoredFormatter 12 | 13 | ch = logging.StreamHandler() 14 | ch.setLevel(logging.DEBUG) 15 | 16 | formatter = ColoredFormatter( 17 | "%(log_color)s[%(asctime)s] %(message)s", 18 | # datefmt='%H:%M:%S.%f', 19 | datefmt=None, 20 | reset=True, 21 | log_colors={ 22 | 'DEBUG': 'cyan', 23 | 'INFO': 'white,bold', 24 | 'INFOV': 'cyan,bold', 25 | 'WARNING': 'yellow', 26 | 'ERROR': 'red,bold', 27 | 'CRITICAL': 'red,bg_white', 28 | }, 29 | secondary_log_colors={}, 30 | style='%' 31 | ) 32 | ch.setFormatter(formatter) 33 | 34 | log = logging.getLogger('rn') 35 | log.setLevel(logging.DEBUG) 36 | log.handlers = [] # No duplicated handlers 37 | log.propagate = False # workaround for duplicated logs in ipython 38 | log.addHandler(ch) 39 | 40 | logging.addLevelName(logging.INFO + 1, 'INFOV') 41 | def _infov(self, msg, *args, **kwargs): 42 | self.log(logging.INFO + 1, msg, *args, **kwargs) 43 | 44 | logging.Logger.infov = _infov 45 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/models/__init__.py -------------------------------------------------------------------------------- /models/baselines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/models/baselines/__init__.py -------------------------------------------------------------------------------- /models/ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.layers as layers 3 | import tensorflow.contrib.slim as slim 4 | from util import log 5 | 6 | 7 | def lrelu(x, leak=0.2, name="lrelu"): 8 | with tf.variable_scope(name): 9 | f1 = 0.5 * (1 + leak) 10 | f2 = 0.5 * (1 - leak) 11 | return f1 * x + f2 * abs(x) 12 | 13 | 14 | def bn_act(input, is_train, batch_norm=True, activation_fn=None, name="bn_act"): 15 | with tf.variable_scope(name): 16 | _ = input 17 | if activation_fn is not None: 18 | _ = activation_fn(_) 19 | if batch_norm is True: 20 | _ = tf.contrib.layers.batch_norm( 21 | _, center=True, scale=True, 
decay=0.9, 22 | is_training=is_train, updates_collections=None 23 | ) 24 | return _ 25 | 26 | 27 | def conv2d(input, output_shape, is_train, info=False, k_h=4, k_w=4, s=2, 28 | stddev=0.01, name="conv2d", activation_fn=lrelu, batch_norm=True): 29 | with tf.variable_scope(name): 30 | _ = slim.conv2d(input, output_shape, [k_h, k_w], stride=s, activation_fn=None) 31 | _ = bn_act(_, is_train, batch_norm=batch_norm, activation_fn=activation_fn) 32 | if info: log.info('{} {}'.format(name, _)) 33 | return _ 34 | 35 | 36 | def residual_block(input, output_shape, is_train, info=False, k=3, s=1, 37 | name="residual", activation_fn=lrelu, batch_norm=True): 38 | with tf.variable_scope(name): 39 | with tf.variable_scope('res1'): 40 | _ = conv2d(input, output_shape, is_train, k_h=k, k_w=k, s=s, 41 | activation_fn=activation_fn, batch_norm=batch_norm) 42 | with tf.variable_scope('res2'): 43 | _ = conv2d(input, output_shape, is_train, k_h=k, k_w=k, s=s, 44 | activation_fn=None, batch_norm=batch_norm) 45 | _ = activation_fn(_ + input) 46 | if info: log.info('{} {}'.format(name, _)) 47 | return _ 48 | 49 | 50 | def deconv2d(input, deconv_info, is_train, name="deconv2d", info=False, 51 | stddev=0.01, activation_fn=tf.nn.relu, batch_norm=True): 52 | with tf.variable_scope(name): 53 | output_shape = deconv_info[0] 54 | k = deconv_info[1] 55 | s = deconv_info[2] 56 | _ = layers.conv2d_transpose( 57 | input, 58 | num_outputs=output_shape, 59 | weights_initializer=tf.truncated_normal_initializer(stddev=stddev), 60 | biases_initializer=tf.zeros_initializer(), 61 | kernel_size=[k, k], stride=[s, s], padding='SAME' 62 | ) 63 | _ = bn_act(_, is_train, batch_norm=batch_norm, activation_fn=activation_fn) 64 | if info: log.info('{} {}'.format(name, _)) 65 | return _ 66 | 67 | 68 | def bilinear_deconv2d(input, deconv_info, is_train, name="bilinear_deconv2d", 69 | info=False, activation_fn=tf.nn.relu, batch_norm=True): 70 | with tf.variable_scope(name): 71 | output_shape = deconv_info[0] 72 | k = deconv_info[1] 73 | s = deconv_info[2] 74 | h = int(input.get_shape()[1]) * s 75 | w = int(input.get_shape()[2]) * s 76 | _ = tf.image.resize_bilinear(input, [h, w]) 77 | _ = conv2d(_, output_shape, is_train, k_h=k, k_w=k, s=1, 78 | batch_norm=False, activation_fn=None) 79 | _ = bn_act(_, is_train, batch_norm=batch_norm, activation_fn=activation_fn) 80 | if info: log.info('{} {}'.format(name, _)) 81 | return _ 82 | 83 | 84 | def nn_deconv2d(input, deconv_info, is_train, name="nn_deconv2d", 85 | info=False, activation_fn=tf.nn.relu, batch_norm=True): 86 | with tf.variable_scope(name): 87 | output_shape = deconv_info[0] 88 | k = deconv_info[1] 89 | s = deconv_info[2] 90 | h = int(input.get_shape()[1]) * s 91 | w = int(input.get_shape()[2]) * s 92 | _ = tf.image.resize_nearest_neighbor(input, [h, w]) 93 | _ = conv2d(_, output_shape, is_train, k_h=k, k_w=k, s=1, 94 | batch_norm=False, activation_fn=None) 95 | _ = bn_act(_, is_train, batch_norm=batch_norm, activation_fn=activation_fn) 96 | if info: log.info('{} {}'.format(name, _)) 97 | return _ 98 | 99 | 100 | def transpose_deconv3d(input, deconv_info, is_train=True, name="deconv3d", 101 | stddev=0.01, activation_fn=tf.nn.relu, batch_norm=True): 102 | with tf.variable_scope(name): 103 | output_shape = deconv_info[0] 104 | k = deconv_info[1] 105 | s = deconv_info[2] 106 | _ = tf.layers.conv3d_transpose( 107 | input, 108 | filters=output_shape, 109 | kernel_initializer=tf.truncated_normal_initializer(stddev=stddev), 110 | bias_initializer=tf.zeros_initializer(), 111 | kernel_size=[k, 
k, k], strides=[s, s, s], padding='SAME' 112 | ) 113 | _ = bn_act(_, is_train, batch_norm=batch_norm, activation_fn=activation_fn) 114 | return _ 115 | 116 | 117 | def residual_conv(input, num_filters, filter_size, stride, reuse=False, 118 | pad='SAME', dtype=tf.float32, bias=False, name='res_conv'): 119 | with tf.variable_scope(name): 120 | stride_shape = [1, stride, stride, 1] 121 | filter_shape = [filter_size, filter_size, input.get_shape()[3], num_filters] 122 | w = tf.get_variable('w', filter_shape, dtype, tf.random_normal_initializer(0.0, 0.02)) 123 | p = (filter_size - 1) // 2 124 | x = tf.pad(input, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT') 125 | conv = tf.nn.conv2d(x, w, stride_shape, padding='VALID') 126 | return conv 127 | 128 | 129 | def residual(input, num_filters, name, is_train, reuse=False, pad='REFLECT'): 130 | with tf.variable_scope(name, reuse=reuse): 131 | with tf.variable_scope('res1', reuse=reuse): 132 | out = residual_conv(input, num_filters, 3, 1, reuse, pad, name=name) 133 | out = tf.contrib.layers.batch_norm( 134 | out, center=True, scale=True, decay=0.9, 135 | is_training=is_train, updates_collections=None 136 | ) 137 | out = tf.nn.relu(out) 138 | 139 | with tf.variable_scope('res2', reuse=reuse): 140 | out = residual_conv(out, num_filters, 3, 1, reuse, pad, name=name) 141 | out = tf.contrib.layers.batch_norm( 142 | out, center=True, scale=True, decay=0.9, 143 | is_training=is_train, updates_collections=None 144 | ) 145 | 146 | return tf.nn.relu(input + out) 147 | 148 | 149 | def fc(input, output_shape, is_train, info=False, batch_norm=True, 150 | activation_fn=lrelu, name="fc"): 151 | with tf.variable_scope(name): 152 | _ = slim.fully_connected(input, output_shape, activation_fn=None) 153 | _ = bn_act(_, is_train, batch_norm=batch_norm, activation_fn=activation_fn) 154 | if info: log.info('{} {}'.format(name, _)) 155 | return _ 156 | -------------------------------------------------------------------------------- /models/util.py: -------------------------------------------------------------------------------- 1 | """ Utilities """ 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | 8 | # Logging 9 | # ======= 10 | 11 | import logging 12 | from colorlog import ColoredFormatter 13 | import matplotlib.colors as cl 14 | import numpy as np 15 | 16 | ch = logging.StreamHandler() 17 | ch.setLevel(logging.DEBUG) 18 | 19 | formatter = ColoredFormatter( 20 | "%(log_color)s[%(asctime)s] %(message)s", 21 | datefmt=None, 22 | reset=True, 23 | log_colors={ 24 | 'DEBUG': 'cyan', 25 | 'INFO': 'white,bold', 26 | 'INFOV': 'cyan,bold', 27 | 'WARNING': 'yellow', 28 | 'ERROR': 'red,bold', 29 | 'CRITICAL': 'red,bg_white', 30 | }, 31 | secondary_log_colors={}, 32 | style='%' 33 | ) 34 | ch.setFormatter(formatter) 35 | 36 | log = logging.getLogger('attcap') 37 | log.setLevel(logging.DEBUG) 38 | log.handlers = [] # No duplicated handlers 39 | log.propagate = False # workaround for duplicated logs in ipython 40 | log.addHandler(ch) 41 | 42 | logging.addLevelName(logging.INFO + 1, 'INFOV') 43 | 44 | 45 | def _infov(self, msg, *args, **kwargs): 46 | self.log(logging.INFO + 1, msg, *args, **kwargs) 47 | 48 | logging.Logger.infov = _infov 49 | 50 | 51 | def visualize_flow(x, y): 52 | img_batch = [] 53 | h, w = x.shape[-2:] 54 | for i in range(x.shape[1]): 55 | img_time_step = [] 56 | for j in range(x.shape[0]): 57 | du = x[j, i] 58 | dv = y[j, i] 59 | # valid = flow[:, :, 2] 60 | max_flow = max(np.max(du), 
np.max(dv)) 61 | img = np.zeros((h, w, 3), dtype=np.float64) 62 | # angle layer 63 | img[:, :, 0] = np.arctan2(dv, du) / (2 * np.pi) 64 | # magnitude layer, normalized to 1 65 | img[:, :, 1] = np.sqrt(du * du + dv * dv) * 8 / max_flow 66 | # phase layer 67 | img[:, :, 2] = 8 - img[:, :, 1] 68 | # clip to [0,1] 69 | small_idx = img < 0 70 | large_idx = img > 1 71 | img[small_idx] = 0 72 | img[large_idx] = 1 73 | # convert to rgb 74 | img = cl.hsv_to_rgb(img) 75 | img_time_step.append(img) 76 | img_time_step = np.stack(img_time_step, axis=-1) 77 | img_time_step = np.transpose(img_time_step, [0, 1, 3, 2]) 78 | img_time_step = np.reshape(img_time_step, [h, w*x.shape[0], 3]) 79 | img_batch.append(img_time_step) 80 | return np.stack(img_batch, axis=0).astype(np.float32) 81 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu==1.3.0 2 | scipy==1.0.0 3 | numpy==1.14.2 4 | colorlog==3.1.0 5 | h5py==2.7.1 6 | Pillow==5.0.0 7 | progressbar==2.3 8 | ply==3.10 9 | -------------------------------------------------------------------------------- /vizdoom_env/README.md: -------------------------------------------------------------------------------- 1 | # Vizdoom Environment 2 | This directory includes code for Vizdoom environments, which includes: 3 | - Random programs and demonstration generator 4 | - Domain specific language interpreter 5 | 6 | ## Dependencies 7 | Using vizdoom environment requires [VizDoom Deterministic](https://github.com/HyeonwooNoh/ViZDoomDeterministic) as a dependency. 8 | You can install the VizDoom Deterministic as follows: 9 | ```bash 10 | git clone https://github.com/HyeonwooNoh/ViZDoomDeterministic 11 | cd ViZDoomDeterministic 12 | # use pip3 for Python3 13 | sudo pip install . 14 | ``` 15 | Please find the detailed installation instruction of the VizDoom Deterministic from [this repository](https://github.com/HyeonwooNoh/ViZDoomDeterministic). 16 | 17 | ## Dataset generation 18 | 19 | Datasets used in the paper (vizdoom_dataset, vizdoom_dataset_ifelse) are generated with the following script 20 | ```bash 21 | ./vizdoom_world/generate_dataset.sh 22 | ``` 23 | 24 | ## Domain specific language 25 | The interpreter and random program generator for vizdoom domain specific language (DSL) is in the [dsl directory](./dsl). You can find detailed definition of the DSL from the supplementary material of the paper. 
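## Sanity check (optional)
The snippet below is a minimal sketch for verifying the VizDoom dependency before generating datasets. It assumes the ViZDoomDeterministic fork exposes the same `vizdoom` Python API as upstream ViZDoom (`DoomGame`, `load_config`, `init`, ...) and that it is run from `vizdoom_env/asset` so the relative scenario paths in `default.cfg` resolve.
```python
from vizdoom import DoomGame  # provided by ViZDoomDeterministic

game = DoomGame()
game.load_config('default.cfg')  # bundled config; points to scenarios/doom_state.wad
game.set_window_visible(False)   # keep it headless, as in the config
game.init()

game.new_episode()
state = game.get_state()
print('screen buffer shape:', state.screen_buffer.shape)
print('num game variables:', len(state.game_variables))
game.close()
```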
26 | -------------------------------------------------------------------------------- /vizdoom_env/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/__init__.py -------------------------------------------------------------------------------- /vizdoom_env/asset/default.cfg: -------------------------------------------------------------------------------- 1 | #doom_scenario_path = scenarios/basic.wad 2 | #doom_scenario_path = scenarios/cig.wad 3 | #doom_scenario_path = scenarios/deadly_corridor.wad 4 | #doom_scenario_path = scenarios/deathmatch.wad 5 | doom_scenario_path = scenarios/doom_state.wad 6 | #doom_scenario_path = scenarios/defend_the_center.wad 7 | #doom_scenario_path = scenarios/defend_the_line.wad 8 | #doom_scenario_path = scenarios/health_gathering.wad 9 | #doom_scenario_path = scenarios/predict_position.wad 10 | #doom_scenario_path = scenarios/rocket_basic.wad 11 | #doom_scenario_path = scenarios/simpler_basic.wad 12 | #doom_scenario_path = scenarios/take_cover.wad 13 | #doom_map = map01 14 | 15 | # Rewards 16 | living_reward = 0 17 | 18 | # Rendering options 19 | screen_resolution = RES_160x120 20 | screen_format = CRCGCB 21 | render_hud = True 22 | render_crosshair = False 23 | render_weapon = True 24 | render_decals = False 25 | render_particles = False 26 | #window_visible = True 27 | window_visible = False 28 | 29 | # Buffer enable 30 | labels_buffer_enabled = True 31 | 32 | 33 | # make episodes start after 20 tics (after unholstering the gun) 34 | episode_start_time = 14 35 | 36 | # make episodes finish after 300 actions (tics) 37 | episode_timeout = 1000 38 | 39 | # Available buttons 40 | available_buttons = 41 | { 42 | MOVE_FORWARD 43 | MOVE_BACKWARD 44 | MOVE_LEFT 45 | MOVE_RIGHT 46 | TURN_LEFT 47 | TURN_RIGHT 48 | 49 | ATTACK 50 | 51 | SELECT_WEAPON1 52 | SELECT_WEAPON2 53 | SELECT_WEAPON3 54 | SELECT_WEAPON4 55 | SELECT_WEAPON5 56 | SELECT_WEAPON6 57 | 58 | SELECT_NEXT_WEAPON 59 | SELECT_PREV_WEAPON 60 | } 61 | 62 | # Game variables that will be in the state 63 | available_game_variables = 64 | { 65 | ANGLE 66 | PITCH 67 | ROLL 68 | POSITION_X 69 | POSITION_Y 70 | POSITION_Z 71 | VELOCITY_X 72 | VELOCITY_Y 73 | VELOCITY_Z 74 | 75 | ON_GROUND 76 | ATTACK_READY 77 | ALTATTACK_READY 78 | SELECTED_WEAPON 79 | SELECTED_WEAPON_AMMO 80 | 81 | ITEMCOUNT 82 | KILLCOUNT 83 | HEALTH 84 | ARMOR 85 | 86 | AMMO0 87 | AMMO1 88 | AMMO2 89 | AMMO3 90 | AMMO4 91 | AMMO5 92 | AMMO6 93 | AMMO7 94 | AMMO8 95 | AMMO9 96 | 97 | WEAPON0 98 | WEAPON1 99 | WEAPON2 100 | WEAPON3 101 | WEAPON4 102 | WEAPON5 103 | WEAPON6 104 | WEAPON7 105 | WEAPON8 106 | WEAPON9 107 | } 108 | 109 | mode = PLAYER 110 | 111 | # Level of difficulty 112 | # 1 - VERY EASY 113 | # 2 - EASY 114 | # 3 - NORMAL 115 | # 4 - HARD 116 | # 5 - VERY HARD 117 | doom_skill = 3 118 | -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/README.md: -------------------------------------------------------------------------------- 1 | # Scenarios Decription: 2 | 3 | Scenarios contained in iwad files do not support action constraints, death penalty and living rewards. 4 | Every mention of any settings that are not included in iwads will be specified with "(config)". 5 | 6 | Note: Vizdoom does not support setting certain rewards (such as killing oponents) in .cfg files. 
These must be set in the .wad files instead. 7 | 8 | ## BASIC 9 | The purpose of the scenario is just to check if using this 10 | framework to train some AI in a 3D environment is feasible. 11 | 12 | Map is a rectangle with gray walls, ceiling and floor. 13 | Player is spawned along the longer wall, in the center. 14 | A red, circular monster is spawned randomly somewhere along 15 | the opposite wall. Player can only (config) go left/right 16 | and shoot. 1 hit is enough to kill the monster. Episode 17 | finishes when monster is killed or on timeout. 18 | 19 | __REWARDS:__ 20 | 21 | +101 for killing the monster 22 | -5 for missing 23 | Episode ends after killing the monster or on timeout. 24 | 25 | Further configuration: 26 | * living reward = -1, 27 | * 3 available buttons: move left, move right, shoot (attack) 28 | * timeout = 300 29 | 30 | ## DEADLY CORRIDOR 31 | The purpose of this scenario is to teach the agent to navigate towards 32 | his fundamental goal (the vest) and make sure he survives at the 33 | same time. 34 | 35 | Map is a corridor with shooting monsters on both sides (6 monsters 36 | in total). A green vest is placed at the opposite end of the corridor. 37 | Reward is proportional (negative or positive) to the change of the 38 | distance between the player and the vest. If the player ignores monsters 39 | on the sides and runs straight for the vest he will be killed somewhere 40 | along the way. To ensure this behavior doom_skill = 5 (config) is 41 | needed. 42 | 43 | __REWARDS:__ 44 | 45 | +dX for getting closer to the vest. 46 | -dX for getting further from the vest. 47 | 48 | Further configuration: 49 | * 5 available buttons: turn left, turn right, move left, move right, shoot (attack) 50 | * timeout = 4200 51 | * death penalty = 100 52 | * doom_skill = 5 53 | 54 | 55 | ## DEFEND THE CENTER 56 | The purpose of this scenario is to teach the agent that killing the 57 | monsters is GOOD and when monsters kill you is BAD. In addition, 58 | wasting ammunition is not very good either. Agent is rewarded only 59 | for killing monsters so he has to figure out the rest for himself. 60 | 61 | Map is a large circle. Player is spawned in the exact center. 62 | 5 melee-only monsters are spawned along the wall. Monsters are 63 | killed after a single shot. After dying each monster is respawned 64 | after some time. Episode ends when the player dies (it's inevitable 65 | because of limited ammo). 66 | 67 | __REWARDS:__ 68 | +1 for killing a monster 69 | 70 | Further configuration: 71 | * 3 available buttons: turn left, turn right, shoot (attack) 72 | * death penalty = 1 73 | 74 | ## DEFEND THE LINE 75 | The purpose of this scenario is to teach the agent that killing the 76 | monsters is GOOD and when monsters kill you is BAD. In addition, 77 | wasting ammunition is not very good either. Agent is rewarded only 78 | for killing monsters so he has to figure out the rest for himself. 79 | 80 | Map is a rectangle. Player is spawned along the longer wall, in the 81 | center. 3 melee-only and 3 shooting monsters are spawned along the 82 | opposite wall. Monsters are killed after a single shot, at first. 83 | After dying each monster is respawned after some time and can endure 84 | more damage. Episode ends when the player dies (it's inevitable 85 | because of limited ammo).
86 | 87 | __REWARDS:__ 88 | +1 for killing a monster 89 | 90 | Further configuration: 91 | * 3 available buttons: turn left, turn right, shoot (attack) 92 | * death penalty = 1 93 | 94 | ## HEALTH GATHERING 95 | The purpose of this scenario is to teach the agent how to survive 96 | without knowing what makes him survive. Agent knows only that life 97 | is precious and death is bad so he must learn what prolongs his 98 | existence and that his health is connected with it. 99 | 100 | Map is a rectangle with green, acidic floor which hurts the player 101 | periodically. Initially there are some medkits spread uniformly 102 | over the map. A new medkit falls from the skies every now and then. 103 | Medkits heal some portion of the player's health - to survive the agent 104 | needs to pick them up. Episode finishes after the player's death or 105 | on timeout. 106 | 107 | 108 | Further configuration: 109 | * living_reward = 1 110 | * 3 available buttons: turn left, turn right, move forward 111 | * 1 available game variable: HEALTH 112 | * death penalty = 100 113 | 114 | ## MY WAY HOME 115 | The purpose of this scenario is to teach the agent how to navigate 116 | in labyrinth-like surroundings and reach his ultimate goal 117 | (and learn what it actually is). 118 | 119 | Map is a series of rooms with interconnections and 1 corridor 120 | with a dead end. Each room has a different color. There is a 121 | green vest in one of the rooms (the same room every time). 122 | Player is spawned in a randomly chosen room facing a random 123 | direction. Episode ends when the vest is reached or on timeout. 124 | 125 | __REWARDS:__ 126 | +1 for reaching the vest 127 | 128 | Further configuration: 129 | * 3 available buttons: turn left, turn right, move forward 130 | * living reward = -0.0001 131 | * timeout = 2100 132 | 133 | ## PREDICT POSITION 134 | The purpose of the scenario is to teach the agent to synchronize 135 | a missile weapon shot (involving a significant delay between 136 | shooting and hitting) with target movements. Agent should be 137 | able to shoot so that missile and monster meet each other. 138 | 139 | The map is a rectangular room. Player is spawned along the longer 140 | wall, in the center. A monster is spawned randomly somewhere 141 | along the opposite wall and walks between left and right corners 142 | along the wall. Player is equipped with a rocket launcher and 143 | a single rocket. Episode ends when the missile hits a wall/the monster 144 | or on timeout. 145 | 146 | __REWARDS:__ 147 | +1 for killing the monster 148 | 149 | Further configuration: 150 | * living reward = -0.0001, 151 | * 3 available buttons: move left, move right, shoot (attack) 152 | * timeout = 300 153 | 154 | ## TAKE COVER 155 | The purpose of this scenario is to teach the agent to link incoming 156 | missiles with his estimated lifespan. Agent should learn that 157 | being hit means a health decrease and this in turn will lead to 158 | death which is undesirable. In effect the agent should avoid 159 | missiles. 160 | 161 | Map is a rectangle. Player is spawned along the longer wall, 162 | in the center. A couple of shooting monsters are spawned 163 | randomly somewhere along the opposite wall and try to kill 164 | the player with fireballs. The player can only (config) move 165 | left/right. More monsters appear with time. Episode ends when 166 | the player dies.
167 | 168 | __REWARDS:__ 169 | +1 for each tic of life 170 | 171 | Further configuration: 172 | * living reward = 1.0, 173 | * 2 available buttons: move left, move right 174 | -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/basic.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 4 | 5 | doom_scenario_path = basic.wad 6 | doom_map = map01 7 | 8 | # Rewards 9 | living_reward = -1 10 | 11 | # Rendering options 12 | screen_resolution = RES_320X240 13 | screen_format = CRCGCB 14 | render_hud = True 15 | render_crosshair = false 16 | render_weapon = true 17 | render_decals = false 18 | render_particles = false 19 | window_visible = true 20 | 21 | # make episodes start after 20 tics (after unholstering the gun) 22 | episode_start_time = 14 23 | 24 | # make episodes finish after 300 actions (tics) 25 | episode_timeout = 300 26 | 27 | # Available buttons 28 | available_buttons = 29 | { 30 | MOVE_LEFT 31 | MOVE_RIGHT 32 | ATTACK 33 | } 34 | 35 | # Game variables that will be in the state 36 | available_game_variables = { AMMO2} 37 | 38 | mode = PLAYER 39 | doom_skill = 5 40 | -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/basic.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/asset/scenarios/basic.wad -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/bots.cfg: -------------------------------------------------------------------------------- 1 | { 2 | name Rambo 3 | aiming 67 4 | perfection 50 5 | reaction 70 6 | isp 50 7 | color "40 cf 00" 8 | skin base 9 | //weaponpref 012385678 10 | } 11 | 12 | { 13 | name McClane 14 | aiming 34 15 | perfection 75 16 | reaction 15 17 | isp 90 18 | color "b0 b0 b0" 19 | skin base 20 | //weaponpref 012345678 21 | } 22 | 23 | { 24 | name MacGyver 25 | aiming 80 26 | perfection 67 27 | reaction 72 28 | isp 87 29 | color "50 50 60" 30 | skin base 31 | //weaponpref 012345678 32 | } 33 | 34 | { 35 | name Plissken 36 | aiming 15 37 | perfection 50 38 | reaction 50 39 | isp 50 40 | color "8f 00 00" 41 | skin base 42 | //weaponpref 082345678 43 | } 44 | 45 | { 46 | name Machete 47 | aiming 50 48 | perfection 13 49 | reaction 20 50 | isp 100 51 | color "ff ff ff" 52 | skin base 53 | //weaponpref 012345678 54 | } 55 | 56 | { 57 | name Anderson 58 | aiming 45 59 | perfection 30 60 | reaction 70 61 | isp 60 62 | color "ff af 3f" 63 | skin base 64 | //weaponpref 012345678 65 | } 66 | 67 | { 68 | name Leone 69 | aiming 56 70 | perfection 34 71 | reaction 78 72 | isp 50 73 | color "bf 00 00" 74 | skin base 75 | //weaponpref 012345678 76 | } 77 | 78 | { 79 | name Predator 80 | aiming 25 81 | perfection 55 82 | reaction 32 83 | isp 70 84 | color "00 00 ff" 85 | skin base 86 | //weaponpref 012345678 87 | } 88 | 89 | { 90 | name Ripley 91 | aiming 61 92 | perfection 50 93 | reaction 23 94 | isp 32 95 | color "00 00 7f" 96 | skin base 97 | //weaponpref 012345678 98 | } 99 | 100 | { 101 | name T800 102 | aiming 90 103 | perfection 85 104 | reaction 10 105 | isp 30 106 
| color "ff ff 00" 107 | skin base 108 | //weaponpref 012345678 109 | } 110 | 111 | { 112 | name Dredd 113 | aiming 12 114 | perfection 35 115 | reaction 56 116 | isp 37 117 | color "40 cf 00" 118 | skin base 119 | //weaponpref 012345678 120 | } 121 | 122 | { 123 | name Conan 124 | aiming 10 125 | perfection 35 126 | reaction 10 127 | isp 100 128 | color "b0 b0 b0" 129 | skin base 130 | //weaponpref 012345678 131 | } 132 | 133 | { 134 | name Bond 135 | aiming 67 136 | perfection 15 137 | reaction 76 138 | isp 37 139 | color "50 50 60" 140 | skin base 141 | //weaponpref 012345678 142 | } 143 | 144 | { 145 | name Jones 146 | aiming 52 147 | perfection 35 148 | reaction 50 149 | isp 37 150 | color "8f 00 00" 151 | skin base 152 | //weaponpref 012345678 153 | } 154 | 155 | { 156 | name Blazkowicz 157 | aiming 80 158 | perfection 80 159 | reaction 80 160 | isp 100 161 | color "00 00 00" 162 | skin base 163 | //weaponpref 012345678 164 | } 165 | 166 | -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/cig.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 4 | 5 | doom_scenario_path = cig.wad 6 | 7 | #12 minutes 8 | episode_timeout = 25200 9 | 10 | # Rendering options 11 | screen_resolution = RES_640X480 12 | screen_format = CRCGCB 13 | render_hud = true 14 | render_crosshair = true 15 | render_weapon = true 16 | render_decals = false 17 | render_particles = false 18 | 19 | window_visible = true 20 | 21 | # Available buttons 22 | available_buttons = 23 | { 24 | TURN_LEFT 25 | TURN_RIGHT 26 | ATTACK 27 | 28 | MOVE_RIGHT 29 | MOVE_LEFT 30 | 31 | MOVE_FORWARD 32 | MOVE_BACKWARD 33 | TURN_LEFT_RIGHT_DELTA 34 | LOOK_UP_DOWN_DELTA 35 | 36 | } 37 | 38 | mode = ASYNC_PLAYER 39 | 40 | -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/cig.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/asset/scenarios/cig.wad -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/cig_with_unknown.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/asset/scenarios/cig_with_unknown.wad -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/deadly_corridor.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 4 | 5 | doom_scenario_path = deadly_corridor.wad 6 | 7 | # Skill 5 is reccomanded for the scenario to be a challenge. 
8 | doom_skill = 5 9 | 10 | # Rewards 11 | death_penalty = 100 12 | #living_reward = 0 13 | 14 | # Rendering options 15 | screen_resolution = RES_320X240 16 | screen_format = CRCGCB 17 | render_hud = true 18 | render_crosshair = false 19 | render_weapon = true 20 | render_decals = false 21 | render_particles = false 22 | window_visible = true 23 | 24 | episode_timeout = 2100 25 | 26 | # Available buttons 27 | available_buttons = 28 | { 29 | MOVE_LEFT 30 | MOVE_RIGHT 31 | ATTACK 32 | MOVE_FORWARD 33 | MOVE_BACKWARD 34 | TURN_LEFT 35 | TURN_RIGHT 36 | } 37 | 38 | # Game variables that will be in the state 39 | available_game_variables = { HEALTH } 40 | 41 | mode = PLAYER 42 | 43 | 44 | -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/deadly_corridor.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/asset/scenarios/deadly_corridor.wad -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/deathmatch.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 4 | 5 | doom_scenario_path = deathmatch.wad 6 | 7 | # Rendering options 8 | screen_resolution = RES_320X240 9 | screen_format = CRCGCB 10 | render_hud = true 11 | render_crosshair = false 12 | render_weapon = true 13 | render_decals = false 14 | render_particles = false 15 | window_visible = true 16 | 17 | # make episodes finish after 4200 actions (tics) 18 | episode_timeout = 4200 19 | 20 | # Available buttons 21 | available_buttons = 22 | { 23 | ATTACK 24 | SPEED 25 | STRAFE 26 | 27 | MOVE_RIGHT 28 | MOVE_LEFT 29 | MOVE_BACKWARD 30 | MOVE_FORWARD 31 | TURN_RIGHT 32 | TURN_LEFT 33 | 34 | SELECT_WEAPON1 35 | SELECT_WEAPON2 36 | SELECT_WEAPON3 37 | SELECT_WEAPON4 38 | SELECT_WEAPON5 39 | SELECT_WEAPON6 40 | 41 | SELECT_NEXT_WEAPON 42 | SELECT_PREV_WEAPON 43 | 44 | LOOK_UP_DOWN_DELTA 45 | TURN_LEFT_RIGHT_DELTA 46 | MOVE_LEFT_RIGHT_DELTA 47 | 48 | } 49 | 50 | # Game variables that will be in the state 51 | available_game_variables = 52 | { 53 | KILLCOUNT 54 | HEALTH 55 | ARMOR 56 | SELECTED_WEAPON 57 | SELECTED_WEAPON_AMMO 58 | } 59 | mode = PLAYER 60 | -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/deathmatch.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/asset/scenarios/deathmatch.wad -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/deathmatch.wad.bak: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/asset/scenarios/deathmatch.wad.bak -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/defend_the_center.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as 
comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 4 | 5 | doom_scenario_path = defend_the_center.wad 6 | 7 | # Rewards 8 | death_penalty = 1 9 | 10 | # Rendering options 11 | screen_resolution = RES_640X480 12 | screen_format = CRCGCB 13 | render_hud = True 14 | render_crosshair = false 15 | render_weapon = true 16 | render_decals = false 17 | render_particles = false 18 | window_visible = true 19 | 20 | # make episodes start after 10 tics (after unholstering the gun) 21 | episode_start_time = 10 22 | 23 | # make episodes finish after 2100 actions (tics) 24 | episode_timeout = 2100 25 | 26 | # Available buttons 27 | available_buttons = 28 | { 29 | TURN_LEFT 30 | TURN_RIGHT 31 | ATTACK 32 | } 33 | 34 | # Game variables that will be in the state 35 | available_game_variables = { AMMO2 HEALTH } 36 | 37 | mode = PLAYER 38 | doom_skill = 3 39 | -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/defend_the_center.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/asset/scenarios/defend_the_center.wad -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/defend_the_line.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 
4 | 5 | doom_scenario_path = defend_the_line.wad 6 | 7 | # Rewards 8 | death_penalty = 1 9 | 10 | # Rendering options 11 | screen_resolution = RES_320X240 12 | screen_format = CRCGCB 13 | render_hud = True 14 | render_crosshair = false 15 | render_weapon = true 16 | render_decals = false 17 | render_particles = false 18 | window_visible = true 19 | 20 | # make episodes start after 10 tics (after unholstering the gun) 21 | episode_start_time = 10 22 | 23 | 24 | # Available buttons 25 | available_buttons = 26 | { 27 | TURN_LEFT 28 | TURN_RIGHT 29 | ATTACK 30 | } 31 | 32 | # Game variables that will be in the state 33 | available_game_variables = { AMMO2 HEALTH } 34 | 35 | mode = PLAYER 36 | doom_skill = 3 37 | -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/defend_the_line.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/asset/scenarios/defend_the_line.wad -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/doom2.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/asset/scenarios/doom2.wad -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/doom_state.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/asset/scenarios/doom_state.wad -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/health_gathering.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 4 | 5 | doom_scenario_path = health_gathering.wad 6 | 7 | # Each step is good for you! 8 | living_reward = 1 9 | # And death is not!
10 | death_penalty = 100 11 | 12 | # Rendering options 13 | screen_resolution = RES_320X240 14 | screen_format = CRCGCB 15 | render_hud = false 16 | render_crosshair = false 17 | render_weapon = false 18 | render_decals = false 19 | render_particles = false 20 | window_visible = true 21 | 22 | # make episodes finish after 2100 actions (tics) 23 | episode_timeout = 2100 24 | 25 | # Available buttons 26 | available_buttons = 27 | { 28 | TURN_LEFT 29 | TURN_RIGHT 30 | MOVE_FORWARD 31 | } 32 | 33 | # Game variables that will be in the state 34 | available_game_variables = { HEALTH } 35 | 36 | mode = PLAYER -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/health_gathering.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/asset/scenarios/health_gathering.wad -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/health_gathering_supreme.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 4 | 5 | doom_scenario_path = health_gathering_supreme.wad 6 | 7 | # Each step is good for you! 8 | living_reward = 1 9 | # And death is not! 10 | death_penalty = 100 11 | 12 | # Rendering options 13 | screen_resolution = RES_320X240 14 | screen_format = CRCGCB 15 | render_hud = false 16 | render_crosshair = false 17 | render_weapon = false 18 | render_decals = false 19 | render_particles = false 20 | window_visible = true 21 | 22 | # make episodes finish after 2100 actions (tics) 23 | episode_timeout = 2100 24 | 25 | # Available buttons 26 | available_buttons = 27 | { 28 | TURN_LEFT 29 | TURN_RIGHT 30 | MOVE_FORWARD 31 | } 32 | 33 | # Game variables that will be in the state 34 | available_game_variables = { HEALTH } 35 | 36 | mode = PLAYER -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/health_gathering_supreme.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/asset/scenarios/health_gathering_supreme.wad -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/learning.cfg: -------------------------------------------------------------------------------- 1 | doom_scenario_path = basic.wad 2 | 3 | # Rewards 4 | living_reward = -1 5 | 6 | # Rendering options 7 | screen_resolution = RES_640X480 8 | screen_format = GRAY8 9 | render_hud = false 10 | render_crosshair = false 11 | render_weapon = true 12 | render_decals = false 13 | render_particles = false 14 | window_visible = false 15 | 16 | # make episodes start after 20 tics (after unholstering the gun) 17 | episode_start_time = 14 18 | 19 | # make episodes finish after 300 actions (tics) 20 | episode_timeout = 300 21 | 22 | # Available buttons 23 | available_buttons = 24 | { 25 | MOVE_LEFT 26 | MOVE_RIGHT 27 | ATTACK 28 | } 29 | 30 | mode = PLAYER 31 | 32 | 33 | -------------------------------------------------------------------------------- 
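These scenario configuration files are consumed directly by the ViZDoom engine; keys such as `screen_format`, `available_buttons`, `available_game_variables`, and `episode_timeout` determine what an agent observes and which actions it can issue. A minimal sketch of loading one of them through the standard ViZDoom Python API follows; it is illustrative only (the choice of `basic.cfg` and the relative path are placeholders, and attribute names such as `screen_buffer` assume ViZDoom 1.1 or later):

```python
# Illustrative sketch, not part of this repository.
from vizdoom import DoomGame

game = DoomGame()
game.load_config('vizdoom_env/asset/scenarios/basic.cfg')  # any scenario .cfg above
game.set_window_visible(False)  # override window_visible from the .cfg if desired
game.init()

game.new_episode()
while not game.is_episode_finished():
    state = game.get_state()
    frame = state.screen_buffer       # laid out per screen_format (e.g. CRCGCB, GRAY8)
    game_vars = state.game_variables  # values listed in available_game_variables
    # a one-hot action over the buttons declared in available_buttons
    action = [0] * game.get_available_buttons_size()
    action[0] = 1
    reward = game.make_action(action)
game.close()
```

The buttons and game variables declared in each `.cfg` above map one-to-one onto what `make_action()` accepts and what `get_state()` exposes here.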
/vizdoom_env/asset/scenarios/multi.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 4 | 5 | doom_scenario_path = multi_deathmatch.wad 6 | 7 | # Rewards 8 | death_penalty = 1 9 | 10 | # Rendering options 11 | screen_resolution = RES_640X480 12 | screen_format = CRCGCB 13 | render_hud = true 14 | render_crosshair = true 15 | render_weapon = true 16 | render_decals = false 17 | render_particles = false 18 | 19 | window_visible = true 20 | 21 | 22 | # Available buttons 23 | available_buttons = 24 | { 25 | TURN_LEFT 26 | TURN_RIGHT 27 | ATTACK 28 | 29 | MOVE_RIGHT 30 | MOVE_LEFT 31 | 32 | MOVE_FORWARD 33 | MOVE_BACKWARD 34 | TURN_LEFT_RIGHT_DELTA 35 | LOOK_UP_DOWN_DELTA 36 | 37 | } 38 | 39 | available_game_variables = 40 | { 41 | HEALTH 42 | AMMO3 43 | } 44 | mode = ASYNC_PLAYER 45 | 46 | -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/multi_deathmatch.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/asset/scenarios/multi_deathmatch.wad -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/multi_duel.cfg: -------------------------------------------------------------------------------- 1 | doom_scenario_path = multi_duel.wad 2 | 3 | screen_resolution = RES_640X480 4 | screen_format = CRCGCB 5 | render_hud = true 6 | render_crosshair = false 7 | render_weapon = true 8 | render_decals = true 9 | render_particles = true 10 | window_visible = true 11 | 12 | available_buttons = 13 | { 14 | MOVE_LEFT 15 | MOVE_RIGHT 16 | ATTACK 17 | } 18 | 19 | mode = PLAYER 20 | doom_skill = 5 21 | 22 | -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/multi_duel.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/asset/scenarios/multi_duel.wad -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/my_way_home.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 
4 | 5 | doom_scenario_path = my_way_home.wad 6 | 7 | # Rewards 8 | living_reward = -0.0001 9 | 10 | # Rendering options 11 | screen_resolution = RES_640X480 12 | screen_format = CRCGCB 13 | render_hud = false 14 | render_crosshair = false 15 | render_weapon = true 16 | render_decals = false 17 | render_particles = false 18 | window_visible = true 19 | 20 | # make episodes start after 10 tics (after unholstering the gun) 21 | episode_start_time = 10 22 | 23 | # make episodes finish after 2100 actions (tics) 24 | episode_timeout = 2100 25 | 26 | # Available buttons 27 | available_buttons = 28 | { 29 | TURN_LEFT 30 | TURN_RIGHT 31 | MOVE_FORWARD 32 | MOVE_LEFT 33 | MOVE_RIGHT 34 | } 35 | 36 | # Game variables that will be in the state 37 | available_game_variables = { AMMO0 } 38 | 39 | mode = PLAYER 40 | doom_skill = 5 41 | -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/my_way_home.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/asset/scenarios/my_way_home.wad -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/predict_position.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 4 | 5 | doom_scenario_path = predict_position.wad 6 | 7 | # Rewards 8 | living_reward = -0.001 9 | 10 | # Rendering options 11 | screen_resolution = RES_800X450 12 | screen_format = CRCGCB 13 | render_hud = false 14 | render_crosshair = false 15 | render_weapon = true 16 | render_decals = false 17 | render_particles = false 18 | window_visible = true 19 | 20 | # make episodes start after 16 tics (after producing the rocket launcher) 21 | episode_start_time = 16 22 | 23 | # make episodes finish after 300 actions (tics) 24 | episode_timeout = 300 25 | 26 | # Available buttons 27 | available_buttons = 28 | { 29 | TURN_LEFT 30 | TURN_RIGHT 31 | ATTACK 32 | } 33 | 34 | # Empty list is allowed, in case you are lazy. 
35 | available_game_variables = { } 36 | 37 | game_args += +sv_noautoaim 1 38 | 39 | mode = PLAYER 40 | doom_skill = 1 41 | -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/predict_position.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/asset/scenarios/predict_position.wad -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/rocket_basic.cfg: -------------------------------------------------------------------------------- 1 | doom_scenario_path = rocket_basic.wad 2 | 3 | # Rewards 4 | living_reward = -1 5 | 6 | # Rendering options 7 | screen_resolution = RES_640X480 8 | screen_format = GRAY8 9 | render_hud = true 10 | render_crosshair = false 11 | render_weapon = true 12 | render_decals = false 13 | render_particles = false 14 | 15 | # make episodes start after 14 tics (after unholstering the gun) 16 | episode_start_time = 14 17 | 18 | # make episodes finish after 300 actions (tics) 19 | episode_timeout = 300 20 | 21 | # Available buttons 22 | available_buttons = 23 | { 24 | MOVE_LEFT 25 | MOVE_RIGHT 26 | ATTACK 27 | } 28 | 29 | game_args += +sv_noautoaim 1 30 | -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/rocket_basic.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/asset/scenarios/rocket_basic.wad -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/simpler_basic.cfg: -------------------------------------------------------------------------------- 1 | doom_scenario_path = simpler_basic.wad 2 | 3 | # Rewards 4 | living_reward = -1 5 | 6 | # Rendering options 7 | screen_resolution = RES_640X480 8 | screen_format = GRAY8 9 | 10 | render_hud = true 11 | render_crosshair = false 12 | render_weapon = true 13 | render_decals = false 14 | render_particles = false 15 | 16 | # make episodes start after 20 tics (after unholstering the gun) 17 | episode_start_time = 14 18 | 19 | # make episodes finish after 300 actions (tics) 20 | episode_timeout = 300 21 | 22 | # Available buttons 23 | available_buttons = 24 | { 25 | MOVE_LEFT 26 | MOVE_RIGHT 27 | ATTACK 28 | } 29 | -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/simpler_basic.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/asset/scenarios/simpler_basic.wad -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/take_cover.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 
4 | 5 | doom_scenario_path = take_cover.wad 6 | doom_map = map01 7 | 8 | # Rewards 9 | living_reward = 1 10 | 11 | # Rendering options 12 | screen_resolution = RES_320X240 13 | screen_format = CRCGCB 14 | render_hud = false 15 | render_crosshair = false 16 | render_weapon = false 17 | render_decals = false 18 | render_particles = false 19 | window_visible = true 20 | 21 | # Available buttons 22 | available_buttons = 23 | { 24 | MOVE_LEFT 25 | MOVE_RIGHT 26 | } 27 | 28 | # Game variables that will be in the state 29 | available_game_variables = { HEALTH } 30 | 31 | # Change it if you wish. 32 | doom_skill = 4 33 | 34 | -------------------------------------------------------------------------------- /vizdoom_env/asset/scenarios/take_cover.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/asset/scenarios/take_cover.wad -------------------------------------------------------------------------------- /vizdoom_env/dataset_vizdoom.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os.path as osp 6 | import numpy as np 7 | import h5py 8 | from vizdoom_env.util import log 9 | 10 | 11 | rs = np.random.RandomState(123) 12 | 13 | 14 | class Dataset(object): 15 | 16 | def __init__(self, ids, dataset_path, name='default', num_k=10, is_train=True): 17 | self._ids = list(ids) 18 | self.name = name 19 | self.dataset_image_path = osp.join(dataset_path, 'images') 20 | self.is_train = is_train 21 | self.num_k = num_k 22 | 23 | filename = 'data.hdf5' 24 | file = osp.join(dataset_path, filename) 25 | log.info("Reading %s ...", file) 26 | 27 | self.data = h5py.File(file, 'r') 28 | self.num_demo = int(self.data['data_info']['num_demo_per_program'].value) 29 | self.max_demo_len = int(self.data['data_info']['max_demo_length'].value) 30 | self.max_program_len = int(self.data['data_info']['max_program_length'].value) 31 | self.num_program_tokens = int(self.data['data_info']['num_program_tokens'].value) 32 | self.num_action_tokens = int(self.data['data_info']['num_action_tokens'].value) 33 | self.vizdoom_pos_keys = list(self.data['data_info']['vizdoom_pos_keys'].value) 34 | self.vizdoom_max_init_pos_len = \ 35 | int(self.data['data_info']['vizdoom_max_init_pos_len'].value) 36 | self.perception_type = self.data['data_info']['perception_type'].value 37 | if 'level' in self.data['data_info'].keys(): 38 | self.level = self.data['data_info']['level'].value 39 | else: self.level = 'not_simple' 40 | 41 | self.k = int(self.data['data_info']['num_demo_per_program'].value) 42 | self.test_k = int(self.data['data_info']['num_test_demo_per_program'].value) 43 | self.s_h_h = int(self.data['data_info']['s_h_h'].value) 44 | self.s_h_w = int(self.data['data_info']['s_h_w'].value) 45 | self.s_h_c = int(self.data['data_info']['s_h_c'].value) 46 | log.info("Reading Done: %s", file) 47 | 48 | def get_data(self, id, order=None): 49 | # preprocessing and data augmentation 50 | 51 | # each data point consist of a program + k demo 52 | 53 | # dim: [one hot dim of program tokens, program len] 54 | program_tokens = self.data[id]['program'].value 55 | program = np.zeros([self.num_program_tokens, self.max_program_len], dtype=bool) 56 | program[:, :len(program_tokens)][program_tokens, np.arange(len(program_tokens))] = 1 57 | 
padded_program_tokens = np.zeros([self.max_program_len], dtype=program_tokens.dtype) 58 | padded_program_tokens[:len(program_tokens)] = program_tokens 59 | 60 | # get s_h and test_s_h 61 | demo_data = self.data[id]['s_h'].value[:self.num_k] 62 | test_demo_data = self.data[id]['test_s_h'].value 63 | 64 | sz = demo_data.shape 65 | demo = np.zeros([sz[0], self.max_demo_len, sz[2], sz[3], sz[4]], dtype=demo_data.dtype) 66 | demo[:, :sz[1], :, :, :] = demo_data 67 | sz = test_demo_data.shape 68 | test_demo = np.zeros([sz[0], self.max_demo_len, sz[2], sz[3], sz[4]], dtype=demo_data.dtype) 69 | test_demo[:, :sz[1], :, :, :] = test_demo_data 70 | 71 | # dim: [k, action_space, max len of demo - 1] 72 | action_history_tokens = self.data[id]['a_h'].value[:self.num_k] 73 | action_history = [] 74 | for a_h_tokens in action_history_tokens: 75 | # num_action_tokens + 1 is token which is required for detecting 76 | # the end of the sequence. Even though the original length of the 77 | # action history is max_demo_len - 1, we make it max_demo_len, by 78 | # including the last token. 79 | a_h = np.zeros([self.max_demo_len, self.num_action_tokens + 1], dtype=bool) 80 | a_h[:len(a_h_tokens), :][np.arange(len(a_h_tokens)), a_h_tokens] = 1 81 | a_h[len(a_h_tokens), self.num_action_tokens] = 1 # 82 | action_history.append(a_h) 83 | action_history = np.stack(action_history, axis=0) 84 | padded_action_history_tokens = np.argmax(action_history, axis=2) 85 | 86 | # dim: [test_k, action_space, max len of demo - 1] 87 | test_action_history_tokens = self.data[id]['test_a_h'].value 88 | test_action_history = [] 89 | for test_a_h_tokens in test_action_history_tokens: 90 | # num_action_tokens + 1 is token which is required for detecting 91 | # the end of the sequence. Even though the original length of the 92 | # action history is max_demo_len - 1, we make it max_demo_len, by 93 | # including the last token. 94 | test_a_h = np.zeros([self.max_demo_len, self.num_action_tokens + 1], dtype=bool) 95 | test_a_h[:len(test_a_h_tokens), :][np.arange(len(test_a_h_tokens)), test_a_h_tokens] = 1 96 | test_a_h[len(test_a_h_tokens), self.num_action_tokens] = 1 # 97 | test_action_history.append(test_a_h) 98 | test_action_history = np.stack(test_action_history, axis=0) 99 | padded_test_action_history_tokens = np.argmax(test_action_history, axis=2) 100 | 101 | # program length: [1] 102 | program_length = np.array([len(program_tokens)], dtype=np.float32) 103 | 104 | # len of each demo. 
dim: [k] 105 | demo_length = self.data[id]['s_h_len'].value[:self.num_k] 106 | test_demo_length = self.data[id]['test_s_h_len'].value 107 | 108 | demo_percept_data = self.data[id]['p_v_h'].value[:self.num_k] 109 | sz = demo_percept_data.shape 110 | demo_percept = np.zeros([sz[0], self.max_demo_len, sz[2]], 111 | dtype=demo_percept_data.dtype) 112 | demo_percept[:, :sz[1], :] = demo_percept_data 113 | 114 | test_demo_percept_data = self.data[id]['test_p_v_h'].value 115 | sz = test_demo_percept_data.shape 116 | test_demo_percept = np.zeros([sz[0], self.max_demo_len, sz[2]], 117 | dtype=test_demo_percept_data.dtype) 118 | test_demo_percept[:, :sz[1], :] = test_demo_percept_data 119 | 120 | init_pos_data = self.data[id]['vizdoom_init_pos'].value[:self.num_k] 121 | sz = init_pos_data.shape 122 | init_pos = np.zeros([sz[0], sz[1], self.vizdoom_max_init_pos_len, 2], 123 | dtype=init_pos_data.dtype) 124 | init_pos[:, :, :sz[2], :] = init_pos_data 125 | init_pos_len = self.data[id]['vizdoom_init_pos_len'].value[:self.num_k] 126 | 127 | test_init_pos_data = self.data[id]['test_vizdoom_init_pos'].value 128 | sz = test_init_pos_data.shape 129 | test_init_pos = np.zeros([sz[0], sz[1], self.vizdoom_max_init_pos_len, 2], 130 | dtype=test_init_pos_data.dtype) 131 | test_init_pos[:, :, :sz[2], :] = test_init_pos_data 132 | test_init_pos_len = self.data[id]['test_vizdoom_init_pos_len'].value 133 | 134 | outputs = [program, padded_program_tokens, demo, test_demo, 135 | action_history, padded_action_history_tokens, 136 | test_action_history, padded_test_action_history_tokens, 137 | program_length, demo_length, test_demo_length, 138 | demo_percept, test_demo_percept, 139 | init_pos, init_pos_len, test_init_pos, test_init_pos_len] 140 | return tuple(outputs) 141 | 142 | @property 143 | def ids(self): 144 | return self._ids 145 | 146 | def __len__(self): 147 | return len(self.ids) 148 | 149 | def __repr__(self): 150 | return 'Dataset (%s, %d examples)' % ( 151 | self.name, 152 | len(self) 153 | ) 154 | 155 | 156 | def create_default_splits(dataset_path, num_k=10, is_train=True): 157 | ids_train, ids_test, ids_val = all_ids(dataset_path) 158 | 159 | dataset_train = Dataset(ids_train, dataset_path, name='train', 160 | num_k=num_k, is_train=is_train) 161 | dataset_test = Dataset(ids_test, dataset_path, name='test', 162 | num_k=num_k, is_train=is_train) 163 | dataset_val = Dataset(ids_val, dataset_path, name='val', 164 | num_k=num_k, is_train=is_train) 165 | return dataset_train, dataset_test, dataset_val 166 | 167 | 168 | def all_ids(dataset_path): 169 | with h5py.File(osp.join(dataset_path, 'data.hdf5'), 'r') as f: 170 | num_train = int(f['data_info']['num_train'].value) 171 | num_test = int(f['data_info']['num_test'].value) 172 | num_val = int(f['data_info']['num_val'].value) 173 | 174 | with open(osp.join(dataset_path, 'id.txt'), 'r') as fp: 175 | ids_total = [s.strip() for s in fp.readlines() if s] 176 | 177 | ids_train = ids_total[:num_train] 178 | ids_test = ids_total[num_train: num_train + num_test] 179 | ids_val = ids_total[num_train + num_test: num_train + num_test + num_val] 180 | 181 | rs.shuffle(ids_train) 182 | rs.shuffle(ids_test) 183 | rs.shuffle(ids_val) 184 | 185 | return ids_train, ids_test, ids_val 186 | -------------------------------------------------------------------------------- /vizdoom_env/dsl/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shaohua0116/demo2program/23464a69bfbf6fac9752fd423d14b03d37d1d1c6/vizdoom_env/dsl/__init__.py -------------------------------------------------------------------------------- /vizdoom_env/dsl/dsl_enum_program.py: -------------------------------------------------------------------------------- 1 | MONSTER_LIST = ['Demon', 'HellKnight', 'Revenant'] 2 | 3 | ITEMS_IN_INTEREST = ['MyAmmo'] 4 | 5 | ACTION_LIST = ['MOVE_FORWARD', 'MOVE_BACKWARD', 'MOVE_LEFT', 'MOVE_RIGHT', 6 | 'TURN_LEFT', 'TURN_RIGHT', 'ATTACK', 7 | 'SELECT_WEAPON1', 'SELECT_WEAPON2', 'SELECT_WEAPON3', 8 | 'SELECT_WEAPON4', 'SELECT_WEAPON5'] 9 | 10 | DISTANCE_DICT = { 11 | 'doncare_dist': lambda d: True, 12 | 'far': lambda d: d > 400, 13 | 'mid': lambda d: d < 300, 14 | 'close': lambda d: d < 180, 15 | 'very_close': lambda d: d < 135} 16 | 17 | HORIZONTAL_DICT = { 18 | 'doncare_horz': lambda l, r, x: True, 19 | 'center': lambda l, r, x: l < x and x < r, 20 | 'slight_left': lambda l, r, x: r < x and x <= r + 10, 21 | 'slight_right': lambda l, r, x: l > x and x >= l - 10, 22 | 'mid_left': lambda l, r, x: r < x and x <= r + 20, 23 | 'mid_right': lambda l, r, x: l > x and x >= l - 20, 24 | 'left': lambda l, r, x: r < x, 25 | 'right': lambda l, r, x: l > x} 26 | 27 | CLEAR_DISTANCE_DICT = { 28 | 'far': lambda d: d > 400, 29 | 'mid_far': lambda d: 300 < d and d <= 400, 30 | 'mid': lambda d: 180 < d and d <= 300, 31 | 'close': lambda d: 135 < d and d <= 180, 32 | 'very_close': lambda d: d <= 135} 33 | 34 | CLEAR_HORIZONTAL_DICT = { 35 | 'slight_left': lambda l, r, x: r < x and x <= r + 10, 36 | 'slight_right': lambda l, r, x: l > x and x >= l - 10, 37 | 'mid_left': lambda l, r, x: r + 10 < x and x <= r + 20, 38 | 'mid_right': lambda l, r, x: l - 10 > x and x >= l - 20, 39 | 'left': lambda l, r, x: r + 20 < x, 40 | 'right': lambda l, r, x: l - 20 > x} 41 | 42 | merge_distance_vocab = list(set(DISTANCE_DICT.keys()).union( 43 | set(CLEAR_DISTANCE_DICT.keys()))) 44 | merge_horizontal_vocab = list(set(HORIZONTAL_DICT.keys()).union( 45 | set(CLEAR_HORIZONTAL_DICT.keys()))) 46 | 47 | 48 | def check_and_apply(queue, rule): 49 | r = rule[0].split() 50 | l = len(r) 51 | if len(queue) >= l: 52 | t = queue[-l:] 53 | if list(zip(*t)[0]) == r: 54 | new_t = rule[1](list(zip(*t)[1])) 55 | del queue[-l:] 56 | queue.extend(new_t) 57 | return True 58 | return False 59 | 60 | rules = [] 61 | 62 | # world, n, s = fn(world, n) 63 | # world: vizdoom_world 64 | # n: num_call 65 | # s: success 66 | # c: condition [True, False] 67 | MAX_WHILE = 1000 68 | 69 | 70 | def r_prog(t): 71 | stmt = t[3] 72 | 73 | return [('prog', stmt(0, 0))] 74 | rules.append(('DEF run m( stmt m)', r_prog)) 75 | 76 | 77 | def r_stmt(t): 78 | stmt = t[0] 79 | 80 | def fn(k, n): 81 | return stmt(k, n) 82 | return [('stmt', fn)] 83 | rules.append(('while_stmt', r_stmt)) 84 | rules.append(('repeat_stmt', r_stmt)) 85 | rules.append(('stmt_stmt', r_stmt)) 86 | rules.append(('action', r_stmt)) 87 | rules.append(('if_stmt', r_stmt)) 88 | rules.append(('ifelse_stmt', r_stmt)) 89 | 90 | 91 | def r_stmt_stmt(t): 92 | stmt1, stmt2 = t[0], t[1] 93 | 94 | def fn(k, n): 95 | return stmt1(k, n) + stmt2(k, n) 96 | return [('stmt_stmt', fn)] 97 | rules.append(('stmt stmt', r_stmt_stmt)) 98 | 99 | 100 | def r_if(t): 101 | cond, stmt = t[2], t[5] 102 | 103 | def fn(k, n): 104 | return ['if'] + cond(k, n) + stmt(k, n) 105 | return [('if_stmt', fn)] 106 | rules.append(('IF c( cond c) i( stmt i)', r_if)) 107 | 108 | 109 | def r_ifelse(t): 110 | cond, stmt1, stmt2 = t[2], 
t[5], t[9] 111 | 112 | def fn(k, n): 113 | stmt1_out = stmt1(k, n) 114 | stmt2_out = stmt2(k, n) 115 | if stmt1_out == stmt2_out: 116 | return stmt1_out 117 | cond_out = cond(k, n) 118 | if cond_out[0] == 'not': 119 | else_cond = ['if'] + cond_out[1:] 120 | else: 121 | else_cond = ['if', 'not'] + cond_out 122 | return ['if'] + cond_out + stmt1(k, n) + else_cond + stmt2(k, n) 123 | return [('ifelse_stmt', fn)] 124 | rules.append(('IFELSE c( cond c) i( stmt i) ELSE e( stmt e)', r_ifelse)) 125 | 126 | 127 | def r_while(t): 128 | cond, stmt = t[2], t[5] 129 | 130 | def fn(k, n): 131 | cond_out = cond(k, n) 132 | stmt_out = stmt(k, n) 133 | while_out = [] 134 | for _ in range(MAX_WHILE): 135 | while_out.extend(['if'] + cond_out + stmt_out) 136 | return while_out 137 | return [('while_stmt', fn)] 138 | rules.append(('WHILE c( cond c) w( stmt w)', r_while)) 139 | 140 | 141 | def r_repeat(t): 142 | cste, stmt = t[1], t[3] 143 | 144 | def fn(k, n): 145 | repeat_out = [] 146 | for _ in range(cste()): 147 | repeat_out.extend(stmt(k, n)) 148 | return repeat_out 149 | return [('repeat_stmt', fn)] 150 | rules.append(('REPEAT cste r( stmt r)', r_repeat)) 151 | 152 | 153 | def r_cond1(t): 154 | cond = t[0] 155 | 156 | def fn(k, n): 157 | return cond(k, n) 158 | return [('cond', fn)] 159 | rules.append(('percept', r_cond1)) 160 | 161 | 162 | def r_cond2(t): 163 | cond = t[2] 164 | 165 | def fn(k, n): 166 | cond_out = cond(k, n) 167 | if cond_out[0] == 'not': 168 | cond_out = cond_out[1:] 169 | else: 170 | cond_out = ['not'] + cond_out 171 | return cond_out 172 | return [('cond', fn)] 173 | rules.append(('not c( cond c)', r_cond2)) 174 | 175 | 176 | def r_percept1(t): 177 | actor, dist, horz = t[1], t[3], t[4] 178 | 179 | def fn(world, n): 180 | return ['exist_actor_in_distance_horizontal', actor(), dist(), horz()] 181 | return [('percept', fn)] 182 | rules.append(('EXIST actor IN distance horizontal', r_percept1)) 183 | 184 | 185 | def r_percept2(t): 186 | actor = t[1] 187 | 188 | def fn(world, n): 189 | return ['in_target', actor()] 190 | return [('percept', fn)] 191 | rules.append(('INTARGET actor', r_percept2)) 192 | 193 | 194 | def r_percept3(t): 195 | actor = t[1] 196 | 197 | def fn(world, n): 198 | return ['is_there', actor()] 199 | return [('percept', fn)] 200 | rules.append(('ISTHERE actor', r_percept3)) 201 | 202 | 203 | def r_actor1(t): 204 | return [('actor', t[0])] 205 | rules.append(('monster', r_actor1)) 206 | 207 | 208 | def create_r_monster(monster): 209 | def r_monster(t): 210 | return [('monster', lambda: monster)] 211 | return r_monster 212 | for monster in MONSTER_LIST: 213 | rules.append((monster, create_r_monster(monster))) 214 | 215 | 216 | def r_actor2(t): 217 | return [('actor', t[0])] 218 | rules.append(('items', r_actor2)) 219 | 220 | 221 | def create_r_item(item): 222 | def r_item(t): 223 | return [('items', lambda: item)] 224 | return r_item 225 | for item in ITEMS_IN_INTEREST: 226 | rules.append((item, create_r_item(item))) 227 | 228 | 229 | def create_r_distance(distance): 230 | def r_distance(t): 231 | return [('distance', lambda: distance)] 232 | return r_distance 233 | for distance in merge_distance_vocab: 234 | rules.append((distance, create_r_distance(distance))) 235 | 236 | 237 | def create_r_horizontal(horizontal): 238 | def r_horizontal(t): 239 | return [('horizontal', lambda: horizontal)] 240 | return r_horizontal 241 | for horizontal in merge_horizontal_vocab: 242 | rules.append((horizontal, create_r_horizontal(horizontal))) 243 | 244 | 245 | def 
create_r_slot(slot_number): 246 | def r_slot(t): 247 | return [('slot', lambda: slot_number)] 248 | return r_slot 249 | for slot_number in range(1, 7): 250 | rules.append(('S={}'.format(slot_number), create_r_slot(slot_number))) 251 | 252 | 253 | def create_r_action(action): 254 | def r_action(t): 255 | def fn(world, n): 256 | return [action] 257 | return [('action', fn)] 258 | return r_action 259 | for action in ACTION_LIST: 260 | rules.append((action, create_r_action(action))) 261 | 262 | 263 | def create_r_cste(number): 264 | def r_cste(t): 265 | return [('cste', lambda: number)] 266 | return r_cste 267 | for i in range(20): 268 | rules.append(('R={}'.format(i), create_r_cste(i))) 269 | 270 | 271 | def parse(program): 272 | p_tokens = program.split()[::-1] 273 | queue = [] 274 | applied = False 275 | while len(p_tokens) > 0 or len(queue) != 1: 276 | if applied: applied = False 277 | else: 278 | queue.append((p_tokens.pop(), None)) 279 | for rule in rules: 280 | applied = check_and_apply(queue, rule) 281 | if applied: break 282 | if not applied and len(p_tokens) == 0: # error parsing 283 | return None, False 284 | return queue[0][1], True 285 | -------------------------------------------------------------------------------- /vizdoom_env/dsl/dsl_hit_analysis.py: -------------------------------------------------------------------------------- 1 | MONSTER_LIST = ['Demon', 'HellKnight', 'Revenant'] 2 | 3 | ITEMS_IN_INTEREST = ['MyAmmo'] 4 | 5 | ACTION_LIST = ['MOVE_FORWARD', 'MOVE_BACKWARD', 'MOVE_LEFT', 'MOVE_RIGHT', 6 | 'TURN_LEFT', 'TURN_RIGHT', 'ATTACK', 7 | 'SELECT_WEAPON1', 'SELECT_WEAPON2', 'SELECT_WEAPON3', 8 | 'SELECT_WEAPON4', 'SELECT_WEAPON5'] 9 | 10 | DISTANCE_DICT = { 11 | 'doncare_dist': lambda d: True, 12 | 'far': lambda d: d > 400, 13 | 'mid': lambda d: d < 300, 14 | 'close': lambda d: d < 180, 15 | 'very_close': lambda d: d < 135} 16 | 17 | HORIZONTAL_DICT = { 18 | 'doncare_horz': lambda l, r, x: True, 19 | 'center': lambda l, r, x: l < x and x < r, 20 | 'slight_left': lambda l, r, x: r < x and x <= r + 10, 21 | 'slight_right': lambda l, r, x: l > x and x >= l - 10, 22 | 'mid_left': lambda l, r, x: r < x and x <= r + 20, 23 | 'mid_right': lambda l, r, x: l > x and x >= l - 20, 24 | 'left': lambda l, r, x: r < x, 25 | 'right': lambda l, r, x: l > x} 26 | 27 | CLEAR_DISTANCE_DICT = { 28 | 'far': lambda d: d > 400, 29 | 'mid_far': lambda d: 300 < d and d <= 400, 30 | 'mid': lambda d: 180 < d and d <= 300, 31 | 'close': lambda d: 135 < d and d <= 180, 32 | 'very_close': lambda d: d <= 135} 33 | 34 | CLEAR_HORIZONTAL_DICT = { 35 | 'slight_left': lambda l, r, x: r < x and x <= r + 10, 36 | 'slight_right': lambda l, r, x: l > x and x >= l - 10, 37 | 'mid_left': lambda l, r, x: r + 10 < x and x <= r + 20, 38 | 'mid_right': lambda l, r, x: l - 10 > x and x >= l - 20, 39 | 'left': lambda l, r, x: r + 20 < x, 40 | 'right': lambda l, r, x: l - 20 > x} 41 | 42 | merge_distance_vocab = list(set(DISTANCE_DICT.keys()).union( 43 | set(CLEAR_DISTANCE_DICT.keys()))) 44 | merge_horizontal_vocab = list(set(HORIZONTAL_DICT.keys()).union( 45 | set(CLEAR_HORIZONTAL_DICT.keys()))) 46 | 47 | 48 | def check_and_apply(queue, rule): 49 | r = rule[0].split() 50 | l = len(r) 51 | if len(queue) >= l: 52 | t = queue[-l:] 53 | if list(zip(*t)[0]) == r: 54 | new_t = rule[1](list(zip(*t)[1]), list(zip(*t)[2])) 55 | del queue[-l:] 56 | queue.extend(new_t) 57 | return True 58 | return False 59 | 60 | rules = [] 61 | 62 | # world, n, s = fn(world, n) 63 | # world: vizdoom_world 64 | # n: num_call 65 | # s: 
success 66 | # c: condition [True, False] 67 | MAX_FUNC_CALL = 100 68 | 69 | 70 | def r_prog(tn, t): 71 | stmt = t[3] 72 | token_hit = tn[:3] + tn[4:] 73 | 74 | def fn(world, n): 75 | if n > MAX_FUNC_CALL: return token_hit, n, False 76 | hit_s, n, s = stmt(world, n + 1) 77 | return token_hit + hit_s, n, s 78 | return [('prog', -1, fn)] 79 | rules.append(('DEF run m( stmt m)', r_prog)) 80 | 81 | 82 | def r_stmt(tn, t): 83 | stmt = t[0] 84 | 85 | def fn(world, n): 86 | if n > MAX_FUNC_CALL: return [], n, False 87 | return stmt(world, n + 1) 88 | return [('stmt', -1, fn)] 89 | rules.append(('while_stmt', r_stmt)) 90 | rules.append(('repeat_stmt', r_stmt)) 91 | rules.append(('stmt_stmt', r_stmt)) 92 | rules.append(('action', r_stmt)) 93 | rules.append(('if_stmt', r_stmt)) 94 | rules.append(('ifelse_stmt', r_stmt)) 95 | 96 | 97 | def r_stmt_stmt(tn, t): 98 | stmt1, stmt2 = t[0], t[1] 99 | 100 | def fn(world, n): 101 | if n > MAX_FUNC_CALL: return [], n, False 102 | hit_s1, n, s = stmt1(world, n + 1) 103 | if not s: return hit_s1, n, s 104 | if n > MAX_FUNC_CALL: return hit_s1, n, False 105 | hit_s2, n, s = stmt2(world, n) 106 | return hit_s1 + hit_s2, n, s 107 | return [('stmt_stmt', -1, fn)] 108 | rules.append(('stmt stmt', r_stmt_stmt)) 109 | 110 | 111 | def r_if(tn, t): 112 | cond, stmt = t[2], t[5] 113 | token_hit = tn[:2] + tn[3:5] + tn[6:] 114 | 115 | def fn(world, n): 116 | if n > MAX_FUNC_CALL: return [], n, False 117 | hit_c, n, s, c = cond(world, n + 1) 118 | if not s: return token_hit + hit_c, n, s 119 | if c: 120 | hit_s, n, s = stmt(world, n) 121 | return token_hit + hit_c + hit_s, n, s 122 | else: return token_hit + hit_c, n, s 123 | return [('if_stmt', -1, fn)] 124 | rules.append(('IF c( cond c) i( stmt i)', r_if)) 125 | 126 | 127 | def r_ifelse(tn, t): 128 | cond, stmt1, stmt2 = t[2], t[5], t[9] 129 | token_hit = tn[:2] + tn[3:5] + tn[6:9] + tn[10:] 130 | 131 | def fn(world, n): 132 | if n > MAX_FUNC_CALL: return token_hit, n, False 133 | hit_c, n, s, c = cond(world, n + 1) 134 | if not s: return token_hit + hit_c, n, s 135 | if c: 136 | hit_s1, n, s = stmt1(world, n) 137 | return token_hit + hit_c + hit_s1, n, s 138 | else: 139 | hit_s2, n, s = stmt2(world, n) 140 | return token_hit + hit_c + hit_s2, n, s 141 | return [('ifelse_stmt', -1, fn)] 142 | rules.append(('IFELSE c( cond c) i( stmt i) ELSE e( stmt e)', r_ifelse)) 143 | 144 | 145 | def r_while(tn, t): 146 | cond, stmt = t[2], t[5] 147 | token_hit = tn[:2] + tn[3:5] + tn[6:] 148 | 149 | def fn(world, n): 150 | if n > MAX_FUNC_CALL: return token_hit, n, False 151 | hit_c, n, s, c = cond(world, n) 152 | if not s: return token_hit + hit_c, n, s 153 | total_hit = token_hit 154 | while(c): 155 | hit_s, n, s = stmt(world, n) 156 | total_hit.extend(hit_s) 157 | if not s: return total_hit, n, s 158 | hit_c, n, s, c = cond(world, n) 159 | total_hit.extend(hit_c) 160 | if not s: return total_hit, n, s 161 | return total_hit, n, s 162 | return [('while_stmt', -1, fn)] 163 | rules.append(('WHILE c( cond c) w( stmt w)', r_while)) 164 | 165 | 166 | def r_repeat(tn, t): 167 | cste, stmt = t[1], t[3] 168 | token_hit = tn[:3] + tn[4:] 169 | 170 | def fn(world, n): 171 | if n > MAX_FUNC_CALL: return token_hit, n, False 172 | n += 1 173 | s = True 174 | total_hit = token_hit 175 | for _ in range(cste()): 176 | hit_s, n, s = stmt(world, n) 177 | total_hit.extend(hit_s) 178 | if not s: return total_hit, n, s 179 | return total_hit, n, s 180 | return [('repeat_stmt', -1, fn)] 181 | rules.append(('REPEAT cste r( stmt r)', r_repeat)) 182 | 183 | 
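# The rules in this file mirror dsl_parse.py, but every generated fn returns
# (token_hit, n, s): the indices of the program tokens that were actually
# reached during execution, the running count of interpreter calls (bounded
# by MAX_FUNC_CALL), and a success flag. hit_count() at the bottom of the
# file compiles a program string into such an fn. Usage sketch, assuming
# `world` is the vizdoom world object referenced in the comment above (it
# must expose state_transition() and the perception checks used here):
#
#   exe, ok = hit_count('DEF run m( ATTACK m)')
#   if ok:
#       hit_tokens, num_calls, success = exe(world, 0)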
184 | def r_cond1(tn, t): 185 | cond = t[0] 186 | 187 | def fn(world, n): 188 | if n > MAX_FUNC_CALL: return [], n, False, False 189 | return cond(world, n) 190 | return [('cond', -1, fn)] 191 | rules.append(('percept', r_cond1)) 192 | 193 | 194 | def r_cond2(tn, t): 195 | cond = t[2] 196 | token_hit = tn[:2] + tn[3:] 197 | 198 | def fn(world, n): 199 | if n > MAX_FUNC_CALL: return token_hit, n, False, False 200 | hit_c, n, s, c = cond(world, n) 201 | return token_hit + hit_c, n, s, not c 202 | return [('cond', -1, fn)] 203 | rules.append(('not c( cond c)', r_cond2)) 204 | 205 | 206 | def r_percept1(tn, t): 207 | actor, dist, horz = t[1], t[3], t[4] 208 | token_hit = tn 209 | 210 | def fn(world, n): 211 | if n > MAX_FUNC_CALL: return token_hit, n, False, False 212 | c = world.exist_actor_in_distance_horizontal(actor(), dist(), horz()) 213 | return token_hit, n, True, c 214 | return [('percept', -1, fn)] 215 | rules.append(('EXIST actor IN distance horizontal', r_percept1)) 216 | 217 | 218 | def r_percept2(tn, t): 219 | actor = t[1] 220 | token_hit = tn 221 | 222 | def fn(world, n): 223 | if n > MAX_FUNC_CALL: return token_hit, n, False, False 224 | c = world.in_target(actor()) 225 | return token_hit, n, True, c 226 | return [('percept', -1, fn)] 227 | rules.append(('INTARGET actor', r_percept2)) 228 | 229 | 230 | def r_percept3(tn, t): 231 | actor = t[1] 232 | token_hit = tn 233 | 234 | def fn(world, n): 235 | if n > MAX_FUNC_CALL: return token_hit, n, False, False 236 | c = world.is_there(actor()) 237 | return token_hit, n, True, c 238 | return [('percept', -1, fn)] 239 | rules.append(('ISTHERE actor', r_percept3)) 240 | 241 | 242 | def r_actor1(tn, t): 243 | return [('actor', tn[0], t[0])] 244 | rules.append(('monster', r_actor1)) 245 | 246 | 247 | def create_r_monster(monster): 248 | def r_monster(tn, t): 249 | return [('monster', tn[0], lambda: monster)] 250 | return r_monster 251 | for monster in MONSTER_LIST: 252 | rules.append((monster, create_r_monster(monster))) 253 | 254 | 255 | def r_actor2(tn, t): 256 | return [('actor', tn[0], t[0])] 257 | rules.append(('items', r_actor2)) 258 | 259 | 260 | def create_r_item(item): 261 | def r_item(tn, t): 262 | return [('items', tn[0], lambda: item)] 263 | return r_item 264 | for item in ITEMS_IN_INTEREST: 265 | rules.append((item, create_r_item(item))) 266 | 267 | 268 | def create_r_distance(distance): 269 | def r_distance(tn, t): 270 | return [('distance', tn[0], lambda: distance)] 271 | return r_distance 272 | for distance in merge_distance_vocab: 273 | rules.append((distance, create_r_distance(distance))) 274 | 275 | 276 | def create_r_horizontal(horizontal): 277 | def r_horizontal(tn, t): 278 | return [('horizontal', tn[0], lambda: horizontal)] 279 | return r_horizontal 280 | for horizontal in merge_horizontal_vocab: 281 | rules.append((horizontal, create_r_horizontal(horizontal))) 282 | 283 | 284 | def create_r_slot(slot_number): 285 | def r_slot(tn, t): 286 | return [('slot', tn[0], lambda: slot_number)] 287 | return r_slot 288 | for slot_number in range(1, 7): 289 | rules.append(('S={}'.format(slot_number), create_r_slot(slot_number))) 290 | 291 | 292 | def create_r_action(action): 293 | def r_action(tn, t): 294 | token_hit = tn 295 | 296 | def fn(world, n): 297 | if n > MAX_FUNC_CALL: token_hit, n, False 298 | try: world.state_transition(action) 299 | except: return token_hit, n, False 300 | else: return token_hit, n, True 301 | return [('action', -1, fn)] 302 | return r_action 303 | for action in ACTION_LIST: 304 | 
rules.append((action, create_r_action(action))) 305 | 306 | 307 | def create_r_cste(number): 308 | def r_cste(tn, t): 309 | return [('cste', tn[0], lambda: number)] 310 | return r_cste 311 | for i in range(20): 312 | rules.append(('R={}'.format(i), create_r_cste(i))) 313 | 314 | 315 | def hit_count(program): 316 | p_tokens = program.split()[::-1] 317 | token_nums = list(range(len(p_tokens)))[::-1] 318 | queue = [] 319 | applied = False 320 | while len(p_tokens) > 0 or len(queue) != 1: 321 | if applied: applied = False 322 | else: 323 | queue.append((p_tokens.pop(), token_nums.pop(), None)) 324 | for rule in rules: 325 | applied = check_and_apply(queue, rule) 326 | if applied: break 327 | if not applied and len(p_tokens) == 0: # error parsing 328 | return None, False 329 | return queue[0][2], True 330 | -------------------------------------------------------------------------------- /vizdoom_env/dsl/dsl_parse.py: -------------------------------------------------------------------------------- 1 | MONSTER_LIST = ['Demon', 'HellKnight', 'Revenant'] 2 | 3 | ITEMS_IN_INTEREST = ['MyAmmo'] 4 | 5 | ACTION_LIST = ['MOVE_FORWARD', 'MOVE_BACKWARD', 'MOVE_LEFT', 'MOVE_RIGHT', 6 | 'TURN_LEFT', 'TURN_RIGHT', 'ATTACK', 7 | 'SELECT_WEAPON1', 'SELECT_WEAPON2', 'SELECT_WEAPON3', 8 | 'SELECT_WEAPON4', 'SELECT_WEAPON5'] 9 | 10 | DISTANCE_DICT = { 11 | 'doncare_dist': lambda d: True, 12 | 'far': lambda d: d > 400, 13 | 'mid': lambda d: d < 300, 14 | 'close': lambda d: d < 180, 15 | 'very_close': lambda d: d < 135} 16 | 17 | HORIZONTAL_DICT = { 18 | 'doncare_horz': lambda l, r, x: True, 19 | 'center': lambda l, r, x: l < x and x < r, 20 | 'slight_left': lambda l, r, x: r < x and x <= r + 10, 21 | 'slight_right': lambda l, r, x: l > x and x >= l - 10, 22 | 'mid_left': lambda l, r, x: r < x and x <= r + 20, 23 | 'mid_right': lambda l, r, x: l > x and x >= l - 20, 24 | 'left': lambda l, r, x: r < x, 25 | 'right': lambda l, r, x: l > x} 26 | 27 | CLEAR_DISTANCE_DICT = { 28 | 'far': lambda d: d > 400, 29 | 'mid_far': lambda d: 300 < d and d <= 400, 30 | 'mid': lambda d: 180 < d and d <= 300, 31 | 'close': lambda d: 135 < d and d <= 180, 32 | 'very_close': lambda d: d <= 135} 33 | 34 | CLEAR_HORIZONTAL_DICT = { 35 | 'slight_left': lambda l, r, x: r < x and x <= r + 10, 36 | 'slight_right': lambda l, r, x: l > x and x >= l - 10, 37 | 'mid_left': lambda l, r, x: r + 10 < x and x <= r + 20, 38 | 'mid_right': lambda l, r, x: l - 10 > x and x >= l - 20, 39 | 'left': lambda l, r, x: r + 20 < x, 40 | 'right': lambda l, r, x: l - 20 > x} 41 | 42 | merge_distance_vocab = list(set(DISTANCE_DICT.keys()).union( 43 | set(CLEAR_DISTANCE_DICT.keys()))) 44 | merge_horizontal_vocab = list(set(HORIZONTAL_DICT.keys()).union( 45 | set(CLEAR_HORIZONTAL_DICT.keys()))) 46 | 47 | 48 | def check_and_apply(queue, rule): 49 | r = rule[0].split() 50 | l = len(r) 51 | if len(queue) >= l: 52 | t = queue[-l:] 53 | if list(zip(*t)[0]) == r: 54 | new_t = rule[1](list(zip(*t)[1])) 55 | del queue[-l:] 56 | queue.extend(new_t) 57 | return True 58 | return False 59 | 60 | rules = [] 61 | 62 | # world, n, s = fn(world, n) 63 | # world: vizdoom_world 64 | # n: num_call 65 | # s: success 66 | # c: condition [True, False] 67 | MAX_FUNC_CALL = 100 68 | 69 | 70 | def r_prog(t): 71 | stmt = t[3] 72 | 73 | def fn(world, n): 74 | if n > MAX_FUNC_CALL: return world, n, False 75 | return stmt(world, n + 1) 76 | return [('prog', fn)] 77 | rules.append(('DEF run m( stmt m)', r_prog)) 78 | 79 | 80 | def r_stmt(t): 81 | stmt = t[0] 82 | 83 | def fn(world, n): 84 | if 
n > MAX_FUNC_CALL: return world, n, False 85 | return stmt(world, n + 1) 86 | return [('stmt', fn)] 87 | rules.append(('while_stmt', r_stmt)) 88 | rules.append(('repeat_stmt', r_stmt)) 89 | rules.append(('stmt_stmt', r_stmt)) 90 | rules.append(('action', r_stmt)) 91 | rules.append(('if_stmt', r_stmt)) 92 | rules.append(('ifelse_stmt', r_stmt)) 93 | 94 | 95 | def r_stmt_stmt(t): 96 | stmt1, stmt2 = t[0], t[1] 97 | 98 | def fn(world, n): 99 | if n > MAX_FUNC_CALL: return world, n, False 100 | world, n, s = stmt1(world, n + 1) 101 | if not s: return world, n, s 102 | if n > MAX_FUNC_CALL: return world, n, False 103 | return stmt2(world, n) 104 | return [('stmt_stmt', fn)] 105 | rules.append(('stmt stmt', r_stmt_stmt)) 106 | 107 | 108 | def r_if(t): 109 | cond, stmt = t[2], t[5] 110 | 111 | def fn(world, n): 112 | if n > MAX_FUNC_CALL: return world, n, False 113 | world, n, s, c = cond(world, n + 1) 114 | if not s: return world, n, s 115 | if c: return stmt(world, n) 116 | else: return world, n, s 117 | return [('if_stmt', fn)] 118 | rules.append(('IF c( cond c) i( stmt i)', r_if)) 119 | 120 | 121 | def r_ifelse(t): 122 | cond, stmt1, stmt2 = t[2], t[5], t[9] 123 | 124 | def fn(world, n): 125 | if n > MAX_FUNC_CALL: return world, n, False 126 | world, n, s, c = cond(world, n + 1) 127 | if not s: return world, n, s 128 | if c: return stmt1(world, n) 129 | else: return stmt2(world, n) 130 | return [('ifelse_stmt', fn)] 131 | rules.append(('IFELSE c( cond c) i( stmt i) ELSE e( stmt e)', r_ifelse)) 132 | 133 | 134 | def r_while(t): 135 | cond, stmt = t[2], t[5] 136 | 137 | def fn(world, n): 138 | if n > MAX_FUNC_CALL: return world, n, False 139 | world, n, s, c = cond(world, n) 140 | if not s: return world, n, s 141 | while(c): 142 | world, n, s = stmt(world, n) 143 | if not s: return world, n, s 144 | world, n, s, c = cond(world, n) 145 | if not s: return world, n, s 146 | return world, n, s 147 | return [('while_stmt', fn)] 148 | rules.append(('WHILE c( cond c) w( stmt w)', r_while)) 149 | 150 | 151 | def r_repeat(t): 152 | cste, stmt = t[1], t[3] 153 | 154 | def fn(world, n): 155 | if n > MAX_FUNC_CALL: return world, n, False 156 | n += 1 157 | s = True 158 | for _ in range(cste()): 159 | world, n, s = stmt(world, n) 160 | if not s: return world, n, s 161 | return world, n, s 162 | return [('repeat_stmt', fn)] 163 | rules.append(('REPEAT cste r( stmt r)', r_repeat)) 164 | 165 | 166 | def r_cond1(t): 167 | cond = t[0] 168 | 169 | def fn(world, n): 170 | if n > MAX_FUNC_CALL: return world, n, False, False 171 | return cond(world, n) 172 | return [('cond', fn)] 173 | rules.append(('percept', r_cond1)) 174 | 175 | 176 | def r_cond2(t): 177 | cond = t[2] 178 | 179 | def fn(world, n): 180 | if n > MAX_FUNC_CALL: return world, n, False, False 181 | world, n, s, c = cond(world, n) 182 | return world, n, s, not c 183 | return [('cond', fn)] 184 | rules.append(('not c( cond c)', r_cond2)) 185 | 186 | 187 | def r_percept1(t): 188 | actor, dist, horz = t[1], t[3], t[4] 189 | 190 | def fn(world, n): 191 | if n > MAX_FUNC_CALL: return world, n, False, False 192 | c = world.exist_actor_in_distance_horizontal(actor(), dist(), horz()) 193 | return world, n, True, c 194 | return [('percept', fn)] 195 | rules.append(('EXIST actor IN distance horizontal', r_percept1)) 196 | 197 | 198 | def r_percept2(t): 199 | actor = t[1] 200 | 201 | def fn(world, n): 202 | if n > MAX_FUNC_CALL: return world, n, False, False 203 | c = world.in_target(actor()) 204 | return world, n, True, c 205 | return [('percept', fn)] 206 | 
rules.append(('INTARGET actor', r_percept2)) 207 | 208 | 209 | def r_percept3(t): 210 | actor = t[1] 211 | 212 | def fn(world, n): 213 | if n > MAX_FUNC_CALL: return world, n, False, False 214 | c = world.is_there(actor()) 215 | return world, n, True, c 216 | return [('percept', fn)] 217 | rules.append(('ISTHERE actor', r_percept3)) 218 | 219 | 220 | def r_actor1(t): 221 | return [('actor', t[0])] 222 | rules.append(('monster', r_actor1)) 223 | 224 | 225 | def create_r_monster(monster): 226 | def r_monster(t): 227 | return [('monster', lambda: monster)] 228 | return r_monster 229 | for monster in MONSTER_LIST: 230 | rules.append((monster, create_r_monster(monster))) 231 | 232 | 233 | def r_actor2(t): 234 | return [('actor', t[0])] 235 | rules.append(('items', r_actor2)) 236 | 237 | 238 | def create_r_item(item): 239 | def r_item(t): 240 | return [('items', lambda: item)] 241 | return r_item 242 | for item in ITEMS_IN_INTEREST: 243 | rules.append((item, create_r_item(item))) 244 | 245 | 246 | def create_r_distance(distance): 247 | def r_distance(t): 248 | return [('distance', lambda: distance)] 249 | return r_distance 250 | for distance in merge_distance_vocab: 251 | rules.append((distance, create_r_distance(distance))) 252 | 253 | 254 | def create_r_horizontal(horizontal): 255 | def r_horizontal(t): 256 | return [('horizontal', lambda: horizontal)] 257 | return r_horizontal 258 | for horizontal in merge_horizontal_vocab: 259 | rules.append((horizontal, create_r_horizontal(horizontal))) 260 | 261 | 262 | def create_r_slot(slot_number): 263 | def r_slot(t): 264 | return [('slot', lambda: slot_number)] 265 | return r_slot 266 | for slot_number in range(1, 7): 267 | rules.append(('S={}'.format(slot_number), create_r_slot(slot_number))) 268 | 269 | 270 | def create_r_action(action): 271 | def r_action(t): 272 | def fn(world, n): 273 | if n > MAX_FUNC_CALL: world, n, False 274 | try: world.state_transition(action) 275 | except: return world, n, False 276 | else: return world, n, True 277 | return [('action', fn)] 278 | return r_action 279 | for action in ACTION_LIST: 280 | rules.append((action, create_r_action(action))) 281 | 282 | 283 | def create_r_cste(number): 284 | def r_cste(t): 285 | return [('cste', lambda: number)] 286 | return r_cste 287 | for i in range(20): 288 | rules.append(('R={}'.format(i), create_r_cste(i))) 289 | 290 | 291 | def parse(program): 292 | p_tokens = program.split()[::-1] 293 | queue = [] 294 | applied = False 295 | while len(p_tokens) > 0 or len(queue) != 1: 296 | if applied: applied = False 297 | else: 298 | queue.append((p_tokens.pop(), None)) 299 | for rule in rules: 300 | applied = check_and_apply(queue, rule) 301 | if applied: break 302 | if not applied and len(p_tokens) == 0: # error parsing 303 | return None, False 304 | return queue[0][1], True 305 | -------------------------------------------------------------------------------- /vizdoom_env/dsl/random_code_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from dsl_parse import parse 4 | 5 | stmt_length_range = { 6 | 'span0': (1, 3), # [1, 6] 7 | 'span1': (1, 2), # [1, 2] 8 | 'span2': (1, 2)} # [1, 1] 9 | 10 | rules = {} 11 | rules['prog'] = [] 12 | rules['prog'].append(('DEF run m( stmt0 m)', 1)) 13 | 14 | rules['action'] = [] 15 | rules['action'].append(('MOVE_FORWARD', 0.1)) 16 | rules['action'].append(('MOVE_BACKWARD', 0.1)) 17 | rules['action'].append(('MOVE_LEFT', 0.2)) 18 | rules['action'].append(('MOVE_RIGHT', 0.2)) 19 | 
rules['action'].append(('TURN_LEFT', 0.1)) 20 | rules['action'].append(('TURN_RIGHT', 0.1)) 21 | rules['action'].append(('ATTACK', 0.2)) 22 | 23 | rules['stmt0'] = [] 24 | rules['stmt0'].append(('action_stmt1', 0.2)) 25 | rules['stmt0'].append(('if_stmt1', 0.25)) 26 | rules['stmt0'].append(('ifelse_stmt1', 0.24)) 27 | rules['stmt0'].append(('while_stmt1', 0.3)) 28 | rules['stmt0'].append(('repeat_stmt1', 0.01)) 29 | 30 | rules['stmt1'] = [] 31 | rules['stmt1'].append(('action_stmt1', 0.2)) 32 | rules['stmt1'].append(('if_stmt1', 0.25)) 33 | rules['stmt1'].append(('ifelse_stmt1', 0.24)) 34 | rules['stmt1'].append(('while_stmt1', 0.3)) 35 | rules['stmt1'].append(('repeat_stmt1', 0.01)) 36 | 37 | rules['stmt2'] = [] 38 | rules['stmt2'].append(('action', 0.8)) 39 | rules['stmt2'].append(('action action', 0.2)) 40 | 41 | rules['action_stmt1'] = [] 42 | rules['action_stmt1'].append(('action', 0.85)) 43 | rules['action_stmt1'].append(('action action', 0.1)) 44 | rules['action_stmt1'].append(('action action action', 0.05)) 45 | 46 | rules['if_stmt1'] = [] 47 | rules['if_stmt1'].append(('IF c( cond c) i( stmt2 i)', 1)) 48 | 49 | rules['ifelse_stmt1'] = [] 50 | rules['ifelse_stmt1'].append( 51 | ('IFELSE c( cond c) i( stmt2 i) ELSE e( stmt2 e)', 1)) 52 | 53 | rules['while_stmt1'] = [] 54 | rules['while_stmt1'].append(('WHILE c( cond c) w( stmt2 w)', 1)) 55 | 56 | rules['repeat_stmt1'] = [] 57 | rules['repeat_stmt1'].append(('REPEAT cste r( stmt2 r)', 1)) 58 | 59 | rules['if_stmt2'] = [] 60 | rules['if_stmt2'].append(('IF c( cond c) i( stmt1 i)', 1)) 61 | 62 | rules['ifelse_stmt2'] = [] 63 | rules['ifelse_stmt2'].append( 64 | ('IFELSE c( cond c) i( stmt1 i) ELSE e( stmt1 e)', 1)) 65 | 66 | rules['while_stmt2'] = [] 67 | rules['while_stmt2'].append(('WHILE c( cond c) w( stmt1 w)', 1)) 68 | 69 | rules['repeat_stmt2'] = [] 70 | rules['repeat_stmt2'].append(('REPEAT cste r( stmt1 r)', 1)) 71 | 72 | rules['cond'] = [] 73 | rules['cond'].append(('not c( percept c)', 0.2)) 74 | rules['cond'].append(('percept', 0.8)) 75 | 76 | rules['cste'] = [] 77 | rules['cste'].append(('R=2', 0.4)) 78 | rules['cste'].append(('R=3', 0.3)) 79 | rules['cste'].append(('R=4', 0.3)) 80 | 81 | 82 | # This generator handles single depth case only 83 | class DoomProgramGenerator(): 84 | def __init__(self, seed=123): 85 | self.rng = np.random.RandomState(seed) 86 | 87 | def get_percepts_value(self, world_list): 88 | percepts_value = [] 89 | for world in world_list: 90 | percepts_value.append(world.get_perception_vector()) 91 | percepts_value = np.stack(percepts_value).astype(np.float) 92 | return percepts_value 93 | 94 | def compute_percepts_prob(self, world_list): 95 | percepts_value = self.get_percepts_value(world_list) 96 | num_demo = float(len(world_list)) 97 | percepts_sum = percepts_value.sum(axis=0) 98 | percepts_diff = (num_demo / 2.0 - abs(num_demo / 2.0 - percepts_sum)) 99 | percepts_diff = percepts_diff ** 2 100 | if percepts_diff.sum() == 0: 101 | percepts_diff[:] += 1e-10 102 | percepts_prob = percepts_diff / percepts_diff.sum() 103 | return percepts_prob 104 | 105 | def random_expand_token(self, token, percepts, world_list, depth=0): 106 | # Expansion 107 | candidates, sample_prob = zip(*rules[token]) 108 | sample_idx = self.rng.choice(range(len(candidates)), p=sample_prob) 109 | expansion = [] 110 | for new_t in candidates[sample_idx].split(): 111 | if new_t in ['stmt0', 'stmt1', 'stmt2']: 112 | stmt_len = self.rng.choice( 113 | range(*stmt_length_range['span{}'.format(depth)])) 114 | expansion.extend([new_t] * 
stmt_len) 115 | else: expansion.append(new_t) 116 | codes = [] 117 | for t in expansion: 118 | if t in rules: 119 | # Increase nested depth 120 | if t in ['stmt0', 'stmt1', 'stmt2']: 121 | sub_codes, success = self.random_expand_token(t, percepts, world_list, depth + 1) 122 | if not success: 123 | return [], False 124 | codes.extend(sub_codes) 125 | else: 126 | sub_codes, success = self.random_expand_token(t, percepts, world_list, depth) 127 | if not success: 128 | return [], False 129 | codes.extend(sub_codes) 130 | elif t == 'percept': 131 | percepts_prob = self.compute_percepts_prob(world_list) 132 | percept_idx = self.rng.choice(range(len(percepts)), p=percepts_prob) 133 | codes.append(percepts[percept_idx]) 134 | else: codes.append(t) 135 | if token in ['action_stmt1', 'if_stmt1', 'ifelse_stmt1', 136 | 'while_stmt1', 'repeat_stmt1']: 137 | # run new statement to be capable of getting next statements 138 | stmt = ' '.join(codes) 139 | exe, compile_success = parse(stmt) 140 | if not compile_success: 141 | raise RuntimeError('Compile failure should not happen') 142 | for world in world_list: 143 | w, num_call, success = exe(world, 0) 144 | if not success: 145 | return [], False 146 | 147 | return codes, True 148 | 149 | def random_code(self, percepts, world_list): 150 | codes, success = self.random_expand_token('prog', percepts, world_list, depth=0) 151 | return ' '.join(codes), success 152 | -------------------------------------------------------------------------------- /vizdoom_env/dsl/random_code_generator_ifelse.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from dsl_parse import parse 4 | 5 | stmt_length_range = { 6 | 'span0': (1, 2), # [1, 6] 7 | 'span1': (1, 2), # [1, 2] 8 | 'span2': (1, 2)} # [1, 1] 9 | 10 | rules = {} 11 | rules['prog'] = [] 12 | rules['prog'].append(('DEF run m( stmt0 m)', 1)) 13 | 14 | rules['action'] = [] 15 | rules['action'].append(('MOVE_FORWARD', 0.1)) 16 | rules['action'].append(('MOVE_BACKWARD', 0.1)) 17 | rules['action'].append(('MOVE_LEFT', 0.2)) 18 | rules['action'].append(('MOVE_RIGHT', 0.2)) 19 | rules['action'].append(('TURN_LEFT', 0.1)) 20 | rules['action'].append(('TURN_RIGHT', 0.1)) 21 | rules['action'].append(('ATTACK', 0.1)) 22 | rules['action'].append(('SELECT_WEAPON1', 0.025)) 23 | rules['action'].append(('SELECT_WEAPON3', 0.025)) 24 | rules['action'].append(('SELECT_WEAPON4', 0.025)) 25 | rules['action'].append(('SELECT_WEAPON5', 0.025)) 26 | 27 | rules['stmt0'] = [] 28 | rules['stmt0'].append(('ifelse_stmt1', 1.0)) 29 | 30 | 31 | rules['stmt2'] = [] 32 | rules['stmt2'].append(('action', 1)) 33 | 34 | rules['ifelse_stmt1'] = [] 35 | rules['ifelse_stmt1'].append( 36 | ('IFELSE c( cond c) i( stmt2 i) ELSE e( stmt2 e)', 1)) 37 | 38 | rules['cond'] = [] 39 | rules['cond'].append(('not c( percept c)', 0.2)) 40 | rules['cond'].append(('percept', 0.8)) 41 | 42 | 43 | # This generator handles single depth case only 44 | class DoomProgramGenerator(): 45 | def __init__(self, seed=123): 46 | self.rng = np.random.RandomState(seed) 47 | 48 | def get_percepts_value(self, world_list): 49 | percepts_value = [] 50 | for world in world_list: 51 | percepts_value.append(world.get_perception_vector()) 52 | percepts_value = np.stack(percepts_value).astype(np.float) 53 | return percepts_value 54 | 55 | def compute_percepts_prob(self, world_list): 56 | percepts_value = self.get_percepts_value(world_list) 57 | num_demo = float(len(world_list)) 58 | percepts_sum = percepts_value.sum(axis=0) 59 | 
percepts_diff = (num_demo / 2.0 - abs(num_demo / 2.0 - percepts_sum)) 60 | percepts_diff = percepts_diff ** 2 61 | if percepts_diff.sum() == 0: 62 | percepts_diff[:] += 1e-10 63 | percepts_prob = percepts_diff / percepts_diff.sum() 64 | return percepts_prob 65 | 66 | def random_expand_token(self, token, percepts, world_list, depth=0): 67 | # Expansion 68 | candidates, sample_prob = zip(*rules[token]) 69 | sample_idx = self.rng.choice(range(len(candidates)), p=sample_prob) 70 | expansion = [] 71 | for new_t in candidates[sample_idx].split(): 72 | if new_t in ['stmt0', 'stmt1', 'stmt2']: 73 | stmt_len = self.rng.choice( 74 | range(*stmt_length_range['span{}'.format(depth)])) 75 | expansion.extend([new_t] * stmt_len) 76 | else: expansion.append(new_t) 77 | codes = [] 78 | for t in expansion: 79 | if t in rules: 80 | # Increase nested depth 81 | if t in ['stmt0', 'stmt1', 'stmt2']: 82 | sub_codes, success = self.random_expand_token(t, percepts, world_list, depth + 1) 83 | if not success: 84 | return [], False 85 | codes.extend(sub_codes) 86 | else: 87 | sub_codes, success = self.random_expand_token(t, percepts, world_list, depth) 88 | if not success: 89 | return [], False 90 | codes.extend(sub_codes) 91 | elif t == 'percept': 92 | percepts_prob = self.compute_percepts_prob(world_list) 93 | percept_idx = self.rng.choice(range(len(percepts)), p=percepts_prob) 94 | codes.append(percepts[percept_idx]) 95 | else: codes.append(t) 96 | if token in ['action_stmt1', 'if_stmt1', 'ifelse_stmt1', 97 | 'while_stmt1', 'repeat_stmt1']: 98 | # run new statement to be capable of getting next statements 99 | stmt = ' '.join(codes) 100 | exe, compile_success = parse(stmt) 101 | if not compile_success: 102 | raise RuntimeError('Compile failure should not happen') 103 | for world in world_list: 104 | w, num_call, success = exe(world, 0) 105 | if not success: 106 | return [], False 107 | 108 | return codes, True 109 | 110 | def random_code(self, percepts, world_list): 111 | codes, success = self.random_expand_token('prog', percepts, world_list, depth=0) 112 | return ' '.join(codes), success 113 | -------------------------------------------------------------------------------- /vizdoom_env/dsl/vocab.py: -------------------------------------------------------------------------------- 1 | from dsl_parse import MONSTER_LIST, ITEMS_IN_INTEREST, ACTION_LIST, \ 2 | DISTANCE_DICT, HORIZONTAL_DICT, CLEAR_DISTANCE_DICT, CLEAR_HORIZONTAL_DICT 3 | 4 | SIMPLE_ACTION_LIST = ['MOVE_FORWARD', 'MOVE_BACKWARD', 'MOVE_LEFT', 'MOVE_RIGHT', 5 | 'TURN_LEFT', 'TURN_RIGHT', 'ATTACK'] 6 | SIMPLE_PROGRAM_TOKENS = ['DEF', 'run', 'm(', 'm)', 'WHILE', 'c(', 'c)', 7 | 'w(', 'w)', 'IF', 'i(', 'i)', 'IFELSE', 'ELSE', 8 | 'e(', 'e)', 'not', 'EXIST', 'IN', 'INTARGET'] 9 | 10 | PROGRAM_TOKENS = ['DEF', 'run', 'm(', 'm)', 'WHILE', 'c(', 'c)', 'w(', 'w)', 11 | 'REPEAT', 'r(', 'r)', 'R=2', 'R=3', 'R=4', 'R=5', 'R=6', 12 | 'IF', 'i(', 'i)', 'IFELSE', 'ELSE', 'e(', 'e)', 'not', 'EXIST', 'IN', 'INTARGET', 13 | 'ISTHERE'] 14 | 15 | 16 | class VizDoomDSLVocab(object): 17 | def __init__(self, perception_type='clear', level='not_simple'): 18 | if perception_type == 'clear': 19 | distance_vocab = CLEAR_DISTANCE_DICT.keys() 20 | horizontal_vocab = CLEAR_HORIZONTAL_DICT.keys() 21 | elif perception_type == 'simple' or perception_type == 'more_simple': 22 | distance_vocab = [] 23 | horizontal_vocab = [] 24 | else: 25 | distance_vocab = DISTANCE_DICT.keys() 26 | horizontal_vocab = HORIZONTAL_DICT.keys() 27 | if level == 'simple': 28 | action_list = 
SIMPLE_ACTION_LIST 29 | program_tokens = SIMPLE_PROGRAM_TOKENS 30 | elif perception_type == 'simple': 31 | action_list = ['MOVE_FORWARD', 'MOVE_BACKWARD', 'MOVE_LEFT', 32 | 'MOVE_RIGHT', 'TURN_LEFT', 'TURN_RIGHT', 33 | 'ATTACK', 'SELECT_WEAPON1', 'SELECT_WEAPON3', 34 | 'SELECT_WEAPON4', 'SELECT_WEAPON5'] 35 | program_tokens = ['DEF', 'run', 'm(', 'm)', 'WHILE', 'c(', 'c)', 36 | 'w(', 'w)', 'REPEAT', 'r(', 'r)', 'R=2', 'R=3', 37 | 'R=4', 'R=5', 'R=6', 'IF', 'i(', 'i)', 38 | 'IFELSE', 'ELSE', 'e(', 'e)', 'not', 39 | 'INTARGET', 'ISTHERE'] 40 | elif perception_type == 'more_simple': 41 | action_list = ['MOVE_FORWARD', 'MOVE_BACKWARD', 'MOVE_LEFT', 42 | 'MOVE_RIGHT', 'TURN_LEFT', 'TURN_RIGHT', 43 | 'ATTACK', 'SELECT_WEAPON1', 'SELECT_WEAPON3', 44 | 'SELECT_WEAPON4', 'SELECT_WEAPON5'] 45 | program_tokens = ['DEF', 'run', 'm(', 'm)', 'WHILE', 'c(', 'c)', 46 | 'w(', 'w)', 'REPEAT', 'r(', 'r)', 'R=2', 'R=3', 47 | 'R=4', 'R=5', 'R=6', 'IF', 'i(', 'i)', 48 | 'IFELSE', 'ELSE', 'e(', 'e)', 'not', 49 | 'ISTHERE'] 50 | else: 51 | action_list = ACTION_LIST 52 | program_tokens = PROGRAM_TOKENS 53 | self.int2token = program_tokens + action_list + distance_vocab +\ 54 | horizontal_vocab + MONSTER_LIST + ITEMS_IN_INTEREST 55 | self.token2int = {v: i for i, v in enumerate(self.int2token)} 56 | 57 | self.action_int2token = action_list 58 | self.action_token2int = {v: i for i, v in enumerate(self.action_int2token)} 59 | 60 | def str2intseq(self, string): 61 | return [self.token2int[t] for t in string.split()] 62 | 63 | def strlist2intseq(self, strlist): 64 | return [self.token2int[t] for t in strlist] 65 | 66 | def intseq2str(self, intseq): 67 | return ' '.join([self.int2token[i] for i in intseq]) 68 | 69 | def token_dim(self): 70 | return len(self.int2token) 71 | 72 | def action_str2intseq(self, string): 73 | return [self.action_token2int[t] for t in string.split()] 74 | 75 | def action_intseq2str(self, intseq): 76 | return ' '.join([self.action_int2token[i] for i in intseq]) 77 | 78 | def action_token_dim(self): 79 | return len(self.action_int2token) 80 | 81 | def action_strlist2intseq(self, strlist): 82 | return [self.action_token2int[t] for t in strlist] 83 | -------------------------------------------------------------------------------- /vizdoom_env/generate_dataset.sh: -------------------------------------------------------------------------------- 1 | # Generate small vizdoom datasets with different seeds 2 | # Note that we first generate small datasets to make generation parallelizable. 3 | # We generate datasets with 40 seen demonstrations for generalization experiments, 4 | # but we used only 25 seen demonstrations for training as described in the paper. 
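# NOTE: the commands below assume they are run from the repository root; the
# generators write their output under datasets/ and load assets via paths that
# are relative to the current working directory (e.g. vizdoom_env/asset/default.cfg
# in generator_ifelse.py).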
5 | python vizdoom_env/generator.py --max_demo_length 8 --seed 123
6 | python vizdoom_env/generator.py --max_demo_length 8 --seed 234
7 | python vizdoom_env/generator.py --max_demo_length 8 --seed 345
8 | python vizdoom_env/generator.py --max_demo_length 8 --seed 456
9 | python vizdoom_env/generator.py --max_demo_length 8 --seed 567
10 | python vizdoom_env/generator.py --max_demo_length 8 --seed 678
11 | python vizdoom_env/generator.py --max_demo_length 8 --seed 789
12 | python vizdoom_env/generator.py --max_demo_length 8 --seed 890
13 | python vizdoom_env/generator.py --max_demo_length 20 --seed 234
14 | python vizdoom_env/generator.py --max_demo_length 20 --seed 789
15 | # Merge datasets
16 | python vizdoom_env/merge_datasets.py --dataset_paths \
17 | datasets/vizdoom_small_len8_seed123 \
18 | datasets/vizdoom_small_len8_seed234 \
19 | datasets/vizdoom_small_len8_seed345 \
20 | datasets/vizdoom_small_len8_seed456 \
21 | datasets/vizdoom_small_len8_seed567 \
22 | datasets/vizdoom_small_len8_seed678 \
23 | datasets/vizdoom_small_len8_seed789 \
24 | datasets/vizdoom_small_len8_seed890 \
25 | datasets/vizdoom_small_len20_seed234 \
26 | datasets/vizdoom_small_len20_seed789
27 | 
28 | # For if-else experiments, you can generate datasets with the following script
29 | python vizdoom_env/generator_ifelse.py
30 | -------------------------------------------------------------------------------- /vizdoom_env/generator_ifelse.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import h5py
6 | import os
7 | import argparse
8 | import numpy as np
9 | 
10 | from cv2 import resize, INTER_AREA
11 | from tqdm import tqdm
12 | 
13 | from vizdoom_env import Vizdoom_env
14 | from dsl.dsl_parse import parse as vizdoom_parse
15 | from dsl.random_code_generator_ifelse import DoomProgramGenerator
16 | from dsl.vocab import VizDoomDSLVocab
17 | from util import log
18 | 
19 | 
20 | class DoomStateGenerator(object):
21 | def __init__(self, seed=None):
22 | self.rng = np.random.RandomState(seed)
23 | self.x_max = 64
24 | self.x_min = -480
25 | self.y_max = 480
26 | self.y_min = 64
27 | 
28 | def gen_rand_pos(self):
29 | return [self.rng.randint(self.x_min, self.x_max),
30 | self.rng.randint(self.y_min, self.y_max)]
31 | 
32 | def get_pos_keys(self):
33 | return ['player_pos', 'demon_pos', 'hellknight_pos',
34 | 'revenant_pos', 'ammo_pos']
35 | 
36 | # generate an initial env
37 | def generate_initial_state(self, min_ammo=4, max_ammo=5,
38 | min_monster=4, max_monster=5):
39 | """ h is y, w is x
40 | s = [{"player_pos": [x, y], "monster_pos": [[x1, y1], [x2, y2]]}]
41 | """
42 | s = {}
43 | locs = []
44 | s["player_pos"] = self.gen_rand_pos()
45 | s["demon_pos"] = []
46 | s["hellknight_pos"] = []
47 | s["revenant_pos"] = []
48 | s["ammo_pos"] = []
49 | locs.append(s["player_pos"])
50 | 
51 | ammo_count = self.rng.randint(min_ammo, max_ammo + 1)
52 | demon_count = self.rng.randint(min_monster, max_monster + 1)
53 | hellknight_count = self.rng.randint(min_monster, max_monster + 1)
54 | revenant_count = self.rng.randint(min_monster, max_monster + 1)
55 | while(revenant_count > 0):
56 | new_pos = self.gen_rand_pos()
57 | if new_pos not in locs:
58 | s["revenant_pos"].append(new_pos)
59 | locs.append(new_pos)
60 | revenant_count -= 1
61 | 
62 | while(hellknight_count > 0):
63 | new_pos = self.gen_rand_pos()
64 | if new_pos not in locs:
65 | 
s["hellknight_pos"].append(new_pos) 66 | locs.append(new_pos) 67 | hellknight_count -= 1 68 | 69 | while(demon_count > 0): 70 | new_pos = self.gen_rand_pos() 71 | if new_pos not in locs: 72 | s["demon_pos"].append(new_pos) 73 | locs.append(new_pos) 74 | demon_count -= 1 75 | 76 | while(ammo_count > 0): 77 | new_pos = self.gen_rand_pos() 78 | if new_pos not in locs: 79 | s["ammo_pos"].append(new_pos) 80 | locs.append(new_pos) 81 | ammo_count -= 1 82 | return s 83 | 84 | 85 | def downsize(img, h=80, w=80): 86 | image_resize = resize(img, (h, w), interpolation=INTER_AREA) 87 | return image_resize 88 | 89 | 90 | def generator(config): 91 | dir_name = config.dir_name 92 | 93 | image_dir = os.path.join(dir_name, 'images') 94 | check_path(image_dir) 95 | 96 | num_train = config.num_train 97 | num_test = config.num_test 98 | num_val = config.num_val 99 | num_total = num_train + num_test + num_val 100 | 101 | # output files 102 | f = h5py.File(os.path.join(dir_name, 'data.hdf5'), 'w') 103 | id_file = open(os.path.join(dir_name, 'id.txt'), 'w') 104 | 105 | num_demo = config.num_demo_per_program + config.num_test_demo_per_program 106 | world_list = [] 107 | log.info('Initializing {} vizdoom environments...'.format(num_demo)) 108 | for _ in range(num_demo): 109 | log.info('[{}/{}]'.format(_, num_demo)) 110 | world = Vizdoom_env(config="vizdoom_env/asset/default.cfg", 111 | perception_type='simple') 112 | world.init_game() 113 | world_list.append(world) 114 | h = config.height 115 | w = config.width 116 | c = world_list[0].channel 117 | 118 | gen = DoomStateGenerator(seed=config.seed) 119 | prog_gen = DoomProgramGenerator(seed=config.seed) 120 | 121 | percepts = world_list[0].get_perception_vector_cond() 122 | vizdoom_vocab = VizDoomDSLVocab( 123 | perception_type='simple') 124 | 125 | count = 0 126 | max_demo_length_in_dataset = -1 127 | max_program_length_in_dataset = -1 128 | pos_keys = gen.get_pos_keys() 129 | max_init_poslen = -1 130 | pbar = tqdm(total=num_total) 131 | while True: 132 | init_states = [] 133 | for world in world_list: 134 | init_states.append(gen.generate_initial_state()) 135 | world.new_episode(init_states[-1]) 136 | 137 | program, gen_success = prog_gen.random_code( 138 | percepts, world_list[:config.num_demo_per_program]) 139 | if not gen_success: 140 | continue 141 | if len(program.split()) > config.max_program_length: 142 | continue 143 | 144 | program_seq = np.array(vizdoom_vocab.str2intseq(program), dtype=np.int8) 145 | 146 | exe, compile_success = vizdoom_parse(program) 147 | if not compile_success: 148 | print('compile failure') 149 | print('program: {}'.format(program)) 150 | raise RuntimeError('Program compile failure should not happen') 151 | 152 | all_success = True 153 | for k, world in enumerate(world_list[config.num_demo_per_program:]): 154 | idx = k + config.num_demo_per_program 155 | world.new_episode(init_states[idx]) 156 | new_w, num_call, success = exe(world, 0) 157 | if not success or len(world.s_h) < config.min_demo_length \ 158 | or len(world.s_h) > config.max_demo_length: 159 | all_success = False 160 | break 161 | if not all_success: continue 162 | 163 | s_h_len_fail = False 164 | for world in world_list: 165 | if len(world.s_h) < config.min_demo_length or \ 166 | len(world.s_h) > config.max_demo_length: 167 | s_h_len_fail = True 168 | if s_h_len_fail: continue 169 | 170 | program_seq = np.array(vizdoom_vocab.str2intseq(program), dtype=np.int8) 171 | 172 | s_h_list = [] 173 | a_h_list = [] 174 | p_v_h_list = [] 175 | for k, world in enumerate(world_list): 
176 | s_h_list.append(np.stack(world.s_h, axis=0).copy()) 177 | a_h_list.append(np.array( 178 | vizdoom_vocab.action_strlist2intseq(world.a_h))) 179 | p_v_h_list.append(np.stack(world.p_v_h, axis=0).copy()) 180 | 181 | len_s_h = np.array([s_h.shape[0] for s_h in s_h_list], dtype=np.int16) 182 | 183 | demos_s_h = np.zeros([num_demo, np.max(len_s_h), h, w, c], dtype=np.int16) 184 | for i, s_h in enumerate(s_h_list): 185 | downsize_s_h = [] 186 | for t, s in enumerate(s_h): 187 | if s.shape[0] != h or s.shape[1] != w: 188 | s = downsize(s, h, w) 189 | downsize_s_h.append(s.copy()) 190 | demos_s_h[i, :s_h.shape[0]] = np.stack(downsize_s_h, 0) 191 | 192 | len_a_h = np.array([a_h.shape[0] for a_h in a_h_list], dtype=np.int16) 193 | 194 | demos_a_h = np.zeros([num_demo, np.max(len_a_h)], dtype=np.int8) 195 | for i, a_h in enumerate(a_h_list): 196 | demos_a_h[i, :a_h.shape[0]] = a_h 197 | 198 | demos_p_v_h = np.zeros([num_demo, np.max(len_s_h), len(percepts)], dtype=np.bool) 199 | for i, p_v in enumerate(p_v_h_list): 200 | demos_p_v_h[i, :p_v.shape[0]] = p_v 201 | 202 | max_demo_length_in_dataset = max( 203 | max_demo_length_in_dataset, np.max(len_s_h)) 204 | max_program_length_in_dataset = max( 205 | max_program_length_in_dataset, program_seq.shape[0]) 206 | 207 | # save the state 208 | id = 'no_{}_prog_len_{}_max_s_h_len_{}'.format( 209 | count, program_seq.shape[0], np.max(len_s_h)) 210 | id_file.write(id+'\n') 211 | 212 | # data: [# demo, # pos_key, max(# pos), 2] 213 | # len: [# demo, #pos_key] 214 | np_init_states = {} 215 | np_init_state_len = {} 216 | pos_key_maxlen = -1 217 | for k in pos_keys: 218 | np_init_states[k] = [] 219 | np_init_state_len[k] = [] 220 | for s in init_states: 221 | np_pos = np.array(s[k], dtype=np.int32) 222 | if np_pos.ndim == 1: 223 | np_pos = np.expand_dims(np_pos, axis=0) 224 | np_init_states[k].append(np_pos) 225 | np_init_state_len[k].append(np_pos.shape[0]) 226 | pos_key_maxlen = max(pos_key_maxlen, np_pos.shape[0]) 227 | max_init_poslen = max(max_init_poslen, pos_key_maxlen) 228 | 229 | # 3rd dimension is 2 as they are positions 230 | np_merged_init_states = np.zeros([num_demo, len(pos_keys), 231 | pos_key_maxlen, 2], 232 | dtype=np.int32) 233 | merged_pos_len = [] 234 | for p, key in enumerate(pos_keys): 235 | single_key_pos_len = [] 236 | for k, state in enumerate(np_init_states[key]): 237 | np_merged_init_states[k, p, :state.shape[0]] = state 238 | single_key_pos_len.append(state.shape[0]) 239 | merged_pos_len.append(np.array(single_key_pos_len, dtype=np.int32)) 240 | np_merged_pos_len = np.stack(merged_pos_len, axis=1) 241 | 242 | grp = f.create_group(id) 243 | grp['program'] = program_seq 244 | grp['s_h_len'] = len_s_h[:config.num_demo_per_program] 245 | grp['s_h'] = demos_s_h[:config.num_demo_per_program] 246 | grp['a_h_len'] = len_a_h[:config.num_demo_per_program] 247 | grp['a_h'] = demos_a_h[:config.num_demo_per_program] 248 | grp['p_v_h'] = demos_p_v_h[:config.num_demo_per_program] 249 | grp['test_s_h_len'] = len_s_h[config.num_demo_per_program:] 250 | grp['test_s_h'] = demos_s_h[config.num_demo_per_program:] 251 | grp['test_a_h_len'] = len_a_h[config.num_demo_per_program:] 252 | grp['test_a_h'] = demos_a_h[config.num_demo_per_program:] 253 | grp['test_p_v_h'] = demos_p_v_h[config.num_demo_per_program:] 254 | grp['vizdoom_init_pos'] = \ 255 | np_merged_init_states[:config.num_demo_per_program] 256 | grp['vizdoom_init_pos_len'] = \ 257 | np_merged_pos_len[:config.num_demo_per_program] 258 | grp['test_vizdoom_init_pos'] = \ 259 | 
np_merged_init_states[config.num_demo_per_program:] 260 | grp['test_vizdoom_init_pos_len'] = \ 261 | np_merged_pos_len[config.num_demo_per_program:] 262 | 263 | count += 1 264 | pbar.update(1) 265 | if count >= num_total: 266 | grp = f.create_group('data_info') 267 | grp['max_demo_length'] = max_demo_length_in_dataset 268 | grp['max_program_length'] = max_program_length_in_dataset 269 | grp['num_program_tokens'] = len(vizdoom_vocab.int2token) 270 | grp['num_demo_per_program'] = config.num_demo_per_program 271 | grp['num_test_demo_per_program'] = config.num_test_demo_per_program 272 | grp['num_action_tokens'] = len(vizdoom_vocab.action_int2token) 273 | grp['num_train'] = config.num_train 274 | grp['num_test'] = config.num_test 275 | grp['num_val'] = config.num_val 276 | grp['s_h_h'] = h 277 | grp['s_h_w'] = w 278 | grp['s_h_c'] = c 279 | grp['percepts'] = percepts 280 | grp['vizdoom_pos_keys'] = pos_keys 281 | grp['vizdoom_max_init_pos_len'] = max_init_poslen 282 | grp['perception_type'] = 'simple' 283 | f.close() 284 | id_file.close() 285 | print('Dataset generated under {} with {}' 286 | ' samples ({} for training and {} for testing ' 287 | 'and {} for val'.format(dir_name, num_total, 288 | num_train, num_test, num_val)) 289 | pbar.close() 290 | return 291 | 292 | 293 | def check_path(path): 294 | if not os.path.exists(path): 295 | os.makedirs(path) 296 | 297 | 298 | def main(): 299 | parser = argparse.ArgumentParser( 300 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 301 | parser.add_argument('--dir_name', type=str, default='vizdoom_dataset_ifelse') 302 | parser.add_argument('--num_train', type=int, default=10000) 303 | parser.add_argument('--num_test', type=int, default=1000) 304 | parser.add_argument('--num_val', type=int, default=100) 305 | parser.add_argument('--seed', type=int, default=123) 306 | parser.add_argument('--max_program_length', type=int, default=19) 307 | parser.add_argument('--min_demo_length', type=int, default=2) 308 | parser.add_argument('--max_demo_length', type=int, default=2) 309 | parser.add_argument('--num_demo_per_program', type=int, default=40) 310 | parser.add_argument('--num_test_demo_per_program', type=int, default=10) 311 | parser.add_argument('--width', type=int, default=80) 312 | parser.add_argument('--height', type=int, default=80) 313 | args = parser.parse_args() 314 | 315 | args.dir_name = os.path.join('datasets/', args.dir_name) 316 | check_path('datasets') 317 | check_path(args.dir_name) 318 | 319 | generator(args) 320 | 321 | 322 | if __name__ == '__main__': 323 | main() 324 | -------------------------------------------------------------------------------- /vizdoom_env/input_ops_vizdoom.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from util import log 5 | 6 | 7 | def check_data_id(dataset, data_id): 8 | if not data_id: 9 | return 10 | 11 | wrong = [] 12 | for id in data_id: 13 | if id in dataset.data: 14 | pass 15 | else: 16 | wrong.append(id) 17 | 18 | if len(wrong) > 0: 19 | raise RuntimeError("There are %d invalid ids, including %s" % ( 20 | len(wrong), wrong[:5] 21 | )) 22 | 23 | 24 | def create_input_ops(dataset, 25 | batch_size, 26 | num_threads=16, # for creating batches 27 | is_training=False, 28 | data_id=None, 29 | scope='inputs', 30 | shuffle=True, 31 | ): 32 | ''' 33 | Return a batched tensor for the inputs from the dataset. 
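    Returns a tuple (input_ops, batch_ops): input_ops maps names such as
    'program', 's_h', 'a_h', 'demo_len', and 'init_pos' to single-example
    tensors loaded through tf.py_func, while batch_ops holds the corresponding
    batches built with tf.train.shuffle_batch (or tf.train.batch when
    shuffle=False).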
34 | ''' 35 | input_ops = {} 36 | 37 | if data_id is None: 38 | data_id = dataset.ids 39 | log.info("input_ops [%s]: Using %d IDs from dataset", scope, len(data_id)) 40 | else: 41 | log.info("input_ops [%s]: Using specified %d IDs", scope, len(data_id)) 42 | 43 | # single operations 44 | with tf.device("/cpu:0"), tf.name_scope(scope): 45 | input_ops['id'] = tf.train.string_input_producer( 46 | tf.convert_to_tensor(data_id), capacity=128 47 | ).dequeue(name='input_ids_dequeue') 48 | 49 | p, pt, s, ts, a, at, ta, tat, pl, dl, tdl, per, tper, \ 50 | ip, ipl, tip, tipl = dataset.get_data(data_id[0]) 51 | 52 | def load_fn(id): 53 | # program [n, max_program_len] 54 | # program_tokens [max_program_len] 55 | # s_h [k, max_demo_len, h, w, 16] 56 | # test_s_h [test_k, max_demo_len, h, w, 16] 57 | # a_h [k, max_demo_len - 1, ac] 58 | # a_h_tokens [k, max_demo_len - 1] 59 | # test_a_h [test_k, max_demo_len - 1, ac] 60 | # test_a_h_tokens [test_k, max_demo_len - 1] 61 | # program_len [1] 62 | # demo_len [k] 63 | # test_demo_len [k] 64 | # per [k, t, c] 65 | # test_per [test_k, t, c] 66 | program, program_tokens, s_h, test_s_h, a_h, a_h_tokens, \ 67 | test_a_h, test_a_h_tokens, program_len, demo_len, test_demo_len, \ 68 | per, test_per, init_pos, init_pos_len, \ 69 | test_init_pos, test_init_pos_len= dataset.get_data(id) 70 | return (id, program.astype(np.float32), program_tokens.astype(np.int32), 71 | s_h.astype(np.float32), test_s_h.astype(np.float32), 72 | a_h.astype(np.float32), a_h_tokens.astype(np.int32), 73 | test_a_h.astype(np.float32), test_a_h_tokens.astype(np.int32), 74 | program_len.astype(np.float32), demo_len.astype(np.float32), 75 | test_demo_len.astype(np.float32), 76 | per.astype(np.float32), test_per.astype(np.float32), 77 | init_pos.astype(np.int32), init_pos_len.astype(np.int32), 78 | test_init_pos.astype(np.int32), test_init_pos_len.astype(np.int32)) 79 | 80 | input_ops['id'], input_ops['program'], input_ops['program_tokens'], \ 81 | input_ops['s_h'], input_ops['test_s_h'], \ 82 | input_ops['a_h'], input_ops['a_h_tokens'], \ 83 | input_ops['test_a_h'], input_ops['test_a_h_tokens'], \ 84 | input_ops['program_len'], input_ops['demo_len'], \ 85 | input_ops['test_demo_len'], input_ops['per'], input_ops['test_per'], \ 86 | input_ops['init_pos'], input_ops['init_pos_len'], \ 87 | input_ops['test_init_pos'], input_ops['test_init_pos_len'] = tf.py_func( 88 | load_fn, inp=[input_ops['id']], 89 | Tout=[tf.string, tf.float32, tf.int32, tf.float32, tf.float32, 90 | tf.float32, tf.int32, tf.float32, tf.int32, 91 | tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, 92 | tf.int32, tf.int32, tf.int32, tf.int32], 93 | name='func_hp' 94 | ) 95 | 96 | input_ops['id'].set_shape([]) 97 | input_ops['program'].set_shape(list(p.shape)) 98 | input_ops['program_tokens'].set_shape(list(pt.shape)) 99 | input_ops['s_h'].set_shape(list(s.shape)) 100 | input_ops['test_s_h'].set_shape(list(ts.shape)) 101 | input_ops['a_h'].set_shape(list(a.shape)) 102 | input_ops['a_h_tokens'].set_shape(list(at.shape)) 103 | input_ops['test_a_h'].set_shape(list(ta.shape)) 104 | input_ops['test_a_h_tokens'].set_shape(list(tat.shape)) 105 | input_ops['program_len'].set_shape(list(pl.shape)) 106 | input_ops['demo_len'].set_shape(list(dl.shape)) 107 | input_ops['test_demo_len'].set_shape(list(tdl.shape)) 108 | input_ops['per'].set_shape(list(per.shape)) 109 | input_ops['test_per'].set_shape(list(tper.shape)) 110 | input_ops['init_pos'].set_shape(list(ip.shape)) 111 | input_ops['init_pos_len'].set_shape(list(ipl.shape)) 112 | 
input_ops['test_init_pos'].set_shape(list(tip.shape)) 113 | input_ops['test_init_pos_len'].set_shape(list(tipl.shape)) 114 | 115 | # batchify 116 | capacity = 2 * batch_size * num_threads 117 | min_capacity = min(int(capacity * 0.75), 1024) 118 | 119 | if shuffle: 120 | batch_ops = tf.train.shuffle_batch( 121 | input_ops, 122 | batch_size=batch_size, 123 | num_threads=num_threads, 124 | capacity=capacity, 125 | min_after_dequeue=min_capacity, 126 | ) 127 | else: 128 | batch_ops = tf.train.batch( 129 | input_ops, 130 | batch_size=batch_size, 131 | num_threads=num_threads, 132 | capacity=capacity, 133 | ) 134 | 135 | return input_ops, batch_ops 136 | -------------------------------------------------------------------------------- /vizdoom_env/measure_program_fix_accuracy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import editdistance 3 | import h5py 4 | import numpy as np 5 | from cv2 import resize, INTER_AREA 6 | from tqdm import tqdm 7 | from vizdoom_env import Vizdoom_env 8 | from dsl.dsl_hit_analysis import hit_count 9 | from dsl.vocab import VizDoomDSLVocab 10 | 11 | 12 | def downsize(img, h=80, w=80): 13 | image_resize = resize(img, (h, w), interpolation=INTER_AREA) 14 | return image_resize 15 | 16 | parser = argparse.ArgumentParser( 17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | parser.add_argument('--result_file', type=str, default='result.hdf5', help=' ') 19 | parser.add_argument('--data_file', type=str, 20 | default='datasets/vizdoom_dataset/data.hdf5', help=' ') 21 | args = parser.parse_args() 22 | 23 | fr = h5py.File(args.result_file, 'r') 24 | ft = h5py.File(args.data_file, 'r') 25 | 26 | perception_type = ft['data_info']['perception_type'].value 27 | vocab = VizDoomDSLVocab(perception_type=perception_type) 28 | world = Vizdoom_env(config='vizdoom_env/asset/default.cfg', 29 | perception_type=perception_type) 30 | world.init_game() 31 | id_dict = {} 32 | execute_correct = [] 33 | sequence_match = [] 34 | edit_distances = [] 35 | 36 | num_test_demo = ft['data_info']['num_test_demo_per_program'].value 37 | vizdoom_pos_keys = list(ft['data_info']['vizdoom_pos_keys'].value) 38 | for i, id in enumerate(tqdm(fr.keys())): 39 | id_dict[id] = i 40 | prog_len = fr[id]['pred_program_len'].value 41 | program_tokens = np.argmax(fr[id]['pred_program'].value, axis=0)[:prog_len] 42 | program_tokens_str = ''.join([str(t) for t in program_tokens]) 43 | program = vocab.intseq2str(program_tokens) 44 | gt_program_tokens = ft[id]['program'].value 45 | gt_program_tokens_str = ''.join([str(t) for t in gt_program_tokens]) 46 | gt_program = vocab.intseq2str(gt_program_tokens) 47 | 48 | edit_dist = int(editdistance.eval(program_tokens_str, gt_program_tokens_str)) 49 | edit_distances.append(edit_dist) 50 | 51 | sequence_match.append(program == gt_program) 52 | 53 | hit_exe, hit_compile_success = hit_count(program) 54 | if not hit_compile_success: 55 | execute_correct.append(False) 56 | continue 57 | 58 | test_s_h = ft[id]['test_s_h'].value 59 | test_s_h_len = ft[id]['test_s_h_len'].value 60 | init_pos = ft[id]['test_vizdoom_init_pos'].value 61 | init_pos_len = ft[id]['test_vizdoom_init_pos_len'].value 62 | is_correct = True 63 | for k in range(num_test_demo): 64 | init_dict = {} 65 | for p, key in enumerate(vizdoom_pos_keys): 66 | init_dict[key] = np.squeeze( 67 | init_pos[k, p][:init_pos_len[k, p]]) 68 | world.new_episode(init_dict) 69 | hit, num_cal, success = hit_exe(world, 0) 70 | if not success or len(world.s_h) == 1: 
71 | is_correct = False 72 | break 73 | if len(world.s_h) != test_s_h_len[k]: 74 | is_correct = False 75 | break 76 | small_s_h = [] 77 | for s in world.s_h: 78 | small_s_h.append(downsize(s, 80, 80)) 79 | small_s_h = np.stack(small_s_h, 0) 80 | if not np.all(test_s_h[k, :test_s_h_len[k]] == small_s_h): 81 | is_correct = False 82 | break 83 | execute_correct.append(is_correct) 84 | 85 | execute_correct = np.array(execute_correct).astype(np.int32) 86 | sequence_match = np.array(sequence_match).astype(np.int32) 87 | edit_distances = np.array(edit_distances).astype(np.int32) 88 | for d in range(20): 89 | seq_acc = np.clip((sequence_match + (edit_distances <= d).astype(np.int32)), 0, 1).mean() 90 | exe_acc = np.clip((execute_correct + (edit_distances <= d).astype(np.int32)), 0, 1).mean() 91 | print('edit distance: {}, seq_acc: {}, exe_acc: {}'.format(d, seq_acc, exe_acc)) 92 | 93 | fr.close() 94 | ft.close() 95 | -------------------------------------------------------------------------------- /vizdoom_env/merge_datasets.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import h5py 7 | import argparse 8 | 9 | from tqdm import tqdm 10 | 11 | 12 | def merge(config): 13 | dir_name = os.path.join('datasets/', config.dir_name) 14 | check_path(dir_name) 15 | 16 | f = h5py.File(os.path.join(dir_name, 'data.hdf5'), 'w') 17 | id_file = open(os.path.join(dir_name, 'id.txt'), 'w') 18 | 19 | new_dataset_paths = list(set(config.dataset_paths)) 20 | if len(new_dataset_paths) != len(config.dataset_paths): 21 | raise ValueError('There is overlap in the dataset paths') 22 | 23 | num_train, num_test, num_val = 0, 0, 0 24 | h, w, c = None, None, None 25 | max_demo_length = 0 26 | max_program_length = 0 27 | num_program_tokens = None 28 | num_demo_per_program = 0 29 | num_test_demo_per_program = 0 30 | num_action_tokens = None 31 | percepts = None 32 | vizdoom_pos_keys = None 33 | vizdoom_max_init_pos_len = 0 34 | perception_type = None 35 | print('data_info checking') 36 | for i, dataset_path in enumerate(config.dataset_paths): 37 | print('dataset [{}/{}]'.format(i, len(config.dataset_paths))) 38 | fs = h5py.File(os.path.join(dataset_path, 'data.hdf5'), 'r') 39 | fs_max_demo_length = fs['data_info']['max_demo_length'].value 40 | fs_max_program_length = fs['data_info']['max_program_length'].value 41 | fs_num_program_tokens = fs['data_info']['num_program_tokens'].value 42 | fs_num_demo_per_program = fs['data_info']['num_demo_per_program'].value 43 | fs_num_test_demo_per_program = fs['data_info']['num_test_demo_per_program'].value 44 | fs_num_action_tokens = fs['data_info']['num_action_tokens'].value 45 | fs_num_train = fs['data_info']['num_train'].value 46 | fs_num_test = fs['data_info']['num_test'].value 47 | fs_num_val = fs['data_info']['num_val'].value 48 | fs_h = fs['data_info']['s_h_h'].value 49 | fs_w = fs['data_info']['s_h_w'].value 50 | fs_c = fs['data_info']['s_h_c'].value 51 | fs_percepts = list(fs['data_info']['percepts'].value) 52 | fs_vizdoom_pos_keys = list(fs['data_info']['vizdoom_pos_keys'].value) 53 | fs_vizdoom_max_init_pos_len = fs['data_info']['vizdoom_max_init_pos_len'].value 54 | fs_perception_type = fs['data_info']['perception_type'].value 55 | 56 | max_demo_length = max(max_demo_length, fs_max_demo_length) 57 | max_program_length = max(max_program_length, fs_max_program_length) 58 | if num_program_tokens is None: 
num_program_tokens = fs_num_program_tokens 59 | elif num_program_tokens != fs_num_program_tokens: 60 | raise ValueError('program token mismatch: {}'.format(dataset_path)) 61 | num_demo_per_program = max(num_demo_per_program, fs_num_demo_per_program) 62 | num_test_demo_per_program = max(num_test_demo_per_program, 63 | fs_num_test_demo_per_program) 64 | if num_action_tokens is None: num_action_tokens = fs_num_action_tokens 65 | elif num_action_tokens != fs_num_action_tokens: 66 | raise ValueError('num action token mismatch: {}'.format(dataset_path)) 67 | num_train += fs_num_train 68 | num_test += fs_num_test 69 | num_val += fs_num_val 70 | if h is None: h = fs_h 71 | elif h != fs_h: raise ValueError('image height mismatch: {}'.format(dataset_path)) 72 | if w is None: w = fs_w 73 | elif w != fs_w: raise ValueError('image width mismatch: {}'.format(dataset_path)) 74 | if c is None: c = fs_c 75 | elif c != fs_c: raise ValueError('image channel mismatch: {}'.format(dataset_path)) 76 | if percepts is None: percepts = fs_percepts 77 | elif percepts != fs_percepts: 78 | raise ValueError('percepts mismatch: {}'.format(dataset_path)) 79 | if vizdoom_pos_keys is None: vizdoom_pos_keys = fs_vizdoom_pos_keys 80 | elif vizdoom_pos_keys != fs_vizdoom_pos_keys: 81 | raise ValueError('vizdoom_pos_keys mismatch: {}'.format(dataset_path)) 82 | vizdoom_max_init_pos_len = max(vizdoom_max_init_pos_len, fs_vizdoom_max_init_pos_len) 83 | if perception_type is None: perception_type = fs_perception_type 84 | elif perception_type != fs_perception_type: 85 | raise ValueError('perception_type mismatch: {}'.format(dataset_path)) 86 | fs.close() 87 | print('copy data') 88 | for i, dataset_path in enumerate(config.dataset_paths): 89 | print('dataset [{}/{}]'.format(i, len(config.dataset_paths))) 90 | fs = h5py.File(os.path.join(dataset_path, 'data.hdf5'), 'r') 91 | ids = open(os.path.join(dataset_path, 'id.txt'), 92 | 'r').read().splitlines() 93 | for id in tqdm(ids): 94 | new_id = '{}_{}'.format(i, id) 95 | 96 | id_file.write(new_id+'\n') 97 | grp = f.create_group(new_id) 98 | for key in fs[id].keys(): 99 | grp[key] = fs[id][key].value 100 | fs.close() 101 | grp = f.create_group('data_info') 102 | grp['max_demo_length'] = max_demo_length 103 | grp['max_program_length'] = max_program_length 104 | grp['num_program_tokens'] = num_program_tokens 105 | grp['num_demo_per_program'] = num_demo_per_program 106 | grp['num_test_demo_per_program'] = num_test_demo_per_program 107 | grp['num_action_tokens'] = num_action_tokens 108 | grp['num_train'] = num_train 109 | grp['num_test'] = num_test 110 | grp['num_val'] = num_val 111 | grp['s_h_h'] = h 112 | grp['s_h_w'] = w 113 | grp['s_h_c'] = c 114 | grp['percepts'] = percepts 115 | grp['vizdoom_pos_keys'] = vizdoom_pos_keys 116 | grp['vizdoom_max_init_pos_len'] = vizdoom_max_init_pos_len 117 | grp['perception_type'] = perception_type 118 | f.close() 119 | id_file.close() 120 | print('Dataset generated under {} with {}' 121 | ' samples ({} for training and {} for testing ' 122 | 'and {} for val'.format(dir_name, num_train + num_test + num_val, 123 | num_train, num_test, num_val)) 124 | 125 | 126 | def check_path(path): 127 | if not os.path.exists(path): 128 | os.makedirs(path) 129 | else: 130 | raise ValueError('Be careful, you are trying to overwrite some dir') 131 | 132 | 133 | def get_args(): 134 | parser = argparse.ArgumentParser( 135 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 136 | parser.add_argument('--dir_name', type=str, default='vizdoom_dataset') 137 | 
parser.add_argument('--dataset_paths', nargs='+', 138 | help='list of existing dataset paths') 139 | args = parser.parse_args() 140 | return args 141 | 142 | 143 | if __name__ == '__main__': 144 | args = get_args() 145 | 146 | merge(args) 147 | -------------------------------------------------------------------------------- /vizdoom_env/util.py: -------------------------------------------------------------------------------- 1 | """ Utilities """ 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | # Logging 8 | # ======= 9 | 10 | import logging 11 | from colorlog import ColoredFormatter 12 | 13 | ch = logging.StreamHandler() 14 | ch.setLevel(logging.DEBUG) 15 | 16 | formatter = ColoredFormatter( 17 | "%(log_color)s[%(asctime)s] %(message)s", 18 | # datefmt='%H:%M:%S.%f', 19 | datefmt=None, 20 | reset=True, 21 | log_colors={ 22 | 'DEBUG': 'cyan', 23 | 'INFO': 'white,bold', 24 | 'INFOV': 'cyan,bold', 25 | 'WARNING': 'yellow', 26 | 'ERROR': 'red,bold', 27 | 'CRITICAL': 'red,bg_white', 28 | }, 29 | secondary_log_colors={}, 30 | style='%' 31 | ) 32 | ch.setFormatter(formatter) 33 | 34 | log = logging.getLogger('rn') 35 | log.setLevel(logging.DEBUG) 36 | log.handlers = [] # No duplicated handlers 37 | log.propagate = False # workaround for duplicated logs in ipython 38 | log.addHandler(ch) 39 | 40 | logging.addLevelName(logging.INFO + 1, 'INFOV') 41 | def _infov(self, msg, *args, **kwargs): 42 | self.log(logging.INFO + 1, msg, *args, **kwargs) 43 | 44 | logging.Logger.infov = _infov 45 | --------------------------------------------------------------------------------
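As a usage reference, below is a minimal sketch of how a generated ViZDoom dataset can be inspected once the scripts above have been run. It assumes a dataset produced by generator_ifelse.py with its default --dir_name (datasets/vizdoom_dataset_ifelse), and that the snippet is saved under vizdoom_env/ and launched from the repository root, mirroring how measure_program_fix_accuracy.py is invoked; the dataset path and the choice of example are illustrative, not part of the original code.

```python
import h5py

from dsl.vocab import VizDoomDSLVocab

# Assumed output location of generator_ifelse.py (its default --dir_name),
# relative to the repository root.
dataset_dir = 'datasets/vizdoom_dataset_ifelse'

with h5py.File(dataset_dir + '/data.hdf5', 'r') as f:
    # Dataset-wide metadata written by the generator.
    info = f['data_info']
    vocab = VizDoomDSLVocab(perception_type=info['perception_type'].value)

    # id.txt lists one HDF5 group name per generated sample; take the first.
    example_id = open(dataset_dir + '/id.txt').read().splitlines()[0]
    grp = f[example_id]

    # Decode the ground-truth program and the first demonstration's actions.
    program = vocab.intseq2str(grp['program'].value)
    first_a_h = grp['a_h'].value[0][:grp['a_h_len'].value[0]]
    actions = vocab.action_intseq2str(first_a_h)

    print('program: {}'.format(program))
    print('demos: {} lengths: {}'.format(grp['s_h'].value.shape,
                                         grp['s_h_len'].value))
    print('first demo actions: {}'.format(actions))
```

The `.value` accessor matches the h5py usage elsewhere in this repository; with recent h5py versions, `[()]` indexing would be used instead.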