├── .gitignore ├── LICENSE ├── README.md ├── asset ├── leaps_acc_L.PNG ├── leaps_acc_P.PNG └── leaps_model.jpeg ├── docs ├── index.html └── resources │ ├── DNN_to_LEAPS.png │ ├── LEAPS_generalization_fig.png │ ├── LEAPS_generalization_table.png │ ├── LEAPS_perf_table.png │ ├── LEAPS_slides.pdf │ ├── cem.gif │ ├── clvrbanner.png │ ├── drl_karel_gifs │ ├── cleanHouse.gif │ ├── fourCorner.gif │ ├── harvester.gif │ ├── randomMaze.gif │ ├── stairClimber.gif │ └── topOff.gif │ ├── favicon-32x32.png │ ├── human_interpretability.mp4 │ ├── karel_gifs │ ├── gt_cleanHouse_14_22.gif │ ├── gt_fourCorners_10_10.gif │ ├── gt_harvester_10_10.gif │ ├── gt_randomMaze_8_8.gif │ ├── gt_stairClimber_10_10.gif │ └── gt_topOff_10_10.gif │ ├── karel_tasks.png │ ├── leaps_embedding_vis.png │ ├── leaps_karel_gifs │ ├── pred_one_for_all_cleanHouse_14_22.gif │ ├── pred_one_for_all_fourCorners_12_12.gif │ ├── pred_one_for_all_harvester_8_8.gif │ ├── pred_one_for_all_randomMaze_8_8.gif │ ├── pred_one_for_all_stairClimber_12_12.gif │ └── pred_one_for_all_topOff_12_12.gif │ ├── leaps_model.jpeg │ ├── leaps_teaser.jpeg │ └── perf_improvement.png ├── fetch_mapping.py ├── github-assets └── leaps_teaser.jpeg ├── karel_env ├── README.md ├── __init__.py ├── asset │ ├── agent_0.PNG │ ├── agent_1.PNG │ ├── agent_2.PNG │ ├── agent_3.PNG │ ├── blank.PNG │ ├── marker.PNG │ ├── texture.hdf5 │ ├── texture.pkl │ └── wall.PNG ├── dsl │ ├── __init__.py │ ├── _parsetab.py │ ├── dsl_base.py │ ├── dsl_data.py │ ├── dsl_parse.py │ ├── dsl_parse_and_trace.py │ ├── dsl_prob.py │ ├── parser.out │ └── third_party │ │ ├── __init__.py │ │ └── yacc.py ├── generator.py ├── karel.py ├── karel_supervised.py ├── tool │ └── syntax_checker.py └── util.py ├── mapping_karel2prl.txt ├── pretrain ├── BaseModel.py ├── BaseRLModel.py ├── CEM.py ├── RLModel.py ├── SupervisedModel.py ├── SupervisedRLModel.py ├── cfg.py ├── customargparse.py ├── leaps_cleanhouse.py ├── leaps_fourcorners.py ├── leaps_harvester.py ├── leaps_maze.py ├── leaps_stairclimber.py ├── leaps_topoff.py ├── misc_utils.py ├── models.py └── trainer.py ├── prl_gym ├── README.md ├── __init__.py ├── condition_env.py ├── exec_env.py └── program_env.py ├── requirements.txt ├── rl ├── algo │ ├── a2c_acktr.py │ ├── kfac.py │ ├── ppo.py │ └── reinforce.py ├── baselines │ ├── __init__.py │ ├── bench │ │ ├── __init__.py │ │ ├── benchmarks.py │ │ ├── monitor.py │ │ └── test_monitor.py │ ├── common │ │ ├── __init__.py │ │ ├── atari_wrappers.py │ │ ├── cg.py │ │ ├── cmd_util.py │ │ ├── console_util.py │ │ ├── dataset.py │ │ ├── distributions.py │ │ ├── input.py │ │ ├── math_util.py │ │ ├── misc_util.py │ │ ├── models.py │ │ ├── mpi_adam.py │ │ ├── mpi_adam_optimizer.py │ │ ├── mpi_fork.py │ │ ├── mpi_moments.py │ │ ├── mpi_running_mean_std.py │ │ ├── mpi_util.py │ │ ├── plot_util.py │ │ ├── policies.py │ │ ├── retro_wrappers.py │ │ ├── runners.py │ │ ├── running_mean_std.py │ │ ├── schedules.py │ │ ├── segment_tree.py │ │ ├── test_mpi_util.py │ │ ├── tests │ │ │ ├── __init__.py │ │ │ ├── envs │ │ │ │ ├── __init__.py │ │ │ │ ├── fixed_sequence_env.py │ │ │ │ ├── identity_env.py │ │ │ │ ├── identity_env_test.py │ │ │ │ └── mnist_env.py │ │ │ ├── test_cartpole.py │ │ │ ├── test_doc_examples.py │ │ │ ├── test_env_after_learn.py │ │ │ ├── test_fetchreach.py │ │ │ ├── test_fixed_sequence.py │ │ │ ├── test_identity.py │ │ │ ├── test_mnist.py │ │ │ ├── test_plot_util.py │ │ │ ├── test_schedules.py │ │ │ ├── test_segment_tree.py │ │ │ ├── test_serialization.py │ │ │ ├── test_tf_util.py │ │ │ ├── test_with_mpi.py │ 
│ │ └── util.py │ │ ├── tf_util.py │ │ ├── tile_images.py │ │ ├── vec_env │ │ │ ├── __init__.py │ │ │ ├── dummy_vec_env.py │ │ │ ├── shmem_vec_env.py │ │ │ ├── subproc_vec_env.py │ │ │ ├── test_vec_env.py │ │ │ ├── test_video_recorder.py │ │ │ ├── util.py │ │ │ ├── vec_env.py │ │ │ ├── vec_frame_stack.py │ │ │ ├── vec_monitor.py │ │ │ ├── vec_normalize.py │ │ │ ├── vec_remove_dict_obs.py │ │ │ └── vec_video_recorder.py │ │ └── wrappers.py │ ├── logger.py │ ├── results_plotter.py │ └── run.py ├── distributions.py ├── envs.py ├── model.py ├── storage.py └── utils.py ├── tasks ├── cleanHouse1.txt ├── fourCorners1.txt ├── harvester1.txt ├── maze.txt ├── maze1.txt ├── randomMaze1.txt ├── stairClimber1.txt ├── test1.txt ├── test2.txt ├── test3.txt ├── test4.txt └── topOff1.txt ├── utils ├── __init__.py └── misc_utils.py └── weights ├── LEAPS └── best_valid_params.ptp ├── LEAPSP └── best_valid_params.ptp ├── LEAPSPL └── best_valid_params.ptp └── LEAPSPR └── best_valid_params.ptp /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .*.sw[pon] 3 | *.json 4 | *.p 5 | .ipynb_checkpoints 6 | .ipynb_checkpoints/* 7 | **/__pycache__/ 8 | log/ 9 | .ropeproject/ 10 | train_dir 11 | trained_models/ 12 | wandb/ 13 | *.DS_Store 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Dweep Trivedi 4 | Copyright (c) 2021 Jesse Zhang 5 | Copyright (c) 2021 Shao-Hua Sun 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 
24 | -------------------------------------------------------------------------------- /asset/leaps_acc_L.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/asset/leaps_acc_L.PNG -------------------------------------------------------------------------------- /asset/leaps_acc_P.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/asset/leaps_acc_P.PNG -------------------------------------------------------------------------------- /asset/leaps_model.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/asset/leaps_model.jpeg -------------------------------------------------------------------------------- /docs/resources/DNN_to_LEAPS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/DNN_to_LEAPS.png -------------------------------------------------------------------------------- /docs/resources/LEAPS_generalization_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/LEAPS_generalization_fig.png -------------------------------------------------------------------------------- /docs/resources/LEAPS_generalization_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/LEAPS_generalization_table.png -------------------------------------------------------------------------------- /docs/resources/LEAPS_perf_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/LEAPS_perf_table.png -------------------------------------------------------------------------------- /docs/resources/LEAPS_slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/LEAPS_slides.pdf -------------------------------------------------------------------------------- /docs/resources/cem.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/cem.gif -------------------------------------------------------------------------------- /docs/resources/clvrbanner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/clvrbanner.png -------------------------------------------------------------------------------- /docs/resources/drl_karel_gifs/cleanHouse.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/drl_karel_gifs/cleanHouse.gif 
-------------------------------------------------------------------------------- /docs/resources/drl_karel_gifs/fourCorner.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/drl_karel_gifs/fourCorner.gif -------------------------------------------------------------------------------- /docs/resources/drl_karel_gifs/harvester.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/drl_karel_gifs/harvester.gif -------------------------------------------------------------------------------- /docs/resources/drl_karel_gifs/randomMaze.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/drl_karel_gifs/randomMaze.gif -------------------------------------------------------------------------------- /docs/resources/drl_karel_gifs/stairClimber.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/drl_karel_gifs/stairClimber.gif -------------------------------------------------------------------------------- /docs/resources/drl_karel_gifs/topOff.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/drl_karel_gifs/topOff.gif -------------------------------------------------------------------------------- /docs/resources/favicon-32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/favicon-32x32.png -------------------------------------------------------------------------------- /docs/resources/human_interpretability.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/human_interpretability.mp4 -------------------------------------------------------------------------------- /docs/resources/karel_gifs/gt_cleanHouse_14_22.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/karel_gifs/gt_cleanHouse_14_22.gif -------------------------------------------------------------------------------- /docs/resources/karel_gifs/gt_fourCorners_10_10.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/karel_gifs/gt_fourCorners_10_10.gif -------------------------------------------------------------------------------- /docs/resources/karel_gifs/gt_harvester_10_10.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/karel_gifs/gt_harvester_10_10.gif -------------------------------------------------------------------------------- /docs/resources/karel_gifs/gt_randomMaze_8_8.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/karel_gifs/gt_randomMaze_8_8.gif -------------------------------------------------------------------------------- /docs/resources/karel_gifs/gt_stairClimber_10_10.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/karel_gifs/gt_stairClimber_10_10.gif -------------------------------------------------------------------------------- /docs/resources/karel_gifs/gt_topOff_10_10.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/karel_gifs/gt_topOff_10_10.gif -------------------------------------------------------------------------------- /docs/resources/karel_tasks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/karel_tasks.png -------------------------------------------------------------------------------- /docs/resources/leaps_embedding_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/leaps_embedding_vis.png -------------------------------------------------------------------------------- /docs/resources/leaps_karel_gifs/pred_one_for_all_cleanHouse_14_22.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/leaps_karel_gifs/pred_one_for_all_cleanHouse_14_22.gif -------------------------------------------------------------------------------- /docs/resources/leaps_karel_gifs/pred_one_for_all_fourCorners_12_12.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/leaps_karel_gifs/pred_one_for_all_fourCorners_12_12.gif -------------------------------------------------------------------------------- /docs/resources/leaps_karel_gifs/pred_one_for_all_harvester_8_8.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/leaps_karel_gifs/pred_one_for_all_harvester_8_8.gif -------------------------------------------------------------------------------- /docs/resources/leaps_karel_gifs/pred_one_for_all_randomMaze_8_8.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/leaps_karel_gifs/pred_one_for_all_randomMaze_8_8.gif -------------------------------------------------------------------------------- /docs/resources/leaps_karel_gifs/pred_one_for_all_stairClimber_12_12.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/leaps_karel_gifs/pred_one_for_all_stairClimber_12_12.gif 
-------------------------------------------------------------------------------- /docs/resources/leaps_karel_gifs/pred_one_for_all_topOff_12_12.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/leaps_karel_gifs/pred_one_for_all_topOff_12_12.gif -------------------------------------------------------------------------------- /docs/resources/leaps_model.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/leaps_model.jpeg -------------------------------------------------------------------------------- /docs/resources/leaps_teaser.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/leaps_teaser.jpeg -------------------------------------------------------------------------------- /docs/resources/perf_improvement.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/docs/resources/perf_improvement.png -------------------------------------------------------------------------------- /fetch_mapping.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | 4 | def fetch_mapping(filename): 5 | f = open(filename, 'r') 6 | lines = f.readlines() 7 | dsl2prl_mapping = OrderedDict() 8 | prl2dsl_mapping = OrderedDict() 9 | dsl_tokens = [] 10 | prl_tokens = [] 11 | for line in lines: 12 | tokens = [t for t in line.strip().split(' ') if not t == ''] 13 | assert len(tokens) == 2 14 | token_dsl = tokens[0] 15 | token_prl = tokens[1] 16 | dsl_tokens.append(token_dsl) 17 | dsl2prl_mapping[token_dsl] = token_prl 18 | if not token_prl == '#': 19 | prl_tokens.append(token_prl) 20 | prl2dsl_mapping[token_prl] = token_prl 21 | return dsl2prl_mapping, prl2dsl_mapping, dsl_tokens, prl_tokens 22 | 23 | if __name__ == "__main__": 24 | fetch_mapping('mapping_karel2prl.txt') 25 | -------------------------------------------------------------------------------- /github-assets/leaps_teaser.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/github-assets/leaps_teaser.jpeg -------------------------------------------------------------------------------- /karel_env/README.md: -------------------------------------------------------------------------------- 1 | # Karel Environment 2 | 3 | This directory includes code for Karel environments, which includes: 4 | - Random programs and demonstration generator 5 | - Domain specific language interpreter 6 | 7 | ## Dataset generation 8 | Dataset used in the paper is generated with the following script 9 | ```bash 10 | ./scripts/generate_dataset.sh 11 | ``` 12 | ## Domain specific language 13 | The interpreter and random program generator for Karel domain specific language (DSL) is in the [dsl directory](./dsl). 14 | You can find detailed definition of the DSL from the supplementary material of the paper. 
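
For reference, a minimal usage sketch of `fetch_mapping.py` shown above. The concrete counts and values in the comments assume the `mapping_karel2prl.txt` shipped with the repository (listed further down in this dump); tokens mapped to `#` exist only in the full Karel DSL and are dropped from the simplified (PRL) vocabulary.

```python
# Minimal sketch: load the Karel-DSL <-> simplified-DSL token mapping.
# Run from the repository root so that mapping_karel2prl.txt is found.
from fetch_mapping import fetch_mapping

dsl2prl, prl2dsl, dsl_tokens, prl_tokens = fetch_mapping('mapping_karel2prl.txt')

print(len(dsl_tokens))   # 50  -- full Karel DSL vocabulary
print(len(prl_tokens))   # 32  -- simplified vocabulary ('#'-mapped tokens removed)
print(dsl2prl['R=2'])    # 'R=2' -- kept in the simplified DSL
print(dsl2prl['R=7'])    # '#'   -- dropped from the simplified DSL
```

The returned `prl2dsl` mapping is what `ProgramEnv._prl_to_dsl` in `prl_gym/program_env.py` uses to translate predicted simplified-DSL tokens back into full Karel DSL tokens before execution.
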
15 | -------------------------------------------------------------------------------- /karel_env/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/karel_env/__init__.py -------------------------------------------------------------------------------- /karel_env/asset/agent_0.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/karel_env/asset/agent_0.PNG -------------------------------------------------------------------------------- /karel_env/asset/agent_1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/karel_env/asset/agent_1.PNG -------------------------------------------------------------------------------- /karel_env/asset/agent_2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/karel_env/asset/agent_2.PNG -------------------------------------------------------------------------------- /karel_env/asset/agent_3.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/karel_env/asset/agent_3.PNG -------------------------------------------------------------------------------- /karel_env/asset/blank.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/karel_env/asset/blank.PNG -------------------------------------------------------------------------------- /karel_env/asset/marker.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/karel_env/asset/marker.PNG -------------------------------------------------------------------------------- /karel_env/asset/texture.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/karel_env/asset/texture.hdf5 -------------------------------------------------------------------------------- /karel_env/asset/texture.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/karel_env/asset/texture.pkl -------------------------------------------------------------------------------- /karel_env/asset/wall.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/karel_env/asset/wall.PNG -------------------------------------------------------------------------------- /karel_env/dsl/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, 'karel_env/dsl') 3 | from dsl_prob import DSLProb 4 | 5 | 6 | def get_DSL(dsl_type='prob', seed=None, environment='karel'): 7 | if dsl_type == 'prob': 8 | return DSLProb(seed=seed, environment=environment) 9 | else: 10 | raise ValueError('Undefined dsl type') 
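
A hedged usage sketch of `get_DSL` above. It assumes `DSLProb` (defined in `dsl_prob.py`, whose body is not included in this listing) inherits the vocabulary helpers from `dsl_base.py` below and registers the standard Karel tokens, and that `karel_env/dsl/dsl_parse.py` (also shown below) is importable from the repository root; the program string is purely illustrative.

```python
# Minimal sketch, assuming DSLProb exposes the dsl_base.py helpers and the
# standard Karel token set. Run from the repository root so that the relative
# sys.path.insert('karel_env/dsl') in karel_env/dsl/__init__.py resolves.
from karel_env.dsl import get_DSL
from karel_env.dsl.dsl_parse import parse

dsl = get_DSL(dsl_type='prob', seed=123, environment='karel')

program = 'DEF run m( WHILE c( frontIsClear c) w( move w) m)'
ints = dsl.str2intseq(program)           # token string -> integer sequence
assert dsl.intseq2str(ints) == program   # and back again

exe, ok = parse(program, environment='karel')
assert ok   # syntactically valid under the shift-reduce rules in dsl_parse.py
# exe(karel_world, 0) would then execute the program against a
# karel_env.karel.Karel_world instance, returning (world, num_calls, success).
```
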
-------------------------------------------------------------------------------- /karel_env/dsl/dsl_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Domain specific language for Karel Environment 3 | 4 | Code is adapted from https://github.com/carpedm20/karel 5 | """ 6 | 7 | import numpy as np 8 | import ply.lex as lex 9 | from functools import wraps 10 | 11 | from third_party import yacc 12 | 13 | MIN_INT = 0 14 | MAX_INT = 19 15 | INT_PREFIX = 'R=' 16 | 17 | 18 | class DSLBase(object): 19 | 20 | def get_yacc(self): 21 | self.yacc, self.grammar = yacc.yacc( 22 | module=self, 23 | tabmodule="_parsetab", 24 | with_grammar=True) 25 | 26 | def __init__(self, seed=None): 27 | self.lexer = lex.lex(module=self) 28 | self.get_yacc() 29 | 30 | self.prodnames = self.grammar.Prodnames 31 | self.call_counter = [0] 32 | self.max_func_call = 220 33 | self.rng = np.random.RandomState(seed) 34 | 35 | self.construct_vocab() 36 | 37 | def callout(f): 38 | @wraps(f) 39 | def wrapped(*args, **kwargs): 40 | if self.call_counter[0] > self.max_func_call: 41 | raise RuntimeError("Program execution timeout.") 42 | r = f(*args, **kwargs) 43 | self.call_counter[0] += 1 44 | return r 45 | return wrapped 46 | 47 | self.callout = callout 48 | 49 | def construct_vocab(self): 50 | self.token2int = [] 51 | self.int2token = [] 52 | for term in self.tokens: 53 | token = getattr(self, 't_{}'.format(term)) 54 | if callable(token): 55 | if token == self.t_INT: 56 | for i in range(MIN_INT, MAX_INT + 1): 57 | self.int2token.append("{}{}".format(INT_PREFIX, i)) 58 | else: 59 | self.int2token.append(str(token).replace('\\', '')) 60 | self.token2int = {v: i for i, v in enumerate(self.int2token)} 61 | 62 | def str2intseq(self, code): 63 | return [self.token2int[t] for t in code.split()] 64 | 65 | def code2intseq(self, code): 66 | return [self.token2int[t] for t in code.split()] 67 | 68 | def intseq2str(self, intseq): 69 | if max(intseq) < len(self.int2token): 70 | return ' '.join([self.int2token[i] for i in intseq]) 71 | else: 72 | # intseq contains a termination token 73 | program_str = [] 74 | for i in intseq: 75 | if i < len(self.int2token): 76 | program_str.append(self.int2token[i]) 77 | else: 78 | break 79 | return ' '.join(program_str) 80 | 81 | conditional_functions = [] 82 | 83 | action_functions = [] 84 | 85 | ######### 86 | # lexer 87 | ######### 88 | 89 | def t_error(self, t): 90 | t.lexer.skip(1) 91 | raise RuntimeError('Syntax Error') 92 | 93 | ######### 94 | # parser 95 | ######### 96 | 97 | def p_error(self, p): 98 | raise RuntimeError('Syntax Error') 99 | 100 | def random_code(self, start_token="prog", depth=0, max_depth=6, nesting_depth=0, max_nesting_depth=4): 101 | code = " ".join(self.random_tokens(start_token, depth, max_depth, nesting_depth, max_nesting_depth)) 102 | 103 | return code 104 | 105 | def parse(self, code, **kwargs): 106 | self.call_counter = [0] 107 | self.error = False 108 | program = self.yacc.parse(code, **kwargs) 109 | return program 110 | 111 | def run(self, world, code, **kwargs): 112 | self.call_counter = [0] 113 | program = self.parse(code, **kwargs) 114 | 115 | # run program 116 | world.clear_history() 117 | program(world) 118 | return world.s_h 119 | -------------------------------------------------------------------------------- /karel_env/dsl/dsl_parse.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import defaultdict 3 | 4 | import dsl_data 5 | 6 | 7 | 
def check_and_apply(queue, rule): 8 | r = rule[0].split() 9 | l = len(r) 10 | if len(queue) >= l: 11 | t = queue[-l:] 12 | if list(list(zip(*t))[0]) == r: 13 | new_t = rule[1](list(list(zip(*t))[1])) 14 | del queue[-l:] 15 | queue.extend(new_t) 16 | return True 17 | return False 18 | 19 | rules = [] 20 | 21 | # k, n, s = fn(k, n) 22 | # k: karel_world 23 | # n: num_call 24 | # s: success 25 | # c: condition [True, False] 26 | MAX_FUNC_CALL = 220 27 | 28 | 29 | def r_prog(t): 30 | stmt = t[3] 31 | 32 | def fn(k, n): 33 | if n > MAX_FUNC_CALL: return k, n, False 34 | return stmt(k, n + 1) 35 | return [('prog', fn)] 36 | rules.append(('DEF run m( stmt m)', r_prog)) 37 | 38 | 39 | def r_stmt(t): 40 | stmt = t[0] 41 | 42 | def fn(k, n): 43 | if n > MAX_FUNC_CALL: return k, n, False 44 | return stmt(k, n + 1) 45 | return [('stmt', fn)] 46 | rules.append(('while_stmt', r_stmt)) 47 | rules.append(('repeat_stmt', r_stmt)) 48 | rules.append(('stmt_stmt', r_stmt)) 49 | rules.append(('action', r_stmt)) 50 | rules.append(('if_stmt', r_stmt)) 51 | rules.append(('ifelse_stmt', r_stmt)) 52 | 53 | 54 | def r_stmt_stmt(t): 55 | stmt1, stmt2 = t[0], t[1] 56 | 57 | def fn(k, n): 58 | if n > MAX_FUNC_CALL: return k, n, False 59 | k, n, s = stmt1(k, n + 1) 60 | if not s: return k, n, s 61 | if n > MAX_FUNC_CALL: return k, n, False 62 | return stmt2(k, n) 63 | return [('stmt_stmt', fn)] 64 | rules.append(('stmt stmt', r_stmt_stmt)) 65 | 66 | 67 | def r_if(t): 68 | cond, stmt = t[2], t[5] 69 | 70 | def fn(k, n): 71 | if n > MAX_FUNC_CALL: return k, n, False 72 | k, n, s, c = cond(k, n + 1) 73 | if not s: return k, n, s 74 | if c: return stmt(k, n) 75 | else: return k, n, s 76 | return [('if_stmt', fn)] 77 | rules.append(('IF c( cond c) i( stmt i)', r_if)) 78 | 79 | 80 | def r_ifelse(t): 81 | cond, stmt1, stmt2 = t[2], t[5], t[9] 82 | 83 | def fn(k, n): 84 | if n > MAX_FUNC_CALL: return k, n, False 85 | k, n, s, c = cond(k, n + 1) 86 | if not s: return k, n, s 87 | if c: return stmt1(k, n) 88 | else: return stmt2(k, n) 89 | return [('ifelse_stmt', fn)] 90 | rules.append(('IFELSE c( cond c) i( stmt i) ELSE e( stmt e)', r_ifelse)) 91 | 92 | 93 | def r_while(t): 94 | cond, stmt = t[2], t[5] 95 | 96 | def fn(k, n): 97 | if n > MAX_FUNC_CALL: return k, n, False 98 | k, n, s, c = cond(k, n + 1) 99 | if not s: return k, n, s 100 | while(c): 101 | k, n, s = stmt(k, n) 102 | if not s: return k, n, s 103 | k, n, s, c = cond(k, n) 104 | if not s: return k, n, s 105 | return k, n, s 106 | return [('while_stmt', fn)] 107 | rules.append(('WHILE c( cond c) w( stmt w)', r_while)) 108 | 109 | 110 | def r_repeat(t): 111 | cste, stmt = t[1], t[3] 112 | 113 | def fn(k, n): 114 | if n > MAX_FUNC_CALL: return k, n, False 115 | n += 1 116 | s = True 117 | for _ in range(cste()): 118 | k, n, s = stmt(k, n) 119 | if not s: return k, n, s 120 | return k, n, s 121 | return [('repeat_stmt', fn)] 122 | rules.append(('REPEAT cste r( stmt r)', r_repeat)) 123 | 124 | 125 | def r_cond1(t): 126 | cond = t[0] 127 | 128 | def fn(k, n): 129 | if n > MAX_FUNC_CALL: return k, n, False, False 130 | return cond(k, n) 131 | return [('cond', fn)] 132 | rules.append(('cond_without_not', r_cond1)) 133 | 134 | 135 | def r_cond2(t): 136 | cond = t[2] 137 | 138 | def fn(k, n): 139 | if n > MAX_FUNC_CALL: return k, n, False, False 140 | k, n, s, c = cond(k, n) 141 | return k, n, s, not c 142 | return [('cond', fn)] 143 | rules.append(('not c( cond c)', r_cond2)) 144 | 145 | 146 | env_rules = defaultdict(list) 147 | for env in dsl_data.envs: 148 | # Condition 
tokens 149 | func_str = ''' 150 | def {}_r_cond_without_not_{}(t): 151 | def fn(k, n): 152 | if n > MAX_FUNC_CALL: return k, n, False 153 | c = k.{}() 154 | return k, n, True, c 155 | return [('cond_without_not', fn)] 156 | ''' 157 | 158 | for token, api in dsl_data.obv_token_api_dict[env].items(): 159 | current_func_str = func_str.format(env.replace('-','_'), token, api) 160 | exec(current_func_str) 161 | fn = eval('{}_r_cond_without_not_{}'.format(env.replace('-','_'), token)) 162 | env_rules[env].append((token, fn)) 163 | 164 | # Action tokens 165 | func_str = ''' 166 | def {}_r_action{}(t): 167 | def fn(k, n): 168 | if n > MAX_FUNC_CALL: return k, n, False 169 | action = np.zeros({}) 170 | action[{}] = 1 171 | try: k.state_transition(action) 172 | except: raise RuntimeError() 173 | #except: return k, n, False 174 | else: return k, n, True 175 | return [('action', fn)] 176 | ''' 177 | 178 | for i, token in enumerate(dsl_data.action_token_list[env]): 179 | current_func_str = func_str.format(env.replace('-','_'), i+1, len(dsl_data.action_token_list[env]), i) 180 | exec(current_func_str) 181 | fn = eval('{}_r_action{}'.format(env.replace('-','_'), i+1)) 182 | env_rules[env].append((token, fn)) 183 | i += 1 184 | 185 | 186 | def create_r_cste(number): 187 | def r_cste(t): 188 | return [('cste', lambda: number)] 189 | return r_cste 190 | 191 | 192 | for i in range(20): 193 | rules.append(('R={}'.format(i), create_r_cste(i))) 194 | 195 | 196 | def parse(program, environment='karel'): 197 | p_tokens = program.split()[::-1] 198 | queue = [] 199 | applied = False 200 | while len(p_tokens) > 0 or len(queue) != 1: 201 | if applied: applied = False 202 | else: 203 | queue.append((p_tokens.pop(), None)) 204 | for rule in rules + env_rules[environment]: 205 | applied = check_and_apply(queue, rule) 206 | if applied: break 207 | if not applied and len(p_tokens) == 0: # error parsing 208 | return None, False 209 | return queue[0][1], True 210 | 211 | 212 | -------------------------------------------------------------------------------- /karel_env/dsl/parser.out: -------------------------------------------------------------------------------- 1 | Created by PLY version 3.11 (http://www.dabeaz.com/ply) 2 | -------------------------------------------------------------------------------- /karel_env/dsl/third_party/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /karel_env/util.py: -------------------------------------------------------------------------------- 1 | """ Utilities """ 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | # Logging 8 | # ======= 9 | 10 | import logging 11 | from colorlog import ColoredFormatter 12 | 13 | ch = logging.StreamHandler() 14 | ch.setLevel(logging.DEBUG) 15 | 16 | formatter = ColoredFormatter( 17 | "%(log_color)s[%(asctime)s] %(message)s", 18 | # datefmt='%H:%M:%S.%f', 19 | datefmt=None, 20 | reset=True, 21 | log_colors={ 22 | 'DEBUG': 'cyan', 23 | 'INFO': 'white,bold', 24 | 'INFOV': 'cyan,bold', 25 | 'WARNING': 'yellow', 26 | 'ERROR': 'red,bold', 27 | 'CRITICAL': 'red,bg_white', 28 | }, 29 | secondary_log_colors={}, 30 | style='%' 31 | ) 32 | ch.setFormatter(formatter) 33 | 34 | log = logging.getLogger('rn') 35 | log.setLevel(logging.DEBUG) 36 | log.handlers = [] # No duplicated handlers 37 | log.propagate = False # workaround for duplicated logs in ipython 
38 | log.addHandler(ch) 39 | 40 | logging.addLevelName(logging.INFO + 1, 'INFOV') 41 | def _infov(self, msg, *args, **kwargs): 42 | self.log(logging.INFO + 1, msg, *args, **kwargs) 43 | 44 | logging.Logger.infov = _infov 45 | -------------------------------------------------------------------------------- /mapping_karel2prl.txt: -------------------------------------------------------------------------------- 1 | DEF # 2 | run # 3 | m( m( 4 | m) m) 5 | move move 6 | turnRight turnRight 7 | turnLeft turnLeft 8 | pickMarker pickMarker 9 | putMarker putMarker 10 | r( r( 11 | r) r) 12 | R=0 R=0 13 | R=1 R=1 14 | R=2 R=2 15 | R=3 R=3 16 | R=4 # 17 | R=5 # 18 | R=6 # 19 | R=7 # 20 | R=8 # 21 | R=9 # 22 | R=10 # 23 | R=11 # 24 | R=12 # 25 | R=13 # 26 | R=14 # 27 | R=15 # 28 | R=16 # 29 | R=17 # 30 | R=18 # 31 | R=19 # 32 | REPEAT REPEAT 33 | c( c( 34 | c) c) 35 | i( i( 36 | i) i) 37 | e( e( 38 | e) e) 39 | IF IF 40 | IFELSE IFELSE 41 | ELSE ELSE 42 | frontIsClear frontIsClear 43 | leftIsClear leftIsClear 44 | rightIsClear rightIsClear 45 | markersPresent markersPresent 46 | noMarkersPresent noMarkersPresent 47 | not not 48 | w( w( 49 | w) w) 50 | WHILE WHILE 51 | -------------------------------------------------------------------------------- /pretrain/misc_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | 4 | 5 | #### Files and Directories #### 6 | 7 | def delete_files(folder, recursive=False): 8 | for the_file in os.listdir(folder): 9 | file_path = os.path.join(folder, the_file) 10 | try: 11 | if os.path.isfile(file_path): 12 | os.unlink(file_path) 13 | elif recursive and os.path.isdir(file_path): 14 | delete_files(file_path, recursive) 15 | os.unlink(file_path) 16 | except Exception as e: 17 | print(e) 18 | 19 | 20 | def create_directory(directory): 21 | try: 22 | os.makedirs(directory) 23 | except OSError as e: 24 | if e.errno != errno.EEXIST: 25 | raise 26 | 27 | 28 | def get_files(dirpath): 29 | return [name for name in os.listdir(dirpath) if os.path.isfile(os.path.join(dirpath, name))] 30 | 31 | 32 | def get_dirs(dirpath): 33 | return [name for name in os.listdir(dirpath) if os.path.isdir(os.path.join(dirpath, name))] 34 | 35 | 36 | #### Logging Dictionary Tools #### 37 | 38 | def get_by_dotted_path(d, path, default=[]): 39 | """ Get an entry from nested dictionaries using a dotted path. 
40 | 41 | Args: 42 | d: Dictionary 43 | path: Entry to extract 44 | 45 | Example: 46 | >>> get_by_dotted_path({'foo': {'a': 12}}, 'foo.a') 47 | 12 48 | """ 49 | if not path: 50 | return d 51 | split_path = path.split('.') 52 | current_option = d 53 | for p in split_path: 54 | if p not in current_option: 55 | return default 56 | current_option = current_option[p] 57 | return current_option 58 | 59 | 60 | def add_record(key, value, global_logs): 61 | if 'logs' not in global_logs['info']: 62 | global_logs['info']['logs'] = {} 63 | logs = global_logs['info']['logs'] 64 | split_path = key.split('.') 65 | current = logs 66 | for p in split_path[:-1]: 67 | if p not in current: 68 | current[p] = {} 69 | current = current[p] 70 | 71 | final_key = split_path[-1] 72 | if final_key not in current: 73 | current[final_key] = [] 74 | entries = current[final_key] 75 | entries.append(value) 76 | 77 | 78 | def get_records(key, global_logs): 79 | logs = global_logs['info'].get('logs', {}) 80 | return get_by_dotted_path(logs, key) 81 | 82 | 83 | def log_record_dict(usage, log_dict, global_logs): 84 | for log_key, value in log_dict.items(): 85 | add_record('{}.{}'.format(usage, log_key), value, global_logs) 86 | 87 | 88 | #### Controlling verbosity #### 89 | 90 | def vprint(verbose, *args, **kwargs): 91 | ''' Prints only if verbose is True. 92 | ''' 93 | if verbose: 94 | print(*args, **kwargs) 95 | 96 | 97 | def vcall(verbose, fn, *args, **kwargs): 98 | ''' Calls function fn only if verbose is True. 99 | ''' 100 | if verbose: 101 | fn(*args, **kwargs) -------------------------------------------------------------------------------- /prl_gym/README.md: -------------------------------------------------------------------------------- 1 | # Neural Network Policy: 2 | - agent: karel_state -> NN -> action 3 | - karel_world: action -> next_karel_state, reward(currently sparse) 4 | 5 | # Programmatic Policy: 6 | ## Execution environment: 7 | - agent: (karel_state, program_counter) -> Program -> action 8 | - karel_world: action -> (next_karel_state, reward(currently sparse)) 9 | 10 | ## Program Environment 1: 11 | - agent: -> NN -> program 12 | - environment: program(action) -> (next_state(currently None), reward) 13 | 14 | ## Program Environment 2: 15 | - agent: -> NN -> program 16 | - environment: program(action) -> (next_state(currently None), reward (task done or not)) 17 | 18 | -------------------------------------------------------------------------------- /prl_gym/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from gym.envs.registration import register 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | register( 7 | id='CartPoleDiscrete-v0', 8 | entry_point='prl_gym.envs:CartPoleDiscreteEnv', 9 | ) 10 | -------------------------------------------------------------------------------- /prl_gym/condition_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import spaces 3 | 4 | class ConditionEnvGym(gym.Env): 5 | """Custom Environment that follows gym interface""" 6 | metadata = {'render.modes': ['human']} 7 | 8 | def __init__(self, num_agent_actions, max_demo_length): 9 | super(ConditionEnvGym, self).__init__() 10 | 11 | self.config = config 12 | if self.config.env_name == "karel": 13 | self.dsl = get_DSL(dsl_type='prob', seed=config.seed, environment=self.config.env_name) 14 | self.s_gen = KarelStateGenerator(seed=config.seed) 15 | self._world = karel.Karel_world(make_error=False, 
env_task=config.env_task, reward_diff=config.reward_diff) 16 | elif self.config.env_name == "CartPoleDiscrete-v0": 17 | self.dsl = get_DSL(dsl_type='prob', seed=config.seed, environment=self.config.env_name) 18 | gym_env = gym.make(self.config.env_name) 19 | self._world = cartpole.CartPole_World(gym_env) 20 | else: 21 | raise NotImplementedError('{} not implemented for PRL setup'.format(self.config.env_name)) 22 | 23 | # Define action and observation space 24 | # They must be gym.spaces objects 25 | # Example when using discrete actions: 26 | self.action_space = spaces.Box(low=0, high=num_agent_actions+1, 27 | shape=(max_demo_length,), dtype=np.int16) 28 | # Example for using image as input: 29 | self.observation_space = spaces.Box(low=0, high=1, shape=(height, width, 16), dtype=np.bool_) 30 | 31 | if env_task == 'maze': 32 | self.init_func = self.s_gen.generate_single_state_find_marker 33 | elif env_task == 'stairClimber': 34 | self.init_func = self.s_gen.generate_single_state_stair_climber 35 | else: 36 | raise NotImplementedError('task not implemented yet') 37 | 38 | self.init_states = [self.init_func(height, width) for _ in range(config.num_demo_per_program)] 39 | self._world.set_new_state(self.init_states[0][0]) 40 | 41 | def step(self, action): 42 | return observation, reward, done, info 43 | 44 | def reset(self): 45 | 46 | return observation 47 | 48 | def render(self, mode='init_states'): 49 | if mode == 'init_states': 50 | return [x[0] for x in self.init_states] 51 | else: 52 | return self._world.render(mode) 53 | 54 | def close (self): 55 | raise NotImplementedError() -------------------------------------------------------------------------------- /prl_gym/program_env.py: -------------------------------------------------------------------------------- 1 | """ 2 | NOTE: 3 | This file works only on Karel DSL regardless of whether we use simplified DSL for training or not. 
4 | So if any of the public functions change, make sure that you call _prl_to_dsl function appropriately 5 | as we do in ProgramEnv.step (via _modify()) 6 | """ 7 | import time 8 | import numpy as np 9 | import gym 10 | from gym import spaces 11 | from exec_env import ExecEnv1, ExecEnv2 12 | from karel_env.dsl.dsl_parse import parse 13 | from fetch_mapping import fetch_mapping 14 | 15 | 16 | class ProgramEnv(gym.Env): 17 | """Environment that will follow gym interface""" 18 | 19 | def __init__(self, config, task=None, metadata={}): 20 | super(ProgramEnv, self).__init__() 21 | self.metadata = {'render.modes': ['rgb_array', 'program', 'init_states']} 22 | self.config = config 23 | self.max_program_len = config.max_program_len 24 | 25 | # load task definition (task can be defined by program or inbuilt environment reward function) 26 | if self.config.task_definition == 'program': 27 | self.task_env = ExecEnv1(config, task, metadata) 28 | self.gt_reward, _ = self.task_env.reward(self.task_env.gt_program_seq) 29 | elif self.config.task_definition == 'custom_reward': 30 | self.gt_reward = 10000.0 31 | self.task_env = ExecEnv2(config, metadata) 32 | else: 33 | raise NotImplementedError 34 | 35 | # Add one token for invalid token (all tokens after end token should be invalid) 36 | if config.use_simplified_dsl: 37 | self.num_program_tokens = len(config.prl_tokens)+1 38 | self.T2I = {tkn: i for i, tkn in enumerate(config.prl_tokens)} 39 | else: 40 | self.num_program_tokens = len(self.task_env.dsl.int2token)+1 41 | self.T2I = {tkn: i for i, tkn in enumerate(config.dsl_tokens)} 42 | 43 | self._elapsed_steps = 0 44 | self.partial_program = [] 45 | 46 | def _prl_to_dsl(self, program_seq): 47 | def func(x): 48 | return self.config.dsl_tokens.index(self.config.prl2dsl_mapping[self.config.prl_tokens[x]]) 49 | return np.array(list(map(func, program_seq)), program_seq.dtype) 50 | 51 | def _set_bad_transition(self, done, info): 52 | # TODO: need to shift this code under rl.envs.TimeLimitMask 53 | if self._elapsed_steps >= self._max_episode_steps: 54 | info['TimeLimit.truncated'] = not done 55 | info['bad_transition'] = done 56 | done = True 57 | return done, info 58 | 59 | def step(self, action): 60 | raise NotImplementedError() 61 | 62 | def render(self, mode='init_states'): 63 | """render current program for a random initial state""" 64 | if mode == 'program': 65 | pred_program = self.task_env.execute_pred_program(self.state) 66 | return pred_program 67 | elif mode == 'init_states': 68 | return self.task_env.render(mode='init_states') 69 | else: 70 | raise NotImplementedError('Yet to generate video of predicted program execution') 71 | 72 | 73 | class ProgramEnv1(ProgramEnv): 74 | """MDP2 75 | state: None 76 | action: complete program 77 | Transition: complete program -> complete program 78 | reward: environment reward for executing the program 79 | """ 80 | 81 | def __init__(self, config, task=None, metadata={}): 82 | super(ProgramEnv1, self).__init__(config, task, metadata) 83 | 84 | # define action space 85 | self.alpha = metadata.get('alpha', 1) 86 | if config.action_type == "program_multidiscrete": 87 | self.action_space = spaces.MultiDiscrete(self.alpha*self.max_program_len*[self.num_program_tokens]) 88 | elif config.action_type == "program": 89 | self.action_space = spaces.Box(low=0, high=self.num_program_tokens, 90 | shape=(self.alpha*self.max_program_len,), dtype=np.int8) 91 | 92 | # define observation space 93 | if config.obv_type == "program": 94 | self.observation_space = spaces.Box(low=0, 
high=self.num_program_tokens, 95 | shape=(self.max_program_len,), dtype=np.int8) 96 | self.initial_obv = (self.num_program_tokens-1) * np.ones(self.max_program_len, dtype=np.int8) 97 | elif config.obv_type == "encoded": 98 | self.observation_space = spaces.Box(low=-float('inf'), high=float('inf'), 99 | shape=[config.num_lstm_cell_units], dtype=np.float32) 100 | self.initial_obv = np.zeros(config.num_lstm_cell_units) 101 | else: 102 | raise NotImplementedError('observation not recognized') 103 | 104 | self.state = self.initial_obv 105 | 106 | def _modify(self, action): 107 | # Ignore everything after end-of-program token 108 | null_token_idx = np.argwhere(action == (self.num_program_tokens-1)) 109 | if null_token_idx.shape[0] > 0: 110 | action = action[:null_token_idx[0].squeeze()] 111 | # remap prl tokens to dsl tokens if we are using simplified DSL 112 | action = self._prl_to_dsl(action) if self.config.use_simplified_dsl else action 113 | return action 114 | 115 | def step(self, action): 116 | """Currently state is previous program, action is new program 117 | Alert: action can be in simplified DSL format, make sure to use transformed action 118 | (here we transform it in _modify()) 119 | """ 120 | if self.alpha > 1: 121 | gt_program_seq, pred_program_seq = action[:len(action)//2], action[len(action)//2:] 122 | self.task_env.gt_program_seq = gt_program_seq 123 | self.task_env.gt_program = self.task_env._execute_gt_program(self.config, gt_program_seq) 124 | action = pred_program_seq 125 | self._elapsed_steps += 1 126 | dsl_action = self._modify(action) 127 | # FIXME: temporary fix for ignoring DEF, run, )m kind of tokens 128 | if self.config.experiment == 'intention_space': 129 | dsl_action = np.concatenate((np.array([0]), dsl_action)) 130 | else: 131 | if self.config.grammar is not None: 132 | dsl_action = np.concatenate((np.array([0, 1, 2]), dsl_action)) 133 | else: 134 | dsl_action = np.concatenate((np.array([0, 1, 2]), dsl_action, np.array([3]))) 135 | 136 | self.state = program_seq = dsl_action 137 | reward, exec_data = self.task_env.reward(program_seq) 138 | done = True if reward == self.gt_reward else False 139 | info = {'cur_state': action, 'modified_action': dsl_action, 'exec_data': exec_data} 140 | 141 | done, info = self._set_bad_transition(done, info) 142 | 143 | return self.initial_obv, reward, done, info 144 | 145 | def reset(self): 146 | # Reset the state of the environment to an initial state 147 | self._elapsed_steps = 0 148 | self.partial_program = [] 149 | self.state = self.initial_obv 150 | self.task_env.reset() 151 | return self.state 152 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | tensorboardX==1.9 3 | gym==0.15.4 4 | requests 5 | wandb 6 | ply 7 | imageio 8 | tqdm 9 | h5py 10 | pandas 11 | colorlog 12 | progressbar 13 | torch==1.4.0 14 | torchvision==0.5.0 15 | -------------------------------------------------------------------------------- /rl/algo/a2c_acktr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | 5 | from rl.algo.kfac import KFACOptimizer 6 | 7 | 8 | class A2C_ACKTR(): 9 | def __init__(self, 10 | actor_critic, 11 | value_loss_coef, 12 | entropy_coef, 13 | lr=None, 14 | eps=None, 15 | alpha=None, 16 | max_grad_norm=None, 17 | use_recurrent_generator=False, 18 | writer=None, 19 | acktr=False): 
20 | 21 | self.actor_critic = actor_critic 22 | self.acktr = acktr 23 | 24 | self.value_loss_coef = value_loss_coef 25 | self.entropy_coef = entropy_coef 26 | 27 | self.max_grad_norm = max_grad_norm 28 | 29 | self.use_recurrent_generator = use_recurrent_generator 30 | 31 | if acktr: 32 | self.optimizer = KFACOptimizer(actor_critic) 33 | else: 34 | self.optimizer = optim.RMSprop( 35 | actor_critic.parameters(), lr, eps=eps, alpha=alpha) 36 | 37 | self._global_step = 0 38 | self.writer = writer 39 | 40 | def update(self, rollouts): 41 | obs_shape = rollouts.obs.size()[2:] 42 | action_shape = rollouts.actions.size()[-1] 43 | num_steps, num_processes, _ = rollouts.rewards.size() 44 | 45 | values, action_log_probs, dist_entropy, _, _ = self.actor_critic.evaluate_actions( 46 | rollouts.obs[:-1].view(-1, *obs_shape), 47 | rollouts.recurrent_hidden_states[0].view( 48 | -1, self.actor_critic.recurrent_hidden_state_size), 49 | rollouts.masks[:-1].view(-1, 1), 50 | rollouts.actions.view(-1, action_shape), 51 | rollouts.output_masks.view(-1, action_shape)) 52 | 53 | values = values.view(num_steps, num_processes, 1) 54 | action_log_probs = action_log_probs.view(num_steps, num_processes, 1) 55 | 56 | advantages = rollouts.returns[:-1] - values 57 | value_loss = advantages.pow(2).mean() 58 | 59 | action_loss = -(advantages.detach() * action_log_probs).mean() 60 | 61 | if self.acktr and self.optimizer.steps % self.optimizer.Ts == 0: 62 | # Sampled fisher, see Martens 2014 63 | self.actor_critic.zero_grad() 64 | pg_fisher_loss = -action_log_probs.mean() 65 | 66 | value_noise = torch.randn(values.size()) 67 | if values.is_cuda: 68 | value_noise = value_noise.cuda() 69 | 70 | sample_values = values + value_noise 71 | vf_fisher_loss = -(values - sample_values.detach()).pow(2).mean() 72 | 73 | fisher_loss = pg_fisher_loss + vf_fisher_loss 74 | self.optimizer.acc_stats = True 75 | fisher_loss.backward(retain_graph=True) 76 | self.optimizer.acc_stats = False 77 | 78 | self.optimizer.zero_grad() 79 | (value_loss * self.value_loss_coef + action_loss - 80 | dist_entropy * self.entropy_coef).backward() 81 | 82 | if self.acktr == False: 83 | nn.utils.clip_grad_norm_(self.actor_critic.parameters(), 84 | self.max_grad_norm) 85 | 86 | self.optimizer.step() 87 | 88 | self.writer.add_scalar('ppo/avg_policy_loss', action_loss.item(), self._global_step) 89 | self.writer.add_scalar('ppo/avg_value_loss', value_loss.item(), self._global_step) 90 | self.writer.add_scalar('ppo/entropy', dist_entropy.item(), self._global_step) 91 | self._global_step += 1 92 | 93 | return value_loss.item(), action_loss.item(), dist_entropy.item() 94 | -------------------------------------------------------------------------------- /rl/baselines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/rl/baselines/__init__.py -------------------------------------------------------------------------------- /rl/baselines/bench/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa F403 2 | from baselines.bench.benchmarks import * 3 | from baselines.bench.monitor import * 4 | -------------------------------------------------------------------------------- /rl/baselines/bench/benchmarks.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 4 | 5 | _atari7 
= ['BeamRider', 'Breakout', 'Enduro', 'Pong', 'Qbert', 'Seaquest', 'SpaceInvaders'] 6 | _atariexpl7 = ['Freeway', 'Gravitar', 'MontezumaRevenge', 'Pitfall', 'PrivateEye', 'Solaris', 'Venture'] 7 | 8 | _BENCHMARKS = [] 9 | 10 | remove_version_re = re.compile(r'-v\d+$') 11 | 12 | 13 | def register_benchmark(benchmark): 14 | for b in _BENCHMARKS: 15 | if b['name'] == benchmark['name']: 16 | raise ValueError('Benchmark with name %s already registered!' % b['name']) 17 | 18 | # automatically add a description if it is not present 19 | if 'tasks' in benchmark: 20 | for t in benchmark['tasks']: 21 | if 'desc' not in t: 22 | t['desc'] = remove_version_re.sub('', t.get('env_id', t.get('id'))) 23 | _BENCHMARKS.append(benchmark) 24 | 25 | 26 | def list_benchmarks(): 27 | return [b['name'] for b in _BENCHMARKS] 28 | 29 | 30 | def get_benchmark(benchmark_name): 31 | for b in _BENCHMARKS: 32 | if b['name'] == benchmark_name: 33 | return b 34 | raise ValueError('%s not found! Known benchmarks: %s' % (benchmark_name, list_benchmarks())) 35 | 36 | 37 | def get_task(benchmark, env_id): 38 | """Get a task by env_id. Return None if the benchmark doesn't have the env""" 39 | return next(filter(lambda task: task['env_id'] == env_id, benchmark['tasks']), None) 40 | 41 | 42 | def find_task_for_env_id_in_any_benchmark(env_id): 43 | for bm in _BENCHMARKS: 44 | for task in bm["tasks"]: 45 | if task["env_id"] == env_id: 46 | return bm, task 47 | return None, None 48 | 49 | 50 | _ATARI_SUFFIX = 'NoFrameskip-v4' 51 | 52 | register_benchmark({ 53 | 'name': 'Atari50M', 54 | 'description': '7 Atari games from Mnih et al. (2013), with pixel observations, 50M timesteps', 55 | 'tasks': [{'desc': _game, 'env_id': _game + _ATARI_SUFFIX, 'trials': 2, 'num_timesteps': int(50e6)} for _game in _atari7] 56 | }) 57 | 58 | register_benchmark({ 59 | 'name': 'Atari10M', 60 | 'description': '7 Atari games from Mnih et al. (2013), with pixel observations, 10M timesteps', 61 | 'tasks': [{'desc': _game, 'env_id': _game + _ATARI_SUFFIX, 'trials': 6, 'num_timesteps': int(10e6)} for _game in _atari7] 62 | }) 63 | 64 | register_benchmark({ 65 | 'name': 'Atari1Hr', 66 | 'description': '7 Atari games from Mnih et al. 
(2013), with pixel observations, 1 hour of walltime', 67 | 'tasks': [{'desc': _game, 'env_id': _game + _ATARI_SUFFIX, 'trials': 2, 'num_seconds': 60 * 60} for _game in _atari7] 68 | }) 69 | 70 | register_benchmark({ 71 | 'name': 'AtariExploration10M', 72 | 'description': '7 Atari games emphasizing exploration, with pixel observations, 10M timesteps', 73 | 'tasks': [{'desc': _game, 'env_id': _game + _ATARI_SUFFIX, 'trials': 2, 'num_timesteps': int(10e6)} for _game in _atariexpl7] 74 | }) 75 | 76 | 77 | # MuJoCo 78 | 79 | _mujocosmall = [ 80 | 'InvertedDoublePendulum-v2', 'InvertedPendulum-v2', 81 | 'HalfCheetah-v2', 'Hopper-v2', 'Walker2d-v2', 82 | 'Reacher-v2', 'Swimmer-v2'] 83 | register_benchmark({ 84 | 'name': 'Mujoco1M', 85 | 'description': 'Some small 2D MuJoCo tasks, run for 1M timesteps', 86 | 'tasks': [{'env_id': _envid, 'trials': 6, 'num_timesteps': int(1e6)} for _envid in _mujocosmall] 87 | }) 88 | 89 | register_benchmark({ 90 | 'name': 'MujocoWalkers', 91 | 'description': 'MuJoCo forward walkers, run for 8M, humanoid 100M', 92 | 'tasks': [ 93 | {'env_id': "Hopper-v1", 'trials': 4, 'num_timesteps': 8 * 1000000}, 94 | {'env_id': "Walker2d-v1", 'trials': 4, 'num_timesteps': 8 * 1000000}, 95 | {'env_id': "Humanoid-v1", 'trials': 4, 'num_timesteps': 100 * 1000000}, 96 | ] 97 | }) 98 | 99 | # Bullet 100 | _bulletsmall = [ 101 | 'InvertedDoublePendulum', 'InvertedPendulum', 'HalfCheetah', 'Reacher', 'Walker2D', 'Hopper', 'Ant' 102 | ] 103 | _bulletsmall = [e + 'BulletEnv-v0' for e in _bulletsmall] 104 | 105 | register_benchmark({ 106 | 'name': 'Bullet1M', 107 | 'description': '6 mujoco-like tasks from bullet, 1M steps', 108 | 'tasks': [{'env_id': e, 'trials': 6, 'num_timesteps': int(1e6)} for e in _bulletsmall] 109 | }) 110 | 111 | 112 | # Roboschool 113 | 114 | register_benchmark({ 115 | 'name': 'Roboschool8M', 116 | 'description': 'Small 2D tasks, up to 30 minutes to complete on 8 cores', 117 | 'tasks': [ 118 | {'env_id': "RoboschoolReacher-v1", 'trials': 4, 'num_timesteps': 2 * 1000000}, 119 | {'env_id': "RoboschoolAnt-v1", 'trials': 4, 'num_timesteps': 8 * 1000000}, 120 | {'env_id': "RoboschoolHalfCheetah-v1", 'trials': 4, 'num_timesteps': 8 * 1000000}, 121 | {'env_id': "RoboschoolHopper-v1", 'trials': 4, 'num_timesteps': 8 * 1000000}, 122 | {'env_id': "RoboschoolWalker2d-v1", 'trials': 4, 'num_timesteps': 8 * 1000000}, 123 | ] 124 | }) 125 | register_benchmark({ 126 | 'name': 'RoboschoolHarder', 127 | 'description': 'Test your might!!! 
Up to 12 hours on 32 cores', 128 | 'tasks': [ 129 | {'env_id': "RoboschoolHumanoid-v1", 'trials': 4, 'num_timesteps': 100 * 1000000}, 130 | {'env_id': "RoboschoolHumanoidFlagrun-v1", 'trials': 4, 'num_timesteps': 200 * 1000000}, 131 | {'env_id': "RoboschoolHumanoidFlagrunHarder-v1", 'trials': 4, 'num_timesteps': 400 * 1000000}, 132 | ] 133 | }) 134 | 135 | # Other 136 | 137 | _atari50 = [ # actually 47 138 | 'Alien', 'Amidar', 'Assault', 'Asterix', 'Asteroids', 139 | 'Atlantis', 'BankHeist', 'BattleZone', 'BeamRider', 'Bowling', 140 | 'Breakout', 'Centipede', 'ChopperCommand', 'CrazyClimber', 141 | 'DemonAttack', 'DoubleDunk', 'Enduro', 'FishingDerby', 'Freeway', 142 | 'Frostbite', 'Gopher', 'Gravitar', 'IceHockey', 'Jamesbond', 143 | 'Kangaroo', 'Krull', 'KungFuMaster', 'MontezumaRevenge', 'MsPacman', 144 | 'NameThisGame', 'Pitfall', 'Pong', 'PrivateEye', 'Qbert', 145 | 'RoadRunner', 'Robotank', 'Seaquest', 'SpaceInvaders', 'StarGunner', 146 | 'Tennis', 'TimePilot', 'Tutankham', 'UpNDown', 'Venture', 147 | 'VideoPinball', 'WizardOfWor', 'Zaxxon', 148 | ] 149 | 150 | register_benchmark({ 151 | 'name': 'Atari50_10M', 152 | 'description': '47 Atari games from Mnih et al. (2013), with pixel observations, 10M timesteps', 153 | 'tasks': [{'desc': _game, 'env_id': _game + _ATARI_SUFFIX, 'trials': 2, 'num_timesteps': int(10e6)} for _game in _atari50] 154 | }) 155 | 156 | # HER DDPG 157 | 158 | _fetch_tasks = ['FetchReach-v1', 'FetchPush-v1', 'FetchSlide-v1'] 159 | register_benchmark({ 160 | 'name': 'Fetch1M', 161 | 'description': 'Fetch* benchmarks for 1M timesteps', 162 | 'tasks': [{'trials': 6, 'env_id': env_id, 'num_timesteps': int(1e6)} for env_id in _fetch_tasks] 163 | }) 164 | 165 | -------------------------------------------------------------------------------- /rl/baselines/bench/monitor.py: -------------------------------------------------------------------------------- 1 | __all__ = ['Monitor', 'get_monitor_files', 'load_results'] 2 | 3 | from gym.core import Wrapper 4 | import time 5 | from glob import glob 6 | import csv 7 | import os.path as osp 8 | import json 9 | 10 | class Monitor(Wrapper): 11 | EXT = "monitor.csv" 12 | f = None 13 | 14 | def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()): 15 | Wrapper.__init__(self, env=env) 16 | self.tstart = time.time() 17 | if filename: 18 | self.results_writer = ResultsWriter(filename, 19 | header={"t_start": time.time(), 'env_id' : env.spec and env.spec.id}, 20 | extra_keys=reset_keywords + info_keywords 21 | ) 22 | else: 23 | self.results_writer = None 24 | self.reset_keywords = reset_keywords 25 | self.info_keywords = info_keywords 26 | self.allow_early_resets = allow_early_resets 27 | self.rewards = None 28 | self.needs_reset = True 29 | self.episode_rewards = [] 30 | self.episode_lengths = [] 31 | self.episode_times = [] 32 | self.total_steps = 0 33 | self.current_reset_info = {} # extra info about the current episode, that was passed in during reset() 34 | 35 | def reset(self, **kwargs): 36 | self.reset_state() 37 | for k in self.reset_keywords: 38 | v = kwargs.get(k) 39 | if v is None: 40 | raise ValueError('Expected you to pass kwarg %s into reset'%k) 41 | self.current_reset_info[k] = v 42 | return self.env.reset(**kwargs) 43 | 44 | def reset_state(self): 45 | if not self.allow_early_resets and not self.needs_reset: 46 | raise RuntimeError("Tried to reset an environment before done. 
If you want to allow early resets, wrap your env with Monitor(env, path, allow_early_resets=True)") 47 | self.rewards = [] 48 | self.needs_reset = False 49 | 50 | 51 | def step(self, action): 52 | if self.needs_reset: 53 | raise RuntimeError("Tried to step environment that needs reset") 54 | ob, rew, done, info = self.env.step(action) 55 | self.update(ob, rew, done, info) 56 | return (ob, rew, done, info) 57 | 58 | def update(self, ob, rew, done, info): 59 | self.rewards.append(rew) 60 | if done: 61 | self.needs_reset = True 62 | eprew = sum(self.rewards) 63 | eplen = len(self.rewards) 64 | epinfo = {"r": round(eprew, 6), "l": eplen, "t": round(time.time() - self.tstart, 6)} 65 | for k in self.info_keywords: 66 | epinfo[k] = info[k] 67 | self.episode_rewards.append(eprew) 68 | self.episode_lengths.append(eplen) 69 | self.episode_times.append(time.time() - self.tstart) 70 | epinfo.update(self.current_reset_info) 71 | if self.results_writer: 72 | self.results_writer.write_row(epinfo) 73 | assert isinstance(info, dict) 74 | if isinstance(info, dict): 75 | info['episode'] = epinfo 76 | 77 | self.total_steps += 1 78 | 79 | def close(self): 80 | if self.f is not None: 81 | self.f.close() 82 | 83 | def get_total_steps(self): 84 | return self.total_steps 85 | 86 | def get_episode_rewards(self): 87 | return self.episode_rewards 88 | 89 | def get_episode_lengths(self): 90 | return self.episode_lengths 91 | 92 | def get_episode_times(self): 93 | return self.episode_times 94 | 95 | class LoadMonitorResultsError(Exception): 96 | pass 97 | 98 | 99 | class ResultsWriter(object): 100 | def __init__(self, filename, header='', extra_keys=()): 101 | self.extra_keys = extra_keys 102 | assert filename is not None 103 | if not filename.endswith(Monitor.EXT): 104 | if osp.isdir(filename): 105 | filename = osp.join(filename, Monitor.EXT) 106 | else: 107 | filename = filename + "." 
+ Monitor.EXT 108 | self.f = open(filename, "wt") 109 | if isinstance(header, dict): 110 | header = '# {} \n'.format(json.dumps(header)) 111 | self.f.write(header) 112 | self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+tuple(extra_keys)) 113 | self.logger.writeheader() 114 | self.f.flush() 115 | 116 | def write_row(self, epinfo): 117 | if self.logger: 118 | self.logger.writerow(epinfo) 119 | self.f.flush() 120 | 121 | 122 | def get_monitor_files(dir): 123 | return glob(osp.join(dir, "*" + Monitor.EXT)) 124 | 125 | def load_results(dir): 126 | import pandas 127 | monitor_files = ( 128 | glob(osp.join(dir, "*monitor.json")) + 129 | glob(osp.join(dir, "*monitor.csv"))) # get both csv and (old) json files 130 | if not monitor_files: 131 | raise LoadMonitorResultsError("no monitor files of the form *%s found in %s" % (Monitor.EXT, dir)) 132 | dfs = [] 133 | headers = [] 134 | for fname in monitor_files: 135 | with open(fname, 'rt') as fh: 136 | if fname.endswith('csv'): 137 | firstline = fh.readline() 138 | if not firstline: 139 | continue 140 | assert firstline[0] == '#' 141 | header = json.loads(firstline[1:]) 142 | df = pandas.read_csv(fh, index_col=None) 143 | headers.append(header) 144 | elif fname.endswith('json'): # Deprecated json format 145 | episodes = [] 146 | lines = fh.readlines() 147 | header = json.loads(lines[0]) 148 | headers.append(header) 149 | for line in lines[1:]: 150 | episode = json.loads(line) 151 | episodes.append(episode) 152 | df = pandas.DataFrame(episodes) 153 | else: 154 | assert 0, 'unreachable' 155 | df['t'] += header['t_start'] 156 | dfs.append(df) 157 | df = pandas.concat(dfs) 158 | df.sort_values('t', inplace=True) 159 | df.reset_index(inplace=True) 160 | df['t'] -= min(header['t_start'] for header in headers) 161 | df.headers = headers # HACK to preserve backwards compatibility 162 | return df 163 | -------------------------------------------------------------------------------- /rl/baselines/bench/test_monitor.py: -------------------------------------------------------------------------------- 1 | from .monitor import Monitor 2 | import gym 3 | import json 4 | 5 | def test_monitor(): 6 | import pandas 7 | import os 8 | import uuid 9 | 10 | env = gym.make("CartPole-v1") 11 | env.seed(0) 12 | mon_file = "/tmp/baselines-test-%s.monitor.csv" % uuid.uuid4() 13 | menv = Monitor(env, mon_file) 14 | menv.reset() 15 | for _ in range(1000): 16 | _, _, done, _ = menv.step(0) 17 | if done: 18 | menv.reset() 19 | 20 | f = open(mon_file, 'rt') 21 | 22 | firstline = f.readline() 23 | assert firstline.startswith('#') 24 | metadata = json.loads(firstline[1:]) 25 | assert metadata['env_id'] == "CartPole-v1" 26 | assert set(metadata.keys()) == {'env_id', 't_start'}, "Incorrect keys in monitor metadata" 27 | 28 | last_logline = pandas.read_csv(f, index_col=None) 29 | assert set(last_logline.keys()) == {'l', 't', 'r'}, "Incorrect keys in monitor logline" 30 | f.close() 31 | os.remove(mon_file) 32 | -------------------------------------------------------------------------------- /rl/baselines/common/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa F403 2 | from baselines.common.console_util import * 3 | from baselines.common.dataset import Dataset 4 | from baselines.common.math_util import * 5 | from baselines.common.misc_util import * 6 | -------------------------------------------------------------------------------- /rl/baselines/common/cg.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | def cg(f_Ax, b, cg_iters=10, callback=None, verbose=False, residual_tol=1e-10): 3 | """ 4 | Demmel p 312 5 | """ 6 | p = b.copy() 7 | r = b.copy() 8 | x = np.zeros_like(b) 9 | rdotr = r.dot(r) 10 | 11 | fmtstr = "%10i %10.3g %10.3g" 12 | titlestr = "%10s %10s %10s" 13 | if verbose: print(titlestr % ("iter", "residual norm", "soln norm")) 14 | 15 | for i in range(cg_iters): 16 | if callback is not None: 17 | callback(x) 18 | if verbose: print(fmtstr % (i, rdotr, np.linalg.norm(x))) 19 | z = f_Ax(p) 20 | v = rdotr / p.dot(z) 21 | x += v*p 22 | r -= v*z 23 | newrdotr = r.dot(r) 24 | mu = newrdotr/rdotr 25 | p = r + mu*p 26 | 27 | rdotr = newrdotr 28 | if rdotr < residual_tol: 29 | break 30 | 31 | if callback is not None: 32 | callback(x) 33 | if verbose: print(fmtstr % (i+1, rdotr, np.linalg.norm(x))) # pylint: disable=W0631 34 | return x 35 | -------------------------------------------------------------------------------- /rl/baselines/common/console_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from contextlib import contextmanager 3 | import numpy as np 4 | import time 5 | import shlex 6 | import subprocess 7 | 8 | # ================================================================ 9 | # Misc 10 | # ================================================================ 11 | 12 | def fmt_row(width, row, header=False): 13 | out = " | ".join(fmt_item(x, width) for x in row) 14 | if header: out = out + "\n" + "-"*len(out) 15 | return out 16 | 17 | def fmt_item(x, l): 18 | if isinstance(x, np.ndarray): 19 | assert x.ndim==0 20 | x = x.item() 21 | if isinstance(x, (float, np.float32, np.float64)): 22 | v = abs(x) 23 | if (v < 1e-4 or v > 1e+4) and v > 0: 24 | rep = "%7.2e" % x 25 | else: 26 | rep = "%7.5f" % x 27 | else: rep = str(x) 28 | return " "*(l - len(rep)) + rep 29 | 30 | color2num = dict( 31 | gray=30, 32 | red=31, 33 | green=32, 34 | yellow=33, 35 | blue=34, 36 | magenta=35, 37 | cyan=36, 38 | white=37, 39 | crimson=38 40 | ) 41 | 42 | def colorize(string, color='green', bold=False, highlight=False): 43 | attr = [] 44 | num = color2num[color] 45 | if highlight: num += 10 46 | attr.append(str(num)) 47 | if bold: attr.append('1') 48 | return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string) 49 | 50 | def print_cmd(cmd, dry=False): 51 | if isinstance(cmd, str): # for shell=True 52 | pass 53 | else: 54 | cmd = ' '.join(shlex.quote(arg) for arg in cmd) 55 | print(colorize(('CMD: ' if not dry else 'DRY: ') + cmd)) 56 | 57 | 58 | def get_git_commit(cwd=None): 59 | return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], cwd=cwd).decode('utf8') 60 | 61 | def get_git_commit_message(cwd=None): 62 | return subprocess.check_output(['git', 'show', '-s', '--format=%B', 'HEAD'], cwd=cwd).decode('utf8') 63 | 64 | def ccap(cmd, dry=False, env=None, **kwargs): 65 | print_cmd(cmd, dry) 66 | if not dry: 67 | subprocess.check_call(cmd, env=env, **kwargs) 68 | 69 | 70 | MESSAGE_DEPTH = 0 71 | 72 | @contextmanager 73 | def timed(msg): 74 | global MESSAGE_DEPTH #pylint: disable=W0603 75 | print(colorize('\t'*MESSAGE_DEPTH + '=: ' + msg, color='magenta')) 76 | tstart = time.time() 77 | MESSAGE_DEPTH += 1 78 | yield 79 | MESSAGE_DEPTH -= 1 80 | print(colorize('\t'*MESSAGE_DEPTH + "done in %.3f seconds"%(time.time() - tstart), color='magenta')) 81 | 
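[Editor's note] A brief usage sketch for the cg() conjugate-gradient solver defined in cg.py above (this snippet is an editorial addition, not part of the original baselines source). cg() solves A x = b for symmetric positive-definite A, but takes the matrix-vector product as a callable f_Ax instead of an explicit matrix; the sketch below assumes only numpy and checks the result against numpy.linalg.solve. The import path is a guess and may need adjusting to this repository's layout.

import numpy as np
from baselines.common.cg import cg  # assumed import path

rng = np.random.RandomState(0)
M = rng.randn(6, 6)
A = M.dot(M.T) + 6.0 * np.eye(6)  # symmetric positive definite by construction
b = rng.randn(6)

# cg() only ever sees the product A @ p, so A could just as well be a large implicit operator
# (e.g. a Fisher-vector product), which is how TRPO-style code typically uses it.
x = cg(lambda p: A.dot(p), b, cg_iters=50, residual_tol=1e-12)

assert np.allclose(x, np.linalg.solve(A, b), atol=1e-6)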
-------------------------------------------------------------------------------- /rl/baselines/common/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Dataset(object): 4 | def __init__(self, data_map, deterministic=False, shuffle=True): 5 | self.data_map = data_map 6 | self.deterministic = deterministic 7 | self.enable_shuffle = shuffle 8 | self.n = next(iter(data_map.values())).shape[0] 9 | self._next_id = 0 10 | self.shuffle() 11 | 12 | def shuffle(self): 13 | if self.deterministic: 14 | return 15 | perm = np.arange(self.n) 16 | np.random.shuffle(perm) 17 | 18 | for key in self.data_map: 19 | self.data_map[key] = self.data_map[key][perm] 20 | 21 | self._next_id = 0 22 | 23 | def next_batch(self, batch_size): 24 | if self._next_id >= self.n and self.enable_shuffle: 25 | self.shuffle() 26 | 27 | cur_id = self._next_id 28 | cur_batch_size = min(batch_size, self.n - self._next_id) 29 | self._next_id += cur_batch_size 30 | 31 | data_map = dict() 32 | for key in self.data_map: 33 | data_map[key] = self.data_map[key][cur_id:cur_id+cur_batch_size] 34 | return data_map 35 | 36 | def iterate_once(self, batch_size): 37 | if self.enable_shuffle: self.shuffle() 38 | 39 | while self._next_id <= self.n - batch_size: 40 | yield self.next_batch(batch_size) 41 | self._next_id = 0 42 | 43 | def subset(self, num_elements, deterministic=True): 44 | data_map = dict() 45 | for key in self.data_map: 46 | data_map[key] = self.data_map[key][:num_elements] 47 | return Dataset(data_map, deterministic) 48 | 49 | 50 | def iterbatches(arrays, *, num_batches=None, batch_size=None, shuffle=True, include_final_partial_batch=True): 51 | assert (num_batches is None) != (batch_size is None), 'Provide num_batches or batch_size, but not both' 52 | arrays = tuple(map(np.asarray, arrays)) 53 | n = arrays[0].shape[0] 54 | assert all(a.shape[0] == n for a in arrays[1:]) 55 | inds = np.arange(n) 56 | if shuffle: np.random.shuffle(inds) 57 | sections = np.arange(0, n, batch_size)[1:] if num_batches is None else num_batches 58 | for batch_inds in np.array_split(inds, sections): 59 | if include_final_partial_batch or len(batch_inds) == batch_size: 60 | yield tuple(a[batch_inds] for a in arrays) 61 | -------------------------------------------------------------------------------- /rl/baselines/common/input.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from gym.spaces import Discrete, Box, MultiDiscrete 4 | 5 | def observation_placeholder(ob_space, batch_size=None, name='Ob'): 6 | ''' 7 | Create placeholder to feed observations into of the size appropriate to the observation space 8 | 9 | Parameters: 10 | ---------- 11 | 12 | ob_space: gym.Space observation space 13 | 14 | batch_size: int size of the batch to be fed into input. Can be left None in most cases. 
15 | 16 | name: str name of the placeholder 17 | 18 | Returns: 19 | ------- 20 | 21 | tensorflow placeholder tensor 22 | ''' 23 | 24 | assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box) or isinstance(ob_space, MultiDiscrete), \ 25 | 'Can only deal with Discrete and Box observation spaces for now' 26 | 27 | dtype = ob_space.dtype 28 | if dtype == np.int8: 29 | dtype = np.uint8 30 | 31 | return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name) 32 | 33 | 34 | def observation_input(ob_space, batch_size=None, name='Ob'): 35 | ''' 36 | Create placeholder to feed observations into of the size appropriate to the observation space, and add input 37 | encoder of the appropriate type. 38 | ''' 39 | 40 | placeholder = observation_placeholder(ob_space, batch_size, name) 41 | return placeholder, encode_observation(ob_space, placeholder) 42 | 43 | def encode_observation(ob_space, placeholder): 44 | ''' 45 | Encode input in the way that is appropriate to the observation space 46 | 47 | Parameters: 48 | ---------- 49 | 50 | ob_space: gym.Space observation space 51 | 52 | placeholder: tf.placeholder observation input placeholder 53 | ''' 54 | if isinstance(ob_space, Discrete): 55 | return tf.to_float(tf.one_hot(placeholder, ob_space.n)) 56 | elif isinstance(ob_space, Box): 57 | return tf.to_float(placeholder) 58 | elif isinstance(ob_space, MultiDiscrete): 59 | placeholder = tf.cast(placeholder, tf.int32) 60 | one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])] 61 | return tf.concat(one_hots, axis=-1) 62 | else: 63 | raise NotImplementedError 64 | 65 | -------------------------------------------------------------------------------- /rl/baselines/common/math_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.signal 3 | 4 | 5 | def discount(x, gamma): 6 | """ 7 | computes discounted sums along 0th dimension of x. 8 | 9 | inputs 10 | ------ 11 | x: ndarray 12 | gamma: float 13 | 14 | outputs 15 | ------- 16 | y: ndarray with same shape as x, satisfying 17 | 18 | y[t] = x[t] + gamma*x[t+1] + gamma^2*x[t+2] + ... + gamma^k x[t+k], 19 | where k = len(x) - t - 1 20 | 21 | """ 22 | assert x.ndim >= 1 23 | return scipy.signal.lfilter([1],[1,-gamma],x[::-1], axis=0)[::-1] 24 | 25 | def explained_variance(ypred,y): 26 | """ 27 | Computes fraction of variance that ypred explains about y. 
28 | Returns 1 - Var[y-ypred] / Var[y] 29 | 30 | interpretation: 31 | ev=0 => might as well have predicted zero 32 | ev=1 => perfect prediction 33 | ev<0 => worse than just predicting zero 34 | 35 | """ 36 | assert y.ndim == 1 and ypred.ndim == 1 37 | vary = np.var(y) 38 | return np.nan if vary==0 else 1 - np.var(y-ypred)/vary 39 | 40 | def explained_variance_2d(ypred, y): 41 | assert y.ndim == 2 and ypred.ndim == 2 42 | vary = np.var(y, axis=0) 43 | out = 1 - np.var(y-ypred)/vary 44 | out[vary < 1e-10] = 0 45 | return out 46 | 47 | def ncc(ypred, y): 48 | return np.corrcoef(ypred, y)[1,0] 49 | 50 | def flatten_arrays(arrs): 51 | return np.concatenate([arr.flat for arr in arrs]) 52 | 53 | def unflatten_vector(vec, shapes): 54 | i=0 55 | arrs = [] 56 | for shape in shapes: 57 | size = np.prod(shape) 58 | arr = vec[i:i+size].reshape(shape) 59 | arrs.append(arr) 60 | i += size 61 | return arrs 62 | 63 | def discount_with_boundaries(X, New, gamma): 64 | """ 65 | X: 2d array of floats, time x features 66 | New: 2d array of bools, indicating when a new episode has started 67 | """ 68 | Y = np.zeros_like(X) 69 | T = X.shape[0] 70 | Y[T-1] = X[T-1] 71 | for t in range(T-2, -1, -1): 72 | Y[t] = X[t] + gamma * Y[t+1] * (1 - New[t+1]) 73 | return Y 74 | 75 | def test_discount_with_boundaries(): 76 | gamma=0.9 77 | x = np.array([1.0, 2.0, 3.0, 4.0], 'float32') 78 | starts = [1.0, 0.0, 0.0, 1.0] 79 | y = discount_with_boundaries(x, starts, gamma) 80 | assert np.allclose(y, [ 81 | 1 + gamma * 2 + gamma**2 * 3, 82 | 2 + gamma * 3, 83 | 3, 84 | 4 85 | ]) 86 | -------------------------------------------------------------------------------- /rl/baselines/common/mpi_adam.py: -------------------------------------------------------------------------------- 1 | import baselines.common.tf_util as U 2 | import tensorflow as tf 3 | import numpy as np 4 | try: 5 | from mpi4py import MPI 6 | except ImportError: 7 | MPI = None 8 | 9 | 10 | class MpiAdam(object): 11 | def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None): 12 | self.var_list = var_list 13 | self.beta1 = beta1 14 | self.beta2 = beta2 15 | self.epsilon = epsilon 16 | self.scale_grad_by_procs = scale_grad_by_procs 17 | size = sum(U.numel(v) for v in var_list) 18 | self.m = np.zeros(size, 'float32') 19 | self.v = np.zeros(size, 'float32') 20 | self.t = 0 21 | self.setfromflat = U.SetFromFlat(var_list) 22 | self.getflat = U.GetFlat(var_list) 23 | self.comm = MPI.COMM_WORLD if comm is None and MPI is not None else comm 24 | 25 | def update(self, localg, stepsize): 26 | if self.t % 100 == 0: 27 | self.check_synced() 28 | localg = localg.astype('float32') 29 | if self.comm is not None: 30 | globalg = np.zeros_like(localg) 31 | self.comm.Allreduce(localg, globalg, op=MPI.SUM) 32 | if self.scale_grad_by_procs: 33 | globalg /= self.comm.Get_size() 34 | else: 35 | globalg = np.copy(localg) 36 | 37 | self.t += 1 38 | a = stepsize * np.sqrt(1 - self.beta2**self.t)/(1 - self.beta1**self.t) 39 | self.m = self.beta1 * self.m + (1 - self.beta1) * globalg 40 | self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg) 41 | step = (- a) * self.m / (np.sqrt(self.v) + self.epsilon) 42 | self.setfromflat(self.getflat() + step) 43 | 44 | def sync(self): 45 | if self.comm is None: 46 | return 47 | theta = self.getflat() 48 | self.comm.Bcast(theta, root=0) 49 | self.setfromflat(theta) 50 | 51 | def check_synced(self): 52 | if self.comm is None: 53 | return 54 | if self.comm.Get_rank() == 0: # this is root 55 | 
theta = self.getflat() 56 | self.comm.Bcast(theta, root=0) 57 | else: 58 | thetalocal = self.getflat() 59 | thetaroot = np.empty_like(thetalocal) 60 | self.comm.Bcast(thetaroot, root=0) 61 | assert (thetaroot == thetalocal).all(), (thetaroot, thetalocal) 62 | 63 | @U.in_session 64 | def test_MpiAdam(): 65 | np.random.seed(0) 66 | tf.set_random_seed(0) 67 | 68 | a = tf.Variable(np.random.randn(3).astype('float32')) 69 | b = tf.Variable(np.random.randn(2,5).astype('float32')) 70 | loss = tf.reduce_sum(tf.square(a)) + tf.reduce_sum(tf.sin(b)) 71 | 72 | stepsize = 1e-2 73 | update_op = tf.train.AdamOptimizer(stepsize).minimize(loss) 74 | do_update = U.function([], loss, updates=[update_op]) 75 | 76 | tf.get_default_session().run(tf.global_variables_initializer()) 77 | losslist_ref = [] 78 | for i in range(10): 79 | l = do_update() 80 | print(i, l) 81 | losslist_ref.append(l) 82 | 83 | 84 | 85 | tf.set_random_seed(0) 86 | tf.get_default_session().run(tf.global_variables_initializer()) 87 | 88 | var_list = [a,b] 89 | lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)]) 90 | adam = MpiAdam(var_list) 91 | 92 | losslist_test = [] 93 | for i in range(10): 94 | l,g = lossandgrad() 95 | adam.update(g, stepsize) 96 | print(i,l) 97 | losslist_test.append(l) 98 | 99 | np.testing.assert_allclose(np.array(losslist_ref), np.array(losslist_test), atol=1e-4) 100 | 101 | 102 | if __name__ == '__main__': 103 | test_MpiAdam() 104 | -------------------------------------------------------------------------------- /rl/baselines/common/mpi_adam_optimizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from baselines.common import tf_util as U 4 | from baselines.common.tests.test_with_mpi import with_mpi 5 | from baselines import logger 6 | try: 7 | from mpi4py import MPI 8 | except ImportError: 9 | MPI = None 10 | 11 | class MpiAdamOptimizer(tf.train.AdamOptimizer): 12 | """Adam optimizer that averages gradients across mpi processes.""" 13 | def __init__(self, comm, grad_clip=None, mpi_rank_weight=1, **kwargs): 14 | self.comm = comm 15 | self.grad_clip = grad_clip 16 | self.mpi_rank_weight = mpi_rank_weight 17 | tf.train.AdamOptimizer.__init__(self, **kwargs) 18 | def compute_gradients(self, loss, var_list, **kwargs): 19 | grads_and_vars = tf.train.AdamOptimizer.compute_gradients(self, loss, var_list, **kwargs) 20 | grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None] 21 | flat_grad = tf.concat([tf.reshape(g, (-1,)) for g, v in grads_and_vars], axis=0) * self.mpi_rank_weight 22 | shapes = [v.shape.as_list() for g, v in grads_and_vars] 23 | sizes = [int(np.prod(s)) for s in shapes] 24 | 25 | total_weight = np.zeros(1, np.float32) 26 | self.comm.Allreduce(np.array([self.mpi_rank_weight], dtype=np.float32), total_weight, op=MPI.SUM) 27 | total_weight = total_weight[0] 28 | 29 | buf = np.zeros(sum(sizes), np.float32) 30 | countholder = [0] # Counts how many times _collect_grads has been called 31 | stat = tf.reduce_sum(grads_and_vars[0][1]) # sum of first variable 32 | def _collect_grads(flat_grad, np_stat): 33 | if self.grad_clip is not None: 34 | gradnorm = np.linalg.norm(flat_grad) 35 | if gradnorm > 1: 36 | flat_grad /= gradnorm 37 | logger.logkv_mean('gradnorm', gradnorm) 38 | logger.logkv_mean('gradclipfrac', float(gradnorm > 1)) 39 | self.comm.Allreduce(flat_grad, buf, op=MPI.SUM) 40 | np.divide(buf, float(total_weight), out=buf) 41 | if countholder[0] % 100 == 0: 42 | check_synced(np_stat, 
self.comm) 43 | countholder[0] += 1 44 | return buf 45 | 46 | avg_flat_grad = tf.py_func(_collect_grads, [flat_grad, stat], tf.float32) 47 | avg_flat_grad.set_shape(flat_grad.shape) 48 | avg_grads = tf.split(avg_flat_grad, sizes, axis=0) 49 | avg_grads_and_vars = [(tf.reshape(g, v.shape), v) 50 | for g, (_, v) in zip(avg_grads, grads_and_vars)] 51 | return avg_grads_and_vars 52 | 53 | def check_synced(localval, comm=None): 54 | """ 55 | It's common to forget to initialize your variables to the same values, or 56 | (less commonly) if you update them in some other way than adam, to get them out of sync. 57 | This function checks that variables on all MPI workers are the same, and raises 58 | an AssertionError otherwise 59 | 60 | Arguments: 61 | comm: MPI communicator 62 | localval: list of local variables (list of variables on current worker to be compared with the other workers) 63 | """ 64 | comm = comm or MPI.COMM_WORLD 65 | vals = comm.gather(localval) 66 | if comm.rank == 0: 67 | assert all(val==vals[0] for val in vals[1:]),\ 68 | 'MpiAdamOptimizer detected that different workers have different weights: {}'.format(vals) 69 | 70 | @with_mpi(timeout=5) 71 | def test_nonfreeze(): 72 | np.random.seed(0) 73 | tf.set_random_seed(0) 74 | 75 | a = tf.Variable(np.random.randn(3).astype('float32')) 76 | b = tf.Variable(np.random.randn(2,5).astype('float32')) 77 | loss = tf.reduce_sum(tf.square(a)) + tf.reduce_sum(tf.sin(b)) 78 | 79 | stepsize = 1e-2 80 | # for some reason the session config with inter_op_parallelism_threads was causing 81 | # nested sess.run calls to freeze 82 | config = tf.ConfigProto(inter_op_parallelism_threads=1) 83 | sess = U.get_session(config=config) 84 | update_op = MpiAdamOptimizer(comm=MPI.COMM_WORLD, learning_rate=stepsize).minimize(loss) 85 | sess.run(tf.global_variables_initializer()) 86 | losslist_ref = [] 87 | for i in range(100): 88 | l,_ = sess.run([loss, update_op]) 89 | print(i, l) 90 | losslist_ref.append(l) 91 | -------------------------------------------------------------------------------- /rl/baselines/common/mpi_fork.py: -------------------------------------------------------------------------------- 1 | import os, subprocess, sys 2 | 3 | def mpi_fork(n, bind_to_core=False): 4 | """Re-launches the current script with workers 5 | Returns "parent" for original parent, "child" for MPI children 6 | """ 7 | if n<=1: 8 | return "child" 9 | if os.getenv("IN_MPI") is None: 10 | env = os.environ.copy() 11 | env.update( 12 | MKL_NUM_THREADS="1", 13 | OMP_NUM_THREADS="1", 14 | IN_MPI="1" 15 | ) 16 | args = ["mpirun", "-np", str(n)] 17 | if bind_to_core: 18 | args += ["-bind-to", "core"] 19 | args += [sys.executable] + sys.argv 20 | subprocess.check_call(args, env=env) 21 | return "parent" 22 | else: 23 | return "child" 24 | -------------------------------------------------------------------------------- /rl/baselines/common/mpi_moments.py: -------------------------------------------------------------------------------- 1 | from mpi4py import MPI 2 | import numpy as np 3 | from baselines.common import zipsame 4 | 5 | 6 | def mpi_mean(x, axis=0, comm=None, keepdims=False): 7 | x = np.asarray(x) 8 | assert x.ndim > 0 9 | if comm is None: comm = MPI.COMM_WORLD 10 | xsum = x.sum(axis=axis, keepdims=keepdims) 11 | n = xsum.size 12 | localsum = np.zeros(n+1, x.dtype) 13 | localsum[:n] = xsum.ravel() 14 | localsum[n] = x.shape[axis] 15 | # globalsum = np.zeros_like(localsum) 16 | # comm.Allreduce(localsum, globalsum, op=MPI.SUM) 17 | globalsum = comm.allreduce(localsum, 
op=MPI.SUM) 18 | return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n] 19 | 20 | def mpi_moments(x, axis=0, comm=None, keepdims=False): 21 | x = np.asarray(x) 22 | assert x.ndim > 0 23 | mean, count = mpi_mean(x, axis=axis, comm=comm, keepdims=True) 24 | sqdiffs = np.square(x - mean) 25 | meansqdiff, count1 = mpi_mean(sqdiffs, axis=axis, comm=comm, keepdims=True) 26 | assert count1 == count 27 | std = np.sqrt(meansqdiff) 28 | if not keepdims: 29 | newshape = mean.shape[:axis] + mean.shape[axis+1:] 30 | mean = mean.reshape(newshape) 31 | std = std.reshape(newshape) 32 | return mean, std, count 33 | 34 | 35 | def test_runningmeanstd(): 36 | import subprocess 37 | subprocess.check_call(['mpirun', '-np', '3', 38 | 'python','-c', 39 | 'from baselines.common.mpi_moments import _helper_runningmeanstd; _helper_runningmeanstd()']) 40 | 41 | def _helper_runningmeanstd(): 42 | comm = MPI.COMM_WORLD 43 | np.random.seed(0) 44 | for (triple,axis) in [ 45 | ((np.random.randn(3), np.random.randn(4), np.random.randn(5)),0), 46 | ((np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)),0), 47 | ((np.random.randn(2,3), np.random.randn(2,4), np.random.randn(2,4)),1), 48 | ]: 49 | 50 | 51 | x = np.concatenate(triple, axis=axis) 52 | ms1 = [x.mean(axis=axis), x.std(axis=axis), x.shape[axis]] 53 | 54 | 55 | ms2 = mpi_moments(triple[comm.Get_rank()],axis=axis) 56 | 57 | for (a1,a2) in zipsame(ms1, ms2): 58 | print(a1, a2) 59 | assert np.allclose(a1, a2) 60 | print("ok!") 61 | 62 | -------------------------------------------------------------------------------- /rl/baselines/common/mpi_running_mean_std.py: -------------------------------------------------------------------------------- 1 | try: 2 | from mpi4py import MPI 3 | except ImportError: 4 | MPI = None 5 | 6 | import tensorflow as tf, baselines.common.tf_util as U, numpy as np 7 | 8 | class RunningMeanStd(object): 9 | # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 10 | def __init__(self, epsilon=1e-2, shape=()): 11 | 12 | self._sum = tf.get_variable( 13 | dtype=tf.float64, 14 | shape=shape, 15 | initializer=tf.constant_initializer(0.0), 16 | name="runningsum", trainable=False) 17 | self._sumsq = tf.get_variable( 18 | dtype=tf.float64, 19 | shape=shape, 20 | initializer=tf.constant_initializer(epsilon), 21 | name="runningsumsq", trainable=False) 22 | self._count = tf.get_variable( 23 | dtype=tf.float64, 24 | shape=(), 25 | initializer=tf.constant_initializer(epsilon), 26 | name="count", trainable=False) 27 | self.shape = shape 28 | 29 | self.mean = tf.to_float(self._sum / self._count) 30 | self.std = tf.sqrt( tf.maximum( tf.to_float(self._sumsq / self._count) - tf.square(self.mean) , 1e-2 )) 31 | 32 | newsum = tf.placeholder(shape=self.shape, dtype=tf.float64, name='sum') 33 | newsumsq = tf.placeholder(shape=self.shape, dtype=tf.float64, name='var') 34 | newcount = tf.placeholder(shape=[], dtype=tf.float64, name='count') 35 | self.incfiltparams = U.function([newsum, newsumsq, newcount], [], 36 | updates=[tf.assign_add(self._sum, newsum), 37 | tf.assign_add(self._sumsq, newsumsq), 38 | tf.assign_add(self._count, newcount)]) 39 | 40 | 41 | def update(self, x): 42 | x = x.astype('float64') 43 | n = int(np.prod(self.shape)) 44 | totalvec = np.zeros(n*2+1, 'float64') 45 | addvec = np.concatenate([x.sum(axis=0).ravel(), np.square(x).sum(axis=0).ravel(), np.array([len(x)],dtype='float64')]) 46 | if MPI is not None: 47 | MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM) 48 | 
self.incfiltparams(totalvec[0:n].reshape(self.shape), totalvec[n:2*n].reshape(self.shape), totalvec[2*n]) 49 | 50 | @U.in_session 51 | def test_runningmeanstd(): 52 | for (x1, x2, x3) in [ 53 | (np.random.randn(3), np.random.randn(4), np.random.randn(5)), 54 | (np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)), 55 | ]: 56 | 57 | rms = RunningMeanStd(epsilon=0.0, shape=x1.shape[1:]) 58 | U.initialize() 59 | 60 | x = np.concatenate([x1, x2, x3], axis=0) 61 | ms1 = [x.mean(axis=0), x.std(axis=0)] 62 | rms.update(x1) 63 | rms.update(x2) 64 | rms.update(x3) 65 | ms2 = [rms.mean.eval(), rms.std.eval()] 66 | 67 | assert np.allclose(ms1, ms2) 68 | 69 | @U.in_session 70 | def test_dist(): 71 | np.random.seed(0) 72 | p1,p2,p3=(np.random.randn(3,1), np.random.randn(4,1), np.random.randn(5,1)) 73 | q1,q2,q3=(np.random.randn(6,1), np.random.randn(7,1), np.random.randn(8,1)) 74 | 75 | # p1,p2,p3=(np.random.randn(3), np.random.randn(4), np.random.randn(5)) 76 | # q1,q2,q3=(np.random.randn(6), np.random.randn(7), np.random.randn(8)) 77 | 78 | comm = MPI.COMM_WORLD 79 | assert comm.Get_size()==2 80 | if comm.Get_rank()==0: 81 | x1,x2,x3 = p1,p2,p3 82 | elif comm.Get_rank()==1: 83 | x1,x2,x3 = q1,q2,q3 84 | else: 85 | assert False 86 | 87 | rms = RunningMeanStd(epsilon=0.0, shape=(1,)) 88 | U.initialize() 89 | 90 | rms.update(x1) 91 | rms.update(x2) 92 | rms.update(x3) 93 | 94 | bigvec = np.concatenate([p1,p2,p3,q1,q2,q3]) 95 | 96 | def checkallclose(x,y): 97 | print(x,y) 98 | return np.allclose(x,y) 99 | 100 | assert checkallclose( 101 | bigvec.mean(axis=0), 102 | rms.mean.eval(), 103 | ) 104 | assert checkallclose( 105 | bigvec.std(axis=0), 106 | rms.std.eval(), 107 | ) 108 | 109 | 110 | if __name__ == "__main__": 111 | # Run with mpirun -np 2 python 112 | test_dist() 113 | -------------------------------------------------------------------------------- /rl/baselines/common/mpi_util.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import os, numpy as np 3 | import platform 4 | import shutil 5 | import subprocess 6 | import warnings 7 | import sys 8 | 9 | try: 10 | from mpi4py import MPI 11 | except ImportError: 12 | MPI = None 13 | 14 | 15 | def sync_from_root(sess, variables, comm=None): 16 | """ 17 | Send the root node's parameters to every worker. 18 | Arguments: 19 | sess: the TensorFlow session. 20 | variables: all parameter variables including optimizer's 21 | """ 22 | if comm is None: comm = MPI.COMM_WORLD 23 | import tensorflow as tf 24 | values = comm.bcast(sess.run(variables)) 25 | sess.run([tf.assign(var, val) 26 | for (var, val) in zip(variables, values)]) 27 | 28 | def gpu_count(): 29 | """ 30 | Count the GPUs on this machine. 
31 | """ 32 | if shutil.which('nvidia-smi') is None: 33 | return 0 34 | output = subprocess.check_output(['nvidia-smi', '--query-gpu=gpu_name', '--format=csv']) 35 | return max(0, len(output.split(b'\n')) - 2) 36 | 37 | def setup_mpi_gpus(): 38 | """ 39 | Set CUDA_VISIBLE_DEVICES to MPI rank if not already set 40 | """ 41 | if 'CUDA_VISIBLE_DEVICES' not in os.environ: 42 | if sys.platform == 'darwin': # This Assumes if you're on OSX you're just 43 | ids = [] # doing a smoke test and don't want GPUs 44 | else: 45 | lrank, _lsize = get_local_rank_size(MPI.COMM_WORLD) 46 | ids = [lrank] 47 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, ids)) 48 | 49 | def get_local_rank_size(comm): 50 | """ 51 | Returns the rank of each process on its machine 52 | The processes on a given machine will be assigned ranks 53 | 0, 1, 2, ..., N-1, 54 | where N is the number of processes on this machine. 55 | 56 | Useful if you want to assign one gpu per machine 57 | """ 58 | this_node = platform.node() 59 | ranks_nodes = comm.allgather((comm.Get_rank(), this_node)) 60 | node2rankssofar = defaultdict(int) 61 | local_rank = None 62 | for (rank, node) in ranks_nodes: 63 | if rank == comm.Get_rank(): 64 | local_rank = node2rankssofar[node] 65 | node2rankssofar[node] += 1 66 | assert local_rank is not None 67 | return local_rank, node2rankssofar[this_node] 68 | 69 | def share_file(comm, path): 70 | """ 71 | Copies the file from rank 0 to all other ranks 72 | Puts it in the same place on all machines 73 | """ 74 | localrank, _ = get_local_rank_size(comm) 75 | if comm.Get_rank() == 0: 76 | with open(path, 'rb') as fh: 77 | data = fh.read() 78 | comm.bcast(data) 79 | else: 80 | data = comm.bcast(None) 81 | if localrank == 0: 82 | os.makedirs(os.path.dirname(path), exist_ok=True) 83 | with open(path, 'wb') as fh: 84 | fh.write(data) 85 | comm.Barrier() 86 | 87 | def dict_gather(comm, d, op='mean', assert_all_have_data=True): 88 | """ 89 | Perform a reduction operation over dicts 90 | """ 91 | if comm is None: return d 92 | alldicts = comm.allgather(d) 93 | size = comm.size 94 | k2li = defaultdict(list) 95 | for d in alldicts: 96 | for (k,v) in d.items(): 97 | k2li[k].append(v) 98 | result = {} 99 | for (k,li) in k2li.items(): 100 | if assert_all_have_data: 101 | assert len(li)==size, "only %i out of %i MPI workers have sent '%s'" % (len(li), size, k) 102 | if op=='mean': 103 | result[k] = np.mean(li, axis=0) 104 | elif op=='sum': 105 | result[k] = np.sum(li, axis=0) 106 | else: 107 | assert 0, op 108 | return result 109 | 110 | def mpi_weighted_mean(comm, local_name2valcount): 111 | """ 112 | Perform a weighted average over dicts that are each on a different node 113 | Input: local_name2valcount: dict mapping key -> (value, count) 114 | Returns: key -> mean 115 | """ 116 | all_name2valcount = comm.gather(local_name2valcount) 117 | if comm.rank == 0: 118 | name2sum = defaultdict(float) 119 | name2count = defaultdict(float) 120 | for n2vc in all_name2valcount: 121 | for (name, (val, count)) in n2vc.items(): 122 | try: 123 | val = float(val) 124 | except ValueError: 125 | if comm.rank == 0: 126 | warnings.warn('WARNING: tried to compute mean on non-float {}={}'.format(name, val)) 127 | else: 128 | name2sum[name] += val * count 129 | name2count[name] += count 130 | return {name : name2sum[name] / name2count[name] for name in name2sum} 131 | else: 132 | return {} 133 | 134 | -------------------------------------------------------------------------------- /rl/baselines/common/policies.py: 
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from baselines.common import tf_util 3 | from baselines.a2c.utils import fc 4 | from baselines.common.distributions import make_pdtype 5 | from baselines.common.input import observation_placeholder, encode_observation 6 | from baselines.common.tf_util import adjust_shape 7 | from baselines.common.mpi_running_mean_std import RunningMeanStd 8 | from baselines.common.models import get_network_builder 9 | 10 | import gym 11 | 12 | 13 | class PolicyWithValue(object): 14 | """ 15 | Encapsulates fields and methods for RL policy and value function estimation with shared parameters 16 | """ 17 | 18 | def __init__(self, env, observations, latent, estimate_q=False, vf_latent=None, sess=None, **tensors): 19 | """ 20 | Parameters: 21 | ---------- 22 | env RL environment 23 | 24 | observations tensorflow placeholder in which the observations will be fed 25 | 26 | latent latent state from which policy distribution parameters should be inferred 27 | 28 | vf_latent latent state from which value function should be inferred (if None, then latent is used) 29 | 30 | sess tensorflow session to run calculations in (if None, default session is used) 31 | 32 | **tensors tensorflow tensors for additional attributes such as state or mask 33 | 34 | """ 35 | 36 | self.X = observations 37 | self.state = tf.constant([]) 38 | self.initial_state = None 39 | self.__dict__.update(tensors) 40 | 41 | vf_latent = vf_latent if vf_latent is not None else latent 42 | 43 | vf_latent = tf.layers.flatten(vf_latent) 44 | latent = tf.layers.flatten(latent) 45 | 46 | # Based on the action space, will select what probability distribution type 47 | self.pdtype = make_pdtype(env.action_space) 48 | 49 | self.pd, self.pi = self.pdtype.pdfromlatent(latent, init_scale=0.01) 50 | 51 | # Take an action 52 | self.action = self.pd.sample() 53 | 54 | # Calculate the neg log of our probability 55 | self.neglogp = self.pd.neglogp(self.action) 56 | self.sess = sess or tf.get_default_session() 57 | 58 | if estimate_q: 59 | assert isinstance(env.action_space, gym.spaces.Discrete) 60 | self.q = fc(vf_latent, 'q', env.action_space.n) 61 | self.vf = self.q 62 | else: 63 | self.vf = fc(vf_latent, 'vf', 1) 64 | self.vf = self.vf[:,0] 65 | 66 | def _evaluate(self, variables, observation, **extra_feed): 67 | sess = self.sess 68 | feed_dict = {self.X: adjust_shape(self.X, observation)} 69 | for inpt_name, data in extra_feed.items(): 70 | if inpt_name in self.__dict__.keys(): 71 | inpt = self.__dict__[inpt_name] 72 | if isinstance(inpt, tf.Tensor) and inpt._op.type == 'Placeholder': 73 | feed_dict[inpt] = adjust_shape(inpt, data) 74 | 75 | return sess.run(variables, feed_dict) 76 | 77 | def step(self, observation, **extra_feed): 78 | """ 79 | Compute next action(s) given the observation(s) 80 | 81 | Parameters: 82 | ---------- 83 | 84 | observation observation data (either single or a batch) 85 | 86 | **extra_feed additional data such as state or mask (names of the arguments should match the ones in constructor, see __init__) 87 | 88 | Returns: 89 | ------- 90 | (action, value estimate, next state, negative log likelihood of the action under current policy parameters) tuple 91 | """ 92 | 93 | a, v, state, neglogp = self._evaluate([self.action, self.vf, self.state, self.neglogp], observation, **extra_feed) 94 | if state.size == 0: 95 | state = None 96 | return a, v, state, neglogp 97 | 98 | def value(self, ob, *args, **kwargs): 99 | """ 100 | Compute 
value estimate(s) given the observation(s) 101 | 102 | Parameters: 103 | ---------- 104 | 105 | observation observation data (either single or a batch) 106 | 107 | **extra_feed additional data such as state or mask (names of the arguments should match the ones in constructor, see __init__) 108 | 109 | Returns: 110 | ------- 111 | value estimate 112 | """ 113 | return self._evaluate(self.vf, ob, *args, **kwargs) 114 | 115 | def save(self, save_path): 116 | tf_util.save_state(save_path, sess=self.sess) 117 | 118 | def load(self, load_path): 119 | tf_util.load_state(load_path, sess=self.sess) 120 | 121 | def build_policy(env, policy_network, value_network=None, normalize_observations=False, estimate_q=False, **policy_kwargs): 122 | if isinstance(policy_network, str): 123 | network_type = policy_network 124 | policy_network = get_network_builder(network_type)(**policy_kwargs) 125 | 126 | def policy_fn(nbatch=None, nsteps=None, sess=None, observ_placeholder=None): 127 | ob_space = env.observation_space 128 | 129 | X = observ_placeholder if observ_placeholder is not None else observation_placeholder(ob_space, batch_size=nbatch) 130 | 131 | extra_tensors = {} 132 | 133 | if normalize_observations and X.dtype == tf.float32: 134 | encoded_x, rms = _normalize_clip_observation(X) 135 | extra_tensors['rms'] = rms 136 | else: 137 | encoded_x = X 138 | 139 | encoded_x = encode_observation(ob_space, encoded_x) 140 | 141 | with tf.variable_scope('pi', reuse=tf.AUTO_REUSE): 142 | policy_latent = policy_network(encoded_x) 143 | if isinstance(policy_latent, tuple): 144 | policy_latent, recurrent_tensors = policy_latent 145 | 146 | if recurrent_tensors is not None: 147 | # recurrent architecture, need a few more steps 148 | nenv = nbatch // nsteps 149 | assert nenv > 0, 'Bad input for recurrent policy: batch size {} smaller than nsteps {}'.format(nbatch, nsteps) 150 | policy_latent, recurrent_tensors = policy_network(encoded_x, nenv) 151 | extra_tensors.update(recurrent_tensors) 152 | 153 | 154 | _v_net = value_network 155 | 156 | if _v_net is None or _v_net == 'shared': 157 | vf_latent = policy_latent 158 | else: 159 | if _v_net == 'copy': 160 | _v_net = policy_network 161 | else: 162 | assert callable(_v_net) 163 | 164 | with tf.variable_scope('vf', reuse=tf.AUTO_REUSE): 165 | # TODO recurrent architectures are not supported with value_network=copy yet 166 | vf_latent = _v_net(encoded_x) 167 | 168 | policy = PolicyWithValue( 169 | env=env, 170 | observations=X, 171 | latent=policy_latent, 172 | vf_latent=vf_latent, 173 | sess=sess, 174 | estimate_q=estimate_q, 175 | **extra_tensors 176 | ) 177 | return policy 178 | 179 | return policy_fn 180 | 181 | 182 | def _normalize_clip_observation(x, clip_range=[-5.0, 5.0]): 183 | rms = RunningMeanStd(shape=x.shape[1:]) 184 | norm_x = tf.clip_by_value((x - rms.mean) / rms.std, min(clip_range), max(clip_range)) 185 | return norm_x, rms 186 | 187 | -------------------------------------------------------------------------------- /rl/baselines/common/runners.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from abc import ABC, abstractmethod 3 | 4 | class AbstractEnvRunner(ABC): 5 | def __init__(self, *, env, model, nsteps): 6 | self.env = env 7 | self.model = model 8 | self.nenv = nenv = env.num_envs if hasattr(env, 'num_envs') else 1 9 | self.batch_ob_shape = (nenv*nsteps,) + env.observation_space.shape 10 | self.obs = np.zeros((nenv,) + env.observation_space.shape, dtype=env.observation_space.dtype.name) 11 
| self.obs[:] = env.reset() 12 | self.nsteps = nsteps 13 | self.states = model.initial_state 14 | self.dones = [False for _ in range(nenv)] 15 | 16 | @abstractmethod 17 | def run(self): 18 | raise NotImplementedError 19 | 20 | -------------------------------------------------------------------------------- /rl/baselines/common/running_mean_std.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from baselines.common.tf_util import get_session 4 | 5 | class RunningMeanStd(object): 6 | # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 7 | def __init__(self, epsilon=1e-4, shape=()): 8 | self.mean = np.zeros(shape, 'float64') 9 | self.var = np.ones(shape, 'float64') 10 | self.count = epsilon 11 | 12 | def update(self, x): 13 | batch_mean = np.mean(x, axis=0) 14 | batch_var = np.var(x, axis=0) 15 | batch_count = x.shape[0] 16 | self.update_from_moments(batch_mean, batch_var, batch_count) 17 | 18 | def update_from_moments(self, batch_mean, batch_var, batch_count): 19 | self.mean, self.var, self.count = update_mean_var_count_from_moments( 20 | self.mean, self.var, self.count, batch_mean, batch_var, batch_count) 21 | 22 | def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count): 23 | delta = batch_mean - mean 24 | tot_count = count + batch_count 25 | 26 | new_mean = mean + delta * batch_count / tot_count 27 | m_a = var * count 28 | m_b = batch_var * batch_count 29 | M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count 30 | new_var = M2 / tot_count 31 | new_count = tot_count 32 | 33 | return new_mean, new_var, new_count 34 | 35 | 36 | class TfRunningMeanStd(object): 37 | # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 38 | ''' 39 | TensorFlow variables-based implmentation of computing running mean and std 40 | Benefit of this implementation is that it can be saved / loaded together with the tensorflow model 41 | ''' 42 | def __init__(self, epsilon=1e-4, shape=(), scope=''): 43 | sess = get_session() 44 | 45 | self._new_mean = tf.placeholder(shape=shape, dtype=tf.float64) 46 | self._new_var = tf.placeholder(shape=shape, dtype=tf.float64) 47 | self._new_count = tf.placeholder(shape=(), dtype=tf.float64) 48 | 49 | 50 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): 51 | self._mean = tf.get_variable('mean', initializer=np.zeros(shape, 'float64'), dtype=tf.float64) 52 | self._var = tf.get_variable('std', initializer=np.ones(shape, 'float64'), dtype=tf.float64) 53 | self._count = tf.get_variable('count', initializer=np.full((), epsilon, 'float64'), dtype=tf.float64) 54 | 55 | self.update_ops = tf.group([ 56 | self._var.assign(self._new_var), 57 | self._mean.assign(self._new_mean), 58 | self._count.assign(self._new_count) 59 | ]) 60 | 61 | sess.run(tf.variables_initializer([self._mean, self._var, self._count])) 62 | self.sess = sess 63 | self._set_mean_var_count() 64 | 65 | def _set_mean_var_count(self): 66 | self.mean, self.var, self.count = self.sess.run([self._mean, self._var, self._count]) 67 | 68 | def update(self, x): 69 | batch_mean = np.mean(x, axis=0) 70 | batch_var = np.var(x, axis=0) 71 | batch_count = x.shape[0] 72 | 73 | new_mean, new_var, new_count = update_mean_var_count_from_moments(self.mean, self.var, self.count, batch_mean, batch_var, batch_count) 74 | 75 | self.sess.run(self.update_ops, feed_dict={ 76 | self._new_mean: new_mean, 77 | self._new_var: new_var, 78 | 
self._new_count: new_count 79 | }) 80 | 81 | self._set_mean_var_count() 82 | 83 | 84 | 85 | def test_runningmeanstd(): 86 | for (x1, x2, x3) in [ 87 | (np.random.randn(3), np.random.randn(4), np.random.randn(5)), 88 | (np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)), 89 | ]: 90 | 91 | rms = RunningMeanStd(epsilon=0.0, shape=x1.shape[1:]) 92 | 93 | x = np.concatenate([x1, x2, x3], axis=0) 94 | ms1 = [x.mean(axis=0), x.var(axis=0)] 95 | rms.update(x1) 96 | rms.update(x2) 97 | rms.update(x3) 98 | ms2 = [rms.mean, rms.var] 99 | 100 | np.testing.assert_allclose(ms1, ms2) 101 | 102 | def test_tf_runningmeanstd(): 103 | for (x1, x2, x3) in [ 104 | (np.random.randn(3), np.random.randn(4), np.random.randn(5)), 105 | (np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)), 106 | ]: 107 | 108 | rms = TfRunningMeanStd(epsilon=0.0, shape=x1.shape[1:], scope='running_mean_std' + str(np.random.randint(0, 128))) 109 | 110 | x = np.concatenate([x1, x2, x3], axis=0) 111 | ms1 = [x.mean(axis=0), x.var(axis=0)] 112 | rms.update(x1) 113 | rms.update(x2) 114 | rms.update(x3) 115 | ms2 = [rms.mean, rms.var] 116 | 117 | np.testing.assert_allclose(ms1, ms2) 118 | 119 | 120 | def profile_tf_runningmeanstd(): 121 | import time 122 | from baselines.common import tf_util 123 | 124 | tf_util.get_session( config=tf.ConfigProto( 125 | inter_op_parallelism_threads=1, 126 | intra_op_parallelism_threads=1, 127 | allow_soft_placement=True 128 | )) 129 | 130 | x = np.random.random((376,)) 131 | 132 | n_trials = 10000 133 | rms = RunningMeanStd() 134 | tfrms = TfRunningMeanStd() 135 | 136 | tic1 = time.time() 137 | for _ in range(n_trials): 138 | rms.update(x) 139 | 140 | tic2 = time.time() 141 | for _ in range(n_trials): 142 | tfrms.update(x) 143 | 144 | tic3 = time.time() 145 | 146 | print('rms update time ({} trials): {} s'.format(n_trials, tic2 - tic1)) 147 | print('tfrms update time ({} trials): {} s'.format(n_trials, tic3 - tic2)) 148 | 149 | 150 | tic1 = time.time() 151 | for _ in range(n_trials): 152 | z1 = rms.mean 153 | 154 | tic2 = time.time() 155 | for _ in range(n_trials): 156 | z2 = tfrms.mean 157 | 158 | assert z1 == z2 159 | 160 | tic3 = time.time() 161 | 162 | print('rms get mean time ({} trials): {} s'.format(n_trials, tic2 - tic1)) 163 | print('tfrms get mean time ({} trials): {} s'.format(n_trials, tic3 - tic2)) 164 | 165 | 166 | 167 | ''' 168 | options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) #pylint: disable=E1101 169 | run_metadata = tf.RunMetadata() 170 | profile_opts = dict(options=options, run_metadata=run_metadata) 171 | 172 | 173 | 174 | from tensorflow.python.client import timeline 175 | fetched_timeline = timeline.Timeline(run_metadata.step_stats) #pylint: disable=E1101 176 | chrome_trace = fetched_timeline.generate_chrome_trace_format() 177 | outfile = '/tmp/timeline.json' 178 | with open(outfile, 'wt') as f: 179 | f.write(chrome_trace) 180 | print('Successfully saved profile to {}. 
Exiting.'.format(outfile)) 181 | exit(0) 182 | ''' 183 | 184 | 185 | 186 | if __name__ == '__main__': 187 | profile_tf_runningmeanstd() 188 | -------------------------------------------------------------------------------- /rl/baselines/common/schedules.py: -------------------------------------------------------------------------------- 1 | """This file is used for specifying various schedules that evolve over 2 | time throughout the execution of the algorithm, such as: 3 | - learning rate for the optimizer 4 | - exploration epsilon for the epsilon greedy exploration strategy 5 | - beta parameter for prioritized replay 6 | 7 | Each schedule has a function `value(t)` which returns the current value 8 | of the parameter given the timestep t of the optimization procedure. 9 | """ 10 | 11 | 12 | class Schedule(object): 13 | def value(self, t): 14 | """Value of the schedule at time t""" 15 | raise NotImplementedError() 16 | 17 | 18 | class ConstantSchedule(object): 19 | def __init__(self, value): 20 | """Value remains constant over time. 21 | 22 | Parameters 23 | ---------- 24 | value: float 25 | Constant value of the schedule 26 | """ 27 | self._v = value 28 | 29 | def value(self, t): 30 | """See Schedule.value""" 31 | return self._v 32 | 33 | 34 | def linear_interpolation(l, r, alpha): 35 | return l + alpha * (r - l) 36 | 37 | 38 | class PiecewiseSchedule(object): 39 | def __init__(self, endpoints, interpolation=linear_interpolation, outside_value=None): 40 | """Piecewise schedule. 41 | 42 | endpoints: [(int, int)] 43 | list of pairs `(time, value)` meaning that schedule should output 44 | `value` when `t==time`. All the values for time must be sorted in 45 | an increasing order. When t is between two times, e.g. `(time_a, value_a)` 46 | and `(time_b, value_b)`, such that `time_a <= t < time_b` then value outputs 47 | `interpolation(value_a, value_b, alpha)` where alpha is a fraction of 48 | time passed between `time_a` and `time_b` for time `t`. 49 | interpolation: lambda float, float, float: float 50 | a function that takes value to the left and to the right of t according 51 | to the `endpoints`. Alpha is the fraction of distance from left endpoint to 52 | right endpoint that t has covered. See linear_interpolation for example. 53 | outside_value: float 54 | if the value is requested outside of all the intervals specified in 55 | `endpoints` this value is returned. If None then AssertionError is 56 | raised when outside value is requested. 57 | """ 58 | idxes = [e[0] for e in endpoints] 59 | assert idxes == sorted(idxes) 60 | self._interpolation = interpolation 61 | self._outside_value = outside_value 62 | self._endpoints = endpoints 63 | 64 | def value(self, t): 65 | """See Schedule.value""" 66 | for (l_t, l), (r_t, r) in zip(self._endpoints[:-1], self._endpoints[1:]): 67 | if l_t <= t and t < r_t: 68 | alpha = float(t - l_t) / (r_t - l_t) 69 | return self._interpolation(l, r, alpha) 70 | 71 | # t does not belong to any of the pieces, so doom. 72 | assert self._outside_value is not None 73 | return self._outside_value 74 | 75 | 76 | class LinearSchedule(object): 77 | def __init__(self, schedule_timesteps, final_p, initial_p=1.0): 78 | """Linear interpolation between initial_p and final_p over 79 | schedule_timesteps. After this many timesteps pass final_p is 80 | returned.
81 | 82 | Parameters 83 | ---------- 84 | schedule_timesteps: int 85 | Number of timesteps for which to linearly anneal initial_p 86 | to final_p 87 | initial_p: float 88 | initial output value 89 | final_p: float 90 | final output value 91 | """ 92 | self.schedule_timesteps = schedule_timesteps 93 | self.final_p = final_p 94 | self.initial_p = initial_p 95 | 96 | def value(self, t): 97 | """See Schedule.value""" 98 | fraction = min(float(t) / self.schedule_timesteps, 1.0) 99 | return self.initial_p + fraction * (self.final_p - self.initial_p) 100 | -------------------------------------------------------------------------------- /rl/baselines/common/segment_tree.py: -------------------------------------------------------------------------------- 1 | import operator 2 | 3 | 4 | class SegmentTree(object): 5 | def __init__(self, capacity, operation, neutral_element): 6 | """Build a Segment Tree data structure. 7 | 8 | https://en.wikipedia.org/wiki/Segment_tree 9 | 10 | Can be used as regular array, but with two 11 | important differences: 12 | 13 | a) setting item's value is slightly slower. 14 | It is O(lg capacity) instead of O(1). 15 | b) user has access to an efficient ( O(log segment size) ) 16 | `reduce` operation which reduces `operation` over 17 | a contiguous subsequence of items in the array. 18 | 19 | Parameters 20 | --------- 21 | capacity: int 22 | Total size of the array - must be a power of two. 23 | operation: lambda obj, obj -> obj 24 | an operation for combining elements (eg. sum, max) 25 | must form a mathematical group together with the set of 26 | possible values for array elements (i.e. be associative) 27 | neutral_element: obj 28 | neutral element for the operation above. eg. float('-inf') 29 | for max and 0 for sum. 30 | """ 31 | assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2." 32 | self._capacity = capacity 33 | self._value = [neutral_element for _ in range(2 * capacity)] 34 | self._operation = operation 35 | 36 | def _reduce_helper(self, start, end, node, node_start, node_end): 37 | if start == node_start and end == node_end: 38 | return self._value[node] 39 | mid = (node_start + node_end) // 2 40 | if end <= mid: 41 | return self._reduce_helper(start, end, 2 * node, node_start, mid) 42 | else: 43 | if mid + 1 <= start: 44 | return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end) 45 | else: 46 | return self._operation( 47 | self._reduce_helper(start, mid, 2 * node, node_start, mid), 48 | self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end) 49 | ) 50 | 51 | def reduce(self, start=0, end=None): 52 | """Returns result of applying `self.operation` 53 | to a contiguous subsequence of the array. 54 | 55 | self.operation(arr[start], operation(arr[start+1], operation(... arr[end]))) 56 | 57 | Parameters 58 | ---------- 59 | start: int 60 | beginning of the subsequence 61 | end: int 62 | end of the subsequence 63 | 64 | Returns 65 | ------- 66 | reduced: obj 67 | result of reducing self.operation over the specified range of array elements.
68 | """ 69 | if end is None: 70 | end = self._capacity 71 | if end < 0: 72 | end += self._capacity 73 | end -= 1 74 | return self._reduce_helper(start, end, 1, 0, self._capacity - 1) 75 | 76 | def __setitem__(self, idx, val): 77 | # index of the leaf 78 | idx += self._capacity 79 | self._value[idx] = val 80 | idx //= 2 81 | while idx >= 1: 82 | self._value[idx] = self._operation( 83 | self._value[2 * idx], 84 | self._value[2 * idx + 1] 85 | ) 86 | idx //= 2 87 | 88 | def __getitem__(self, idx): 89 | assert 0 <= idx < self._capacity 90 | return self._value[self._capacity + idx] 91 | 92 | 93 | class SumSegmentTree(SegmentTree): 94 | def __init__(self, capacity): 95 | super(SumSegmentTree, self).__init__( 96 | capacity=capacity, 97 | operation=operator.add, 98 | neutral_element=0.0 99 | ) 100 | 101 | def sum(self, start=0, end=None): 102 | """Returns arr[start] + ... + arr[end]""" 103 | return super(SumSegmentTree, self).reduce(start, end) 104 | 105 | def find_prefixsum_idx(self, prefixsum): 106 | """Find the highest index `i` in the array such that 107 | sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum 108 | 109 | if array values are probabilities, this function 110 | allows to sample indexes according to the discrete 111 | probability efficiently. 112 | 113 | Parameters 114 | ---------- 115 | perfixsum: float 116 | upperbound on the sum of array prefix 117 | 118 | Returns 119 | ------- 120 | idx: int 121 | highest index satisfying the prefixsum constraint 122 | """ 123 | assert 0 <= prefixsum <= self.sum() + 1e-5 124 | idx = 1 125 | while idx < self._capacity: # while non-leaf 126 | if self._value[2 * idx] > prefixsum: 127 | idx = 2 * idx 128 | else: 129 | prefixsum -= self._value[2 * idx] 130 | idx = 2 * idx + 1 131 | return idx - self._capacity 132 | 133 | 134 | class MinSegmentTree(SegmentTree): 135 | def __init__(self, capacity): 136 | super(MinSegmentTree, self).__init__( 137 | capacity=capacity, 138 | operation=min, 139 | neutral_element=float('inf') 140 | ) 141 | 142 | def min(self, start=0, end=None): 143 | """Returns min(arr[start], ..., arr[end])""" 144 | 145 | return super(MinSegmentTree, self).reduce(start, end) 146 | -------------------------------------------------------------------------------- /rl/baselines/common/test_mpi_util.py: -------------------------------------------------------------------------------- 1 | from baselines.common import mpi_util 2 | from baselines import logger 3 | from baselines.common.tests.test_with_mpi import with_mpi 4 | try: 5 | from mpi4py import MPI 6 | except ImportError: 7 | MPI = None 8 | 9 | @with_mpi() 10 | def test_mpi_weighted_mean(): 11 | comm = MPI.COMM_WORLD 12 | with logger.scoped_configure(comm=comm): 13 | if comm.rank == 0: 14 | name2valcount = {'a' : (10, 2), 'b' : (20,3)} 15 | elif comm.rank == 1: 16 | name2valcount = {'a' : (19, 1), 'c' : (42,3)} 17 | else: 18 | raise NotImplementedError 19 | d = mpi_util.mpi_weighted_mean(comm, name2valcount) 20 | correctval = {'a' : (10 * 2 + 19) / 3.0, 'b' : 20, 'c' : 42} 21 | if comm.rank == 0: 22 | assert d == correctval, '{} != {}'.format(d, correctval) 23 | 24 | for name, (val, count) in name2valcount.items(): 25 | for _ in range(count): 26 | logger.logkv_mean(name, val) 27 | d2 = logger.dumpkvs() 28 | if comm.rank == 0: 29 | assert d2 == correctval 30 | -------------------------------------------------------------------------------- /rl/baselines/common/tests/__init__.py: -------------------------------------------------------------------------------- 1 | import os, pytest 2 | 
mark_slow = pytest.mark.skipif(not os.getenv('RUNSLOW'), reason='slow') -------------------------------------------------------------------------------- /rl/baselines/common/tests/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/rl/baselines/common/tests/envs/__init__.py -------------------------------------------------------------------------------- /rl/baselines/common/tests/envs/fixed_sequence_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import Env 3 | from gym.spaces import Discrete 4 | 5 | 6 | class FixedSequenceEnv(Env): 7 | def __init__( 8 | self, 9 | n_actions=10, 10 | episode_len=100 11 | ): 12 | self.action_space = Discrete(n_actions) 13 | self.observation_space = Discrete(1) 14 | self.np_random = np.random.RandomState(0) 15 | self.episode_len = episode_len 16 | self.sequence = [self.np_random.randint(0, self.action_space.n) 17 | for _ in range(self.episode_len)] 18 | self.time = 0 19 | 20 | 21 | def reset(self): 22 | self.time = 0 23 | return 0 24 | 25 | def step(self, actions): 26 | rew = self._get_reward(actions) 27 | self._choose_next_state() 28 | done = False 29 | if self.episode_len and self.time >= self.episode_len: 30 | done = True 31 | 32 | return 0, rew, done, {} 33 | 34 | def seed(self, seed=None): 35 | self.np_random.seed(seed) 36 | 37 | def _choose_next_state(self): 38 | self.time += 1 39 | 40 | def _get_reward(self, actions): 41 | return 1 if actions == self.sequence[self.time] else 0 42 | 43 | 44 | -------------------------------------------------------------------------------- /rl/baselines/common/tests/envs/identity_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from abc import abstractmethod 3 | from gym import Env 4 | from gym.spaces import MultiDiscrete, Discrete, Box 5 | from collections import deque 6 | 7 | class IdentityEnv(Env): 8 | def __init__( 9 | self, 10 | episode_len=None, 11 | delay=0, 12 | zero_first_rewards=True 13 | ): 14 | 15 | self.observation_space = self.action_space 16 | self.episode_len = episode_len 17 | self.time = 0 18 | self.delay = delay 19 | self.zero_first_rewards = zero_first_rewards 20 | self.q = deque(maxlen=delay+1) 21 | 22 | def reset(self): 23 | self.q.clear() 24 | for _ in range(self.delay + 1): 25 | self.q.append(self.action_space.sample()) 26 | self.time = 0 27 | 28 | return self.q[-1] 29 | 30 | def step(self, actions): 31 | rew = self._get_reward(self.q.popleft(), actions) 32 | if self.zero_first_rewards and self.time < self.delay: 33 | rew = 0 34 | self.q.append(self.action_space.sample()) 35 | self.time += 1 36 | done = self.episode_len is not None and self.time >= self.episode_len 37 | return self.q[-1], rew, done, {} 38 | 39 | def seed(self, seed=None): 40 | self.action_space.seed(seed) 41 | 42 | @abstractmethod 43 | def _get_reward(self, state, actions): 44 | raise NotImplementedError 45 | 46 | 47 | class DiscreteIdentityEnv(IdentityEnv): 48 | def __init__( 49 | self, 50 | dim, 51 | episode_len=None, 52 | delay=0, 53 | zero_first_rewards=True 54 | ): 55 | 56 | self.action_space = Discrete(dim) 57 | super().__init__(episode_len=episode_len, delay=delay, zero_first_rewards=zero_first_rewards) 58 | 59 | def _get_reward(self, state, actions): 60 | return 1 if state == actions else 0 61 | 62 | class MultiDiscreteIdentityEnv(IdentityEnv): 63 | 
def __init__( 64 | self, 65 | dims, 66 | episode_len=None, 67 | delay=0, 68 | ): 69 | 70 | self.action_space = MultiDiscrete(dims) 71 | super().__init__(episode_len=episode_len, delay=delay) 72 | 73 | def _get_reward(self, state, actions): 74 | return 1 if all(state == actions) else 0 75 | 76 | 77 | class BoxIdentityEnv(IdentityEnv): 78 | def __init__( 79 | self, 80 | shape, 81 | episode_len=None, 82 | ): 83 | 84 | self.action_space = Box(low=-1.0, high=1.0, shape=shape, dtype=np.float32) 85 | super().__init__(episode_len=episode_len) 86 | 87 | def _get_reward(self, state, actions): 88 | diff = actions - state 89 | diff = diff[:] 90 | return -0.5 * np.dot(diff, diff) 91 | -------------------------------------------------------------------------------- /rl/baselines/common/tests/envs/identity_env_test.py: -------------------------------------------------------------------------------- 1 | from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv 2 | 3 | 4 | def test_discrete_nodelay(): 5 | nsteps = 100 6 | eplen = 50 7 | env = DiscreteIdentityEnv(10, episode_len=eplen) 8 | ob = env.reset() 9 | for t in range(nsteps): 10 | action = env.action_space.sample() 11 | next_ob, rew, done, info = env.step(action) 12 | assert rew == (1 if action == ob else 0) 13 | if (t + 1) % eplen == 0: 14 | assert done 15 | next_ob = env.reset() 16 | else: 17 | assert not done 18 | ob = next_ob 19 | 20 | def test_discrete_delay1(): 21 | eplen = 50 22 | env = DiscreteIdentityEnv(10, episode_len=eplen, delay=1) 23 | ob = env.reset() 24 | prev_ob = None 25 | for t in range(eplen): 26 | action = env.action_space.sample() 27 | next_ob, rew, done, info = env.step(action) 28 | if t > 0: 29 | assert rew == (1 if action == prev_ob else 0) 30 | else: 31 | assert rew == 0 32 | prev_ob = ob 33 | ob = next_ob 34 | if t < eplen - 1: 35 | assert not done 36 | assert done 37 | -------------------------------------------------------------------------------- /rl/baselines/common/tests/envs/mnist_env.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import tempfile 4 | from gym import Env 5 | from gym.spaces import Discrete, Box 6 | 7 | 8 | 9 | class MnistEnv(Env): 10 | def __init__( 11 | self, 12 | episode_len=None, 13 | no_images=None 14 | ): 15 | import filelock 16 | from tensorflow.examples.tutorials.mnist import input_data 17 | # we could use temporary directory for this with a context manager and 18 | # TemporaryDirecotry, but then each test that uses mnist would re-download the data 19 | # this way the data is not cleaned up, but we only download it once per machine 20 | mnist_path = osp.join(tempfile.gettempdir(), 'MNIST_data') 21 | with filelock.FileLock(mnist_path + '.lock'): 22 | self.mnist = input_data.read_data_sets(mnist_path) 23 | 24 | self.np_random = np.random.RandomState() 25 | 26 | self.observation_space = Box(low=0.0, high=1.0, shape=(28,28,1)) 27 | self.action_space = Discrete(10) 28 | self.episode_len = episode_len 29 | self.time = 0 30 | self.no_images = no_images 31 | 32 | self.train_mode() 33 | self.reset() 34 | 35 | def reset(self): 36 | self._choose_next_state() 37 | self.time = 0 38 | 39 | return self.state[0] 40 | 41 | def step(self, actions): 42 | rew = self._get_reward(actions) 43 | self._choose_next_state() 44 | done = False 45 | if self.episode_len and self.time >= self.episode_len: 46 | rew = 0 47 | done = True 48 | 49 | return self.state[0], rew, done, {} 50 | 51 | def seed(self, seed=None): 52 | 
self.np_random.seed(seed) 53 | 54 | def train_mode(self): 55 | self.dataset = self.mnist.train 56 | 57 | def test_mode(self): 58 | self.dataset = self.mnist.test 59 | 60 | def _choose_next_state(self): 61 | max_index = (self.no_images if self.no_images is not None else self.dataset.num_examples) - 1 62 | index = self.np_random.randint(0, max_index) 63 | image = self.dataset.images[index].reshape(28,28,1)*255 64 | label = self.dataset.labels[index] 65 | self.state = (image, label) 66 | self.time += 1 67 | 68 | def _get_reward(self, actions): 69 | return 1 if self.state[1] == actions else 0 70 | 71 | 72 | -------------------------------------------------------------------------------- /rl/baselines/common/tests/test_cartpole.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import gym 3 | 4 | from baselines.run import get_learn_function 5 | from baselines.common.tests.util import reward_per_episode_test 6 | from baselines.common.tests import mark_slow 7 | 8 | common_kwargs = dict( 9 | total_timesteps=30000, 10 | network='mlp', 11 | gamma=1.0, 12 | seed=0, 13 | ) 14 | 15 | learn_kwargs = { 16 | 'a2c' : dict(nsteps=32, value_network='copy', lr=0.05), 17 | 'acer': dict(value_network='copy'), 18 | 'acktr': dict(nsteps=32, value_network='copy', is_async=False), 19 | 'deepq': dict(total_timesteps=20000), 20 | 'ppo2': dict(value_network='copy'), 21 | 'trpo_mpi': {} 22 | } 23 | 24 | @mark_slow 25 | @pytest.mark.parametrize("alg", learn_kwargs.keys()) 26 | def test_cartpole(alg): 27 | ''' 28 | Test if the algorithm (with an mlp policy) 29 | can learn to balance the cartpole 30 | ''' 31 | 32 | kwargs = common_kwargs.copy() 33 | kwargs.update(learn_kwargs[alg]) 34 | 35 | learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs) 36 | def env_fn(): 37 | 38 | env = gym.make('CartPole-v0') 39 | env.seed(0) 40 | return env 41 | 42 | reward_per_episode_test(env_fn, learn_fn, 100) 43 | 44 | if __name__ == '__main__': 45 | test_cartpole('acer') 46 | -------------------------------------------------------------------------------- /rl/baselines/common/tests/test_doc_examples.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | try: 3 | import mujoco_py 4 | _mujoco_present = True 5 | except BaseException: 6 | mujoco_py = None 7 | _mujoco_present = False 8 | 9 | 10 | @pytest.mark.skipif( 11 | not _mujoco_present, 12 | reason='error loading mujoco - either mujoco / mujoco key not present, or LD_LIBRARY_PATH is not pointing to mujoco library' 13 | ) 14 | def test_lstm_example(): 15 | import tensorflow as tf 16 | from baselines.common import policies, models, cmd_util 17 | from baselines.common.vec_env.dummy_vec_env import DummyVecEnv 18 | 19 | # create vectorized environment 20 | venv = DummyVecEnv([lambda: cmd_util.make_mujoco_env('Reacher-v2', seed=0)]) 21 | 22 | with tf.Session() as sess: 23 | # build policy based on lstm network with 128 units 24 | policy = policies.build_policy(venv, models.lstm(128))(nbatch=1, nsteps=1) 25 | 26 | # initialize tensorflow variables 27 | sess.run(tf.global_variables_initializer()) 28 | 29 | # prepare environment variables 30 | ob = venv.reset() 31 | state = policy.initial_state 32 | done = [False] 33 | step_counter = 0 34 | 35 | # run a single episode until the end (i.e. 
until done) 36 | while True: 37 | action, _, state, _ = policy.step(ob, S=state, M=done) 38 | ob, reward, done, _ = venv.step(action) 39 | step_counter += 1 40 | if done: 41 | break 42 | 43 | 44 | assert step_counter > 5 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /rl/baselines/common/tests/test_env_after_learn.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import gym 3 | import tensorflow as tf 4 | 5 | from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv 6 | from baselines.run import get_learn_function 7 | from baselines.common.tf_util import make_session 8 | 9 | algos = ['a2c', 'acer', 'acktr', 'deepq', 'ppo2', 'trpo_mpi'] 10 | 11 | @pytest.mark.parametrize('algo', algos) 12 | def test_env_after_learn(algo): 13 | def make_env(): 14 | # acktr requires too much RAM, fails on travis 15 | env = gym.make('CartPole-v1' if algo == 'acktr' else 'PongNoFrameskip-v4') 16 | return env 17 | 18 | make_session(make_default=True, graph=tf.Graph()) 19 | env = SubprocVecEnv([make_env]) 20 | 21 | learn = get_learn_function(algo) 22 | 23 | # Commenting out the following line resolves the issue, though crash happens at env.reset(). 24 | learn(network='mlp', env=env, total_timesteps=0, load_path=None, seed=None) 25 | 26 | env.reset() 27 | env.close() 28 | -------------------------------------------------------------------------------- /rl/baselines/common/tests/test_fetchreach.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import gym 3 | 4 | from baselines.run import get_learn_function 5 | from baselines.common.tests.util import reward_per_episode_test 6 | from baselines.common.tests import mark_slow 7 | 8 | pytest.importorskip('mujoco_py') 9 | 10 | common_kwargs = dict( 11 | network='mlp', 12 | seed=0, 13 | ) 14 | 15 | learn_kwargs = { 16 | 'her': dict(total_timesteps=2000) 17 | } 18 | 19 | @mark_slow 20 | @pytest.mark.parametrize("alg", learn_kwargs.keys()) 21 | def test_fetchreach(alg): 22 | ''' 23 | Test if the algorithm (with an mlp policy) 24 | can learn the FetchReach task 25 | ''' 26 | 27 | kwargs = common_kwargs.copy() 28 | kwargs.update(learn_kwargs[alg]) 29 | 30 | learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs) 31 | def env_fn(): 32 | 33 | env = gym.make('FetchReach-v1') 34 | env.seed(0) 35 | return env 36 | 37 | reward_per_episode_test(env_fn, learn_fn, -15) 38 | 39 | if __name__ == '__main__': 40 | test_fetchreach('her') 41 | -------------------------------------------------------------------------------- /rl/baselines/common/tests/test_fixed_sequence.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from baselines.common.tests.envs.fixed_sequence_env import FixedSequenceEnv 3 | 4 | from baselines.common.tests.util import simple_test 5 | from baselines.run import get_learn_function 6 | from baselines.common.tests import mark_slow 7 | 8 | 9 | common_kwargs = dict( 10 | seed=0, 11 | total_timesteps=50000, 12 | ) 13 | 14 | learn_kwargs = { 15 | 'a2c': {}, 16 | 'ppo2': dict(nsteps=10, ent_coef=0.0, nminibatches=1), 17 | # TODO enable sequential models for trpo_mpi (proper handling of nbatch and nsteps) 18 | # github issue: https://github.com/openai/baselines/issues/188 19 | # 'trpo_mpi': lambda e, p: trpo_mpi.learn(policy_fn=p(env=e), env=e, max_timesteps=30000, timesteps_per_batch=100, cg_iters=10, gamma=0.9, lam=1.0, max_kl=0.001) 20 | 
} 21 | 22 | 23 | alg_list = learn_kwargs.keys() 24 | rnn_list = ['lstm'] 25 | 26 | @mark_slow 27 | @pytest.mark.parametrize("alg", alg_list) 28 | @pytest.mark.parametrize("rnn", rnn_list) 29 | def test_fixed_sequence(alg, rnn): 30 | ''' 31 | Test if the algorithm (with a given policy) 32 | can learn an identity transformation (i.e. return observation as an action) 33 | ''' 34 | 35 | kwargs = learn_kwargs[alg] 36 | kwargs.update(common_kwargs) 37 | 38 | env_fn = lambda: FixedSequenceEnv(n_actions=10, episode_len=5) 39 | learn = lambda e: get_learn_function(alg)( 40 | env=e, 41 | network=rnn, 42 | **kwargs 43 | ) 44 | 45 | simple_test(env_fn, learn, 0.7) 46 | 47 | 48 | if __name__ == '__main__': 49 | test_fixed_sequence('ppo2', 'lstm') 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /rl/baselines/common/tests/test_identity.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv, BoxIdentityEnv, MultiDiscreteIdentityEnv 3 | from baselines.run import get_learn_function 4 | from baselines.common.tests.util import simple_test 5 | from baselines.common.tests import mark_slow 6 | 7 | common_kwargs = dict( 8 | total_timesteps=30000, 9 | network='mlp', 10 | gamma=0.9, 11 | seed=0, 12 | ) 13 | 14 | learn_kwargs = { 15 | 'a2c' : {}, 16 | 'acktr': {}, 17 | 'deepq': {}, 18 | 'ddpg': dict(layer_norm=True), 19 | 'ppo2': dict(lr=1e-3, nsteps=64, ent_coef=0.0), 20 | 'trpo_mpi': dict(timesteps_per_batch=100, cg_iters=10, gamma=0.9, lam=1.0, max_kl=0.01) 21 | } 22 | 23 | 24 | algos_disc = ['a2c', 'acktr', 'deepq', 'ppo2', 'trpo_mpi'] 25 | algos_multidisc = ['a2c', 'acktr', 'ppo2', 'trpo_mpi'] 26 | algos_cont = ['a2c', 'acktr', 'ddpg', 'ppo2', 'trpo_mpi'] 27 | 28 | @mark_slow 29 | @pytest.mark.parametrize("alg", algos_disc) 30 | def test_discrete_identity(alg): 31 | ''' 32 | Test if the algorithm (with an mlp policy) 33 | can learn an identity transformation (i.e. return observation as an action) 34 | ''' 35 | 36 | kwargs = learn_kwargs[alg] 37 | kwargs.update(common_kwargs) 38 | 39 | learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs) 40 | env_fn = lambda: DiscreteIdentityEnv(10, episode_len=100) 41 | simple_test(env_fn, learn_fn, 0.9) 42 | 43 | @mark_slow 44 | @pytest.mark.parametrize("alg", algos_multidisc) 45 | def test_multidiscrete_identity(alg): 46 | ''' 47 | Test if the algorithm (with an mlp policy) 48 | can learn an identity transformation (i.e. return observation as an action) 49 | ''' 50 | 51 | kwargs = learn_kwargs[alg] 52 | kwargs.update(common_kwargs) 53 | 54 | learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs) 55 | env_fn = lambda: MultiDiscreteIdentityEnv((3,3), episode_len=100) 56 | simple_test(env_fn, learn_fn, 0.9) 57 | 58 | @mark_slow 59 | @pytest.mark.parametrize("alg", algos_cont) 60 | def test_continuous_identity(alg): 61 | ''' 62 | Test if the algorithm (with an mlp policy) 63 | can learn an identity transformation (i.e. 
return observation as an action) 64 | to a required precision 65 | ''' 66 | 67 | kwargs = learn_kwargs[alg] 68 | kwargs.update(common_kwargs) 69 | learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs) 70 | 71 | env_fn = lambda: BoxIdentityEnv((1,), episode_len=100) 72 | simple_test(env_fn, learn_fn, -0.1) 73 | 74 | if __name__ == '__main__': 75 | test_multidiscrete_identity('acktr') 76 | 77 | -------------------------------------------------------------------------------- /rl/baselines/common/tests/test_mnist.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | # from baselines.acer import acer_simple as acer 4 | from baselines.common.tests.envs.mnist_env import MnistEnv 5 | from baselines.common.tests.util import simple_test 6 | from baselines.run import get_learn_function 7 | from baselines.common.tests import mark_slow 8 | 9 | # TODO investigate a2c and ppo2 failures - is it due to bad hyperparameters for this problem? 10 | # GitHub issue https://github.com/openai/baselines/issues/189 11 | common_kwargs = { 12 | 'seed': 0, 13 | 'network':'cnn', 14 | 'gamma':0.9, 15 | 'pad':'SAME' 16 | } 17 | 18 | learn_args = { 19 | 'a2c': dict(total_timesteps=50000), 20 | 'acer': dict(total_timesteps=20000), 21 | 'deepq': dict(total_timesteps=5000), 22 | 'acktr': dict(total_timesteps=30000), 23 | 'ppo2': dict(total_timesteps=50000, lr=1e-3, nsteps=128, ent_coef=0.0), 24 | 'trpo_mpi': dict(total_timesteps=80000, timesteps_per_batch=100, cg_iters=10, lam=1.0, max_kl=0.001) 25 | } 26 | 27 | 28 | #tests pass, but are too slow on travis. Same algorithms are covered 29 | # by other tests with less compute-hungry nn's and by benchmarks 30 | @pytest.mark.skip 31 | @mark_slow 32 | @pytest.mark.parametrize("alg", learn_args.keys()) 33 | def test_mnist(alg): 34 | ''' 35 | Test if the algorithm can learn to classify MNIST digits. 36 | Uses CNN policy. 
37 | ''' 38 | 39 | learn_kwargs = learn_args[alg] 40 | learn_kwargs.update(common_kwargs) 41 | 42 | learn = get_learn_function(alg) 43 | learn_fn = lambda e: learn(env=e, **learn_kwargs) 44 | env_fn = lambda: MnistEnv(episode_len=100) 45 | 46 | simple_test(env_fn, learn_fn, 0.6) 47 | 48 | if __name__ == '__main__': 49 | test_mnist('acer') 50 | -------------------------------------------------------------------------------- /rl/baselines/common/tests/test_plot_util.py: -------------------------------------------------------------------------------- 1 | # smoke tests of plot_util 2 | from baselines.common import plot_util as pu 3 | from baselines.common.tests.util import smoketest 4 | 5 | 6 | def test_plot_util(): 7 | nruns = 4 8 | logdirs = [smoketest('--alg=ppo2 --env=CartPole-v0 --num_timesteps=10000') for _ in range(nruns)] 9 | data = pu.load_results(logdirs) 10 | assert len(data) == 4 11 | 12 | _, axes = pu.plot_results(data[:1]); assert len(axes) == 1 13 | _, axes = pu.plot_results(data, tiling='vertical'); assert axes.shape==(4,1) 14 | _, axes = pu.plot_results(data, tiling='horizontal'); assert axes.shape==(1,4) 15 | _, axes = pu.plot_results(data, tiling='symmetric'); assert axes.shape==(2,2) 16 | _, axes = pu.plot_results(data, split_fn=lambda _: ''); assert len(axes) == 1 17 | 18 | -------------------------------------------------------------------------------- /rl/baselines/common/tests/test_schedules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from baselines.common.schedules import ConstantSchedule, PiecewiseSchedule 4 | 5 | 6 | def test_piecewise_schedule(): 7 | ps = PiecewiseSchedule([(-5, 100), (5, 200), (10, 50), (100, 50), (200, -50)], outside_value=500) 8 | 9 | assert np.isclose(ps.value(-10), 500) 10 | assert np.isclose(ps.value(0), 150) 11 | assert np.isclose(ps.value(5), 200) 12 | assert np.isclose(ps.value(9), 80) 13 | assert np.isclose(ps.value(50), 50) 14 | assert np.isclose(ps.value(80), 50) 15 | assert np.isclose(ps.value(150), 0) 16 | assert np.isclose(ps.value(175), -25) 17 | assert np.isclose(ps.value(201), 500) 18 | assert np.isclose(ps.value(500), 500) 19 | 20 | assert np.isclose(ps.value(200 - 1e-10), -50) 21 | 22 | 23 | def test_constant_schedule(): 24 | cs = ConstantSchedule(5) 25 | for i in range(-100, 100): 26 | assert np.isclose(cs.value(i), 5) 27 | -------------------------------------------------------------------------------- /rl/baselines/common/tests/test_segment_tree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from baselines.common.segment_tree import SumSegmentTree, MinSegmentTree 4 | 5 | 6 | def test_tree_set(): 7 | tree = SumSegmentTree(4) 8 | 9 | tree[2] = 1.0 10 | tree[3] = 3.0 11 | 12 | assert np.isclose(tree.sum(), 4.0) 13 | assert np.isclose(tree.sum(0, 2), 0.0) 14 | assert np.isclose(tree.sum(0, 3), 1.0) 15 | assert np.isclose(tree.sum(2, 3), 1.0) 16 | assert np.isclose(tree.sum(2, -1), 1.0) 17 | assert np.isclose(tree.sum(2, 4), 4.0) 18 | 19 | 20 | def test_tree_set_overlap(): 21 | tree = SumSegmentTree(4) 22 | 23 | tree[2] = 1.0 24 | tree[2] = 3.0 25 | 26 | assert np.isclose(tree.sum(), 3.0) 27 | assert np.isclose(tree.sum(2, 3), 3.0) 28 | assert np.isclose(tree.sum(2, -1), 3.0) 29 | assert np.isclose(tree.sum(2, 4), 3.0) 30 | assert np.isclose(tree.sum(1, 2), 0.0) 31 | 32 | 33 | def test_prefixsum_idx(): 34 | tree = SumSegmentTree(4) 35 | 36 | tree[2] = 1.0 37 | tree[3] = 3.0 38 | 39 | assert 
tree.find_prefixsum_idx(0.0) == 2 40 | assert tree.find_prefixsum_idx(0.5) == 2 41 | assert tree.find_prefixsum_idx(0.99) == 2 42 | assert tree.find_prefixsum_idx(1.01) == 3 43 | assert tree.find_prefixsum_idx(3.00) == 3 44 | assert tree.find_prefixsum_idx(4.00) == 3 45 | 46 | 47 | def test_prefixsum_idx2(): 48 | tree = SumSegmentTree(4) 49 | 50 | tree[0] = 0.5 51 | tree[1] = 1.0 52 | tree[2] = 1.0 53 | tree[3] = 3.0 54 | 55 | assert tree.find_prefixsum_idx(0.00) == 0 56 | assert tree.find_prefixsum_idx(0.55) == 1 57 | assert tree.find_prefixsum_idx(0.99) == 1 58 | assert tree.find_prefixsum_idx(1.51) == 2 59 | assert tree.find_prefixsum_idx(3.00) == 3 60 | assert tree.find_prefixsum_idx(5.50) == 3 61 | 62 | 63 | def test_max_interval_tree(): 64 | tree = MinSegmentTree(4) 65 | 66 | tree[0] = 1.0 67 | tree[2] = 0.5 68 | tree[3] = 3.0 69 | 70 | assert np.isclose(tree.min(), 0.5) 71 | assert np.isclose(tree.min(0, 2), 1.0) 72 | assert np.isclose(tree.min(0, 3), 0.5) 73 | assert np.isclose(tree.min(0, -1), 0.5) 74 | assert np.isclose(tree.min(2, 4), 0.5) 75 | assert np.isclose(tree.min(3, 4), 3.0) 76 | 77 | tree[2] = 0.7 78 | 79 | assert np.isclose(tree.min(), 0.7) 80 | assert np.isclose(tree.min(0, 2), 1.0) 81 | assert np.isclose(tree.min(0, 3), 0.7) 82 | assert np.isclose(tree.min(0, -1), 0.7) 83 | assert np.isclose(tree.min(2, 4), 0.7) 84 | assert np.isclose(tree.min(3, 4), 3.0) 85 | 86 | tree[2] = 4.0 87 | 88 | assert np.isclose(tree.min(), 1.0) 89 | assert np.isclose(tree.min(0, 2), 1.0) 90 | assert np.isclose(tree.min(0, 3), 1.0) 91 | assert np.isclose(tree.min(0, -1), 1.0) 92 | assert np.isclose(tree.min(2, 4), 3.0) 93 | assert np.isclose(tree.min(2, 3), 4.0) 94 | assert np.isclose(tree.min(2, -1), 4.0) 95 | assert np.isclose(tree.min(3, 4), 3.0) 96 | 97 | 98 | if __name__ == '__main__': 99 | test_tree_set() 100 | test_tree_set_overlap() 101 | test_prefixsum_idx() 102 | test_prefixsum_idx2() 103 | test_max_interval_tree() 104 | -------------------------------------------------------------------------------- /rl/baselines/common/tests/test_serialization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gym 3 | import tempfile 4 | import pytest 5 | import tensorflow as tf 6 | import numpy as np 7 | 8 | from baselines.common.tests.envs.mnist_env import MnistEnv 9 | from baselines.common.vec_env.dummy_vec_env import DummyVecEnv 10 | from baselines.run import get_learn_function 11 | from baselines.common.tf_util import make_session, get_session 12 | 13 | from functools import partial 14 | 15 | 16 | learn_kwargs = { 17 | 'deepq': {}, 18 | 'a2c': {}, 19 | 'acktr': {}, 20 | 'acer': {}, 21 | 'ppo2': {'nminibatches': 1, 'nsteps': 10}, 22 | 'trpo_mpi': {}, 23 | } 24 | 25 | network_kwargs = { 26 | 'mlp': {}, 27 | 'cnn': {'pad': 'SAME'}, 28 | 'lstm': {}, 29 | 'cnn_lnlstm': {'pad': 'SAME'} 30 | } 31 | 32 | 33 | @pytest.mark.parametrize("learn_fn", learn_kwargs.keys()) 34 | @pytest.mark.parametrize("network_fn", network_kwargs.keys()) 35 | def test_serialization(learn_fn, network_fn): 36 | ''' 37 | Test if the trained model can be serialized 38 | ''' 39 | 40 | 41 | if network_fn.endswith('lstm') and learn_fn in ['acer', 'acktr', 'trpo_mpi', 'deepq']: 42 | # TODO make acktr work with recurrent policies 43 | # and test 44 | # github issue: https://github.com/openai/baselines/issues/660 45 | return 46 | 47 | def make_env(): 48 | env = MnistEnv(episode_len=100) 49 | env.seed(10) 50 | return env 51 | 52 | env = DummyVecEnv([make_env]) 53 | ob = 
env.reset().copy() 54 | learn = get_learn_function(learn_fn) 55 | 56 | kwargs = {} 57 | kwargs.update(network_kwargs[network_fn]) 58 | kwargs.update(learn_kwargs[learn_fn]) 59 | 60 | 61 | learn = partial(learn, env=env, network=network_fn, seed=0, **kwargs) 62 | 63 | with tempfile.TemporaryDirectory() as td: 64 | model_path = os.path.join(td, 'serialization_test_model') 65 | 66 | with tf.Graph().as_default(), make_session().as_default(): 67 | model = learn(total_timesteps=100) 68 | model.save(model_path) 69 | mean1, std1 = _get_action_stats(model, ob) 70 | variables_dict1 = _serialize_variables() 71 | 72 | with tf.Graph().as_default(), make_session().as_default(): 73 | model = learn(total_timesteps=0, load_path=model_path) 74 | mean2, std2 = _get_action_stats(model, ob) 75 | variables_dict2 = _serialize_variables() 76 | 77 | for k, v in variables_dict1.items(): 78 | np.testing.assert_allclose(v, variables_dict2[k], atol=0.01, 79 | err_msg='saved and loaded variable {} value mismatch'.format(k)) 80 | 81 | np.testing.assert_allclose(mean1, mean2, atol=0.5) 82 | np.testing.assert_allclose(std1, std2, atol=0.5) 83 | 84 | 85 | @pytest.mark.parametrize("learn_fn", learn_kwargs.keys()) 86 | @pytest.mark.parametrize("network_fn", ['mlp']) 87 | def test_coexistence(learn_fn, network_fn): 88 | ''' 89 | Test if more than one model can exist at a time 90 | ''' 91 | 92 | if learn_fn == 'deepq': 93 | # TODO enable multiple DQN models to be useable at the same time 94 | # github issue https://github.com/openai/baselines/issues/656 95 | return 96 | 97 | if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']: 98 | # TODO make acktr work with recurrent policies 99 | # and test 100 | # github issue: https://github.com/openai/baselines/issues/660 101 | return 102 | 103 | env = DummyVecEnv([lambda: gym.make('CartPole-v0')]) 104 | learn = get_learn_function(learn_fn) 105 | 106 | kwargs = {} 107 | kwargs.update(network_kwargs[network_fn]) 108 | kwargs.update(learn_kwargs[learn_fn]) 109 | 110 | learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs) 111 | make_session(make_default=True, graph=tf.Graph()) 112 | model1 = learn(seed=1) 113 | make_session(make_default=True, graph=tf.Graph()) 114 | model2 = learn(seed=2) 115 | 116 | model1.step(env.observation_space.sample()) 117 | model2.step(env.observation_space.sample()) 118 | 119 | 120 | 121 | def _serialize_variables(): 122 | sess = get_session() 123 | variables = tf.trainable_variables() 124 | values = sess.run(variables) 125 | return {var.name: value for var, value in zip(variables, values)} 126 | 127 | 128 | def _get_action_stats(model, ob): 129 | ntrials = 1000 130 | if model.initial_state is None or model.initial_state == []: 131 | actions = np.array([model.step(ob)[0] for _ in range(ntrials)]) 132 | else: 133 | actions = np.array([model.step(ob, S=model.initial_state, M=[False])[0] for _ in range(ntrials)]) 134 | 135 | mean = np.mean(actions, axis=0) 136 | std = np.std(actions, axis=0) 137 | 138 | return mean, std 139 | 140 | -------------------------------------------------------------------------------- /rl/baselines/common/tests/test_tf_util.py: -------------------------------------------------------------------------------- 1 | # tests for tf_util 2 | import tensorflow as tf 3 | from baselines.common.tf_util import ( 4 | function, 5 | initialize, 6 | single_threaded_session 7 | ) 8 | 9 | 10 | def test_function(): 11 | with tf.Graph().as_default(): 12 | x = tf.placeholder(tf.int32, (), name="x") 13 | y = 
tf.placeholder(tf.int32, (), name="y") 14 | z = 3 * x + 2 * y 15 | lin = function([x, y], z, givens={y: 0}) 16 | 17 | with single_threaded_session(): 18 | initialize() 19 | 20 | assert lin(2) == 6 21 | assert lin(x=3) == 9 22 | assert lin(2, 2) == 10 23 | assert lin(x=2, y=3) == 12 24 | 25 | 26 | def test_multikwargs(): 27 | with tf.Graph().as_default(): 28 | x = tf.placeholder(tf.int32, (), name="x") 29 | with tf.variable_scope("other"): 30 | x2 = tf.placeholder(tf.int32, (), name="x") 31 | z = 3 * x + 2 * x2 32 | 33 | lin = function([x, x2], z, givens={x2: 0}) 34 | with single_threaded_session(): 35 | initialize() 36 | assert lin(2) == 6 37 | assert lin(2, 2) == 10 38 | 39 | 40 | if __name__ == '__main__': 41 | test_function() 42 | test_multikwargs() 43 | -------------------------------------------------------------------------------- /rl/baselines/common/tests/test_with_mpi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | import cloudpickle 5 | import base64 6 | import pytest 7 | from functools import wraps 8 | 9 | try: 10 | from mpi4py import MPI 11 | except ImportError: 12 | MPI = None 13 | 14 | def with_mpi(nproc=2, timeout=30, skip_if_no_mpi=True): 15 | def outer_thunk(fn): 16 | @wraps(fn) 17 | def thunk(*args, **kwargs): 18 | serialized_fn = base64.b64encode(cloudpickle.dumps(lambda: fn(*args, **kwargs))) 19 | subprocess.check_call([ 20 | 'mpiexec','-n', str(nproc), 21 | sys.executable, 22 | '-m', 'baselines.common.tests.test_with_mpi', 23 | serialized_fn 24 | ], env=os.environ, timeout=timeout) 25 | 26 | if skip_if_no_mpi: 27 | return pytest.mark.skipif(MPI is None, reason="MPI not present")(thunk) 28 | else: 29 | return thunk 30 | 31 | return outer_thunk 32 | 33 | 34 | if __name__ == '__main__': 35 | if len(sys.argv) > 1: 36 | fn = cloudpickle.loads(base64.b64decode(sys.argv[1])) 37 | assert callable(fn) 38 | fn() 39 | -------------------------------------------------------------------------------- /rl/baselines/common/tests/util.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from baselines.common.vec_env.dummy_vec_env import DummyVecEnv 4 | 5 | N_TRIALS = 10000 6 | N_EPISODES = 100 7 | 8 | _sess_config = tf.ConfigProto( 9 | allow_soft_placement=True, 10 | intra_op_parallelism_threads=1, 11 | inter_op_parallelism_threads=1 12 | ) 13 | 14 | def simple_test(env_fn, learn_fn, min_reward_fraction, n_trials=N_TRIALS): 15 | def seeded_env_fn(): 16 | env = env_fn() 17 | env.seed(0) 18 | return env 19 | 20 | np.random.seed(0) 21 | env = DummyVecEnv([seeded_env_fn]) 22 | with tf.Graph().as_default(), tf.Session(config=_sess_config).as_default(): 23 | tf.set_random_seed(0) 24 | model = learn_fn(env) 25 | sum_rew = 0 26 | done = True 27 | for i in range(n_trials): 28 | if done: 29 | obs = env.reset() 30 | state = model.initial_state 31 | if state is not None: 32 | a, v, state, _ = model.step(obs, S=state, M=[False]) 33 | else: 34 | a, v, _, _ = model.step(obs) 35 | obs, rew, done, _ = env.step(a) 36 | sum_rew += float(rew) 37 | print("Reward in {} trials is {}".format(n_trials, sum_rew)) 38 | assert sum_rew > min_reward_fraction * n_trials, \ 39 | 'sum of rewards {} is less than {} of the total number of trials {}'.format(sum_rew, min_reward_fraction, n_trials) 40 | 41 | def reward_per_episode_test(env_fn, learn_fn, min_avg_reward, n_trials=N_EPISODES): 42 | env = DummyVecEnv([env_fn]) 43 | with 
tf.Graph().as_default(), tf.Session(config=_sess_config).as_default(): 44 | model = learn_fn(env) 45 | N_TRIALS = 100 46 | observations, actions, rewards = rollout(env, model, N_TRIALS) 47 | rewards = [sum(r) for r in rewards] 48 | avg_rew = sum(rewards) / N_TRIALS 49 | print("Average reward in {} episodes is {}".format(n_trials, avg_rew)) 50 | assert avg_rew > min_avg_reward, \ 51 | 'average reward in {} episodes ({}) is less than {}'.format(n_trials, avg_rew, min_avg_reward) 52 | 53 | def rollout(env, model, n_trials): 54 | rewards = [] 55 | actions = [] 56 | observations = [] 57 | for i in range(n_trials): 58 | obs = env.reset() 59 | state = model.initial_state if hasattr(model, 'initial_state') else None 60 | episode_rew = [] 61 | episode_actions = [] 62 | episode_obs = [] 63 | while True: 64 | if state is not None: 65 | a, v, state, _ = model.step(obs, S=state, M=[False]) 66 | else: 67 | a,v, _, _ = model.step(obs) 68 | 69 | obs, rew, done, _ = env.step(a) 70 | episode_rew.append(rew) 71 | episode_actions.append(a) 72 | episode_obs.append(obs) 73 | if done: 74 | break 75 | rewards.append(episode_rew) 76 | actions.append(episode_actions) 77 | observations.append(episode_obs) 78 | return observations, actions, rewards 79 | 80 | 81 | def smoketest(argstr, **kwargs): 82 | import tempfile 83 | import subprocess 84 | import os 85 | argstr = 'python -m baselines.run ' + argstr 86 | for key, value in kwargs: 87 | argstr += ' --{}={}'.format(key, value) 88 | tempdir = tempfile.mkdtemp() 89 | env = os.environ.copy() 90 | env['OPENAI_LOGDIR'] = tempdir 91 | subprocess.run(argstr.split(' '), env=env) 92 | return tempdir 93 | -------------------------------------------------------------------------------- /rl/baselines/common/tile_images.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def tile_images(img_nhwc): 4 | """ 5 | Tile N images into one big PxQ image 6 | (P,Q) are chosen to be as close as possible, and if N 7 | is square, then P=Q. 
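# A worked shape example for tile_images (array contents are dummy zeros,
# only the shapes matter here):
#
#   import numpy as np
#   from baselines.common.tile_images import tile_images
#
#   frames = np.zeros((5, 32, 32, 3), dtype=np.uint8)   # five 32x32 RGB frames
#   big = tile_images(frames)
#   big.shape   # (96, 64, 3): H = ceil(sqrt(5)) = 3 rows, W = ceil(5/3) = 2 columns,
#               # and the sixth, unused cell is padded with an all-zero image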
8 | 9 | input: img_nhwc, list or array of images, ndim=4 once turned into array 10 | n = batch index, h = height, w = width, c = channel 11 | returns: 12 | bigim_HWc, ndarray with ndim=3 13 | """ 14 | img_nhwc = np.asarray(img_nhwc) 15 | N, h, w, c = img_nhwc.shape 16 | H = int(np.ceil(np.sqrt(N))) 17 | W = int(np.ceil(float(N)/H)) 18 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)]) 19 | img_HWhwc = img_nhwc.reshape(H, W, h, w, c) 20 | img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4) 21 | img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c) 22 | return img_Hh_Ww_c 23 | 24 | 25 | def stack_images(img_nhwc): 26 | np_img = np.array(img_nhwc) 27 | if len(img_nhwc) == 1: 28 | np_img = np.expand_dims(np_img, axis=0) 29 | return np_img 30 | 31 | -------------------------------------------------------------------------------- /rl/baselines/common/vec_env/__init__.py: -------------------------------------------------------------------------------- 1 | from .vec_env import AlreadySteppingError, NotSteppingError, VecEnv, VecEnvWrapper, VecEnvObservationWrapper, CloudpickleWrapper 2 | from .dummy_vec_env import DummyVecEnv 3 | from .shmem_vec_env import ShmemVecEnv 4 | from .subproc_vec_env import SubprocVecEnv 5 | from .vec_frame_stack import VecFrameStack 6 | from .vec_monitor import VecMonitor 7 | from .vec_normalize import VecNormalize 8 | from .vec_remove_dict_obs import VecExtractDictObs 9 | 10 | __all__ = ['AlreadySteppingError', 'NotSteppingError', 'VecEnv', 'VecEnvWrapper', 'VecEnvObservationWrapper', 'CloudpickleWrapper', 'DummyVecEnv', 'ShmemVecEnv', 'SubprocVecEnv', 'VecFrameStack', 'VecMonitor', 'VecNormalize', 'VecExtractDictObs'] 11 | -------------------------------------------------------------------------------- /rl/baselines/common/vec_env/dummy_vec_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .vec_env import VecEnv 3 | from .util import copy_obs_dict, dict_to_obs, obs_space_info 4 | 5 | class DummyVecEnv(VecEnv): 6 | """ 7 | VecEnv that runs multiple environments sequentially, that is, 8 | the step and reset commands are sent to one environment at a time.
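# A minimal usage sketch for DummyVecEnv (CartPole is just an illustrative env):
#
#   import gym
#   from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
#
#   venv = DummyVecEnv([lambda: gym.make('CartPole-v0')])      # num_envs == 1
#   obs = venv.reset()                                         # shape (1, 4) for CartPole
#   obs, rews, dones, infos = venv.step([venv.action_space.sample()])
#   venv.close()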
9 | Useful when debugging and when num_env == 1 (in the latter case, 10 | avoids communication overhead) 11 | """ 12 | def __init__(self, env_fns): 13 | """ 14 | Arguments: 15 | 16 | env_fns: iterable of callables functions that build environments 17 | """ 18 | self.envs = [fn() for fn in env_fns] 19 | env = self.envs[0] 20 | VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) 21 | obs_space = env.observation_space 22 | self.keys, shapes, dtypes = obs_space_info(obs_space) 23 | 24 | self.buf_obs = { k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys } 25 | self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool) 26 | self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) 27 | self.buf_infos = [{} for _ in range(self.num_envs)] 28 | self.actions = None 29 | self.spec = self.envs[0].spec 30 | 31 | def step_async(self, actions): 32 | listify = True 33 | try: 34 | if len(actions) == self.num_envs: 35 | listify = False 36 | except TypeError: 37 | pass 38 | 39 | if not listify: 40 | self.actions = actions 41 | else: 42 | assert self.num_envs == 1, "actions {} is either not a list or has a wrong size - cannot match to {} environments".format(actions, self.num_envs) 43 | self.actions = [actions] 44 | 45 | def step_wait(self): 46 | for e in range(self.num_envs): 47 | action = self.actions[e] 48 | # if isinstance(self.envs[e].action_space, spaces.Discrete): 49 | # action = int(action) 50 | 51 | obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(action) 52 | if self.buf_dones[e]: 53 | obs = self.envs[e].reset() 54 | self._save_obs(e, obs) 55 | return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), 56 | self.buf_infos.copy()) 57 | 58 | def reset(self): 59 | for e in range(self.num_envs): 60 | obs = self.envs[e].reset() 61 | self._save_obs(e, obs) 62 | return self._obs_from_buf() 63 | 64 | def _save_obs(self, e, obs): 65 | for k in self.keys: 66 | if k is None: 67 | self.buf_obs[k][e] = obs 68 | else: 69 | self.buf_obs[k][e] = obs[k] 70 | 71 | def _obs_from_buf(self): 72 | return dict_to_obs(copy_obs_dict(self.buf_obs)) 73 | 74 | def get_images(self): 75 | return [env.render(mode='rgb_array') for env in self.envs] 76 | 77 | def render(self, mode='human'): 78 | if self.num_envs == 1: 79 | return self.envs[0].render(mode=mode) 80 | else: 81 | return super().render(mode=mode) 82 | -------------------------------------------------------------------------------- /rl/baselines/common/vec_env/shmem_vec_env.py: -------------------------------------------------------------------------------- 1 | """ 2 | An interface for asynchronous vectorized environments. 3 | """ 4 | 5 | import multiprocessing as mp 6 | import numpy as np 7 | from .vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars 8 | import ctypes 9 | from baselines import logger 10 | 11 | from .util import dict_to_obs, obs_space_info, obs_to_dict 12 | 13 | _NP_TO_CT = {np.float32: ctypes.c_float, 14 | np.int32: ctypes.c_int32, 15 | np.int8: ctypes.c_int8, 16 | np.uint8: ctypes.c_char, 17 | np.bool: ctypes.c_bool} 18 | 19 | 20 | class ShmemVecEnv(VecEnv): 21 | """ 22 | Optimized version of SubprocVecEnv that uses shared variables to communicate observations. 23 | """ 24 | 25 | def __init__(self, env_fns, spaces=None, context='spawn'): 26 | """ 27 | If you don't specify observation_space, we'll have to create a dummy 28 | environment to get it. 
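# A usage sketch: passing `spaces` up front skips that dummy-environment step
# (the Atari env name here is only an illustration):
#
#   import gym
#   from baselines.common.vec_env.shmem_vec_env import ShmemVecEnv
#
#   def make_env():
#       return gym.make('PongNoFrameskip-v4')
#
#   probe = make_env()
#   venv = ShmemVecEnv([make_env] * 4,
#                      spaces=(probe.observation_space, probe.action_space))
#   probe.close()
#   obs = venv.reset()          # observations are passed back through shared memory
#   venv.close()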
29 | """ 30 | ctx = mp.get_context(context) 31 | if spaces: 32 | observation_space, action_space = spaces 33 | else: 34 | logger.log('Creating dummy env object to get spaces') 35 | with logger.scoped_configure(format_strs=[]): 36 | dummy = env_fns[0]() 37 | observation_space, action_space = dummy.observation_space, dummy.action_space 38 | dummy.close() 39 | del dummy 40 | VecEnv.__init__(self, len(env_fns), observation_space, action_space) 41 | self.obs_keys, self.obs_shapes, self.obs_dtypes = obs_space_info(observation_space) 42 | self.obs_bufs = [ 43 | {k: ctx.Array(_NP_TO_CT[self.obs_dtypes[k].type], int(np.prod(self.obs_shapes[k]))) for k in self.obs_keys} 44 | for _ in env_fns] 45 | self.parent_pipes = [] 46 | self.procs = [] 47 | with clear_mpi_env_vars(): 48 | for env_fn, obs_buf in zip(env_fns, self.obs_bufs): 49 | wrapped_fn = CloudpickleWrapper(env_fn) 50 | parent_pipe, child_pipe = ctx.Pipe() 51 | proc = ctx.Process(target=_subproc_worker, 52 | args=(child_pipe, parent_pipe, wrapped_fn, obs_buf, self.obs_shapes, self.obs_dtypes, self.obs_keys)) 53 | proc.daemon = True 54 | self.procs.append(proc) 55 | self.parent_pipes.append(parent_pipe) 56 | proc.start() 57 | child_pipe.close() 58 | self.waiting_step = False 59 | self.viewer = None 60 | 61 | def reset(self): 62 | if self.waiting_step: 63 | logger.warn('Called reset() while waiting for the step to complete') 64 | self.step_wait() 65 | for pipe in self.parent_pipes: 66 | pipe.send(('reset', None)) 67 | return self._decode_obses([pipe.recv() for pipe in self.parent_pipes]) 68 | 69 | def step_async(self, actions): 70 | assert len(actions) == len(self.parent_pipes) 71 | for pipe, act in zip(self.parent_pipes, actions): 72 | pipe.send(('step', act)) 73 | self.waiting_step = True 74 | 75 | def step_wait(self): 76 | outs = [pipe.recv() for pipe in self.parent_pipes] 77 | self.waiting_step = False 78 | obs, rews, dones, infos = zip(*outs) 79 | return self._decode_obses(obs), np.array(rews), np.array(dones), infos 80 | 81 | def close_extras(self): 82 | if self.waiting_step: 83 | self.step_wait() 84 | for pipe in self.parent_pipes: 85 | pipe.send(('close', None)) 86 | for pipe in self.parent_pipes: 87 | pipe.recv() 88 | pipe.close() 89 | for proc in self.procs: 90 | proc.join() 91 | 92 | def get_images(self, mode='human'): 93 | for pipe in self.parent_pipes: 94 | pipe.send(('render', mode)) 95 | return [pipe.recv() for pipe in self.parent_pipes] 96 | 97 | def _decode_obses(self, obs): 98 | result = {} 99 | for k in self.obs_keys: 100 | 101 | bufs = [b[k] for b in self.obs_bufs] 102 | o = [np.frombuffer(b.get_obj(), dtype=self.obs_dtypes[k]).reshape(self.obs_shapes[k]) for b in bufs] 103 | result[k] = np.array(o) 104 | return dict_to_obs(result) 105 | 106 | 107 | def _subproc_worker(pipe, parent_pipe, env_fn_wrapper, obs_bufs, obs_shapes, obs_dtypes, keys): 108 | """ 109 | Control a single environment instance using IPC and 110 | shared memory. 
111 | """ 112 | def _write_obs(maybe_dict_obs): 113 | flatdict = obs_to_dict(maybe_dict_obs) 114 | for k in keys: 115 | dst = obs_bufs[k].get_obj() 116 | dst_np = np.frombuffer(dst, dtype=obs_dtypes[k]).reshape(obs_shapes[k]) # pylint: disable=W0212 117 | np.copyto(dst_np, flatdict[k]) 118 | 119 | env = env_fn_wrapper.x() 120 | parent_pipe.close() 121 | try: 122 | while True: 123 | cmd, data = pipe.recv() 124 | if cmd == 'reset': 125 | pipe.send(_write_obs(env.reset())) 126 | elif cmd == 'step': 127 | obs, reward, done, info = env.step(data) 128 | if done: 129 | obs = env.reset() 130 | pipe.send((_write_obs(obs), reward, done, info)) 131 | elif cmd == 'render': 132 | pipe.send(env.render(mode=data)) 133 | elif cmd == 'close': 134 | pipe.send(None) 135 | break 136 | else: 137 | raise RuntimeError('Got unrecognized cmd %s' % cmd) 138 | except KeyboardInterrupt: 139 | print('ShmemVecEnv worker: got KeyboardInterrupt') 140 | finally: 141 | env.close() 142 | -------------------------------------------------------------------------------- /rl/baselines/common/vec_env/subproc_vec_env.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | 3 | import numpy as np 4 | from .vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars 5 | 6 | 7 | def worker(remote, parent_remote, env_fn_wrappers): 8 | def step_env(env, action): 9 | ob, reward, done, info = env.step(action) 10 | if done: 11 | ob = env.reset() 12 | return ob, reward, done, info 13 | 14 | parent_remote.close() 15 | envs = [env_fn_wrapper() for env_fn_wrapper in env_fn_wrappers.x] 16 | try: 17 | while True: 18 | cmd, data = remote.recv() 19 | if cmd == 'step': 20 | remote.send([step_env(env, action) for env, action in zip(envs, data)]) 21 | elif cmd == 'reset': 22 | remote.send([env.reset() for env in envs]) 23 | elif cmd == 'render': 24 | remote.send([env.render(mode='rgb_array') for env in envs]) 25 | elif cmd == 'close': 26 | remote.close() 27 | break 28 | elif cmd == 'get_spaces_spec': 29 | remote.send(CloudpickleWrapper((envs[0].observation_space, envs[0].action_space, envs[0].spec))) 30 | else: 31 | raise NotImplementedError 32 | except KeyboardInterrupt: 33 | print('SubprocVecEnv worker: got KeyboardInterrupt') 34 | finally: 35 | for env in envs: 36 | env.close() 37 | 38 | 39 | class SubprocVecEnv(VecEnv): 40 | """ 41 | VecEnv that runs multiple environments in parallel in subproceses and communicates with them via pipes. 42 | Recommended to use when num_envs > 1 and step() can be a bottleneck. 43 | """ 44 | def __init__(self, env_fns, spaces=None, context='spawn', in_series=1): 45 | """ 46 | Arguments: 47 | 48 | env_fns: iterable of callables - functions that create environments to run in subprocesses. Need to be cloud-pickleable 49 | in_series: number of environments to run in series in a single process 50 | (e.g. 
when len(env_fns) == 12 and in_series == 3, it will run 4 processes, each running 3 envs in series) 51 | """ 52 | self.waiting = False 53 | self.closed = False 54 | self.in_series = in_series 55 | nenvs = len(env_fns) 56 | assert nenvs % in_series == 0, "Number of envs must be divisible by number of envs to run in series" 57 | self.nremotes = nenvs // in_series 58 | env_fns = np.array_split(env_fns, self.nremotes) 59 | ctx = mp.get_context(context) 60 | self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(self.nremotes)]) 61 | self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn))) 62 | for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)] 63 | for p in self.ps: 64 | p.daemon = True # if the main process crashes, we should not cause things to hang 65 | with clear_mpi_env_vars(): 66 | p.start() 67 | for remote in self.work_remotes: 68 | remote.close() 69 | 70 | self.remotes[0].send(('get_spaces_spec', None)) 71 | observation_space, action_space, self.spec = self.remotes[0].recv().x 72 | self.viewer = None 73 | VecEnv.__init__(self, nenvs, observation_space, action_space) 74 | 75 | def step_async(self, actions): 76 | self._assert_not_closed() 77 | actions = np.array_split(actions, self.nremotes) 78 | for remote, action in zip(self.remotes, actions): 79 | remote.send(('step', action)) 80 | self.waiting = True 81 | 82 | def step_wait(self): 83 | self._assert_not_closed() 84 | results = [remote.recv() for remote in self.remotes] 85 | results = _flatten_list(results) 86 | self.waiting = False 87 | obs, rews, dones, infos = zip(*results) 88 | return _flatten_obs(obs), np.stack(rews), np.stack(dones), infos 89 | 90 | def reset(self): 91 | self._assert_not_closed() 92 | for remote in self.remotes: 93 | remote.send(('reset', None)) 94 | obs = [remote.recv() for remote in self.remotes] 95 | obs = _flatten_list(obs) 96 | return _flatten_obs(obs) 97 | 98 | def close_extras(self): 99 | self.closed = True 100 | if self.waiting: 101 | for remote in self.remotes: 102 | remote.recv() 103 | for remote in self.remotes: 104 | remote.send(('close', None)) 105 | for p in self.ps: 106 | p.join() 107 | 108 | def get_images(self): 109 | self._assert_not_closed() 110 | for pipe in self.remotes: 111 | pipe.send(('render', None)) 112 | imgs = [pipe.recv() for pipe in self.remotes] 113 | imgs = _flatten_list(imgs) 114 | return imgs 115 | 116 | def _assert_not_closed(self): 117 | assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()" 118 | 119 | def __del__(self): 120 | if not self.closed: 121 | self.close() 122 | 123 | def _flatten_obs(obs): 124 | assert isinstance(obs, (list, tuple)) 125 | assert len(obs) > 0 126 | 127 | if isinstance(obs[0], dict): 128 | keys = obs[0].keys() 129 | return {k: np.stack([o[k] for o in obs]) for k in keys} 130 | else: 131 | return np.stack(obs) 132 | 133 | def _flatten_list(l): 134 | assert isinstance(l, (list, tuple)) 135 | assert len(l) > 0 136 | assert all([len(l_) > 0 for l_ in l]) 137 | 138 | return [l__ for l_ in l for l__ in l_] 139 | -------------------------------------------------------------------------------- /rl/baselines/common/vec_env/test_vec_env.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for asynchronous vectorized environments. 
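# A sketch of the in-series configuration from subproc_vec_env.py above, which
# the tests in this file compare against DummyVecEnv (env choice and counts are
# illustrative):
#
#   import gym
#   from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
#
#   fns = [lambda: gym.make('CartPole-v0') for _ in range(12)]
#   venv = SubprocVecEnv(fns, in_series=3)    # 12 envs grouped into 4 worker processes
#   obs = venv.reset()                        # observations stacked to shape (12, 4)
#   obs, rews, dones, infos = venv.step([venv.action_space.sample() for _ in range(12)])
#   venv.close()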
3 | """ 4 | 5 | import gym 6 | import numpy as np 7 | import pytest 8 | from .dummy_vec_env import DummyVecEnv 9 | from .shmem_vec_env import ShmemVecEnv 10 | from .subproc_vec_env import SubprocVecEnv 11 | from baselines.common.tests.test_with_mpi import with_mpi 12 | 13 | 14 | def assert_venvs_equal(venv1, venv2, num_steps): 15 | """ 16 | Compare two environments over num_steps steps and make sure 17 | that the observations produced by each are the same when given 18 | the same actions. 19 | """ 20 | assert venv1.num_envs == venv2.num_envs 21 | assert venv1.observation_space.shape == venv2.observation_space.shape 22 | assert venv1.observation_space.dtype == venv2.observation_space.dtype 23 | assert venv1.action_space.shape == venv2.action_space.shape 24 | assert venv1.action_space.dtype == venv2.action_space.dtype 25 | 26 | try: 27 | obs1, obs2 = venv1.reset(), venv2.reset() 28 | assert np.array(obs1).shape == np.array(obs2).shape 29 | assert np.array(obs1).shape == (venv1.num_envs,) + venv1.observation_space.shape 30 | assert np.allclose(obs1, obs2) 31 | venv1.action_space.seed(1337) 32 | for _ in range(num_steps): 33 | actions = np.array([venv1.action_space.sample() for _ in range(venv1.num_envs)]) 34 | for venv in [venv1, venv2]: 35 | venv.step_async(actions) 36 | outs1 = venv1.step_wait() 37 | outs2 = venv2.step_wait() 38 | for out1, out2 in zip(outs1[:3], outs2[:3]): 39 | assert np.array(out1).shape == np.array(out2).shape 40 | assert np.allclose(out1, out2) 41 | assert list(outs1[3]) == list(outs2[3]) 42 | finally: 43 | venv1.close() 44 | venv2.close() 45 | 46 | 47 | @pytest.mark.parametrize('klass', (ShmemVecEnv, SubprocVecEnv)) 48 | @pytest.mark.parametrize('dtype', ('uint8', 'float32')) 49 | def test_vec_env(klass, dtype): # pylint: disable=R0914 50 | """ 51 | Test that a vectorized environment is equivalent to 52 | DummyVecEnv, since DummyVecEnv is less likely to be 53 | error prone. 54 | """ 55 | num_envs = 3 56 | num_steps = 100 57 | shape = (3, 8) 58 | 59 | def make_fn(seed): 60 | """ 61 | Get an environment constructor with a seed. 62 | """ 63 | return lambda: SimpleEnv(seed, shape, dtype) 64 | fns = [make_fn(i) for i in range(num_envs)] 65 | env1 = DummyVecEnv(fns) 66 | env2 = klass(fns) 67 | assert_venvs_equal(env1, env2, num_steps=num_steps) 68 | 69 | 70 | @pytest.mark.parametrize('dtype', ('uint8', 'float32')) 71 | @pytest.mark.parametrize('num_envs_in_series', (3, 4, 6)) 72 | def test_sync_sampling(dtype, num_envs_in_series): 73 | """ 74 | Test that a SubprocVecEnv running with envs in series 75 | outputs the same as DummyVecEnv. 76 | """ 77 | num_envs = 12 78 | num_steps = 100 79 | shape = (3, 8) 80 | 81 | def make_fn(seed): 82 | """ 83 | Get an environment constructor with a seed. 84 | """ 85 | return lambda: SimpleEnv(seed, shape, dtype) 86 | fns = [make_fn(i) for i in range(num_envs)] 87 | env1 = DummyVecEnv(fns) 88 | env2 = SubprocVecEnv(fns, in_series=num_envs_in_series) 89 | assert_venvs_equal(env1, env2, num_steps=num_steps) 90 | 91 | 92 | @pytest.mark.parametrize('dtype', ('uint8', 'float32')) 93 | @pytest.mark.parametrize('num_envs_in_series', (3, 4, 6)) 94 | def test_sync_sampling_sanity(dtype, num_envs_in_series): 95 | """ 96 | Test that a SubprocVecEnv running with envs in series 97 | outputs the same as SubprocVecEnv without running in series. 98 | """ 99 | num_envs = 12 100 | num_steps = 100 101 | shape = (3, 8) 102 | 103 | def make_fn(seed): 104 | """ 105 | Get an environment constructor with a seed. 
106 | """ 107 | return lambda: SimpleEnv(seed, shape, dtype) 108 | fns = [make_fn(i) for i in range(num_envs)] 109 | env1 = SubprocVecEnv(fns) 110 | env2 = SubprocVecEnv(fns, in_series=num_envs_in_series) 111 | assert_venvs_equal(env1, env2, num_steps=num_steps) 112 | 113 | 114 | class SimpleEnv(gym.Env): 115 | """ 116 | An environment with a pre-determined observation space 117 | and RNG seed. 118 | """ 119 | 120 | def __init__(self, seed, shape, dtype): 121 | np.random.seed(seed) 122 | self._dtype = dtype 123 | self._start_obs = np.array(np.random.randint(0, 0x100, size=shape), 124 | dtype=dtype) 125 | self._max_steps = seed + 1 126 | self._cur_obs = None 127 | self._cur_step = 0 128 | # this is 0xFF instead of 0x100 because the Box space includes 129 | # the high end, while randint does not 130 | self.action_space = gym.spaces.Box(low=0, high=0xFF, shape=shape, dtype=dtype) 131 | self.observation_space = self.action_space 132 | 133 | def step(self, action): 134 | self._cur_obs += np.array(action, dtype=self._dtype) 135 | self._cur_step += 1 136 | done = self._cur_step >= self._max_steps 137 | reward = self._cur_step / self._max_steps 138 | return self._cur_obs, reward, done, {'foo': 'bar' + str(reward)} 139 | 140 | def reset(self): 141 | self._cur_obs = self._start_obs 142 | self._cur_step = 0 143 | return self._cur_obs 144 | 145 | def render(self, mode=None): 146 | raise NotImplementedError 147 | 148 | 149 | 150 | @with_mpi() 151 | def test_mpi_with_subprocvecenv(): 152 | shape = (2,3,4) 153 | nenv = 1 154 | venv = SubprocVecEnv([lambda: SimpleEnv(0, shape, 'float32')] * nenv) 155 | ob = venv.reset() 156 | venv.close() 157 | assert ob.shape == (nenv,) + shape 158 | 159 | -------------------------------------------------------------------------------- /rl/baselines/common/vec_env/test_video_recorder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for asynchronous vectorized environments. 
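Specifically: checks that VecVideoRecorder, wrapped around DummyVecEnv, ShmemVecEnv or SubprocVecEnv, writes the expected number of non-empty .mp4 files.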
3 | """ 4 | 5 | import gym 6 | import pytest 7 | import os 8 | import glob 9 | import tempfile 10 | 11 | from .dummy_vec_env import DummyVecEnv 12 | from .shmem_vec_env import ShmemVecEnv 13 | from .subproc_vec_env import SubprocVecEnv 14 | from .vec_video_recorder import VecVideoRecorder 15 | 16 | @pytest.mark.parametrize('klass', (DummyVecEnv, ShmemVecEnv, SubprocVecEnv)) 17 | @pytest.mark.parametrize('num_envs', (1, 4)) 18 | @pytest.mark.parametrize('video_length', (10, 100)) 19 | @pytest.mark.parametrize('video_interval', (1, 50)) 20 | def test_video_recorder(klass, num_envs, video_length, video_interval): 21 | """ 22 | Wrap an existing VecEnv with VevVideoRecorder, 23 | Make (video_interval + video_length + 1) steps, 24 | then check that the file is present 25 | """ 26 | 27 | def make_fn(): 28 | env = gym.make('PongNoFrameskip-v4') 29 | return env 30 | fns = [make_fn for _ in range(num_envs)] 31 | env = klass(fns) 32 | 33 | with tempfile.TemporaryDirectory() as video_path: 34 | env = VecVideoRecorder(env, video_path, record_video_trigger=lambda x: x % video_interval == 0, video_length=video_length) 35 | 36 | env.reset() 37 | for _ in range(video_interval + video_length + 1): 38 | env.step([0] * num_envs) 39 | env.close() 40 | 41 | 42 | recorded_video = glob.glob(os.path.join(video_path, "*.mp4")) 43 | 44 | # first and second step 45 | assert len(recorded_video) == 2 46 | # Files are not empty 47 | assert all(os.stat(p).st_size != 0 for p in recorded_video) 48 | 49 | 50 | -------------------------------------------------------------------------------- /rl/baselines/common/vec_env/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for dealing with vectorized environments. 3 | """ 4 | 5 | from collections import OrderedDict 6 | 7 | import gym 8 | import numpy as np 9 | 10 | 11 | def copy_obs_dict(obs): 12 | """ 13 | Deep-copy an observation dict. 14 | """ 15 | return {k: np.copy(v) for k, v in obs.items()} 16 | 17 | 18 | def dict_to_obs(obs_dict): 19 | """ 20 | Convert an observation dict into a raw array if the 21 | original observation space was not a Dict space. 22 | """ 23 | if set(obs_dict.keys()) == {None}: 24 | return obs_dict[None] 25 | return obs_dict 26 | 27 | 28 | def obs_space_info(obs_space): 29 | """ 30 | Get dict-structured information about a gym.Space. 31 | 32 | Returns: 33 | A tuple (keys, shapes, dtypes): 34 | keys: a list of dict keys. 35 | shapes: a dict mapping keys to shapes. 36 | dtypes: a dict mapping keys to dtypes. 37 | """ 38 | if isinstance(obs_space, gym.spaces.Dict): 39 | assert isinstance(obs_space.spaces, OrderedDict) 40 | subspaces = obs_space.spaces 41 | elif isinstance(obs_space, gym.spaces.Tuple): 42 | assert isinstance(obs_space.spaces, tuple) 43 | subspaces = {i: obs_space.spaces[i] for i in range(len(obs_space.spaces))} 44 | else: 45 | subspaces = {None: obs_space} 46 | keys = [] 47 | shapes = {} 48 | dtypes = {} 49 | for key, box in subspaces.items(): 50 | keys.append(key) 51 | shapes[key] = box.shape 52 | dtypes[key] = box.dtype 53 | return keys, shapes, dtypes 54 | 55 | 56 | def obs_to_dict(obs): 57 | """ 58 | Convert an observation into a dict. 
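A non-dict observation obs is wrapped as {None: obs}; dict observations are returned unchanged.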
59 | """ 60 | if isinstance(obs, dict): 61 | return obs 62 | return {None: obs} 63 | -------------------------------------------------------------------------------- /rl/baselines/common/vec_env/vec_env.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import os 3 | from abc import ABC, abstractmethod 4 | 5 | from baselines.common.tile_images import tile_images, stack_images 6 | 7 | class AlreadySteppingError(Exception): 8 | """ 9 | Raised when an asynchronous step is running while 10 | step_async() is called again. 11 | """ 12 | 13 | def __init__(self): 14 | msg = 'already running an async step' 15 | Exception.__init__(self, msg) 16 | 17 | 18 | class NotSteppingError(Exception): 19 | """ 20 | Raised when an asynchronous step is not running but 21 | step_wait() is called. 22 | """ 23 | 24 | def __init__(self): 25 | msg = 'not running an async step' 26 | Exception.__init__(self, msg) 27 | 28 | 29 | class VecEnv(ABC): 30 | """ 31 | An abstract asynchronous, vectorized environment. 32 | Used to batch data from multiple copies of an environment, so that 33 | each observation becomes an batch of observations, and expected action is a batch of actions to 34 | be applied per-environment. 35 | """ 36 | closed = False 37 | viewer = None 38 | 39 | metadata = { 40 | 'render.modes': ['human', 'rgb_array'] 41 | } 42 | 43 | def __init__(self, num_envs, observation_space, action_space): 44 | self.num_envs = num_envs 45 | self.observation_space = observation_space 46 | self.action_space = action_space 47 | 48 | @abstractmethod 49 | def reset(self): 50 | """ 51 | Reset all the environments and return an array of 52 | observations, or a dict of observation arrays. 53 | 54 | If step_async is still doing work, that work will 55 | be cancelled and step_wait() should not be called 56 | until step_async() is invoked again. 57 | """ 58 | pass 59 | 60 | @abstractmethod 61 | def step_async(self, actions): 62 | """ 63 | Tell all the environments to start taking a step 64 | with the given actions. 65 | Call step_wait() to get the results of the step. 66 | 67 | You should not call this if a step_async run is 68 | already pending. 69 | """ 70 | pass 71 | 72 | @abstractmethod 73 | def step_wait(self): 74 | """ 75 | Wait for the step taken with step_async(). 76 | 77 | Returns (obs, rews, dones, infos): 78 | - obs: an array of observations, or a dict of 79 | arrays of observations. 80 | - rews: an array of rewards 81 | - dones: an array of "episode done" booleans 82 | - infos: a sequence of info objects 83 | """ 84 | pass 85 | 86 | def close_extras(self): 87 | """ 88 | Clean up the extra resources, beyond what's in this base class. 89 | Only runs when not self.closed. 90 | """ 91 | pass 92 | 93 | def close(self): 94 | if self.closed: 95 | return 96 | if self.viewer is not None: 97 | self.viewer.close() 98 | self.close_extras() 99 | self.closed = True 100 | 101 | def step(self, actions): 102 | """ 103 | Step the environments synchronously. 104 | 105 | This is available for backwards compatibility. 
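Equivalent to calling step_async(actions) followed by step_wait().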
106 | """ 107 | self.step_async(actions) 108 | return self.step_wait() 109 | 110 | def render(self, mode='human'): 111 | imgs = self.get_images(mode) 112 | if mode == 'human': 113 | bigimg = tile_images(imgs) 114 | self.get_viewer().imshow(bigimg) 115 | return self.get_viewer().isopen 116 | elif mode == 'rgb_array': 117 | bigimg = tile_images(imgs) 118 | return bigimg 119 | elif mode == 'init_states': 120 | return imgs 121 | else: 122 | raise NotImplementedError 123 | 124 | def get_images(self): 125 | """ 126 | Return RGB images from each environment 127 | """ 128 | raise NotImplementedError 129 | 130 | @property 131 | def unwrapped(self): 132 | if isinstance(self, VecEnvWrapper): 133 | return self.venv.unwrapped 134 | else: 135 | return self 136 | 137 | def get_viewer(self): 138 | if self.viewer is None: 139 | from gym.envs.classic_control import rendering 140 | self.viewer = rendering.SimpleImageViewer() 141 | return self.viewer 142 | 143 | class VecEnvWrapper(VecEnv): 144 | """ 145 | An environment wrapper that applies to an entire batch 146 | of environments at once. 147 | """ 148 | 149 | def __init__(self, venv, observation_space=None, action_space=None): 150 | self.venv = venv 151 | super().__init__(num_envs=venv.num_envs, 152 | observation_space=observation_space or venv.observation_space, 153 | action_space=action_space or venv.action_space) 154 | 155 | def step_async(self, actions): 156 | self.venv.step_async(actions) 157 | 158 | @abstractmethod 159 | def reset(self): 160 | pass 161 | 162 | @abstractmethod 163 | def step_wait(self): 164 | pass 165 | 166 | def close(self): 167 | return self.venv.close() 168 | 169 | def render(self, mode='human'): 170 | return self.venv.render(mode=mode) 171 | 172 | def get_images(self): 173 | return self.venv.get_images() 174 | 175 | def __getattr__(self, name): 176 | if name.startswith('_'): 177 | raise AttributeError("attempted to get missing private attribute '{}'".format(name)) 178 | return getattr(self.venv, name) 179 | 180 | class VecEnvObservationWrapper(VecEnvWrapper): 181 | @abstractmethod 182 | def process(self, obs): 183 | pass 184 | 185 | def reset(self): 186 | obs = self.venv.reset() 187 | return self.process(obs) 188 | 189 | def step_wait(self): 190 | obs, rews, dones, infos = self.venv.step_wait() 191 | return self.process(obs), rews, dones, infos 192 | 193 | class CloudpickleWrapper(object): 194 | """ 195 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 196 | """ 197 | 198 | def __init__(self, x): 199 | self.x = x 200 | 201 | def __getstate__(self): 202 | import cloudpickle 203 | return cloudpickle.dumps(self.x) 204 | 205 | def __setstate__(self, ob): 206 | import pickle 207 | self.x = pickle.loads(ob) 208 | 209 | 210 | @contextlib.contextmanager 211 | def clear_mpi_env_vars(): 212 | """ 213 | from mpi4py import MPI will call MPI_Init by default. If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang. 214 | This context manager is a hacky way to clear those environment variables temporarily such as when we are starting multiprocessing 215 | Processes. 
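Only variables whose names start with OMPI_ or PMI_ are removed, and they are restored when the context manager exits.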
216 | """ 217 | removed_environment = {} 218 | for k, v in list(os.environ.items()): 219 | for prefix in ['OMPI_', 'PMI_']: 220 | if k.startswith(prefix): 221 | removed_environment[k] = v 222 | del os.environ[k] 223 | try: 224 | yield 225 | finally: 226 | os.environ.update(removed_environment) 227 | -------------------------------------------------------------------------------- /rl/baselines/common/vec_env/vec_frame_stack.py: -------------------------------------------------------------------------------- 1 | from .vec_env import VecEnvWrapper 2 | import numpy as np 3 | from gym import spaces 4 | 5 | 6 | class VecFrameStack(VecEnvWrapper): 7 | def __init__(self, venv, nstack): 8 | self.venv = venv 9 | self.nstack = nstack 10 | wos = venv.observation_space # wrapped ob space 11 | low = np.repeat(wos.low, self.nstack, axis=-1) 12 | high = np.repeat(wos.high, self.nstack, axis=-1) 13 | self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype) 14 | observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype) 15 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 16 | 17 | def step_wait(self): 18 | obs, rews, news, infos = self.venv.step_wait() 19 | self.stackedobs = np.roll(self.stackedobs, shift=-1, axis=-1) 20 | for (i, new) in enumerate(news): 21 | if new: 22 | self.stackedobs[i] = 0 23 | self.stackedobs[..., -obs.shape[-1]:] = obs 24 | return self.stackedobs, rews, news, infos 25 | 26 | def reset(self): 27 | obs = self.venv.reset() 28 | self.stackedobs[...] = 0 29 | self.stackedobs[..., -obs.shape[-1]:] = obs 30 | return self.stackedobs 31 | -------------------------------------------------------------------------------- /rl/baselines/common/vec_env/vec_monitor.py: -------------------------------------------------------------------------------- 1 | from . 
import VecEnvWrapper 2 | from baselines.bench.monitor import ResultsWriter 3 | import numpy as np 4 | import time 5 | from collections import deque 6 | 7 | class VecMonitor(VecEnvWrapper): 8 | def __init__(self, venv, filename=None, keep_buf=0, info_keywords=()): 9 | VecEnvWrapper.__init__(self, venv) 10 | self.eprets = None 11 | self.eplens = None 12 | self.epcount = 0 13 | self.tstart = time.time() 14 | if filename: 15 | self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart}, 16 | extra_keys=info_keywords) 17 | else: 18 | self.results_writer = None 19 | self.info_keywords = info_keywords 20 | self.keep_buf = keep_buf 21 | if self.keep_buf: 22 | self.epret_buf = deque([], maxlen=keep_buf) 23 | self.eplen_buf = deque([], maxlen=keep_buf) 24 | 25 | def reset(self): 26 | obs = self.venv.reset() 27 | self.eprets = np.zeros(self.num_envs, 'f') 28 | self.eplens = np.zeros(self.num_envs, 'i') 29 | return obs 30 | 31 | def step_wait(self): 32 | obs, rews, dones, infos = self.venv.step_wait() 33 | self.eprets += rews 34 | self.eplens += 1 35 | 36 | newinfos = list(infos[:]) 37 | for i in range(len(dones)): 38 | if dones[i]: 39 | info = infos[i].copy() 40 | ret = self.eprets[i] 41 | eplen = self.eplens[i] 42 | epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)} 43 | for k in self.info_keywords: 44 | epinfo[k] = info[k] 45 | info['episode'] = epinfo 46 | if self.keep_buf: 47 | self.epret_buf.append(ret) 48 | self.eplen_buf.append(eplen) 49 | self.epcount += 1 50 | self.eprets[i] = 0 51 | self.eplens[i] = 0 52 | if self.results_writer: 53 | self.results_writer.write_row(epinfo) 54 | newinfos[i] = info 55 | return obs, rews, dones, newinfos 56 | -------------------------------------------------------------------------------- /rl/baselines/common/vec_env/vec_normalize.py: -------------------------------------------------------------------------------- 1 | from . import VecEnvWrapper 2 | import numpy as np 3 | 4 | class VecNormalize(VecEnvWrapper): 5 | """ 6 | A vectorized wrapper that normalizes the observations 7 | and returns from an environment. 8 | """ 9 | 10 | def __init__(self, venv, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8, use_tf=False): 11 | VecEnvWrapper.__init__(self, venv) 12 | if use_tf: 13 | from baselines.common.running_mean_std import TfRunningMeanStd 14 | self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='ob_rms') if ob else None 15 | self.ret_rms = TfRunningMeanStd(shape=(), scope='ret_rms') if ret else None 16 | else: 17 | from baselines.common.running_mean_std import RunningMeanStd 18 | self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None 19 | self.ret_rms = RunningMeanStd(shape=()) if ret else None 20 | self.clipob = clipob 21 | self.cliprew = cliprew 22 | self.ret = np.zeros(self.num_envs) 23 | self.gamma = gamma 24 | self.epsilon = epsilon 25 | 26 | def step_wait(self): 27 | obs, rews, news, infos = self.venv.step_wait() 28 | self.ret = self.ret * self.gamma + rews 29 | obs = self._obfilt(obs) 30 | if self.ret_rms: 31 | self.ret_rms.update(self.ret) 32 | rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew) 33 | self.ret[news] = 0. 
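# Note: rewards are rescaled by the running standard deviation of the discounted
# return (no mean subtraction), observations are normalized separately in _obfilt(),
# and self.ret has just been zeroed for environments whose episode ended.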
34 | return obs, rews, news, infos 35 | 36 | def _obfilt(self, obs): 37 | if self.ob_rms: 38 | self.ob_rms.update(obs) 39 | obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob) 40 | return obs 41 | else: 42 | return obs 43 | 44 | def reset(self): 45 | self.ret = np.zeros(self.num_envs) 46 | obs = self.venv.reset() 47 | return self._obfilt(obs) 48 | -------------------------------------------------------------------------------- /rl/baselines/common/vec_env/vec_remove_dict_obs.py: -------------------------------------------------------------------------------- 1 | from .vec_env import VecEnvObservationWrapper 2 | 3 | class VecExtractDictObs(VecEnvObservationWrapper): 4 | def __init__(self, venv, key): 5 | self.key = key 6 | super().__init__(venv=venv, 7 | observation_space=venv.observation_space.spaces[self.key]) 8 | 9 | def process(self, obs): 10 | return obs[self.key] 11 | -------------------------------------------------------------------------------- /rl/baselines/common/vec_env/vec_video_recorder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from baselines import logger 3 | from baselines.common.vec_env import VecEnvWrapper 4 | from gym.wrappers.monitoring import video_recorder 5 | 6 | 7 | class VecVideoRecorder(VecEnvWrapper): 8 | """ 9 | Wrap VecEnv to record rendered image as mp4 video. 10 | """ 11 | 12 | def __init__(self, venv, directory, record_video_trigger, video_length=200): 13 | """ 14 | # Arguments 15 | venv: VecEnv to wrap 16 | directory: Where to save videos 17 | record_video_trigger: 18 | Function that defines when to start recording. 19 | The function takes the current number of step, 20 | and returns whether we should start recording or not. 
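For example, lambda step: step % 2000 == 0 starts a new recording every 2000 steps (the interval here is only illustrative).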
21 | video_length: Length of recorded video 22 | """ 23 | 24 | VecEnvWrapper.__init__(self, venv) 25 | self.record_video_trigger = record_video_trigger 26 | self.video_recorder = None 27 | 28 | self.directory = os.path.abspath(directory) 29 | if not os.path.exists(self.directory): os.mkdir(self.directory) 30 | 31 | self.file_prefix = "vecenv" 32 | self.file_infix = '{}'.format(os.getpid()) 33 | self.step_id = 0 34 | self.video_length = video_length 35 | 36 | self.recording = False 37 | self.recorded_frames = 0 38 | 39 | def reset(self): 40 | obs = self.venv.reset() 41 | 42 | self.start_video_recorder() 43 | 44 | return obs 45 | 46 | def start_video_recorder(self): 47 | self.close_video_recorder() 48 | 49 | base_path = os.path.join(self.directory, '{}.video.{}.video{:06}'.format(self.file_prefix, self.file_infix, self.step_id)) 50 | self.video_recorder = video_recorder.VideoRecorder( 51 | env=self.venv, 52 | base_path=base_path, 53 | metadata={'step_id': self.step_id} 54 | ) 55 | 56 | self.video_recorder.capture_frame() 57 | self.recorded_frames = 1 58 | self.recording = True 59 | 60 | def _video_enabled(self): 61 | return self.record_video_trigger(self.step_id) 62 | 63 | def step_wait(self): 64 | obs, rews, dones, infos = self.venv.step_wait() 65 | 66 | self.step_id += 1 67 | if self.recording: 68 | self.video_recorder.capture_frame() 69 | self.recorded_frames += 1 70 | if self.recorded_frames > self.video_length: 71 | logger.info("Saving video to ", self.video_recorder.path) 72 | self.close_video_recorder() 73 | elif self._video_enabled(): 74 | self.start_video_recorder() 75 | 76 | return obs, rews, dones, infos 77 | 78 | def close_video_recorder(self): 79 | if self.recording: 80 | self.video_recorder.close() 81 | self.recording = False 82 | self.recorded_frames = 0 83 | 84 | def close(self): 85 | VecEnvWrapper.close(self) 86 | self.close_video_recorder() 87 | 88 | def __del__(self): 89 | self.close() 90 | -------------------------------------------------------------------------------- /rl/baselines/common/wrappers.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | class TimeLimit(gym.Wrapper): 4 | def __init__(self, env, max_episode_steps=None): 5 | super(TimeLimit, self).__init__(env) 6 | self._max_episode_steps = max_episode_steps 7 | self._elapsed_steps = 0 8 | 9 | def step(self, ac): 10 | observation, reward, done, info = self.env.step(ac) 11 | self._elapsed_steps += 1 12 | if self._elapsed_steps >= self._max_episode_steps: 13 | done = True 14 | info['TimeLimit.truncated'] = True 15 | return observation, reward, done, info 16 | 17 | def reset(self, **kwargs): 18 | self._elapsed_steps = 0 19 | return self.env.reset(**kwargs) 20 | 21 | class ClipActionsWrapper(gym.Wrapper): 22 | def step(self, action): 23 | import numpy as np 24 | action = np.nan_to_num(action) 25 | action = np.clip(action, self.action_space.low, self.action_space.high) 26 | return self.env.step(action) 27 | 28 | def reset(self, **kwargs): 29 | return self.env.reset(**kwargs) 30 | -------------------------------------------------------------------------------- /rl/baselines/results_plotter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | matplotlib.use('TkAgg') # Can change to 'Agg' for non-interactive mode 4 | 5 | import matplotlib.pyplot as plt 6 | plt.rcParams['svg.fonttype'] = 'none' 7 | 8 | from baselines.common import plot_util 9 | 10 | X_TIMESTEPS = 'timesteps' 11 | 
X_EPISODES = 'episodes' 12 | X_WALLTIME = 'walltime_hrs' 13 | Y_REWARD = 'reward' 14 | Y_TIMESTEPS = 'timesteps' 15 | POSSIBLE_X_AXES = [X_TIMESTEPS, X_EPISODES, X_WALLTIME] 16 | EPISODES_WINDOW = 100 17 | COLORS = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black', 'purple', 'pink', 18 | 'brown', 'orange', 'teal', 'coral', 'lightblue', 'lime', 'lavender', 'turquoise', 19 | 'darkgreen', 'tan', 'salmon', 'gold', 'darkred', 'darkblue'] 20 | 21 | def rolling_window(a, window): 22 | shape = a.shape[:-1] + (a.shape[-1] - window + 1, window) 23 | strides = a.strides + (a.strides[-1],) 24 | return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides) 25 | 26 | def window_func(x, y, window, func): 27 | yw = rolling_window(y, window) 28 | yw_func = func(yw, axis=-1) 29 | return x[window-1:], yw_func 30 | 31 | def ts2xy(ts, xaxis, yaxis): 32 | if xaxis == X_TIMESTEPS: 33 | x = np.cumsum(ts.l.values) 34 | elif xaxis == X_EPISODES: 35 | x = np.arange(len(ts)) 36 | elif xaxis == X_WALLTIME: 37 | x = ts.t.values / 3600. 38 | else: 39 | raise NotImplementedError 40 | if yaxis == Y_REWARD: 41 | y = ts.r.values 42 | elif yaxis == Y_TIMESTEPS: 43 | y = ts.l.values 44 | else: 45 | raise NotImplementedError 46 | return x, y 47 | 48 | def plot_curves(xy_list, xaxis, yaxis, title): 49 | fig = plt.figure(figsize=(8,2)) 50 | maxx = max(xy[0][-1] for xy in xy_list) 51 | minx = 0 52 | for (i, (x, y)) in enumerate(xy_list): 53 | color = COLORS[i % len(COLORS)] 54 | plt.scatter(x, y, s=2) 55 | x, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean) #So returns average of last EPISODE_WINDOW episodes 56 | plt.plot(x, y_mean, color=color) 57 | plt.xlim(minx, maxx) 58 | plt.title(title) 59 | plt.xlabel(xaxis) 60 | plt.ylabel(yaxis) 61 | plt.tight_layout() 62 | fig.canvas.mpl_connect('resize_event', lambda event: plt.tight_layout()) 63 | plt.grid(True) 64 | 65 | 66 | def split_by_task(taskpath): 67 | return taskpath['dirname'].split('/')[-1].split('-')[0] 68 | 69 | def plot_results(dirs, num_timesteps=10e6, xaxis=X_TIMESTEPS, yaxis=Y_REWARD, title='', split_fn=split_by_task): 70 | results = plot_util.load_results(dirs) 71 | plot_util.plot_results(results, xy_fn=lambda r: ts2xy(r['monitor'], xaxis, yaxis), split_fn=split_fn, average_group=True, resample=int(1e6)) 72 | 73 | # Example usage in jupyter-notebook 74 | # from baselines.results_plotter import plot_results 75 | # %matplotlib inline 76 | # plot_results("./log") 77 | # Here ./log is a directory containing the monitor.csv files 78 | 79 | def main(): 80 | import argparse 81 | import os 82 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 83 | parser.add_argument('--dirs', help='List of log directories', nargs = '*', default=['./log']) 84 | parser.add_argument('--num_timesteps', type=int, default=int(10e6)) 85 | parser.add_argument('--xaxis', help = 'Varible on X-axis', default = X_TIMESTEPS) 86 | parser.add_argument('--yaxis', help = 'Varible on Y-axis', default = Y_REWARD) 87 | parser.add_argument('--task_name', help = 'Title of plot', default = 'Breakout') 88 | args = parser.parse_args() 89 | args.dirs = [os.path.abspath(dir) for dir in args.dirs] 90 | plot_results(args.dirs, args.num_timesteps, args.xaxis, args.yaxis, args.task_name) 91 | plt.show() 92 | 93 | if __name__ == '__main__': 94 | main() 95 | -------------------------------------------------------------------------------- /rl/distributions.py: -------------------------------------------------------------------------------- 1 | import math 2 | 
3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from rl.utils import AddBias, init 8 | 9 | """ 10 | Modify standard PyTorch distributions so they are compatible with this code. 11 | """ 12 | 13 | # 14 | # Standardize distribution interfaces 15 | # 16 | 17 | # Categorical 18 | class FixedCategorical(torch.distributions.Categorical): 19 | def sample(self): 20 | return super().sample().unsqueeze(-1) 21 | 22 | def log_probs(self, actions): 23 | return ( 24 | super() 25 | .log_prob(actions.squeeze(-1)) 26 | .view(actions.size(0), -1) 27 | .sum(-1) 28 | .unsqueeze(-1) 29 | ) 30 | 31 | def mode(self): 32 | return self.probs.argmax(dim=-1, keepdim=True) 33 | 34 | 35 | # Normal 36 | class FixedNormal(torch.distributions.Normal): 37 | def log_probs(self, actions): 38 | return super().log_prob(actions).sum(-1, keepdim=True) 39 | 40 | def entropy(self): 41 | return super().entropy().sum(-1) 42 | 43 | def mode(self): 44 | return self.mean 45 | 46 | 47 | # Bernoulli 48 | class FixedBernoulli(torch.distributions.Bernoulli): 49 | def log_probs(self, actions): 50 | return super().log_prob(actions).view(actions.size(0), -1).sum(-1).unsqueeze(-1) 51 | 52 | def entropy(self): 53 | return super().entropy().sum(-1) 54 | 55 | def mode(self): 56 | return torch.gt(self.probs, 0.5).float() 57 | 58 | 59 | class Categorical(nn.Module): 60 | def __init__(self, num_inputs, num_outputs): 61 | super(Categorical, self).__init__() 62 | 63 | init_ = lambda m: init( 64 | m, 65 | nn.init.orthogonal_, 66 | lambda x: nn.init.constant_(x, 0), 67 | gain=0.01) 68 | 69 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 70 | 71 | def forward(self, x, mask=None): 72 | x = self.linear(x) 73 | if mask is not None: 74 | assert x.shape == mask.shape 75 | x += mask 76 | return FixedCategorical(logits=x) 77 | 78 | 79 | class DiagGaussian(nn.Module): 80 | def __init__(self, num_inputs, num_outputs): 81 | super(DiagGaussian, self).__init__() 82 | 83 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 84 | constant_(x, 0)) 85 | 86 | self.fc_mean = init_(nn.Linear(num_inputs, num_outputs)) 87 | self.logstd = AddBias(torch.zeros(num_outputs)) 88 | 89 | def forward(self, x): 90 | action_mean = self.fc_mean(x) 91 | 92 | # An ugly hack for my KFAC implementation. 93 | zeros = torch.zeros(action_mean.size()) 94 | if x.is_cuda: 95 | zeros = zeros.cuda() 96 | 97 | action_logstd = self.logstd(zeros) 98 | return FixedNormal(action_mean, action_logstd.exp()) 99 | 100 | 101 | class Bernoulli(nn.Module): 102 | def __init__(self, num_inputs, num_outputs): 103 | super(Bernoulli, self).__init__() 104 | 105 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
106 | constant_(x, 0)) 107 | 108 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 109 | 110 | def forward(self, x): 111 | x = self.linear(x) 112 | return FixedBernoulli(logits=x) 113 | -------------------------------------------------------------------------------- /rl/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from rl.utils import init 7 | 8 | 9 | class Flatten(nn.Module): 10 | def forward(self, x): 11 | return x.view(x.size(0), -1) 12 | 13 | 14 | class NNBase(nn.Module): 15 | def __init__(self, recurrent, recurrent_input_size, hidden_size, rnn_type='GRU'): 16 | super(NNBase, self).__init__() 17 | 18 | self._hidden_size = hidden_size 19 | self._recurrent = recurrent 20 | self.rnn_type = rnn_type 21 | 22 | if recurrent: 23 | if rnn_type == 'GRU': 24 | self.gru = nn.GRU(recurrent_input_size, hidden_size) 25 | elif rnn_type == 'LSTM': 26 | self.lstm = nn.LSTM(recurrent_input_size, hidden_size) 27 | # need to keep this for backward compatibility to pre-trained weights 28 | self.gru = self.lstm 29 | else: 30 | raise NotImplementedError() 31 | 32 | for name, param in self.gru.named_parameters(): 33 | if 'bias' in name: 34 | nn.init.constant_(param, 0) 35 | elif 'weight' in name: 36 | nn.init.orthogonal_(param) 37 | 38 | @property 39 | def is_recurrent(self): 40 | return self._recurrent 41 | 42 | @property 43 | def recurrent_hidden_state_size(self): 44 | if self._recurrent: 45 | return self._hidden_size 46 | return 1 47 | 48 | @property 49 | def output_size(self): 50 | return self._hidden_size 51 | 52 | def _forward_rnn(self, x, hxs, masks): 53 | if self.rnn_type == 'GRU': 54 | return self._forward_gru(x, hxs, masks) 55 | elif self.rnn_type == 'LSTM': 56 | return self._forward_lstm(x, hxs, masks) 57 | else: 58 | raise NotImplementedError() 59 | 60 | def _forward_lstm(self, x, hxs, masks): 61 | assert isinstance(hxs, tuple) and len(hxs) == 2 62 | if x.size(0) == hxs[0].size(0): 63 | x, hxs = self.lstm(x.unsqueeze(0), (hxs[0].unsqueeze(0), hxs[1].unsqueeze(0))) 64 | x = x.squeeze(0) 65 | hxs = (hxs[0].squeeze(0), hxs[1].squeeze(0)) 66 | else: 67 | assert 0, "current code doesn't support this" 68 | return x, hxs 69 | 70 | def _forward_gru(self, x, hxs, masks): 71 | if x.size(0) == hxs.size(0): 72 | x, hxs = self.gru(x.unsqueeze(0), (hxs * masks).unsqueeze(0)) 73 | x = x.squeeze(0) 74 | hxs = hxs.squeeze(0) 75 | else: 76 | assert 0, "current code doesn't support this" 77 | # x is a (T, N, -1) tensor that has been flatten to (T * N, -1) 78 | N = hxs.size(0) 79 | T = int(x.size(0) / N) 80 | 81 | # unflatten 82 | x = x.view(T, N, x.size(1)) 83 | 84 | # Same deal with masks 85 | masks = masks.view(T, N) 86 | 87 | # Let's figure out which steps in the sequence have a zero for any agent 88 | # We will always assume t=0 has a zero in it as that makes the logic cleaner 89 | has_zeros = ((masks[1:] == 0.0) \ 90 | .any(dim=-1) 91 | .nonzero() 92 | .squeeze() 93 | .cpu()) 94 | 95 | # +1 to correct the masks[1:] 96 | if has_zeros.dim() == 0: 97 | # Deal with scalar 98 | has_zeros = [has_zeros.item() + 1] 99 | else: 100 | has_zeros = (has_zeros + 1).numpy().tolist() 101 | 102 | # add t=0 and t=T to the list 103 | has_zeros = [0] + has_zeros + [T] 104 | 105 | hxs = hxs.unsqueeze(0) 106 | outputs = [] 107 | for i in range(len(has_zeros) - 1): 108 | # We can now process steps that don't have any zeros in masks together! 
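# (each segment between consecutive done-masks is pushed through the GRU in a single call)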
109 | # This is much faster 110 | start_idx = has_zeros[i] 111 | end_idx = has_zeros[i + 1] 112 | 113 | rnn_scores, hxs = self.gru( 114 | x[start_idx:end_idx], 115 | hxs * masks[start_idx].view(1, -1, 1)) 116 | 117 | outputs.append(rnn_scores) 118 | 119 | # assert len(outputs) == T 120 | # x is a (T, N, -1) tensor 121 | x = torch.cat(outputs, dim=0) 122 | # flatten 123 | x = x.view(T * N, -1) 124 | hxs = hxs.squeeze(0) 125 | 126 | return x, hxs -------------------------------------------------------------------------------- /rl/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from operator import itemgetter 4 | from itertools import chain 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.optim import lr_scheduler 10 | 11 | from rl.envs import VecNormalize 12 | 13 | 14 | # Get a render function 15 | def get_render_func(venv): 16 | if hasattr(venv, 'envs'): 17 | return venv.envs[0].render 18 | elif hasattr(venv, 'venv'): 19 | return get_render_func(venv.venv) 20 | elif hasattr(venv, 'env'): 21 | return get_render_func(venv.env) 22 | 23 | return None 24 | 25 | 26 | def get_vec_normalize(venv): 27 | if isinstance(venv, VecNormalize): 28 | return venv 29 | elif hasattr(venv, 'venv'): 30 | return get_vec_normalize(venv.venv) 31 | 32 | return None 33 | 34 | 35 | # Necessary for my KFAC implementation. 36 | class AddBias(nn.Module): 37 | def __init__(self, bias): 38 | super(AddBias, self).__init__() 39 | self._bias = nn.Parameter(bias.unsqueeze(1)) 40 | 41 | def forward(self, x): 42 | if x.dim() == 2: 43 | bias = self._bias.t().view(1, -1) 44 | else: 45 | bias = self._bias.t().view(1, -1, 1, 1) 46 | 47 | return x + bias 48 | 49 | 50 | def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr): 51 | """Decreases the learning rate linearly""" 52 | lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs))) 53 | for param_group in optimizer.param_groups: 54 | param_group['lr'] = lr 55 | return lr 56 | 57 | 58 | def init(module, weight_init, bias_init, gain=1): 59 | weight_init(module.weight.data, gain=gain) 60 | bias_init(module.bias.data) 61 | return module 62 | 63 | 64 | def cleanup_log_dir(log_dir): 65 | try: 66 | os.makedirs(log_dir) 67 | except OSError: 68 | files = glob.glob(os.path.join(log_dir, '*.monitor.csv')) 69 | for f in files: 70 | os.remove(f) 71 | 72 | 73 | def masked_mean(x, mask, dim=-1, keepdim=False): 74 | assert x.shape == mask.shape 75 | return torch.sum(x * mask.float(), dim=dim, keepdim=keepdim) / torch.sum(mask, dim=dim, keepdim=keepdim) 76 | 77 | 78 | def masked_sum(x, mask, dim=-1, keepdim=False): 79 | assert x.shape == mask.shape 80 | return torch.sum(x * mask.float(), dim=dim, keepdim=keepdim) 81 | 82 | 83 | def create_hook(name): 84 | def hook(grad): 85 | print(name, "nan-grad: ", np.isnan(grad.cpu().numpy()).sum()) 86 | return hook 87 | 88 | 89 | def count_parameters(net): 90 | """ Returns total number of trainable parameters in net """ 91 | return sum(p.numel() for p in net.parameters() if p.requires_grad) 92 | 93 | 94 | def _list_to_sequence(x, indices): 95 | return torch.nn.utils.rnn.pack_sequence(list(chain.from_iterable(itemgetter(*indices)(x))), enforce_sorted=False) -------------------------------------------------------------------------------- /tasks/cleanHouse1.txt: -------------------------------------------------------------------------------- 1 | DEF run m( WHILE c( noMarkersPresent c) w( IF c( leftIsClear c) i( turnLeft i) 
move IF c( markersPresent c) i( pickMarker i) w) m) 2 | -------------------------------------------------------------------------------- /tasks/fourCorners1.txt: -------------------------------------------------------------------------------- 1 | DEF run m( WHILE c( noMarkersPresent c) w( WHILE c( frontIsClear c) w( move w) IF c( noMarkersPresent c) i( putMarker turnLeft move i) w) m) 2 | -------------------------------------------------------------------------------- /tasks/harvester1.txt: -------------------------------------------------------------------------------- 1 | DEF run m( WHILE c( markersPresent c) w( WHILE c( markersPresent c) w( pickMarker move w) turnRight move turnLeft WHILE c( markersPresent c) w( pickMarker move w) turnLeft move turnRight w) m) 2 | -------------------------------------------------------------------------------- /tasks/maze.txt: -------------------------------------------------------------------------------- 1 | DEF run m( WHILE c( noMarkersPresent c) w( IFELSE c( rightIsClear c) i( turnRight i) ELSE e( WHILE c( not c( frontIsClear c) c) w( turnLeft w) e) move w) m) 2 | -------------------------------------------------------------------------------- /tasks/maze1.txt: -------------------------------------------------------------------------------- 1 | DEF run m( WHILE c( noMarkersPresent c) w( IFELSE c( rightIsClear c) i( turnRight i) ELSE e( WHILE c( not c( frontIsClear c) c) w( turnLeft w) e) move w) m) 2 | -------------------------------------------------------------------------------- /tasks/randomMaze1.txt: -------------------------------------------------------------------------------- 1 | DEF run m( WHILE c( noMarkersPresent c) w( IFELSE c( rightIsClear c) i( turnRight i) ELSE e( WHILE c( not c( frontIsClear c) c) w( turnLeft w) e) move w) m) 2 | -------------------------------------------------------------------------------- /tasks/stairClimber1.txt: -------------------------------------------------------------------------------- 1 | DEF run m( WHILE c( noMarkersPresent c) w( turnLeft move turnRight move w) m) 2 | -------------------------------------------------------------------------------- /tasks/test1.txt: -------------------------------------------------------------------------------- 1 | DEF run m( WHILE c( frontIsClear c) w( turnRight move pickMarker turnRight w) m) 2 | -------------------------------------------------------------------------------- /tasks/test2.txt: -------------------------------------------------------------------------------- 1 | DEF run m( IFELSE c( markersPresent c) i( move turnRight i) ELSE e( move e) move move WHILE c( leftIsClear c) w( turnLeft w) m) 2 | -------------------------------------------------------------------------------- /tasks/test3.txt: -------------------------------------------------------------------------------- 1 | DEF run m( IF c( frontIsClear c) i( putMarker i) move IF c( rightIsClear c) i( move i) IFELSE c( frontIsClear c) i( move i) ELSE e( move e) m) 2 | -------------------------------------------------------------------------------- /tasks/test4.txt: -------------------------------------------------------------------------------- 1 | DEF run m( WHILE c( leftIsClear c) w( turnLeft w) IF c( frontIsClear c) i( putMarker move i) move IF c( rightIsClear c) i( turnRight move i) IFELSE c( frontIsClear c) i( move i) ELSE e( turnLeft move e) m) 2 | -------------------------------------------------------------------------------- /tasks/topOff1.txt: 
-------------------------------------------------------------------------------- 1 | DEF run m( WHILE c( frontIsClear c) w( IF c( markersPresent c) i( putMarker i) move w) m) 2 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/utils/__init__.py -------------------------------------------------------------------------------- /utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | """misc utils for training neural networks""" 2 | 3 | 4 | class HyperParameterScheduler(object): 5 | def __init__(self, initial_val, num_updates, final_val=None, func='linear', gamma=0.999): 6 | """ Initialize HyperParameter Scheduler class 7 | 8 | :param initial_val: initial value of the hyper-parameter 9 | :param num_updates: total number of updates for hyper-parameter 10 | :param final_val: final value of the hyper-parameter, if None then decay rate is fixed (0.999 for exponential 0 11 | for linear) 12 | :param func: decay type ['exponential', linear'] 13 | :param gamma: fixed decay rate for exponential decay (if final value is given then this gamma is ignored) 14 | """ 15 | self.initial_val = initial_val 16 | self.total_num_epoch = num_updates 17 | self.final_val = final_val 18 | self.cur_hp = self.initial_val 19 | self.cur_step = 0 20 | 21 | if final_val is not None: 22 | assert final_val >= 0, 'final value should be positive' 23 | 24 | if func == "linear": 25 | self.hp_lambda = self.linear_scheduler 26 | elif func == "exponential": 27 | self.hp_lambda = self.exponential_scheduler 28 | if initial_val == final_val: 29 | self.gamma = 1 30 | else: 31 | self.gamma = pow(final_val / initial_val, 1 / self.total_num_epoch) if final_val is not None else gamma 32 | else: 33 | raise NotImplementedError('scheduler not implemented') 34 | 35 | def linear_scheduler(self, epoch): 36 | if self.final_val is not None: 37 | return (self.final_val - self.initial_val)*(epoch/self.total_num_epoch) + self.initial_val 38 | else: 39 | return self.initial_val - (self.initial_val * (epoch / self.total_num_epoch)) 40 | 41 | def exponential_scheduler(self, epoch): 42 | return self.initial_val * (self.gamma ** epoch) 43 | 44 | def step(self, epoch=None): 45 | assert self.cur_step <= self.total_num_epoch, "scheduler step shouldn't be larger than total steps" 46 | if epoch is None: 47 | epoch = self.cur_step 48 | self.cur_hp = self.hp_lambda(epoch) 49 | self.cur_step += 1 50 | return self.cur_hp 51 | 52 | @property 53 | def get_value(self): 54 | return self.cur_hp -------------------------------------------------------------------------------- /weights/LEAPS/best_valid_params.ptp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/weights/LEAPS/best_valid_params.ptp -------------------------------------------------------------------------------- /weights/LEAPSP/best_valid_params.ptp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/weights/LEAPSP/best_valid_params.ptp -------------------------------------------------------------------------------- /weights/LEAPSPL/best_valid_params.ptp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/weights/LEAPSPL/best_valid_params.ptp -------------------------------------------------------------------------------- /weights/LEAPSPR/best_valid_params.ptp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/leaps/d89beb11d1c9e1845f61c3a58e69e9c3f2672c39/weights/LEAPSPR/best_valid_params.ptp --------------------------------------------------------------------------------
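# Usage sketch for HyperParameterScheduler (utils/misc_utils.py). This is a minimal
# illustration, not a file from the repository: the import assumes the repository
# root is on PYTHONPATH, and the update count and initial/final values below are
# arbitrary placeholders rather than settings used by LEAPS.
from utils.misc_utils import HyperParameterScheduler

num_updates = 1000
beta_scheduler = HyperParameterScheduler(initial_val=1.0, num_updates=num_updates,
                                         final_val=0.1, func='linear')
for update in range(num_updates):
    beta = beta_scheduler.step()  # linearly anneals from 1.0 toward 0.1 over num_updates calls
    # ... use beta as, e.g., a loss coefficient for this update ...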