├── .gitignore ├── LICENSE ├── README.md ├── drex-atari ├── LearnAtariSyntheticRankingsBinning.py ├── README.md ├── baselines │ ├── Dockerfile │ ├── LICENSE │ ├── README.md │ ├── baselines │ │ ├── __init__.py │ │ ├── a2c │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── a2c.py │ │ │ ├── runner.py │ │ │ └── utils.py │ │ ├── acer │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── acer.py │ │ │ ├── buffer.py │ │ │ ├── defaults.py │ │ │ ├── policies.py │ │ │ └── runner.py │ │ ├── acktr │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── acktr.py │ │ │ ├── defaults.py │ │ │ ├── kfac.py │ │ │ ├── kfac_utils.py │ │ │ └── utils.py │ │ ├── bench │ │ │ ├── __init__.py │ │ │ ├── benchmarks.py │ │ │ └── monitor.py │ │ ├── common │ │ │ ├── __init__.py │ │ │ ├── atari_wrappers.py │ │ │ ├── cg.py │ │ │ ├── cmd_util.py │ │ │ ├── console_util.py │ │ │ ├── custom_reward_wrapper.py │ │ │ ├── dataset.py │ │ │ ├── distributions.py │ │ │ ├── environment.yaml │ │ │ ├── input.py │ │ │ ├── math_util.py │ │ │ ├── misc_util.py │ │ │ ├── models.py │ │ │ ├── mpi_adam.py │ │ │ ├── mpi_adam_optimizer.py │ │ │ ├── mpi_fork.py │ │ │ ├── mpi_moments.py │ │ │ ├── mpi_running_mean_std.py │ │ │ ├── mpi_util.py │ │ │ ├── plot_util.py │ │ │ ├── policies.py │ │ │ ├── retro_wrappers.py │ │ │ ├── runners.py │ │ │ ├── running_mean_std.py │ │ │ ├── schedules.py │ │ │ ├── segment_tree.py │ │ │ ├── tests │ │ │ │ ├── __init__.py │ │ │ │ ├── envs │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── fixed_sequence_env.py │ │ │ │ │ ├── identity_env.py │ │ │ │ │ └── mnist_env.py │ │ │ │ ├── test_cartpole.py │ │ │ │ ├── test_doc_examples.py │ │ │ │ ├── test_env_after_learn.py │ │ │ │ ├── test_fixed_sequence.py │ │ │ │ ├── test_identity.py │ │ │ │ ├── test_mnist.py │ │ │ │ ├── test_schedules.py │ │ │ │ ├── test_segment_tree.py │ │ │ │ ├── test_serialization.py │ │ │ │ ├── test_tf_util.py │ │ │ │ └── util.py │ │ │ ├── tf_util.py │ │ │ ├── tile_images.py │ │ │ ├── trex_utils.py │ │ │ └── vec_env │ │ │ │ ├── __init__.py │ │ │ │ ├── dummy_vec_env.py │ │ │ │ ├── shmem_vec_env.py │ │ │ │ ├── subproc_vec_env.py │ │ │ │ ├── test_vec_env.py │ │ │ │ ├── test_video_recorder.py │ │ │ │ ├── util.py │ │ │ │ ├── vec_frame_stack.py │ │ │ │ ├── vec_monitor.py │ │ │ │ ├── vec_normalize.py │ │ │ │ └── vec_video_recorder.py │ │ ├── ddpg │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── ddpg.py │ │ │ ├── ddpg_learner.py │ │ │ ├── memory.py │ │ │ ├── models.py │ │ │ └── noise.py │ │ ├── deepq │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── build_graph.py │ │ │ ├── deepq.py │ │ │ ├── defaults.py │ │ │ ├── experiments │ │ │ │ ├── __init__.py │ │ │ │ ├── custom_cartpole.py │ │ │ │ ├── enjoy_cartpole.py │ │ │ │ ├── enjoy_mountaincar.py │ │ │ │ ├── enjoy_pong.py │ │ │ │ ├── train_cartpole.py │ │ │ │ ├── train_mountaincar.py │ │ │ │ └── train_pong.py │ │ │ ├── models.py │ │ │ ├── replay_buffer.py │ │ │ └── utils.py │ │ ├── gail │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── adversary.py │ │ │ ├── behavior_clone.py │ │ │ ├── dataset │ │ │ │ ├── __init__.py │ │ │ │ └── mujoco_dset.py │ │ │ ├── gail-eval.py │ │ │ ├── mlp_policy.py │ │ │ ├── result │ │ │ │ ├── HalfCheetah-normalized-deterministic-scores.png │ │ │ │ ├── HalfCheetah-normalized-stochastic-scores.png │ │ │ │ ├── HalfCheetah-unnormalized-deterministic-scores.png │ │ │ │ ├── HalfCheetah-unnormalized-stochastic-scores.png │ │ │ │ ├── Hopper-normalized-deterministic-scores.png │ │ │ │ ├── Hopper-normalized-stochastic-scores.png │ │ │ │ ├── Hopper-unnormalized-deterministic-scores.png │ │ │ │ ├── 
Hopper-unnormalized-stochastic-scores.png │ │ │ │ ├── Humanoid-normalized-deterministic-scores.png │ │ │ │ ├── Humanoid-normalized-stochastic-scores.png │ │ │ │ ├── Humanoid-unnormalized-deterministic-scores.png │ │ │ │ ├── Humanoid-unnormalized-stochastic-scores.png │ │ │ │ ├── HumanoidStandup-normalized-deterministic-scores.png │ │ │ │ ├── HumanoidStandup-normalized-stochastic-scores.png │ │ │ │ ├── HumanoidStandup-unnormalized-deterministic-scores.png │ │ │ │ ├── HumanoidStandup-unnormalized-stochastic-scores.png │ │ │ │ ├── Walker2d-normalized-deterministic-scores.png │ │ │ │ ├── Walker2d-normalized-stochastic-scores.png │ │ │ │ ├── Walker2d-unnormalized-deterministic-scores.png │ │ │ │ ├── Walker2d-unnormalized-stochastic-scores.png │ │ │ │ ├── gail-result.md │ │ │ │ ├── halfcheetah-training.png │ │ │ │ ├── hopper-training.png │ │ │ │ ├── humanoid-training.png │ │ │ │ ├── humanoidstandup-training.png │ │ │ │ └── walker2d-training.png │ │ │ ├── run_mujoco.py │ │ │ ├── statistics.py │ │ │ └── trpo_mpi.py │ │ ├── her │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── actor_critic.py │ │ │ ├── ddpg.py │ │ │ ├── experiment │ │ │ │ ├── __init__.py │ │ │ │ ├── config.py │ │ │ │ ├── data_generation │ │ │ │ │ └── fetch_data_generation.py │ │ │ │ ├── play.py │ │ │ │ ├── plot.py │ │ │ │ └── train.py │ │ │ ├── her.py │ │ │ ├── normalizer.py │ │ │ ├── replay_buffer.py │ │ │ ├── rollout.py │ │ │ └── util.py │ │ ├── logger.py │ │ ├── ppo1 │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── cnn_policy.py │ │ │ ├── mlp_policy.py │ │ │ ├── pposgd_simple.py │ │ │ ├── run_atari.py │ │ │ ├── run_humanoid.py │ │ │ ├── run_mujoco.py │ │ │ └── run_robotics.py │ │ ├── ppo2 │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── defaults.py │ │ │ ├── microbatched_model.py │ │ │ ├── model.py │ │ │ ├── ppo2.py │ │ │ ├── runner.py │ │ │ └── test_microbatches.py │ │ ├── results_plotter.py │ │ ├── run.py │ │ └── trpo_mpi │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── defaults.py │ │ │ └── trpo_mpi.py │ ├── data │ │ ├── cartpole.gif │ │ ├── fetchPickAndPlaceContrast.png │ │ └── logo.jpg │ ├── docs │ │ └── viz │ │ │ └── viz.ipynb │ ├── setup.cfg │ └── setup.py ├── bc.py ├── bc_degredation_data │ ├── beamrider_batch_rewards.csv │ ├── beamrider_degredation_plot.png │ ├── breakout_batch_rewards.csv │ ├── breakout_degredation_plot.png │ ├── create_table.py │ ├── enduro_batch_rewards.csv │ ├── enduro_degredation_plot.png │ ├── performance_table.txt │ ├── plot_degredation.py │ ├── pong_batch_rewards.csv │ ├── pong_degredation_plot.png │ ├── qbert_batch_rewards.csv │ ├── qbert_degredation_plot.png │ ├── seaquest_batch_rewards.csv │ ├── seaquest_degredation_plot.png │ ├── spaceinvaders_batch_rewards.csv │ └── spaceinvaders_degredation_plot.png ├── bc_degredation_data_generator.py ├── checkpoints │ ├── beamrider_novice_demos_network.pth.tar │ ├── breakout_novice_demos_network.pth.tar │ ├── breakout_standard_bc_network.pth.tar │ ├── enduro_novice_demos_network.pth.tar │ ├── hero_novice_demos_network.pth.tar │ ├── pong_novice_demos_network.pth.tar │ ├── qbert_novice_demos_network.pth.tar │ ├── seaquest_novice_demos_network.pth.tar │ └── spaceinvaders_novice_demos_network.pth.tar ├── cnn.py ├── dataset.py ├── evaluateDREXpolicy.py ├── evaluate_bc.py ├── evaluator.py ├── figs │ ├── beamrider_gt_vs_pred_rewards_progress_sigmoid.png │ ├── beamridermax_attention.png │ ├── beamridermax_frames.png │ ├── beamridermin_attention.png │ ├── beamridermin_frames.png │ ├── breakout_gt_vs_pred_rewards_progress_sigmoid.png │ ├── 
breakoutmax_attention.png │ ├── breakoutmax_frames.png │ ├── breakoutmin_attention.png │ ├── breakoutmin_frames.png │ ├── enduro_gt_vs_pred_rewards_progress_sigmoid.png │ ├── enduromax_attention.png │ ├── enduromax_frames.png │ ├── enduromin_attention.png │ ├── enduromin_frames.png │ ├── pong_gt_vs_pred_rewards_progress_sigmoid.png │ ├── pongmax_attention.png │ ├── pongmax_frames.png │ ├── pongmin_attention.png │ ├── pongmin_frames.png │ ├── qbert_gt_vs_pred_rewards_progress_sigmoid.png │ ├── qbertmax_attention.png │ ├── qbertmax_frames.png │ ├── qbertmin_attention.png │ ├── qbertmin_frames.png │ ├── seaquest_gt_vs_pred_rewards_progress_sigmoid.png │ ├── seaquestmax_attention.png │ ├── seaquestmax_frames.png │ ├── seaquestmin_attention.png │ ├── seaquestmin_frames.png │ ├── spaceinvaders_gt_vs_pred_rewards_progress_sigmoid.png │ ├── spaceinvadersmax_attention.png │ ├── spaceinvadersmax_frames.png │ ├── spaceinvadersmin_attention.png │ └── spaceinvadersmin_frames.png ├── generate_reward_extrapolation_plots.py ├── learned_models │ ├── beamrider_five_bins_noop_earlystop.params │ ├── breakout_five_bins_noop_earlystop.params │ ├── enduro_five_bins_noop_earlystop.params │ ├── pong_five_bins_noop_earlystop.params │ ├── qbert_five_bins_noop_earlystop.params │ ├── seaquest_five_bins_noop_earlystop.params │ └── spaceinvaders_five_bins_noop_earlystop.params ├── main.py ├── main_bc_degredation.py ├── preprocess.py ├── run_test.py ├── state.py ├── synthesize_rankings_bc.py ├── train.py └── utils.py └── drex-mujoco ├── .gitignore ├── README.md ├── bc_mujoco.py ├── bc_noise_dataset.py ├── demos └── suboptimal_demos │ ├── halfcheetah │ └── dataset.pkl │ └── hopper │ └── dataset.pkl ├── drex.py ├── environment.yml ├── learner ├── .gitignore ├── README.md ├── baselines │ ├── .benchmark_pattern │ ├── .gitignore │ ├── .travis.yml │ ├── Dockerfile │ ├── LICENSE │ ├── README.md │ ├── baselines │ │ ├── __init__.py │ │ ├── a2c │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── a2c.py │ │ │ ├── runner.py │ │ │ └── utils.py │ │ ├── acer │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── acer.py │ │ │ ├── buffer.py │ │ │ ├── defaults.py │ │ │ ├── policies.py │ │ │ └── runner.py │ │ ├── acktr │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── acktr.py │ │ │ ├── defaults.py │ │ │ ├── kfac.py │ │ │ ├── kfac_utils.py │ │ │ └── utils.py │ │ ├── bench │ │ │ ├── __init__.py │ │ │ ├── benchmarks.py │ │ │ └── monitor.py │ │ ├── common │ │ │ ├── __init__.py │ │ │ ├── atari_wrappers.py │ │ │ ├── cg.py │ │ │ ├── cmd_util.py │ │ │ ├── console_util.py │ │ │ ├── custom_reward_wrapper.py │ │ │ ├── dataset.py │ │ │ ├── distributions.py │ │ │ ├── input.py │ │ │ ├── math_util.py │ │ │ ├── misc_util.py │ │ │ ├── models.py │ │ │ ├── mpi_adam.py │ │ │ ├── mpi_adam_optimizer.py │ │ │ ├── mpi_fork.py │ │ │ ├── mpi_moments.py │ │ │ ├── mpi_running_mean_std.py │ │ │ ├── mpi_util.py │ │ │ ├── plot_util.py │ │ │ ├── policies.py │ │ │ ├── retro_wrappers.py │ │ │ ├── runners.py │ │ │ ├── running_mean_std.py │ │ │ ├── schedules.py │ │ │ ├── segment_tree.py │ │ │ ├── tests │ │ │ │ ├── __init__.py │ │ │ │ ├── envs │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── fixed_sequence_env.py │ │ │ │ │ ├── identity_env.py │ │ │ │ │ └── mnist_env.py │ │ │ │ ├── test_cartpole.py │ │ │ │ ├── test_doc_examples.py │ │ │ │ ├── test_env_after_learn.py │ │ │ │ ├── test_fixed_sequence.py │ │ │ │ ├── test_identity.py │ │ │ │ ├── test_mnist.py │ │ │ │ ├── test_schedules.py │ │ │ │ ├── test_segment_tree.py │ │ │ │ ├── test_serialization.py │ │ │ │ ├── test_tf_util.py │ │ │ │ 
└── util.py │ │ │ ├── tf_util.py │ │ │ ├── tile_images.py │ │ │ └── vec_env │ │ │ │ ├── __init__.py │ │ │ │ ├── dummy_vec_env.py │ │ │ │ ├── shmem_vec_env.py │ │ │ │ ├── subproc_vec_env.py │ │ │ │ ├── test_vec_env.py │ │ │ │ ├── test_video_recorder.py │ │ │ │ ├── util.py │ │ │ │ ├── vec_frame_stack.py │ │ │ │ ├── vec_monitor.py │ │ │ │ ├── vec_normalize.py │ │ │ │ └── vec_video_recorder.py │ │ ├── ddpg │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── ddpg.py │ │ │ ├── ddpg_learner.py │ │ │ ├── memory.py │ │ │ ├── models.py │ │ │ └── noise.py │ │ ├── deepq │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── build_graph.py │ │ │ ├── deepq.py │ │ │ ├── defaults.py │ │ │ ├── experiments │ │ │ │ ├── __init__.py │ │ │ │ ├── custom_cartpole.py │ │ │ │ ├── enjoy_cartpole.py │ │ │ │ ├── enjoy_mountaincar.py │ │ │ │ ├── enjoy_pong.py │ │ │ │ ├── train_cartpole.py │ │ │ │ ├── train_mountaincar.py │ │ │ │ └── train_pong.py │ │ │ ├── models.py │ │ │ ├── replay_buffer.py │ │ │ └── utils.py │ │ ├── gail │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── adversary.py │ │ │ ├── adversary_atari.py │ │ │ ├── behavior_clone.py │ │ │ ├── cnn_policy.py │ │ │ ├── dataset │ │ │ │ ├── __init__.py │ │ │ │ ├── atari_dset.py │ │ │ │ ├── atari_gen.py │ │ │ │ └── mujoco_dset.py │ │ │ ├── gail-eval.py │ │ │ ├── mlp_policy.py │ │ │ ├── result │ │ │ │ ├── HalfCheetah-normalized-deterministic-scores.png │ │ │ │ ├── HalfCheetah-normalized-stochastic-scores.png │ │ │ │ ├── HalfCheetah-unnormalized-deterministic-scores.png │ │ │ │ ├── HalfCheetah-unnormalized-stochastic-scores.png │ │ │ │ ├── Hopper-normalized-deterministic-scores.png │ │ │ │ ├── Hopper-normalized-stochastic-scores.png │ │ │ │ ├── Hopper-unnormalized-deterministic-scores.png │ │ │ │ ├── Hopper-unnormalized-stochastic-scores.png │ │ │ │ ├── Humanoid-normalized-deterministic-scores.png │ │ │ │ ├── Humanoid-normalized-stochastic-scores.png │ │ │ │ ├── Humanoid-unnormalized-deterministic-scores.png │ │ │ │ ├── Humanoid-unnormalized-stochastic-scores.png │ │ │ │ ├── HumanoidStandup-normalized-deterministic-scores.png │ │ │ │ ├── HumanoidStandup-normalized-stochastic-scores.png │ │ │ │ ├── HumanoidStandup-unnormalized-deterministic-scores.png │ │ │ │ ├── HumanoidStandup-unnormalized-stochastic-scores.png │ │ │ │ ├── Walker2d-normalized-deterministic-scores.png │ │ │ │ ├── Walker2d-normalized-stochastic-scores.png │ │ │ │ ├── Walker2d-unnormalized-deterministic-scores.png │ │ │ │ ├── Walker2d-unnormalized-stochastic-scores.png │ │ │ │ ├── gail-result.md │ │ │ │ ├── halfcheetah-training.png │ │ │ │ ├── hopper-training.png │ │ │ │ ├── humanoid-training.png │ │ │ │ ├── humanoidstandup-training.png │ │ │ │ └── walker2d-training.png │ │ │ ├── run_atari.py │ │ │ ├── run_mujoco.py │ │ │ ├── statistics.py │ │ │ └── trpo_mpi.py │ │ ├── her │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── actor_critic.py │ │ │ ├── ddpg.py │ │ │ ├── experiment │ │ │ │ ├── __init__.py │ │ │ │ ├── config.py │ │ │ │ ├── data_generation │ │ │ │ │ └── fetch_data_generation.py │ │ │ │ ├── play.py │ │ │ │ ├── plot.py │ │ │ │ └── train.py │ │ │ ├── her.py │ │ │ ├── normalizer.py │ │ │ ├── replay_buffer.py │ │ │ ├── rollout.py │ │ │ └── util.py │ │ ├── logger.py │ │ ├── ppo1 │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── cnn_policy.py │ │ │ ├── mlp_policy.py │ │ │ ├── pposgd_simple.py │ │ │ ├── run_atari.py │ │ │ ├── run_humanoid.py │ │ │ ├── run_mujoco.py │ │ │ └── run_robotics.py │ │ ├── ppo2 │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── defaults.py │ │ │ ├── microbatched_model.py │ 
│ │ ├── model.py │ │ │ ├── ppo2.py │ │ │ ├── runner.py │ │ │ └── test_microbatches.py │ │ ├── results_plotter.py │ │ ├── run.py │ │ └── trpo_mpi │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── defaults.py │ │ │ └── trpo_mpi.py │ ├── benchmarks_atari10M.htm │ ├── benchmarks_mujoco1M.htm │ ├── docs │ │ └── viz │ │ │ └── viz.ipynb │ ├── setup.cfg │ └── setup.py └── run_test.py ├── requirements.txt ├── tf_commons └── ops.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | log/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Daniel Brown and Wonjoon Goo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Better-than-Demonstrator Imitation Learning via Automatically-Ranked Demonstrations 2 | Daniel Brown, Wonjoon Goo, and Scott Niekum. 3 | 4 | 5 |

6 | View on ArXiv | 7 | Project Website 8 |

9 | 10 | 11 | This repository contains the code used to conduct the experiments reported in the paper [Better-than-Demonstrator Imitation Learning via Automatically-Ranked Demonstrations](https://arxiv.org/pdf/1907.03976.pdf), presented at the Conference on Robot Learning (CoRL), 2019. 12 | 13 | If you find this repository useful in your research, please cite the paper: 14 | ``` 15 | @InProceedings{brown2019drex, 16 | title = {Better-than-Demonstrator Imitation Learning via Automatically-Ranked Demonstrations}, 17 | author = {Brown, Daniel S. and Goo, Wonjoon and Niekum, Scott}, 18 | booktitle = {Proceedings of the 3rd Conference on Robot Learning}, 19 | year = {2019} 20 | } 21 | ``` 22 | 23 | If you have any questions or problems with the code, please send an email to [Daniel](http://www.cs.utexas.edu/~dsbrown/) or [Wonjoon](http://dev.wonjoon.me/). 24 | -------------------------------------------------------------------------------- /drex-atari/baselines/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.6 2 | 3 | RUN apt-get -y update && apt-get -y install ffmpeg 4 | # RUN apt-get -y update && apt-get -y install git wget python-dev python3-dev libopenmpi-dev python-pip zlib1g-dev cmake python-opencv 5 | 6 | ENV CODE_DIR /root/code 7 | 8 | COPY . $CODE_DIR/baselines 9 | WORKDIR $CODE_DIR/baselines 10 | 11 | # Clean up pycache and pyc files 12 | RUN rm -rf __pycache__ && \ 13 | find . -name "*.pyc" -delete && \ 14 | pip install tensorflow && \ 15 | pip install -e .[test] 16 | 17 | 18 | CMD /bin/bash 19 | -------------------------------------------------------------------------------- /drex-atari/baselines/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2017 OpenAI (http://openai.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE.
22 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/__init__.py -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/a2c/README.md: -------------------------------------------------------------------------------- 1 | # A2C 2 | 3 | - Original paper: https://arxiv.org/abs/1602.01783 4 | - Baselines blog post: https://blog.openai.com/baselines-acktr-a2c/ 5 | - `python -m baselines.run --alg=a2c --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options 6 | - also refer to the repo-wide [README.md](../../README.md#training-models) 7 | 8 | ## Files 9 | - `run_atari`: file used to run the algorithm. 10 | - `policies.py`: contains the different versions of the A2C architecture (MlpPolicy, CNNPolicy, LstmPolicy...). 11 | - `a2c.py`: - Model : class used to initialize the step_model (sampling) and train_model (training) 12 | - learn : Main entrypoint for A2C algorithm. Train a policy with given network architecture on a given environment using a2c algorithm. 13 | - `runner.py`: class used to generates a batch of experiences 14 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/a2c/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/a2c/__init__.py -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/acer/README.md: -------------------------------------------------------------------------------- 1 | # ACER 2 | 3 | - Original paper: https://arxiv.org/abs/1611.01224 4 | - `python -m baselines.run --alg=acer --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options. 5 | - also refer to the repo-wide [README.md](../../README.md#training-models) 6 | 7 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/acer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/acer/__init__.py -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/acer/defaults.py: -------------------------------------------------------------------------------- 1 | def atari(): 2 | return dict( 3 | lrschedule='constant' 4 | ) 5 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/acktr/README.md: -------------------------------------------------------------------------------- 1 | # ACKTR 2 | 3 | - Original paper: https://arxiv.org/abs/1708.05144 4 | - Baselines blog post: https://blog.openai.com/baselines-acktr-a2c/ 5 | - `python -m baselines.run --alg=acktr --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options. 
6 | - also refer to the repo-wide [README.md](../../README.md#training-models) 7 | 8 | ## ACKTR with continuous action spaces 9 | The code of ACKTR has been refactored to handle both discrete and continuous action spaces uniformly. In the original version, discrete and continuous action spaces were handled by different code (actkr_disc.py and acktr_cont.py) with little overlap. If interested in the original version of the acktr for continuous action spaces, use `old_acktr_cont` branch. Note that original code performs better on the mujoco tasks than the refactored version; we are still investigating why. 10 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/acktr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/acktr/__init__.py -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/acktr/defaults.py: -------------------------------------------------------------------------------- 1 | def mujoco(): 2 | return dict( 3 | nsteps=2500, 4 | value_network='copy' 5 | ) 6 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/acktr/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def dense(x, size, name, weight_init=None, bias_init=0, weight_loss_dict=None, reuse=None): 4 | with tf.variable_scope(name, reuse=reuse): 5 | assert (len(tf.get_variable_scope().name.split('/')) == 2) 6 | 7 | w = tf.get_variable("w", [x.get_shape()[1], size], initializer=weight_init) 8 | b = tf.get_variable("b", [size], initializer=tf.constant_initializer(bias_init)) 9 | weight_decay_fc = 3e-4 10 | 11 | if weight_loss_dict is not None: 12 | weight_decay = tf.multiply(tf.nn.l2_loss(w), weight_decay_fc, name='weight_decay_loss') 13 | if weight_loss_dict is not None: 14 | weight_loss_dict[w] = weight_decay_fc 15 | weight_loss_dict[b] = 0.0 16 | 17 | tf.add_to_collection(tf.get_variable_scope().name.split('/')[0] + '_' + 'losses', weight_decay) 18 | 19 | return tf.nn.bias_add(tf.matmul(x, w), b) 20 | 21 | def kl_div(action_dist1, action_dist2, action_size): 22 | mean1, std1 = action_dist1[:, :action_size], action_dist1[:, action_size:] 23 | mean2, std2 = action_dist2[:, :action_size], action_dist2[:, action_size:] 24 | 25 | numerator = tf.square(mean1 - mean2) + tf.square(std1) - tf.square(std2) 26 | denominator = 2 * tf.square(std2) + 1e-8 27 | return tf.reduce_sum( 28 | numerator/denominator + tf.log(std2) - tf.log(std1),reduction_indices=-1) 29 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/bench/__init__.py: -------------------------------------------------------------------------------- 1 | from baselines.bench.benchmarks import * 2 | from baselines.bench.monitor import * 3 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa F403 2 | from baselines.common.console_util import * 3 | from baselines.common.dataset import Dataset 4 | from baselines.common.math_util import * 5 | from baselines.common.misc_util import * 6 | 
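The small `defaults.py` modules shown above (e.g. `acer/defaults.py` returning `lrschedule='constant'` for Atari, `acktr/defaults.py` returning `nsteps=2500` and `value_network='copy'` for MuJoCo) hold per-environment-type hyperparameter overrides, which the top-level launcher (`baselines/run.py`, listed in the tree) appears to merge into the learner's keyword arguments. Below is a minimal sketch of that lookup, assuming only the module/attribute naming convention visible above; the helper name `lookup_alg_defaults` is ours and is not part of the repository.

```python
# Illustrative sketch only -- assumed lookup scheme, not code copied from baselines/run.py.
from importlib import import_module

def lookup_alg_defaults(alg, env_type):
    """Return the kwargs dict defined by baselines.<alg>.defaults.<env_type>(), else {}."""
    try:
        defaults_module = import_module('baselines.' + alg + '.defaults')
        return getattr(defaults_module, env_type)()  # e.g. baselines.acktr.defaults.mujoco()
    except (ImportError, AttributeError):
        return {}

# Example: lookup_alg_defaults('acktr', 'mujoco') -> {'nsteps': 2500, 'value_network': 'copy'}
```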
-------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/cg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | def cg(f_Ax, b, cg_iters=10, callback=None, verbose=False, residual_tol=1e-10): 3 | """ 4 | Demmel p 312 5 | """ 6 | p = b.copy() 7 | r = b.copy() 8 | x = np.zeros_like(b) 9 | rdotr = r.dot(r) 10 | 11 | fmtstr = "%10i %10.3g %10.3g" 12 | titlestr = "%10s %10s %10s" 13 | if verbose: print(titlestr % ("iter", "residual norm", "soln norm")) 14 | 15 | for i in range(cg_iters): 16 | if callback is not None: 17 | callback(x) 18 | if verbose: print(fmtstr % (i, rdotr, np.linalg.norm(x))) 19 | z = f_Ax(p) 20 | v = rdotr / p.dot(z) 21 | x += v*p 22 | r -= v*z 23 | newrdotr = r.dot(r) 24 | mu = newrdotr/rdotr 25 | p = r + mu*p 26 | 27 | rdotr = newrdotr 28 | if rdotr < residual_tol: 29 | break 30 | 31 | if callback is not None: 32 | callback(x) 33 | if verbose: print(fmtstr % (i+1, rdotr, np.linalg.norm(x))) # pylint: disable=W0631 34 | return x 35 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Dataset(object): 4 | def __init__(self, data_map, deterministic=False, shuffle=True): 5 | self.data_map = data_map 6 | self.deterministic = deterministic 7 | self.enable_shuffle = shuffle 8 | self.n = next(iter(data_map.values())).shape[0] 9 | self._next_id = 0 10 | self.shuffle() 11 | 12 | def shuffle(self): 13 | if self.deterministic: 14 | return 15 | perm = np.arange(self.n) 16 | np.random.shuffle(perm) 17 | 18 | for key in self.data_map: 19 | self.data_map[key] = self.data_map[key][perm] 20 | 21 | self._next_id = 0 22 | 23 | def next_batch(self, batch_size): 24 | if self._next_id >= self.n and self.enable_shuffle: 25 | self.shuffle() 26 | 27 | cur_id = self._next_id 28 | cur_batch_size = min(batch_size, self.n - self._next_id) 29 | self._next_id += cur_batch_size 30 | 31 | data_map = dict() 32 | for key in self.data_map: 33 | data_map[key] = self.data_map[key][cur_id:cur_id+cur_batch_size] 34 | return data_map 35 | 36 | def iterate_once(self, batch_size): 37 | if self.enable_shuffle: self.shuffle() 38 | 39 | while self._next_id <= self.n - batch_size: 40 | yield self.next_batch(batch_size) 41 | self._next_id = 0 42 | 43 | def subset(self, num_elements, deterministic=True): 44 | data_map = dict() 45 | for key in self.data_map: 46 | data_map[key] = self.data_map[key][:num_elements] 47 | return Dataset(data_map, deterministic) 48 | 49 | 50 | def iterbatches(arrays, *, num_batches=None, batch_size=None, shuffle=True, include_final_partial_batch=True): 51 | assert (num_batches is None) != (batch_size is None), 'Provide num_batches or batch_size, but not both' 52 | arrays = tuple(map(np.asarray, arrays)) 53 | n = arrays[0].shape[0] 54 | assert all(a.shape[0] == n for a in arrays[1:]) 55 | inds = np.arange(n) 56 | if shuffle: np.random.shuffle(inds) 57 | sections = np.arange(0, n, batch_size)[1:] if num_batches is None else num_batches 58 | for batch_inds in np.array_split(inds, sections): 59 | if include_final_partial_batch or len(batch_inds) == batch_size: 60 | yield tuple(a[batch_inds] for a in arrays) 61 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/input.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from gym.spaces import Discrete, Box, MultiDiscrete 4 | 5 | def observation_placeholder(ob_space, batch_size=None, name='Ob'): 6 | ''' 7 | Create placeholder to feed observations into of the size appropriate to the observation space 8 | 9 | Parameters: 10 | ---------- 11 | 12 | ob_space: gym.Space observation space 13 | 14 | batch_size: int size of the batch to be fed into input. Can be left None in most cases. 15 | 16 | name: str name of the placeholder 17 | 18 | Returns: 19 | ------- 20 | 21 | tensorflow placeholder tensor 22 | ''' 23 | 24 | assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box) or isinstance(ob_space, MultiDiscrete), \ 25 | 'Can only deal with Discrete and Box observation spaces for now' 26 | 27 | dtype = ob_space.dtype 28 | if dtype == np.int8: 29 | dtype = np.uint8 30 | 31 | return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name) 32 | 33 | 34 | def observation_input(ob_space, batch_size=None, name='Ob'): 35 | ''' 36 | Create placeholder to feed observations into of the size appropriate to the observation space, and add input 37 | encoder of the appropriate type. 38 | ''' 39 | 40 | placeholder = observation_placeholder(ob_space, batch_size, name) 41 | return placeholder, encode_observation(ob_space, placeholder) 42 | 43 | def encode_observation(ob_space, placeholder): 44 | ''' 45 | Encode input in the way that is appropriate to the observation space 46 | 47 | Parameters: 48 | ---------- 49 | 50 | ob_space: gym.Space observation space 51 | 52 | placeholder: tf.placeholder observation input placeholder 53 | ''' 54 | if isinstance(ob_space, Discrete): 55 | return tf.to_float(tf.one_hot(placeholder, ob_space.n)) 56 | elif isinstance(ob_space, Box): 57 | return tf.to_float(placeholder) 58 | elif isinstance(ob_space, MultiDiscrete): 59 | placeholder = tf.cast(placeholder, tf.int32) 60 | one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])] 61 | return tf.concat(one_hots, axis=-1) 62 | else: 63 | raise NotImplementedError 64 | 65 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/mpi_adam_optimizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from mpi4py import MPI 4 | 5 | class MpiAdamOptimizer(tf.train.AdamOptimizer): 6 | """Adam optimizer that averages gradients across mpi processes.""" 7 | def __init__(self, comm, **kwargs): 8 | self.comm = comm 9 | tf.train.AdamOptimizer.__init__(self, **kwargs) 10 | def compute_gradients(self, loss, var_list, **kwargs): 11 | grads_and_vars = tf.train.AdamOptimizer.compute_gradients(self, loss, var_list, **kwargs) 12 | grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None] 13 | flat_grad = tf.concat([tf.reshape(g, (-1,)) for g, v in grads_and_vars], axis=0) 14 | shapes = [v.shape.as_list() for g, v in grads_and_vars] 15 | sizes = [int(np.prod(s)) for s in shapes] 16 | 17 | num_tasks = self.comm.Get_size() 18 | buf = np.zeros(sum(sizes), np.float32) 19 | 20 | def _collect_grads(flat_grad): 21 | self.comm.Allreduce(flat_grad, buf, op=MPI.SUM) 22 | np.divide(buf, float(num_tasks), out=buf) 23 | return buf 24 | 25 | avg_flat_grad = tf.py_func(_collect_grads, [flat_grad], tf.float32) 26 | avg_flat_grad.set_shape(flat_grad.shape) 27 | 
avg_grads = tf.split(avg_flat_grad, sizes, axis=0) 28 | avg_grads_and_vars = [(tf.reshape(g, v.shape), v) 29 | for g, (_, v) in zip(avg_grads, grads_and_vars)] 30 | 31 | return avg_grads_and_vars 32 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/mpi_fork.py: -------------------------------------------------------------------------------- 1 | import os, subprocess, sys 2 | 3 | def mpi_fork(n, bind_to_core=False): 4 | """Re-launches the current script with workers 5 | Returns "parent" for original parent, "child" for MPI children 6 | """ 7 | if n<=1: 8 | return "child" 9 | if os.getenv("IN_MPI") is None: 10 | env = os.environ.copy() 11 | env.update( 12 | MKL_NUM_THREADS="1", 13 | OMP_NUM_THREADS="1", 14 | IN_MPI="1" 15 | ) 16 | args = ["mpirun", "-np", str(n)] 17 | if bind_to_core: 18 | args += ["-bind-to", "core"] 19 | args += [sys.executable] + sys.argv 20 | subprocess.check_call(args, env=env) 21 | return "parent" 22 | else: 23 | return "child" 24 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/mpi_moments.py: -------------------------------------------------------------------------------- 1 | from mpi4py import MPI 2 | import numpy as np 3 | from baselines.common import zipsame 4 | 5 | 6 | def mpi_mean(x, axis=0, comm=None, keepdims=False): 7 | x = np.asarray(x) 8 | assert x.ndim > 0 9 | if comm is None: comm = MPI.COMM_WORLD 10 | xsum = x.sum(axis=axis, keepdims=keepdims) 11 | n = xsum.size 12 | localsum = np.zeros(n+1, x.dtype) 13 | localsum[:n] = xsum.ravel() 14 | localsum[n] = x.shape[axis] 15 | globalsum = np.zeros_like(localsum) 16 | comm.Allreduce(localsum, globalsum, op=MPI.SUM) 17 | return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n] 18 | 19 | def mpi_moments(x, axis=0, comm=None, keepdims=False): 20 | x = np.asarray(x) 21 | assert x.ndim > 0 22 | mean, count = mpi_mean(x, axis=axis, comm=comm, keepdims=True) 23 | sqdiffs = np.square(x - mean) 24 | meansqdiff, count1 = mpi_mean(sqdiffs, axis=axis, comm=comm, keepdims=True) 25 | assert count1 == count 26 | std = np.sqrt(meansqdiff) 27 | if not keepdims: 28 | newshape = mean.shape[:axis] + mean.shape[axis+1:] 29 | mean = mean.reshape(newshape) 30 | std = std.reshape(newshape) 31 | return mean, std, count 32 | 33 | 34 | def test_runningmeanstd(): 35 | import subprocess 36 | subprocess.check_call(['mpirun', '-np', '3', 37 | 'python','-c', 38 | 'from baselines.common.mpi_moments import _helper_runningmeanstd; _helper_runningmeanstd()']) 39 | 40 | def _helper_runningmeanstd(): 41 | comm = MPI.COMM_WORLD 42 | np.random.seed(0) 43 | for (triple,axis) in [ 44 | ((np.random.randn(3), np.random.randn(4), np.random.randn(5)),0), 45 | ((np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)),0), 46 | ((np.random.randn(2,3), np.random.randn(2,4), np.random.randn(2,4)),1), 47 | ]: 48 | 49 | 50 | x = np.concatenate(triple, axis=axis) 51 | ms1 = [x.mean(axis=axis), x.std(axis=axis), x.shape[axis]] 52 | 53 | 54 | ms2 = mpi_moments(triple[comm.Get_rank()],axis=axis) 55 | 56 | for (a1,a2) in zipsame(ms1, ms2): 57 | print(a1, a2) 58 | assert np.allclose(a1, a2) 59 | print("ok!") 60 | 61 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/runners.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from abc import ABC, abstractmethod 3 | 4 | class 
AbstractEnvRunner(ABC): 5 | def __init__(self, *, env, model, nsteps): 6 | self.env = env 7 | self.model = model 8 | self.nenv = nenv = env.num_envs if hasattr(env, 'num_envs') else 1 9 | self.batch_ob_shape = (nenv*nsteps,) + env.observation_space.shape 10 | self.obs = np.zeros((nenv,) + env.observation_space.shape, dtype=env.observation_space.dtype.name) 11 | self.obs[:] = env.reset() 12 | self.nsteps = nsteps 13 | self.states = model.initial_state 14 | self.dones = [False for _ in range(nenv)] 15 | 16 | @abstractmethod 17 | def run(self): 18 | raise NotImplementedError 19 | 20 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/common/tests/__init__.py -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/tests/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/common/tests/envs/__init__.py -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/tests/envs/fixed_sequence_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import Env 3 | from gym.spaces import Discrete 4 | 5 | 6 | class FixedSequenceEnv(Env): 7 | def __init__( 8 | self, 9 | n_actions=10, 10 | seed=0, 11 | episode_len=100 12 | ): 13 | self.np_random = np.random.RandomState() 14 | self.np_random.seed(seed) 15 | self.sequence = [self.np_random.randint(0, n_actions-1) for _ in range(episode_len)] 16 | 17 | self.action_space = Discrete(n_actions) 18 | self.observation_space = Discrete(1) 19 | 20 | self.episode_len = episode_len 21 | self.time = 0 22 | self.reset() 23 | 24 | def reset(self): 25 | self.time = 0 26 | return 0 27 | 28 | def step(self, actions): 29 | rew = self._get_reward(actions) 30 | self._choose_next_state() 31 | done = False 32 | if self.episode_len and self.time >= self.episode_len: 33 | rew = 0 34 | done = True 35 | 36 | return 0, rew, done, {} 37 | 38 | def _choose_next_state(self): 39 | self.time += 1 40 | 41 | def _get_reward(self, actions): 42 | return 1 if actions == self.sequence[self.time] else 0 43 | 44 | 45 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/tests/envs/identity_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from abc import abstractmethod 3 | from gym import Env 4 | from gym.spaces import MultiDiscrete, Discrete, Box 5 | 6 | 7 | class IdentityEnv(Env): 8 | def __init__( 9 | self, 10 | episode_len=None 11 | ): 12 | 13 | self.episode_len = episode_len 14 | self.time = 0 15 | self.reset() 16 | 17 | def reset(self): 18 | self._choose_next_state() 19 | self.time = 0 20 | self.observation_space = self.action_space 21 | 22 | return self.state 23 | 24 | def step(self, actions): 25 | rew = self._get_reward(actions) 26 | self._choose_next_state() 27 | done = False 28 | if self.episode_len and self.time >= self.episode_len: 29 | rew = 0 30 | done = True 31 | 32 | return self.state, 
rew, done, {} 33 | 34 | def _choose_next_state(self): 35 | self.state = self.action_space.sample() 36 | self.time += 1 37 | 38 | @abstractmethod 39 | def _get_reward(self, actions): 40 | raise NotImplementedError 41 | 42 | 43 | class DiscreteIdentityEnv(IdentityEnv): 44 | def __init__( 45 | self, 46 | dim, 47 | episode_len=None, 48 | ): 49 | 50 | self.action_space = Discrete(dim) 51 | super().__init__(episode_len=episode_len) 52 | 53 | def _get_reward(self, actions): 54 | return 1 if self.state == actions else 0 55 | 56 | class MultiDiscreteIdentityEnv(IdentityEnv): 57 | def __init__( 58 | self, 59 | dims, 60 | episode_len=None, 61 | ): 62 | 63 | self.action_space = MultiDiscrete(dims) 64 | super().__init__(episode_len=episode_len) 65 | 66 | def _get_reward(self, actions): 67 | return 1 if all(self.state == actions) else 0 68 | 69 | 70 | class BoxIdentityEnv(IdentityEnv): 71 | def __init__( 72 | self, 73 | shape, 74 | episode_len=None, 75 | ): 76 | 77 | self.action_space = Box(low=-1.0, high=1.0, shape=shape) 78 | super().__init__(episode_len=episode_len) 79 | 80 | def _get_reward(self, actions): 81 | diff = actions - self.state 82 | diff = diff[:] 83 | return -0.5 * np.dot(diff, diff) 84 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/tests/envs/mnist_env.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import tempfile 4 | from gym import Env 5 | from gym.spaces import Discrete, Box 6 | 7 | 8 | 9 | class MnistEnv(Env): 10 | def __init__( 11 | self, 12 | seed=0, 13 | episode_len=None, 14 | no_images=None 15 | ): 16 | import filelock 17 | from tensorflow.examples.tutorials.mnist import input_data 18 | # we could use temporary directory for this with a context manager and 19 | # TemporaryDirecotry, but then each test that uses mnist would re-download the data 20 | # this way the data is not cleaned up, but we only download it once per machine 21 | mnist_path = osp.join(tempfile.gettempdir(), 'MNIST_data') 22 | with filelock.FileLock(mnist_path + '.lock'): 23 | self.mnist = input_data.read_data_sets(mnist_path) 24 | 25 | self.np_random = np.random.RandomState() 26 | self.np_random.seed(seed) 27 | 28 | self.observation_space = Box(low=0.0, high=1.0, shape=(28,28,1)) 29 | self.action_space = Discrete(10) 30 | self.episode_len = episode_len 31 | self.time = 0 32 | self.no_images = no_images 33 | 34 | self.train_mode() 35 | self.reset() 36 | 37 | def reset(self): 38 | self._choose_next_state() 39 | self.time = 0 40 | 41 | return self.state[0] 42 | 43 | def step(self, actions): 44 | rew = self._get_reward(actions) 45 | self._choose_next_state() 46 | done = False 47 | if self.episode_len and self.time >= self.episode_len: 48 | rew = 0 49 | done = True 50 | 51 | return self.state[0], rew, done, {} 52 | 53 | def train_mode(self): 54 | self.dataset = self.mnist.train 55 | 56 | def test_mode(self): 57 | self.dataset = self.mnist.test 58 | 59 | def _choose_next_state(self): 60 | max_index = (self.no_images if self.no_images is not None else self.dataset.num_examples) - 1 61 | index = self.np_random.randint(0, max_index) 62 | image = self.dataset.images[index].reshape(28,28,1)*255 63 | label = self.dataset.labels[index] 64 | self.state = (image, label) 65 | self.time += 1 66 | 67 | def _get_reward(self, actions): 68 | return 1 if self.state[1] == actions else 0 69 | 70 | 71 | 
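The `FixedSequenceEnv`, identity, and MNIST environments above are small self-contained `gym.Env`s used only by the test suite; each pays a reward of 1 when the agent's action matches a hidden target. As a quick illustration (a hypothetical smoke test, not a file in the repository), a random policy can be rolled out in `DiscreteIdentityEnv` like this:

```python
# Hypothetical smoke test (not part of the repo): roll out a random policy in
# DiscreteIdentityEnv and accumulate its reward to sanity-check the env logic.
from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv

env = DiscreteIdentityEnv(dim=4, episode_len=10)
obs = env.reset()
total_reward, done = 0, False
while not done:
    action = env.action_space.sample()        # uniform random action in {0, ..., 3}
    obs, reward, done, _ = env.step(action)   # reward is 1 iff action == observed state
    total_reward += reward
print("random-policy return:", total_reward)  # roughly 1 in `dim` steps is rewarded
```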
-------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/tests/test_cartpole.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import gym 3 | 4 | from baselines.run import get_learn_function 5 | from baselines.common.tests.util import reward_per_episode_test 6 | 7 | common_kwargs = dict( 8 | total_timesteps=30000, 9 | network='mlp', 10 | gamma=1.0, 11 | seed=0, 12 | ) 13 | 14 | learn_kwargs = { 15 | 'a2c' : dict(nsteps=32, value_network='copy', lr=0.05), 16 | 'acer': dict(value_network='copy'), 17 | 'acktr': dict(nsteps=32, value_network='copy', is_async=False), 18 | 'deepq': dict(total_timesteps=20000), 19 | 'ppo2': dict(value_network='copy'), 20 | 'trpo_mpi': {} 21 | } 22 | 23 | @pytest.mark.slow 24 | @pytest.mark.parametrize("alg", learn_kwargs.keys()) 25 | def test_cartpole(alg): 26 | ''' 27 | Test if the algorithm (with an mlp policy) 28 | can learn to balance the cartpole 29 | ''' 30 | 31 | kwargs = common_kwargs.copy() 32 | kwargs.update(learn_kwargs[alg]) 33 | 34 | learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs) 35 | def env_fn(): 36 | 37 | env = gym.make('CartPole-v0') 38 | env.seed(0) 39 | return env 40 | 41 | reward_per_episode_test(env_fn, learn_fn, 100) 42 | 43 | if __name__ == '__main__': 44 | test_cartpole('acer') 45 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/tests/test_doc_examples.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | try: 3 | import mujoco_py 4 | _mujoco_present = True 5 | except BaseException: 6 | mujoco_py = None 7 | _mujoco_present = False 8 | 9 | 10 | @pytest.mark.skipif( 11 | not _mujoco_present, 12 | reason='error loading mujoco - either mujoco / mujoco key not present, or LD_LIBRARY_PATH is not pointing to mujoco library' 13 | ) 14 | def test_lstm_example(): 15 | import tensorflow as tf 16 | from baselines.common import policies, models, cmd_util 17 | from baselines.common.vec_env.dummy_vec_env import DummyVecEnv 18 | 19 | # create vectorized environment 20 | venv = DummyVecEnv([lambda: cmd_util.make_mujoco_env('Reacher-v2', seed=0)]) 21 | 22 | with tf.Session() as sess: 23 | # build policy based on lstm network with 128 units 24 | policy = policies.build_policy(venv, models.lstm(128))(nbatch=1, nsteps=1) 25 | 26 | # initialize tensorflow variables 27 | sess.run(tf.global_variables_initializer()) 28 | 29 | # prepare environment variables 30 | ob = venv.reset() 31 | state = policy.initial_state 32 | done = [False] 33 | step_counter = 0 34 | 35 | # run a single episode until the end (i.e. 
until done) 36 | while True: 37 | action, _, state, _ = policy.step(ob, S=state, M=done) 38 | ob, reward, done, _ = venv.step(action) 39 | step_counter += 1 40 | if done: 41 | break 42 | 43 | 44 | assert step_counter > 5 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/tests/test_env_after_learn.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import gym 3 | import tensorflow as tf 4 | 5 | from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv 6 | from baselines.run import get_learn_function 7 | from baselines.common.tf_util import make_session 8 | 9 | algos = ['a2c', 'acer', 'acktr', 'deepq', 'ppo2', 'trpo_mpi'] 10 | 11 | @pytest.mark.parametrize('algo', algos) 12 | def test_env_after_learn(algo): 13 | def make_env(): 14 | # acktr requires too much RAM, fails on travis 15 | env = gym.make('CartPole-v1' if algo == 'acktr' else 'PongNoFrameskip-v4') 16 | return env 17 | 18 | make_session(make_default=True, graph=tf.Graph()) 19 | env = SubprocVecEnv([make_env]) 20 | 21 | learn = get_learn_function(algo) 22 | 23 | # Commenting out the following line resolves the issue, though crash happens at env.reset(). 24 | learn(network='mlp', env=env, total_timesteps=0, load_path=None, seed=None) 25 | 26 | env.reset() 27 | env.close() 28 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/tests/test_fixed_sequence.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from baselines.common.tests.envs.fixed_sequence_env import FixedSequenceEnv 3 | 4 | from baselines.common.tests.util import simple_test 5 | from baselines.run import get_learn_function 6 | 7 | common_kwargs = dict( 8 | seed=0, 9 | total_timesteps=50000, 10 | ) 11 | 12 | learn_kwargs = { 13 | 'a2c': {}, 14 | 'ppo2': dict(nsteps=10, ent_coef=0.0, nminibatches=1), 15 | # TODO enable sequential models for trpo_mpi (proper handling of nbatch and nsteps) 16 | # github issue: https://github.com/openai/baselines/issues/188 17 | # 'trpo_mpi': lambda e, p: trpo_mpi.learn(policy_fn=p(env=e), env=e, max_timesteps=30000, timesteps_per_batch=100, cg_iters=10, gamma=0.9, lam=1.0, max_kl=0.001) 18 | } 19 | 20 | 21 | alg_list = learn_kwargs.keys() 22 | rnn_list = ['lstm'] 23 | 24 | @pytest.mark.slow 25 | @pytest.mark.parametrize("alg", alg_list) 26 | @pytest.mark.parametrize("rnn", rnn_list) 27 | def test_fixed_sequence(alg, rnn): 28 | ''' 29 | Test if the algorithm (with a given policy) 30 | can learn an identity transformation (i.e. 
return observation as an action) 31 | ''' 32 | 33 | kwargs = learn_kwargs[alg] 34 | kwargs.update(common_kwargs) 35 | 36 | episode_len = 5 37 | env_fn = lambda: FixedSequenceEnv(10, episode_len=episode_len) 38 | learn = lambda e: get_learn_function(alg)( 39 | env=e, 40 | network=rnn, 41 | **kwargs 42 | ) 43 | 44 | simple_test(env_fn, learn, 0.7) 45 | 46 | 47 | if __name__ == '__main__': 48 | test_fixed_sequence('ppo2', 'lstm') 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/tests/test_mnist.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | # from baselines.acer import acer_simple as acer 4 | from baselines.common.tests.envs.mnist_env import MnistEnv 5 | from baselines.common.tests.util import simple_test 6 | from baselines.run import get_learn_function 7 | 8 | 9 | # TODO investigate a2c and ppo2 failures - is it due to bad hyperparameters for this problem? 10 | # GitHub issue https://github.com/openai/baselines/issues/189 11 | common_kwargs = { 12 | 'seed': 0, 13 | 'network':'cnn', 14 | 'gamma':0.9, 15 | 'pad':'SAME' 16 | } 17 | 18 | learn_args = { 19 | 'a2c': dict(total_timesteps=50000), 20 | 'acer': dict(total_timesteps=20000), 21 | 'deepq': dict(total_timesteps=5000), 22 | 'acktr': dict(total_timesteps=30000), 23 | 'ppo2': dict(total_timesteps=50000, lr=1e-3, nsteps=128, ent_coef=0.0), 24 | 'trpo_mpi': dict(total_timesteps=80000, timesteps_per_batch=100, cg_iters=10, lam=1.0, max_kl=0.001) 25 | } 26 | 27 | 28 | #tests pass, but are too slow on travis. Same algorithms are covered 29 | # by other tests with less compute-hungry nn's and by benchmarks 30 | @pytest.mark.skip 31 | @pytest.mark.slow 32 | @pytest.mark.parametrize("alg", learn_args.keys()) 33 | def test_mnist(alg): 34 | ''' 35 | Test if the algorithm can learn to classify MNIST digits. 36 | Uses CNN policy. 
37 | ''' 38 | 39 | learn_kwargs = learn_args[alg] 40 | learn_kwargs.update(common_kwargs) 41 | 42 | learn = get_learn_function(alg) 43 | learn_fn = lambda e: learn(env=e, **learn_kwargs) 44 | env_fn = lambda: MnistEnv(seed=0, episode_len=100) 45 | 46 | simple_test(env_fn, learn_fn, 0.6) 47 | 48 | if __name__ == '__main__': 49 | test_mnist('acer') 50 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/tests/test_schedules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from baselines.common.schedules import ConstantSchedule, PiecewiseSchedule 4 | 5 | 6 | def test_piecewise_schedule(): 7 | ps = PiecewiseSchedule([(-5, 100), (5, 200), (10, 50), (100, 50), (200, -50)], outside_value=500) 8 | 9 | assert np.isclose(ps.value(-10), 500) 10 | assert np.isclose(ps.value(0), 150) 11 | assert np.isclose(ps.value(5), 200) 12 | assert np.isclose(ps.value(9), 80) 13 | assert np.isclose(ps.value(50), 50) 14 | assert np.isclose(ps.value(80), 50) 15 | assert np.isclose(ps.value(150), 0) 16 | assert np.isclose(ps.value(175), -25) 17 | assert np.isclose(ps.value(201), 500) 18 | assert np.isclose(ps.value(500), 500) 19 | 20 | assert np.isclose(ps.value(200 - 1e-10), -50) 21 | 22 | 23 | def test_constant_schedule(): 24 | cs = ConstantSchedule(5) 25 | for i in range(-100, 100): 26 | assert np.isclose(cs.value(i), 5) 27 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/tests/test_tf_util.py: -------------------------------------------------------------------------------- 1 | # tests for tf_util 2 | import tensorflow as tf 3 | from baselines.common.tf_util import ( 4 | function, 5 | initialize, 6 | single_threaded_session 7 | ) 8 | 9 | 10 | def test_function(): 11 | with tf.Graph().as_default(): 12 | x = tf.placeholder(tf.int32, (), name="x") 13 | y = tf.placeholder(tf.int32, (), name="y") 14 | z = 3 * x + 2 * y 15 | lin = function([x, y], z, givens={y: 0}) 16 | 17 | with single_threaded_session(): 18 | initialize() 19 | 20 | assert lin(2) == 6 21 | assert lin(2, 2) == 10 22 | 23 | 24 | def test_multikwargs(): 25 | with tf.Graph().as_default(): 26 | x = tf.placeholder(tf.int32, (), name="x") 27 | with tf.variable_scope("other"): 28 | x2 = tf.placeholder(tf.int32, (), name="x") 29 | z = 3 * x + 2 * x2 30 | 31 | lin = function([x, x2], z, givens={x2: 0}) 32 | with single_threaded_session(): 33 | initialize() 34 | assert lin(2) == 6 35 | assert lin(2, 2) == 10 36 | 37 | 38 | if __name__ == '__main__': 39 | test_function() 40 | test_multikwargs() 41 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/tile_images.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def tile_images(img_nhwc): 4 | """ 5 | Tile N images into one big PxQ image 6 | (P,Q) are chosen to be as close as possible, and if N 7 | is square, then P=Q. 
8 | 9 | input: img_nhwc, list or array of images, ndim=4 once turned into array 10 | n = batch index, h = height, w = width, c = channel 11 | returns: 12 | bigim_HWc, ndarray with ndim=3 13 | """ 14 | img_nhwc = np.asarray(img_nhwc) 15 | N, h, w, c = img_nhwc.shape 16 | H = int(np.ceil(np.sqrt(N))) 17 | W = int(np.ceil(float(N)/H)) 18 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)]) 19 | img_HWhwc = img_nhwc.reshape(H, W, h, w, c) 20 | img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4) 21 | img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c) 22 | return img_Hh_Ww_c 23 | 24 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/trex_utils.py: -------------------------------------------------------------------------------- 1 | def normalize_state(obs): 2 | return obs / 255.0 3 | 4 | 5 | #custom masking function for covering up the score/life portions of atari games 6 | def mask_score(obs, env_name): 7 | obs_copy = obs.copy() 8 | if env_name == "spaceinvaders" or env_name == "breakout" or env_name == "pong": 9 | #takes a stack of four observations and blacks out (sets to zero) top n rows 10 | n = 10 11 | #no_score_obs = copy.deepcopy(obs) 12 | obs_copy[:,:n,:,:] = 0 13 | elif env_name == "beamrider": 14 | n_top = 16 15 | n_bottom = 11 16 | obs_copy[:,:n_top,:,:] = 0 17 | obs_copy[:,-n_bottom:,:,:] = 0 18 | elif env_name == "enduro": 19 | n_top = 0 20 | n_bottom = 14 21 | obs_copy[:,:n_top,:,:] = 0 22 | obs_copy[:,-n_bottom:,:,:] = 0 23 | elif env_name == "hero": 24 | n_top = 0 25 | n_bottom = 30 26 | obs_copy[:,:n_top,:,:] = 0 27 | obs_copy[:,-n_bottom:,:,:] = 0 28 | elif env_name == "qbert": 29 | n_top = 12 30 | #n_bottom = 0 31 | obs_copy[:,:n_top,:,:] = 0 32 | #obs_copy[:,-n_bottom:,:,:] = 0 33 | elif env_name == "seaquest": 34 | n_top = 12 35 | n_bottom = 16 36 | obs_copy[:,:n_top,:,:] = 0 37 | obs_copy[:,-n_bottom:,:,:] = 0 38 | #cuts out divers and oxygen 39 | elif env_name == "mspacman": 40 | n_bottom = 15 #mask score and number lives left 41 | obs_copy[:,-n_bottom:,:,:] = 0 42 | elif env_name == "videopinball": 43 | n_top = 15 44 | obs_copy[:,:n_top,:,:] = 0 45 | elif env_name == "montezumarevenge": 46 | n_top = 10 47 | obs_copy[:,:n_top,:,:] = 0 48 | else: 49 | print("NOT MASKING SCORE FOR GAME: " + env_name) 50 | pass 51 | #n = 20 52 | #obs_copy[:,-n:,:,:] = 0 53 | return obs_copy 54 | 55 | def preprocess(ob, env_name): 56 | #print("masking on env", env_name) 57 | return mask_score(normalize_state(ob), env_name) 58 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/vec_env/test_video_recorder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for asynchronous vectorized environments. 
3 | """ 4 | 5 | import gym 6 | import pytest 7 | import os 8 | import glob 9 | import tempfile 10 | 11 | from .dummy_vec_env import DummyVecEnv 12 | from .shmem_vec_env import ShmemVecEnv 13 | from .subproc_vec_env import SubprocVecEnv 14 | from .vec_video_recorder import VecVideoRecorder 15 | 16 | @pytest.mark.parametrize('klass', (DummyVecEnv, ShmemVecEnv, SubprocVecEnv)) 17 | @pytest.mark.parametrize('num_envs', (1, 4)) 18 | @pytest.mark.parametrize('video_length', (10, 100)) 19 | @pytest.mark.parametrize('video_interval', (1, 50)) 20 | def test_video_recorder(klass, num_envs, video_length, video_interval): 21 | """ 22 | Wrap an existing VecEnv with VevVideoRecorder, 23 | Make (video_interval + video_length + 1) steps, 24 | then check that the file is present 25 | """ 26 | 27 | def make_fn(): 28 | env = gym.make('PongNoFrameskip-v4') 29 | return env 30 | fns = [make_fn for _ in range(num_envs)] 31 | env = klass(fns) 32 | 33 | with tempfile.TemporaryDirectory() as video_path: 34 | env = VecVideoRecorder(env, video_path, record_video_trigger=lambda x: x % video_interval == 0, video_length=video_length) 35 | 36 | env.reset() 37 | for _ in range(video_interval + video_length + 1): 38 | env.step([0] * num_envs) 39 | env.close() 40 | 41 | 42 | recorded_video = glob.glob(os.path.join(video_path, "*.mp4")) 43 | 44 | # first and second step 45 | assert len(recorded_video) == 2 46 | # Files are not empty 47 | assert all(os.stat(p).st_size != 0 for p in recorded_video) 48 | 49 | 50 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/vec_env/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for dealing with vectorized environments. 3 | """ 4 | 5 | from collections import OrderedDict 6 | 7 | import gym 8 | import numpy as np 9 | 10 | 11 | def copy_obs_dict(obs): 12 | """ 13 | Deep-copy an observation dict. 14 | """ 15 | return {k: np.copy(v) for k, v in obs.items()} 16 | 17 | 18 | def dict_to_obs(obs_dict): 19 | """ 20 | Convert an observation dict into a raw array if the 21 | original observation space was not a Dict space. 22 | """ 23 | if set(obs_dict.keys()) == {None}: 24 | return obs_dict[None] 25 | return obs_dict 26 | 27 | 28 | def obs_space_info(obs_space): 29 | """ 30 | Get dict-structured information about a gym.Space. 31 | 32 | Returns: 33 | A tuple (keys, shapes, dtypes): 34 | keys: a list of dict keys. 35 | shapes: a dict mapping keys to shapes. 36 | dtypes: a dict mapping keys to dtypes. 37 | """ 38 | if isinstance(obs_space, gym.spaces.Dict): 39 | assert isinstance(obs_space.spaces, OrderedDict) 40 | subspaces = obs_space.spaces 41 | else: 42 | subspaces = {None: obs_space} 43 | keys = [] 44 | shapes = {} 45 | dtypes = {} 46 | for key, box in subspaces.items(): 47 | keys.append(key) 48 | shapes[key] = box.shape 49 | dtypes[key] = box.dtype 50 | return keys, shapes, dtypes 51 | 52 | 53 | def obs_to_dict(obs): 54 | """ 55 | Convert an observation into a dict. 56 | """ 57 | if isinstance(obs, dict): 58 | return obs 59 | return {None: obs} 60 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/vec_env/vec_frame_stack.py: -------------------------------------------------------------------------------- 1 | from . 
import VecEnvWrapper 2 | import numpy as np 3 | from gym import spaces 4 | 5 | 6 | class VecFrameStack(VecEnvWrapper): 7 | def __init__(self, venv, nstack): 8 | self.venv = venv 9 | self.nstack = nstack 10 | wos = venv.observation_space # wrapped ob space 11 | low = np.repeat(wos.low, self.nstack, axis=-1) 12 | high = np.repeat(wos.high, self.nstack, axis=-1) 13 | self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype) 14 | observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype) 15 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 16 | 17 | def step_wait(self): 18 | obs, rews, news, infos = self.venv.step_wait() 19 | self.stackedobs = np.roll(self.stackedobs, shift=-1, axis=-1) 20 | for (i, new) in enumerate(news): 21 | if new: 22 | self.stackedobs[i] = 0 23 | self.stackedobs[..., -obs.shape[-1]:] = obs 24 | return self.stackedobs, rews, news, infos 25 | 26 | def reset(self): 27 | obs = self.venv.reset() 28 | self.stackedobs[...] = 0 29 | self.stackedobs[..., -obs.shape[-1]:] = obs 30 | return self.stackedobs 31 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/common/vec_env/vec_monitor.py: -------------------------------------------------------------------------------- 1 | from . import VecEnvWrapper 2 | from baselines.bench.monitor import ResultsWriter 3 | import numpy as np 4 | import time 5 | 6 | 7 | class VecMonitor(VecEnvWrapper): 8 | def __init__(self, venv, filename=None): 9 | VecEnvWrapper.__init__(self, venv) 10 | self.eprets = None 11 | self.eplens = None 12 | self.tstart = time.time() 13 | self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart}) 14 | 15 | def reset(self): 16 | obs = self.venv.reset() 17 | self.eprets = np.zeros(self.num_envs, 'f') 18 | self.eplens = np.zeros(self.num_envs, 'i') 19 | return obs 20 | 21 | def step_wait(self): 22 | obs, rews, dones, infos = self.venv.step_wait() 23 | self.eprets += rews 24 | self.eplens += 1 25 | newinfos = [] 26 | for (i, (done, ret, eplen, info)) in enumerate(zip(dones, self.eprets, self.eplens, infos)): 27 | info = info.copy() 28 | if done: 29 | epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)} 30 | info['episode'] = epinfo 31 | self.eprets[i] = 0 32 | self.eplens[i] = 0 33 | self.results_writer.write_row(epinfo) 34 | 35 | newinfos.append(info) 36 | 37 | return obs, rews, dones, newinfos 38 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/ddpg/README.md: -------------------------------------------------------------------------------- 1 | # DDPG 2 | 3 | - Original paper: https://arxiv.org/abs/1509.02971 4 | - Baselines post: https://blog.openai.com/better-exploration-with-parameter-noise/ 5 | - `python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6` runs the algorithm for 1M frames = 10M timesteps on a Mujoco environment. See help (`-h`) for more options. 
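The `preprocess` helper in `trex_utils.py` above combines normalization and score masking; here is a minimal usage sketch (not part of the repository — the batch shape and game name are illustrative assumptions):

```python
# Hypothetical sketch: mask the score region of a batch of stacked Atari frames.
import numpy as np
from baselines.common.trex_utils import preprocess

# one stacked observation: (batch, height, width, stacked frames), uint8 pixel values
obs = np.random.randint(0, 256, size=(1, 84, 84, 4)).astype(np.float32)
masked = preprocess(obs, "breakout")  # scales to [0, 1], then zeroes the top 10 rows (score area)
assert masked.shape == obs.shape and masked[:, :10].max() == 0.0
```

Likewise, a minimal sketch of how the `VecFrameStack` and `VecMonitor` wrappers shown above are typically composed around a vectorized environment; the environment id and `nstack=4` are illustrative assumptions, not something this repository prescribes:

```python
# Hypothetical sketch: stack frames and record episode statistics for a vectorized env.
import gym
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from baselines.common.vec_env.vec_monitor import VecMonitor

def make_env():
    return gym.make('PongNoFrameskip-v4')  # example id; any image-observation env works

venv = DummyVecEnv([make_env for _ in range(4)])  # 4 copies stepped in lockstep
venv = VecMonitor(venv)                           # adds per-episode return/length to step infos
venv = VecFrameStack(venv, nstack=4)              # stacks the last 4 frames along the channel axis
obs = venv.reset()                                # shape: (4, 210, 160, 3 * 4)
```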
6 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/ddpg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/ddpg/__init__.py -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/ddpg/models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from baselines.common.models import get_network_builder 3 | 4 | 5 | class Model(object): 6 | def __init__(self, name, network='mlp', **network_kwargs): 7 | self.name = name 8 | self.network_builder = get_network_builder(network)(**network_kwargs) 9 | 10 | @property 11 | def vars(self): 12 | return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name) 13 | 14 | @property 15 | def trainable_vars(self): 16 | return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name) 17 | 18 | @property 19 | def perturbable_vars(self): 20 | return [var for var in self.trainable_vars if 'LayerNorm' not in var.name] 21 | 22 | 23 | class Actor(Model): 24 | def __init__(self, nb_actions, name='actor', network='mlp', **network_kwargs): 25 | super().__init__(name=name, network=network, **network_kwargs) 26 | self.nb_actions = nb_actions 27 | 28 | def __call__(self, obs, reuse=False): 29 | with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE): 30 | x = self.network_builder(obs) 31 | x = tf.layers.dense(x, self.nb_actions, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3)) 32 | x = tf.nn.tanh(x) 33 | return x 34 | 35 | 36 | class Critic(Model): 37 | def __init__(self, name='critic', network='mlp', **network_kwargs): 38 | super().__init__(name=name, network=network, **network_kwargs) 39 | self.layer_norm = True 40 | 41 | def __call__(self, obs, action, reuse=False): 42 | with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE): 43 | x = tf.concat([obs, action], axis=-1) # this assumes observation and action can be concatenated 44 | x = self.network_builder(x) 45 | x = tf.layers.dense(x, 1, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3)) 46 | return x 47 | 48 | @property 49 | def output_vars(self): 50 | output_vars = [var for var in self.trainable_vars if 'output' in var.name] 51 | return output_vars 52 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/ddpg/noise.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class AdaptiveParamNoiseSpec(object): 5 | def __init__(self, initial_stddev=0.1, desired_action_stddev=0.1, adoption_coefficient=1.01): 6 | self.initial_stddev = initial_stddev 7 | self.desired_action_stddev = desired_action_stddev 8 | self.adoption_coefficient = adoption_coefficient 9 | 10 | self.current_stddev = initial_stddev 11 | 12 | def adapt(self, distance): 13 | if distance > self.desired_action_stddev: 14 | # Decrease stddev. 15 | self.current_stddev /= self.adoption_coefficient 16 | else: 17 | # Increase stddev. 
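# (distance fell below the target, so the perturbation scale grows until the induced
# action-space distance reaches desired_action_stddev)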
18 | self.current_stddev *= self.adoption_coefficient 19 | 20 | def get_stats(self): 21 | stats = { 22 | 'param_noise_stddev': self.current_stddev, 23 | } 24 | return stats 25 | 26 | def __repr__(self): 27 | fmt = 'AdaptiveParamNoiseSpec(initial_stddev={}, desired_action_stddev={}, adoption_coefficient={})' 28 | return fmt.format(self.initial_stddev, self.desired_action_stddev, self.adoption_coefficient) 29 | 30 | 31 | class ActionNoise(object): 32 | def reset(self): 33 | pass 34 | 35 | 36 | class NormalActionNoise(ActionNoise): 37 | def __init__(self, mu, sigma): 38 | self.mu = mu 39 | self.sigma = sigma 40 | 41 | def __call__(self): 42 | return np.random.normal(self.mu, self.sigma) 43 | 44 | def __repr__(self): 45 | return 'NormalActionNoise(mu={}, sigma={})'.format(self.mu, self.sigma) 46 | 47 | 48 | # Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab 49 | class OrnsteinUhlenbeckActionNoise(ActionNoise): 50 | def __init__(self, mu, sigma, theta=.15, dt=1e-2, x0=None): 51 | self.theta = theta 52 | self.mu = mu 53 | self.sigma = sigma 54 | self.dt = dt 55 | self.x0 = x0 56 | self.reset() 57 | 58 | def __call__(self): 59 | x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + self.sigma * np.sqrt(self.dt) * np.random.normal(size=self.mu.shape) 60 | self.x_prev = x 61 | return x 62 | 63 | def reset(self): 64 | self.x_prev = self.x0 if self.x0 is not None else np.zeros_like(self.mu) 65 | 66 | def __repr__(self): 67 | return 'OrnsteinUhlenbeckActionNoise(mu={}, sigma={})'.format(self.mu, self.sigma) 68 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/deepq/README.md: -------------------------------------------------------------------------------- 1 | ## If you are curious. 2 | 3 | ##### Train a Cartpole agent and watch it play once it converges! 4 | 5 | Here's a list of commands to run to quickly get a working example: 6 | 7 | 8 | 9 | 10 | ```bash 11 | # Train model and save the results to cartpole_model.pkl 12 | python -m baselines.run --alg=deepq --env=CartPole-v0 --save_path=./cartpole_model.pkl --num_timesteps=1e5 13 | # Load the model saved in cartpole_model.pkl and visualize the learned policy 14 | python -m baselines.run --alg=deepq --env=CartPole-v0 --load_path=./cartpole_model.pkl --num_timesteps=0 --play 15 | ``` 16 | 17 | ## If you wish to apply DQN to solve a problem. 18 | 19 | Check out our simple agent trained with one stop shop `deepq.learn` function. 20 | 21 | - [baselines/deepq/experiments/train_cartpole.py](experiments/train_cartpole.py) - train a Cartpole agent. 22 | 23 | In particular notice that once `deepq.learn` finishes training it returns `act` function which can be used to select actions in the environment. Once trained you can easily save it and load at later time. Complimentary file `enjoy_cartpole.py` loads and visualizes the learned policy. 24 | 25 | ## If you wish to experiment with the algorithm 26 | 27 | ##### Check out the examples 28 | 29 | - [baselines/deepq/experiments/custom_cartpole.py](experiments/custom_cartpole.py) - Cartpole training with more fine grained control over the internals of DQN algorithm. 30 | - [baselines/deepq/defaults.py](defaults.py) - settings for training on atari. 
Run 31 | 32 | ```bash 33 | python -m baselines.run --alg=deepq --env=PongNoFrameskip-v4 34 | ``` 35 | to train on Atari Pong (see more in repo-wide [README.md](../../README.md#training-models)) 36 | 37 | 38 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/deepq/__init__.py: -------------------------------------------------------------------------------- 1 | from baselines.deepq import models # noqa 2 | from baselines.deepq.build_graph import build_act, build_train # noqa 3 | from baselines.deepq.deepq import learn, load_act # noqa 4 | from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer # noqa 5 | 6 | def wrap_atari_dqn(env): 7 | from baselines.common.atari_wrappers import wrap_deepmind 8 | return wrap_deepmind(env, frame_stack=True, scale=True) 9 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/deepq/defaults.py: -------------------------------------------------------------------------------- 1 | def atari(): 2 | return dict( 3 | network='conv_only', 4 | lr=1e-4, 5 | buffer_size=10000, 6 | exploration_fraction=0.1, 7 | exploration_final_eps=0.01, 8 | train_freq=4, 9 | learning_starts=10000, 10 | target_network_update_freq=1000, 11 | gamma=0.99, 12 | prioritized_replay=True, 13 | prioritized_replay_alpha=0.6, 14 | checkpoint_freq=10000, 15 | checkpoint_path=None, 16 | dueling=True 17 | ) 18 | 19 | def retro(): 20 | return atari() 21 | 22 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/deepq/experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/deepq/experiments/__init__.py -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/deepq/experiments/enjoy_cartpole.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | from baselines import deepq 4 | 5 | 6 | def main(): 7 | env = gym.make("CartPole-v0") 8 | act = deepq.learn(env, network='mlp', total_timesteps=0, load_path="cartpole_model.pkl") 9 | 10 | while True: 11 | obs, done = env.reset(), False 12 | episode_rew = 0 13 | while not done: 14 | env.render() 15 | obs, rew, done, _ = env.step(act(obs[None])[0]) 16 | episode_rew += rew 17 | print("Episode reward", episode_rew) 18 | 19 | 20 | if __name__ == '__main__': 21 | main() 22 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/deepq/experiments/enjoy_mountaincar.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | from baselines import deepq 4 | from baselines.common import models 5 | 6 | 7 | def main(): 8 | env = gym.make("MountainCar-v0") 9 | act = deepq.learn( 10 | env, 11 | network=models.mlp(num_layers=1, num_hidden=64), 12 | total_timesteps=0, 13 | load_path='mountaincar_model.pkl' 14 | ) 15 | 16 | while True: 17 | obs, done = env.reset(), False 18 | episode_rew = 0 19 | while not done: 20 | env.render() 21 | obs, rew, done, _ = env.step(act(obs[None])[0]) 22 | episode_rew += rew 23 | print("Episode reward", episode_rew) 24 | 25 | 26 | if __name__ == '__main__': 27 | main() 28 | 
-------------------------------------------------------------------------------- /drex-atari/baselines/baselines/deepq/experiments/enjoy_pong.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from baselines import deepq 3 | 4 | 5 | def main(): 6 | env = gym.make("PongNoFrameskip-v4") 7 | env = deepq.wrap_atari_dqn(env) 8 | model = deepq.learn( 9 | env, 10 | "conv_only", 11 | convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], 12 | hiddens=[256], 13 | dueling=True, 14 | total_timesteps=0 15 | ) 16 | 17 | while True: 18 | obs, done = env.reset(), False 19 | episode_rew = 0 20 | while not done: 21 | env.render() 22 | obs, rew, done, _ = env.step(model(obs[None])[0]) 23 | episode_rew += rew 24 | print("Episode reward", episode_rew) 25 | 26 | 27 | if __name__ == '__main__': 28 | main() 29 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/deepq/experiments/train_cartpole.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | from baselines import deepq 4 | 5 | 6 | def callback(lcl, _glb): 7 | # stop training if reward exceeds 199 8 | is_solved = lcl['t'] > 100 and sum(lcl['episode_rewards'][-101:-1]) / 100 >= 199 9 | return is_solved 10 | 11 | 12 | def main(): 13 | env = gym.make("CartPole-v0") 14 | act = deepq.learn( 15 | env, 16 | network='mlp', 17 | lr=1e-3, 18 | total_timesteps=100000, 19 | buffer_size=50000, 20 | exploration_fraction=0.1, 21 | exploration_final_eps=0.02, 22 | print_freq=10, 23 | callback=callback 24 | ) 25 | print("Saving model to cartpole_model.pkl") 26 | act.save("cartpole_model.pkl") 27 | 28 | 29 | if __name__ == '__main__': 30 | main() 31 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/deepq/experiments/train_mountaincar.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | from baselines import deepq 4 | from baselines.common import models 5 | 6 | 7 | def main(): 8 | env = gym.make("MountainCar-v0") 9 | # Enabling layer_norm here is import for parameter space noise! 
10 | act = deepq.learn( 11 | env, 12 | network=models.mlp(num_hidden=64, num_layers=1), 13 | lr=1e-3, 14 | total_timesteps=100000, 15 | buffer_size=50000, 16 | exploration_fraction=0.1, 17 | exploration_final_eps=0.1, 18 | print_freq=10, 19 | param_noise=True 20 | ) 21 | print("Saving model to mountaincar_model.pkl") 22 | act.save("mountaincar_model.pkl") 23 | 24 | 25 | if __name__ == '__main__': 26 | main() 27 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/deepq/experiments/train_pong.py: -------------------------------------------------------------------------------- 1 | from baselines import deepq 2 | from baselines import bench 3 | from baselines import logger 4 | from baselines.common.atari_wrappers import make_atari 5 | 6 | 7 | def main(): 8 | logger.configure() 9 | env = make_atari('PongNoFrameskip-v4') 10 | env = bench.Monitor(env, logger.get_dir()) 11 | env = deepq.wrap_atari_dqn(env) 12 | 13 | model = deepq.learn( 14 | env, 15 | "conv_only", 16 | convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], 17 | hiddens=[256], 18 | dueling=True, 19 | lr=1e-4, 20 | total_timesteps=int(1e7), 21 | buffer_size=10000, 22 | exploration_fraction=0.1, 23 | exploration_final_eps=0.01, 24 | train_freq=4, 25 | learning_starts=10000, 26 | target_network_update_freq=1000, 27 | gamma=0.99, 28 | ) 29 | 30 | model.save('pong_model.pkl') 31 | env.close() 32 | 33 | if __name__ == '__main__': 34 | main() 35 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/deepq/utils.py: -------------------------------------------------------------------------------- 1 | from baselines.common.input import observation_input 2 | from baselines.common.tf_util import adjust_shape 3 | 4 | # ================================================================ 5 | # Placeholders 6 | # ================================================================ 7 | 8 | 9 | class TfInput(object): 10 | def __init__(self, name="(unnamed)"): 11 | """Generalized Tensorflow placeholder. The main differences are: 12 | - possibly uses multiple placeholders internally and returns multiple values 13 | - can apply light postprocessing to the value feed to placeholder. 14 | """ 15 | self.name = name 16 | 17 | def get(self): 18 | """Return the tf variable(s) representing the possibly postprocessed value 19 | of placeholder(s). 20 | """ 21 | raise NotImplementedError 22 | 23 | def make_feed_dict(data): 24 | """Given data input it to the placeholder(s).""" 25 | raise NotImplementedError 26 | 27 | 28 | class PlaceholderTfInput(TfInput): 29 | def __init__(self, placeholder): 30 | """Wrapper for regular tensorflow placeholder.""" 31 | super().__init__(placeholder.name) 32 | self._placeholder = placeholder 33 | 34 | def get(self): 35 | return self._placeholder 36 | 37 | def make_feed_dict(self, data): 38 | return {self._placeholder: adjust_shape(self._placeholder, data)} 39 | 40 | 41 | class ObservationInput(PlaceholderTfInput): 42 | def __init__(self, observation_space, name=None): 43 | """Creates an input placeholder tailored to a specific observation space 44 | 45 | Parameters 46 | ---------- 47 | 48 | observation_space: 49 | observation space of the environment. 
Should be one of the gym.spaces types 50 | name: str 51 | tensorflow name of the underlying placeholder 52 | """ 53 | inpt, self.processed_inpt = observation_input(observation_space, name=name) 54 | super().__init__(inpt) 55 | 56 | def get(self): 57 | return self.processed_inpt 58 | 59 | 60 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/README.md: -------------------------------------------------------------------------------- 1 | # Generative Adversarial Imitation Learning (GAIL) 2 | 3 | - Original paper: https://arxiv.org/abs/1606.03476 4 | 5 | For results benchmarking on MuJoCo, please navigate to [here](result/gail-result.md) 6 | 7 | ## If you want to train an imitation learning agent 8 | 9 | ### Step 1: Download expert data 10 | 11 | Download the expert data into `./data`, [download link](https://drive.google.com/drive/folders/1h3H4AY_ZBx08hz-Ct0Nxxus-V1melu1U?usp=sharing) 12 | 13 | ### Step 2: Run GAIL 14 | 15 | Run with single thread: 16 | 17 | ```bash 18 | python -m baselines.gail.run_mujoco 19 | ``` 20 | 21 | Run with multiple threads: 22 | 23 | ```bash 24 | mpirun -np 16 python -m baselines.gail.run_mujoco 25 | ``` 26 | 27 | See help (`-h`) for more options. 28 | 29 | #### In case you want to run Behavior Cloning (BC) 30 | 31 | ```bash 32 | python -m baselines.gail.behavior_clone 33 | ``` 34 | 35 | See help (`-h`) for more options. 36 | 37 | 38 | ## Contributing 39 | 40 | Bug reports and pull requests are welcome on GitHub at https://github.com/openai/baselines/pulls. 41 | 42 | ## Maintainers 43 | 44 | - Yuan-Hong Liao, andrewliao11_at_gmail_dot_com 45 | - Ryan Julian, ryanjulian_at_gmail_dot_com 46 | 47 | ## Others 48 | 49 | Thanks to the open source: 50 | 51 | - @openai/imitation 52 | - @carpedm20/deep-rl-tensorflow 53 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/__init__.py -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/dataset/__init__.py -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/HalfCheetah-normalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/HalfCheetah-normalized-deterministic-scores.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/HalfCheetah-normalized-stochastic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/HalfCheetah-normalized-stochastic-scores.png -------------------------------------------------------------------------------- 
/drex-atari/baselines/baselines/gail/result/HalfCheetah-unnormalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/HalfCheetah-unnormalized-deterministic-scores.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/HalfCheetah-unnormalized-stochastic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/HalfCheetah-unnormalized-stochastic-scores.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/Hopper-normalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/Hopper-normalized-deterministic-scores.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/Hopper-normalized-stochastic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/Hopper-normalized-stochastic-scores.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/Hopper-unnormalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/Hopper-unnormalized-deterministic-scores.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/Hopper-unnormalized-stochastic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/Hopper-unnormalized-stochastic-scores.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/Humanoid-normalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/Humanoid-normalized-deterministic-scores.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/Humanoid-normalized-stochastic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/Humanoid-normalized-stochastic-scores.png -------------------------------------------------------------------------------- 
/drex-atari/baselines/baselines/gail/result/Humanoid-unnormalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/Humanoid-unnormalized-deterministic-scores.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/Humanoid-unnormalized-stochastic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/Humanoid-unnormalized-stochastic-scores.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/HumanoidStandup-normalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/HumanoidStandup-normalized-deterministic-scores.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/HumanoidStandup-normalized-stochastic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/HumanoidStandup-normalized-stochastic-scores.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/HumanoidStandup-unnormalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/HumanoidStandup-unnormalized-deterministic-scores.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/HumanoidStandup-unnormalized-stochastic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/HumanoidStandup-unnormalized-stochastic-scores.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/Walker2d-normalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/Walker2d-normalized-deterministic-scores.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/Walker2d-normalized-stochastic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/Walker2d-normalized-stochastic-scores.png 
-------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/Walker2d-unnormalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/Walker2d-unnormalized-deterministic-scores.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/Walker2d-unnormalized-stochastic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/Walker2d-unnormalized-stochastic-scores.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/halfcheetah-training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/halfcheetah-training.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/hopper-training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/hopper-training.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/humanoid-training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/humanoid-training.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/humanoidstandup-training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/humanoidstandup-training.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/result/walker2d-training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/gail/result/walker2d-training.png -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/gail/statistics.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is highly based on https://github.com/carpedm20/deep-rl-tensorflow/blob/master/agents/statistic.py 3 | ''' 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | 8 | import baselines.common.tf_util as U 9 | 10 | 11 | class stats(): 12 | 13 | def __init__(self, scalar_keys=[], histogram_keys=[]): 14 | self.scalar_keys = scalar_keys 15 | self.histogram_keys = histogram_keys 16 | self.scalar_summaries = [] 17 | self.scalar_summaries_ph = 
[] 18 | self.histogram_summaries_ph = [] 19 | self.histogram_summaries = [] 20 | with tf.variable_scope('summary'): 21 | for k in scalar_keys: 22 | ph = tf.placeholder('float32', None, name=k+'.scalar.summary') 23 | sm = tf.summary.scalar(k+'.scalar.summary', ph) 24 | self.scalar_summaries_ph.append(ph) 25 | self.scalar_summaries.append(sm) 26 | for k in histogram_keys: 27 | ph = tf.placeholder('float32', None, name=k+'.histogram.summary') 28 | sm = tf.summary.scalar(k+'.histogram.summary', ph) 29 | self.histogram_summaries_ph.append(ph) 30 | self.histogram_summaries.append(sm) 31 | 32 | self.summaries = tf.summary.merge(self.scalar_summaries+self.histogram_summaries) 33 | 34 | def add_all_summary(self, writer, values, iter): 35 | # Note that the order of the incoming ```values``` should be the same as the that of the 36 | # ```scalar_keys``` given in ```__init__``` 37 | if np.sum(np.isnan(values)+0) != 0: 38 | return 39 | sess = U.get_session() 40 | keys = self.scalar_summaries_ph + self.histogram_summaries_ph 41 | feed_dict = {} 42 | for k, v in zip(keys, values): 43 | feed_dict.update({k: v}) 44 | summaries_str = sess.run(self.summaries, feed_dict) 45 | writer.add_summary(summaries_str, iter) 46 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/her/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/her/__init__.py -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/her/actor_critic.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from baselines.her.util import store_args, nn 3 | 4 | 5 | class ActorCritic: 6 | @store_args 7 | def __init__(self, inputs_tf, dimo, dimg, dimu, max_u, o_stats, g_stats, hidden, layers, 8 | **kwargs): 9 | """The actor-critic network and related training code. 10 | 11 | Args: 12 | inputs_tf (dict of tensors): all necessary inputs for the network: the 13 | observation (o), the goal (g), and the action (u) 14 | dimo (int): the dimension of the observations 15 | dimg (int): the dimension of the goals 16 | dimu (int): the dimension of the actions 17 | max_u (float): the maximum magnitude of actions; action outputs will be scaled 18 | accordingly 19 | o_stats (baselines.her.Normalizer): normalizer for observations 20 | g_stats (baselines.her.Normalizer): normalizer for goals 21 | hidden (int): number of hidden units that should be used in hidden layers 22 | layers (int): number of hidden layers 23 | """ 24 | self.o_tf = inputs_tf['o'] 25 | self.g_tf = inputs_tf['g'] 26 | self.u_tf = inputs_tf['u'] 27 | 28 | # Prepare inputs for actor and critic. 29 | o = self.o_stats.normalize(self.o_tf) 30 | g = self.g_stats.normalize(self.g_tf) 31 | input_pi = tf.concat(axis=1, values=[o, g]) # for actor 32 | 33 | # Networks. 
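# pi: maps the normalized (o, g) input through `layers` hidden layers of `hidden` units to a
# dimu-dimensional action, squashed by tanh and scaled to [-max_u, max_u].
# Q: maps (o, g, u / max_u) to a scalar; it is built twice below with shared weights
# (reuse=True): once on the policy's own action (Q_pi_tf, used for policy training) and once
# on the externally supplied action u_tf (Q_tf, used for critic training).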
34 | with tf.variable_scope('pi'): 35 | self.pi_tf = self.max_u * tf.tanh(nn( 36 | input_pi, [self.hidden] * self.layers + [self.dimu])) 37 | with tf.variable_scope('Q'): 38 | # for policy training 39 | input_Q = tf.concat(axis=1, values=[o, g, self.pi_tf / self.max_u]) 40 | self.Q_pi_tf = nn(input_Q, [self.hidden] * self.layers + [1]) 41 | # for critic training 42 | input_Q = tf.concat(axis=1, values=[o, g, self.u_tf / self.max_u]) 43 | self._input_Q = input_Q # exposed for tests 44 | self.Q_tf = nn(input_Q, [self.hidden] * self.layers + [1], reuse=True) 45 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/her/experiment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/her/experiment/__init__.py -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/her/experiment/play.py: -------------------------------------------------------------------------------- 1 | import click 2 | import numpy as np 3 | import pickle 4 | 5 | from baselines import logger 6 | from baselines.common import set_global_seeds 7 | import baselines.her.experiment.config as config 8 | from baselines.her.rollout import RolloutWorker 9 | 10 | 11 | @click.command() 12 | @click.argument('policy_file', type=str) 13 | @click.option('--seed', type=int, default=0) 14 | @click.option('--n_test_rollouts', type=int, default=10) 15 | @click.option('--render', type=int, default=1) 16 | def main(policy_file, seed, n_test_rollouts, render): 17 | set_global_seeds(seed) 18 | 19 | # Load policy. 20 | with open(policy_file, 'rb') as f: 21 | policy = pickle.load(f) 22 | env_name = policy.info['env_name'] 23 | 24 | # Prepare params. 25 | params = config.DEFAULT_PARAMS 26 | if env_name in config.DEFAULT_ENV_PARAMS: 27 | params.update(config.DEFAULT_ENV_PARAMS[env_name]) # merge env-specific parameters in 28 | params['env_name'] = env_name 29 | params = config.prepare_params(params) 30 | config.log_params(params, logger=logger) 31 | 32 | dims = config.configure_dims(params) 33 | 34 | eval_params = { 35 | 'exploit': True, 36 | 'use_target_net': params['test_with_polyak'], 37 | 'compute_Q': True, 38 | 'rollout_batch_size': 1, 39 | 'render': bool(render), 40 | } 41 | 42 | for name in ['T', 'gamma', 'noise_eps', 'random_eps']: 43 | eval_params[name] = params[name] 44 | 45 | evaluator = RolloutWorker(params['make_env'], policy, dims, logger, **eval_params) 46 | evaluator.seed(seed) 47 | 48 | # Run evaluation. 49 | evaluator.clear_history() 50 | for _ in range(n_test_rollouts): 51 | evaluator.generate_rollouts() 52 | 53 | # record logs 54 | for key, val in evaluator.logs('test'): 55 | logger.record_tabular(key, np.mean(val)) 56 | logger.dump_tabular() 57 | 58 | 59 | if __name__ == '__main__': 60 | main() 61 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/ppo1/README.md: -------------------------------------------------------------------------------- 1 | # PPOSGD 2 | 3 | - Original paper: https://arxiv.org/abs/1707.06347 4 | - Baselines blog post: https://blog.openai.com/openai-baselines-ppo/ 5 | - `mpirun -np 8 python -m baselines.ppo1.run_atari` runs the algorithm for 40M frames = 10M timesteps on an Atari game. See help (`-h`) for more options. 
6 | - `python -m baselines.ppo1.run_mujoco` runs the algorithm for 1M frames on a Mujoco environment. 7 | 8 | - Train mujoco 3d humanoid (with optimal-ish hyperparameters): `mpirun -np 16 python -m baselines.ppo1.run_humanoid --model-path=/path/to/model` 9 | - Render the 3d humanoid: `python -m baselines.ppo1.run_humanoid --play --model-path=/path/to/model` 10 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/ppo1/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/ppo1/__init__.py -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/ppo1/run_atari.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from mpi4py import MPI 4 | from baselines.common import set_global_seeds 5 | from baselines import bench 6 | import os.path as osp 7 | from baselines import logger 8 | from baselines.common.atari_wrappers import make_atari, wrap_deepmind 9 | from baselines.common.cmd_util import atari_arg_parser 10 | 11 | def train(env_id, num_timesteps, seed): 12 | from baselines.ppo1 import pposgd_simple, cnn_policy 13 | import baselines.common.tf_util as U 14 | rank = MPI.COMM_WORLD.Get_rank() 15 | sess = U.single_threaded_session() 16 | sess.__enter__() 17 | if rank == 0: 18 | logger.configure() 19 | else: 20 | logger.configure(format_strs=[]) 21 | workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank() if seed is not None else None 22 | set_global_seeds(workerseed) 23 | env = make_atari(env_id) 24 | def policy_fn(name, ob_space, ac_space): #pylint: disable=W0613 25 | return cnn_policy.CnnPolicy(name=name, ob_space=ob_space, ac_space=ac_space) 26 | env = bench.Monitor(env, logger.get_dir() and 27 | osp.join(logger.get_dir(), str(rank))) 28 | env.seed(workerseed) 29 | 30 | env = wrap_deepmind(env) 31 | env.seed(workerseed) 32 | 33 | pposgd_simple.learn(env, policy_fn, 34 | max_timesteps=int(num_timesteps * 1.1), 35 | timesteps_per_actorbatch=256, 36 | clip_param=0.2, entcoeff=0.01, 37 | optim_epochs=4, optim_stepsize=1e-3, optim_batchsize=64, 38 | gamma=0.99, lam=0.95, 39 | schedule='linear' 40 | ) 41 | env.close() 42 | 43 | def main(): 44 | args = atari_arg_parser().parse_args() 45 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed) 46 | 47 | if __name__ == '__main__': 48 | main() 49 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/ppo1/run_mujoco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from baselines.common.cmd_util import make_mujoco_env, mujoco_arg_parser 4 | from baselines.common import tf_util as U 5 | from baselines import logger 6 | 7 | def train(env_id, num_timesteps, seed): 8 | from baselines.ppo1 import mlp_policy, pposgd_simple 9 | U.make_session(num_cpu=1).__enter__() 10 | def policy_fn(name, ob_space, ac_space): 11 | return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space, 12 | hid_size=64, num_hid_layers=2) 13 | env = make_mujoco_env(env_id, seed) 14 | pposgd_simple.learn(env, policy_fn, 15 | max_timesteps=num_timesteps, 16 | timesteps_per_actorbatch=2048, 17 | clip_param=0.2, entcoeff=0.0, 18 | optim_epochs=10, optim_stepsize=3e-4, optim_batchsize=64, 19 | 
gamma=0.99, lam=0.95, schedule='linear', 20 | ) 21 | env.close() 22 | 23 | def main(): 24 | args = mujoco_arg_parser().parse_args() 25 | logger.configure() 26 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed) 27 | 28 | if __name__ == '__main__': 29 | main() 30 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/ppo1/run_robotics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from mpi4py import MPI 4 | from baselines.common import set_global_seeds 5 | from baselines import logger 6 | from baselines.common.cmd_util import make_robotics_env, robotics_arg_parser 7 | import mujoco_py 8 | 9 | 10 | def train(env_id, num_timesteps, seed): 11 | from baselines.ppo1 import mlp_policy, pposgd_simple 12 | import baselines.common.tf_util as U 13 | rank = MPI.COMM_WORLD.Get_rank() 14 | sess = U.single_threaded_session() 15 | sess.__enter__() 16 | mujoco_py.ignore_mujoco_warnings().__enter__() 17 | workerseed = seed + 10000 * rank 18 | set_global_seeds(workerseed) 19 | env = make_robotics_env(env_id, workerseed, rank=rank) 20 | def policy_fn(name, ob_space, ac_space): 21 | return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space, 22 | hid_size=256, num_hid_layers=3) 23 | 24 | pposgd_simple.learn(env, policy_fn, 25 | max_timesteps=num_timesteps, 26 | timesteps_per_actorbatch=2048, 27 | clip_param=0.2, entcoeff=0.0, 28 | optim_epochs=5, optim_stepsize=3e-4, optim_batchsize=256, 29 | gamma=0.99, lam=0.95, schedule='linear', 30 | ) 31 | env.close() 32 | 33 | 34 | def main(): 35 | args = robotics_arg_parser().parse_args() 36 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed) 37 | 38 | 39 | if __name__ == '__main__': 40 | main() 41 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/ppo2/README.md: -------------------------------------------------------------------------------- 1 | # PPO2 2 | 3 | - Original paper: https://arxiv.org/abs/1707.06347 4 | - Baselines blog post: https://blog.openai.com/openai-baselines-ppo/ 5 | 6 | - `python -m baselines.run --alg=ppo2 --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options. 7 | - `python -m baselines.run --alg=ppo2 --env=Ant-v2 --num_timesteps=1e6` runs the algorithm for 1M frames on a Mujoco Ant environment. 
8 | - also refer to the repo-wide [README.md](../../README.md#training-models) 9 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/ppo2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/ppo2/__init__.py -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/ppo2/defaults.py: -------------------------------------------------------------------------------- 1 | def mujoco(): 2 | return dict( 3 | nsteps=2048, 4 | nminibatches=32, 5 | lam=0.95, 6 | gamma=0.99, 7 | noptepochs=10, 8 | log_interval=1, 9 | ent_coef=0.0, 10 | lr=lambda f: 3e-4 * f, 11 | cliprange=0.2, 12 | value_network='copy' 13 | ) 14 | 15 | def atari(): 16 | return dict( 17 | nsteps=128, nminibatches=4, 18 | lam=0.95, gamma=0.99, noptepochs=4, log_interval=1, 19 | ent_coef=.01, 20 | lr=lambda f : f * 2.5e-4, 21 | cliprange=lambda f : f * 0.1, 22 | ) 23 | 24 | def retro(): 25 | return atari() 26 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/ppo2/test_microbatches.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import tensorflow as tf 3 | import numpy as np 4 | from functools import partial 5 | 6 | from baselines.common.vec_env.dummy_vec_env import DummyVecEnv 7 | from baselines.common.tf_util import make_session 8 | from baselines.ppo2.ppo2 import learn 9 | 10 | from baselines.ppo2.microbatched_model import MicrobatchedModel 11 | 12 | def test_microbatches(): 13 | def env_fn(): 14 | env = gym.make('CartPole-v0') 15 | env.seed(0) 16 | return env 17 | 18 | learn_fn = partial(learn, network='mlp', nsteps=32, total_timesteps=32, seed=0) 19 | 20 | env_ref = DummyVecEnv([env_fn]) 21 | sess_ref = make_session(make_default=True, graph=tf.Graph()) 22 | learn_fn(env=env_ref) 23 | vars_ref = {v.name: sess_ref.run(v) for v in tf.trainable_variables()} 24 | 25 | env_test = DummyVecEnv([env_fn]) 26 | sess_test = make_session(make_default=True, graph=tf.Graph()) 27 | learn_fn(env=env_test, model_fn=partial(MicrobatchedModel, microbatch_size=2)) 28 | vars_test = {v.name: sess_test.run(v) for v in tf.trainable_variables()} 29 | 30 | for v in vars_ref: 31 | np.testing.assert_allclose(vars_ref[v], vars_test[v], atol=1e-3) 32 | 33 | if __name__ == '__main__': 34 | test_microbatches() 35 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/trpo_mpi/README.md: -------------------------------------------------------------------------------- 1 | # trpo_mpi 2 | 3 | - Original paper: https://arxiv.org/abs/1502.05477 4 | - Baselines blog post https://blog.openai.com/openai-baselines-ppo/ 5 | - `mpirun -np 16 python -m baselines.run --alg=trpo_mpi --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options. 6 | - `python -m baselines.run --alg=trpo_mpi --env=Ant-v2 --num_timesteps=1e6` runs the algorithm for 1M timesteps on a Mujoco Ant environment. 
7 | - also refer to the repo-wide [README.md](../../README.md#training-models) 8 | -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/trpo_mpi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/baselines/trpo_mpi/__init__.py -------------------------------------------------------------------------------- /drex-atari/baselines/baselines/trpo_mpi/defaults.py: -------------------------------------------------------------------------------- 1 | from baselines.common.models import mlp, cnn_small 2 | 3 | 4 | def atari(): 5 | return dict( 6 | network = cnn_small(), 7 | timesteps_per_batch=512, 8 | max_kl=0.001, 9 | cg_iters=10, 10 | cg_damping=1e-3, 11 | gamma=0.98, 12 | lam=1.0, 13 | vf_iters=3, 14 | vf_stepsize=1e-4, 15 | entcoeff=0.00, 16 | ) 17 | 18 | def mujoco(): 19 | return dict( 20 | network = mlp(num_hidden=32, num_layers=2), 21 | timesteps_per_batch=1024, 22 | max_kl=0.01, 23 | cg_iters=10, 24 | cg_damping=0.1, 25 | gamma=0.99, 26 | lam=0.98, 27 | vf_iters=5, 28 | vf_stepsize=1e-3, 29 | normalize_observations=True, 30 | ) 31 | -------------------------------------------------------------------------------- /drex-atari/baselines/data/cartpole.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/data/cartpole.gif -------------------------------------------------------------------------------- /drex-atari/baselines/data/fetchPickAndPlaceContrast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/data/fetchPickAndPlaceContrast.png -------------------------------------------------------------------------------- /drex-atari/baselines/data/logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/baselines/data/logo.jpg -------------------------------------------------------------------------------- /drex-atari/baselines/setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | select = F,E999,W291,W293 3 | exclude = 4 | .git, 5 | __pycache__, 6 | baselines/her, 7 | baselines/ppo1, 8 | baselines/bench, 9 | -------------------------------------------------------------------------------- /drex-atari/baselines/setup.py: -------------------------------------------------------------------------------- 1 | import re 2 | from setuptools import setup, find_packages 3 | import sys 4 | 5 | if sys.version_info.major != 3: 6 | print('This Python is only compatible with Python 3, but you are running ' 7 | 'Python {}. 
The installation will likely fail.'.format(sys.version_info.major)) 8 | 9 | 10 | extras = { 11 | 'test': [ 12 | 'filelock', 13 | 'pytest', 14 | 'pytest-forked', 15 | 'atari-py' 16 | ], 17 | 'bullet': [ 18 | 'pybullet', 19 | ], 20 | 'mpi': [ 21 | 'mpi4py' 22 | ] 23 | } 24 | 25 | all_deps = [] 26 | for group_name in extras: 27 | all_deps += extras[group_name] 28 | 29 | extras['all'] = all_deps 30 | 31 | setup(name='baselines', 32 | packages=[package for package in find_packages() 33 | if package.startswith('baselines')], 34 | install_requires=[ 35 | 'gym', 36 | 'scipy', 37 | 'tqdm', 38 | 'joblib', 39 | 'dill', 40 | 'progressbar2', 41 | 'cloudpickle', 42 | 'click', 43 | 'opencv-python' 44 | ], 45 | extras_require=extras, 46 | description='OpenAI baselines: high quality implementations of reinforcement learning algorithms', 47 | author='OpenAI', 48 | url='https://github.com/openai/baselines', 49 | author_email='gym@openai.com', 50 | version='0.1.5') 51 | 52 | 53 | # ensure there is some tensorflow build with version above 1.4 54 | import pkg_resources 55 | tf_pkg = None 56 | for tf_pkg_name in ['tensorflow', 'tensorflow-gpu']: 57 | try: 58 | tf_pkg = pkg_resources.get_distribution(tf_pkg_name) 59 | except pkg_resources.DistributionNotFound: 60 | pass 61 | assert tf_pkg is not None, 'TensorFlow needed, of version above 1.4' 62 | from distutils.version import StrictVersion 63 | assert StrictVersion(re.sub(r'-?rc\d+$', '', tf_pkg.version)) >= StrictVersion('1.4.0') 64 | -------------------------------------------------------------------------------- /drex-atari/bc_degredation_data/beamrider_degredation_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/bc_degredation_data/beamrider_degredation_plot.png -------------------------------------------------------------------------------- /drex-atari/bc_degredation_data/breakout_degredation_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/bc_degredation_data/breakout_degredation_plot.png -------------------------------------------------------------------------------- /drex-atari/bc_degredation_data/create_table.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import sys 4 | 5 | env_names = {'beamrider':'Beam Rider', 'breakout':'Breakout', 'enduro':'Enduro', 6 | 'pong':'Pong', 'qbert':'Q*bert', 'seaquest':'Seaquest', 'spaceinvaders':'Space Invaders'} 7 | 8 | 9 | writer = open("performance_table.txt",'w') 10 | for env_name in env_names: 11 | reader = open(env_name + "_batch_rewards.csv") 12 | bc_returns = [] 13 | demo_returns = [] 14 | for line in reader: 15 | parsed = line.split(", ") 16 | if parsed[0] == "demos": 17 | demo_returns = np.array([float(r) for r in parsed[1:]]) 18 | 19 | elif float(parsed[0]) == 0.01: 20 | bc_returns = [float(r) for r in parsed[1:]] 21 | 22 | writer.write("{} & {} & {} & & () & {} & ({:.1f}) \\\\ \n".format(env_names[env_name], np.mean(demo_returns), np.max(demo_returns), np.mean(bc_returns), np.std(bc_returns))) 23 | -------------------------------------------------------------------------------- /drex-atari/bc_degredation_data/enduro_degredation_plot.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/bc_degredation_data/enduro_degredation_plot.png -------------------------------------------------------------------------------- /drex-atari/bc_degredation_data/performance_table.txt: -------------------------------------------------------------------------------- 1 | Beam Rider & 1524.0 & 2216.0 & & () & 1268.6 & (776.6) \\ 2 | Breakout & 34.5 & 59.0 & & () & 29.75 & (10.1) \\ 3 | Enduro & 85.5 & 134.0 & & () & 83.4 & (27.0) \\ 4 | Pong & 3.7 & 14.0 & & () & 8.6 & (9.5) \\ 5 | Q*bert & 770.0 & 850.0 & & () & 1013.75 & (721.1) \\ 6 | Seaquest & 524.0 & 720.0 & & () & 530.0 & (109.8) \\ 7 | Space Invaders & 538.5 & 930.0 & & () & 426.5 & (187.1) \\ 8 | -------------------------------------------------------------------------------- /drex-atari/bc_degredation_data/plot_degredation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import sys 4 | 5 | if len(sys.argv) < 2: 6 | print("usage: python plot_degredation.py [env_name]") 7 | sys.exit() 8 | env_name = sys.argv[1] 9 | reader = open(env_name + "_batch_rewards.csv") 10 | noise_levels = [] 11 | returns = [] 12 | for line in reader: 13 | parsed = line.split(", ") 14 | if parsed[0] == "demos": 15 | demo_returns = np.array([float(r) for r in parsed[1:]]) 16 | print(demo_returns) 17 | else: 18 | noise_levels.append(float(parsed[0])) 19 | returns.append([float(r) for r in parsed[1:]]) 20 | returns = np.array(returns) 21 | print(noise_levels) 22 | print(returns) 23 | #plot the average of the demos in line 24 | demo_ave = np.mean(demo_returns) 25 | demo_std = np.std(demo_returns) 26 | 27 | import matplotlib.pylab as pylab 28 | params = {'legend.fontsize': 'xx-large', 29 | 'figure.figsize': (5, 4), 30 | 'axes.labelsize': 'xx-large', 31 | 'axes.titlesize':'xx-large', 32 | 'xtick.labelsize':'xx-large', 33 | 'ytick.labelsize':'xx-large'} 34 | pylab.rcParams.update(params) 35 | 36 | #plt.plot([0,1.0],[demo_ave, demo_ave]) 37 | plt.fill_between([0.01, 1.0], [demo_ave - demo_std, demo_ave - demo_std], [demo_ave + demo_std, demo_ave + demo_std], alpha = 0.3) 38 | plt.plot([0.01,1.0],[demo_ave, demo_ave], label='demos') 39 | plt.fill_between(noise_levels, np.mean(returns, axis=1)-np.std(returns, axis=1), np.mean(returns, axis=1) + np.std(returns, axis=1), alpha = 0.3) 40 | plt.plot(noise_levels, np.mean(returns, axis = 1),'-.', label="bc") 41 | #plot the average of pure noise in dashed line for baseline 42 | plt.fill_between([0.01, 1.0], [np.mean(returns[-1]) - np.std(returns[-1]), np.mean(returns[-1]) - np.std(returns[-1])], 43 | [np.mean(returns[-1]) + np.std(returns[-1]), np.mean(returns[-1]) + np.std(returns[-1])], alpha = 0.3) 44 | plt.plot([0.01,1.0], [np.mean(returns[-1]), np.mean(returns[-1])],'--', label="random") 45 | plt.legend(loc="best") 46 | plt.xlabel("Epsilon") 47 | plt.xticks([0.0, 0.25, 0.5, 0.75, 1.0]) 48 | plt.ylabel("Return") 49 | plt.tight_layout() 50 | plt.savefig(env_name + "_degredation_plot.png") 51 | plt.show() 52 | -------------------------------------------------------------------------------- /drex-atari/bc_degredation_data/pong_degredation_plot.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/bc_degredation_data/pong_degredation_plot.png -------------------------------------------------------------------------------- /drex-atari/bc_degredation_data/qbert_degredation_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/bc_degredation_data/qbert_degredation_plot.png -------------------------------------------------------------------------------- /drex-atari/bc_degredation_data/seaquest_degredation_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/bc_degredation_data/seaquest_degredation_plot.png -------------------------------------------------------------------------------- /drex-atari/bc_degredation_data/spaceinvaders_degredation_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/bc_degredation_data/spaceinvaders_degredation_plot.png -------------------------------------------------------------------------------- /drex-atari/checkpoints/beamrider_novice_demos_network.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/checkpoints/beamrider_novice_demos_network.pth.tar -------------------------------------------------------------------------------- /drex-atari/checkpoints/breakout_novice_demos_network.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/checkpoints/breakout_novice_demos_network.pth.tar -------------------------------------------------------------------------------- /drex-atari/checkpoints/breakout_standard_bc_network.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/checkpoints/breakout_standard_bc_network.pth.tar -------------------------------------------------------------------------------- /drex-atari/checkpoints/enduro_novice_demos_network.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/checkpoints/enduro_novice_demos_network.pth.tar -------------------------------------------------------------------------------- /drex-atari/checkpoints/hero_novice_demos_network.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/checkpoints/hero_novice_demos_network.pth.tar -------------------------------------------------------------------------------- /drex-atari/checkpoints/pong_novice_demos_network.pth.tar: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/checkpoints/pong_novice_demos_network.pth.tar -------------------------------------------------------------------------------- /drex-atari/checkpoints/qbert_novice_demos_network.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/checkpoints/qbert_novice_demos_network.pth.tar -------------------------------------------------------------------------------- /drex-atari/checkpoints/seaquest_novice_demos_network.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/checkpoints/seaquest_novice_demos_network.pth.tar -------------------------------------------------------------------------------- /drex-atari/checkpoints/spaceinvaders_novice_demos_network.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/checkpoints/spaceinvaders_novice_demos_network.pth.tar -------------------------------------------------------------------------------- /drex-atari/cnn.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Network(nn.Module): 6 | def __init__(self, num_output_actions, hist_len=4): 7 | super(Network, self).__init__() 8 | self.conv1 = nn.Conv2d(hist_len, 32, kernel_size=8, stride=4) 9 | self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2) 10 | self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1) 11 | self.fc1 = nn.Linear(64 * 7 * 7, 512) 12 | self.output = nn.Linear(512, num_output_actions) 13 | 14 | def forward(self, input): 15 | conv1_output = F.relu(self.conv1(input)) 16 | conv2_output = F.relu(self.conv2(conv1_output)) 17 | conv3_output = F.relu(self.conv3(conv2_output)) 18 | fc1_output = F.relu(self.fc1(conv3_output.view(conv3_output.size(0), -1))) 19 | output = self.output(fc1_output) 20 | return conv1_output, conv2_output, conv3_output, fc1_output, output 21 | -------------------------------------------------------------------------------- /drex-atari/figs/beamrider_gt_vs_pred_rewards_progress_sigmoid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/beamrider_gt_vs_pred_rewards_progress_sigmoid.png -------------------------------------------------------------------------------- /drex-atari/figs/beamridermax_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/beamridermax_attention.png -------------------------------------------------------------------------------- /drex-atari/figs/beamridermax_frames.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/beamridermax_frames.png 
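The `Network` class in `drex-atari/cnn.py` above is the convolutional policy used for behavioral cloning. As a quick sanity check, here is a minimal usage sketch (not part of the repository): the 1 x 4 x 84 x 84 input shape follows the `hist_len=4`, 84x84 convention used by `preprocess.py` and `state.py`, and the action count of 4 is an arbitrary placeholder.

```
# Hypothetical usage sketch for the Network class in drex-atari/cnn.py (not part of the repo).
# Assumes it is run from the drex-atari directory; the action count below is a placeholder.
import torch
from cnn import Network

num_actions = 4  # assumption: depends on the chosen game's action set
policy = Network(num_output_actions=num_actions, hist_len=4)

# One stacked state of shape (batch, hist_len, 84, 84), matching what state.py produces.
dummy_state = torch.zeros(1, 4, 84, 84)
conv1, conv2, conv3, fc1, logits = policy(dummy_state)

print(conv3.shape)   # torch.Size([1, 64, 7, 7]) -- flattened to 64*7*7 before fc1
print(logits.shape)  # torch.Size([1, 4]) -- one logit per action
```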
-------------------------------------------------------------------------------- /drex-atari/figs/beamridermin_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/beamridermin_attention.png -------------------------------------------------------------------------------- /drex-atari/figs/beamridermin_frames.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/beamridermin_frames.png -------------------------------------------------------------------------------- /drex-atari/figs/breakout_gt_vs_pred_rewards_progress_sigmoid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/breakout_gt_vs_pred_rewards_progress_sigmoid.png -------------------------------------------------------------------------------- /drex-atari/figs/breakoutmax_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/breakoutmax_attention.png -------------------------------------------------------------------------------- /drex-atari/figs/breakoutmax_frames.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/breakoutmax_frames.png -------------------------------------------------------------------------------- /drex-atari/figs/breakoutmin_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/breakoutmin_attention.png -------------------------------------------------------------------------------- /drex-atari/figs/breakoutmin_frames.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/breakoutmin_frames.png -------------------------------------------------------------------------------- /drex-atari/figs/enduro_gt_vs_pred_rewards_progress_sigmoid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/enduro_gt_vs_pred_rewards_progress_sigmoid.png -------------------------------------------------------------------------------- /drex-atari/figs/enduromax_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/enduromax_attention.png -------------------------------------------------------------------------------- /drex-atari/figs/enduromax_frames.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/enduromax_frames.png 
-------------------------------------------------------------------------------- /drex-atari/figs/enduromin_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/enduromin_attention.png -------------------------------------------------------------------------------- /drex-atari/figs/enduromin_frames.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/enduromin_frames.png -------------------------------------------------------------------------------- /drex-atari/figs/pong_gt_vs_pred_rewards_progress_sigmoid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/pong_gt_vs_pred_rewards_progress_sigmoid.png -------------------------------------------------------------------------------- /drex-atari/figs/pongmax_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/pongmax_attention.png -------------------------------------------------------------------------------- /drex-atari/figs/pongmax_frames.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/pongmax_frames.png -------------------------------------------------------------------------------- /drex-atari/figs/pongmin_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/pongmin_attention.png -------------------------------------------------------------------------------- /drex-atari/figs/pongmin_frames.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/pongmin_frames.png -------------------------------------------------------------------------------- /drex-atari/figs/qbert_gt_vs_pred_rewards_progress_sigmoid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/qbert_gt_vs_pred_rewards_progress_sigmoid.png -------------------------------------------------------------------------------- /drex-atari/figs/qbertmax_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/qbertmax_attention.png -------------------------------------------------------------------------------- /drex-atari/figs/qbertmax_frames.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/qbertmax_frames.png 
-------------------------------------------------------------------------------- /drex-atari/figs/qbertmin_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/qbertmin_attention.png -------------------------------------------------------------------------------- /drex-atari/figs/qbertmin_frames.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/qbertmin_frames.png -------------------------------------------------------------------------------- /drex-atari/figs/seaquest_gt_vs_pred_rewards_progress_sigmoid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/seaquest_gt_vs_pred_rewards_progress_sigmoid.png -------------------------------------------------------------------------------- /drex-atari/figs/seaquestmax_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/seaquestmax_attention.png -------------------------------------------------------------------------------- /drex-atari/figs/seaquestmax_frames.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/seaquestmax_frames.png -------------------------------------------------------------------------------- /drex-atari/figs/seaquestmin_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/seaquestmin_attention.png -------------------------------------------------------------------------------- /drex-atari/figs/seaquestmin_frames.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/seaquestmin_frames.png -------------------------------------------------------------------------------- /drex-atari/figs/spaceinvaders_gt_vs_pred_rewards_progress_sigmoid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/spaceinvaders_gt_vs_pred_rewards_progress_sigmoid.png -------------------------------------------------------------------------------- /drex-atari/figs/spaceinvadersmax_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/spaceinvadersmax_attention.png -------------------------------------------------------------------------------- /drex-atari/figs/spaceinvadersmax_frames.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/spaceinvadersmax_frames.png -------------------------------------------------------------------------------- /drex-atari/figs/spaceinvadersmin_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/spaceinvadersmin_attention.png -------------------------------------------------------------------------------- /drex-atari/figs/spaceinvadersmin_frames.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/figs/spaceinvadersmin_frames.png -------------------------------------------------------------------------------- /drex-atari/learned_models/beamrider_five_bins_noop_earlystop.params: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/learned_models/beamrider_five_bins_noop_earlystop.params -------------------------------------------------------------------------------- /drex-atari/learned_models/breakout_five_bins_noop_earlystop.params: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/learned_models/breakout_five_bins_noop_earlystop.params -------------------------------------------------------------------------------- /drex-atari/learned_models/enduro_five_bins_noop_earlystop.params: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/learned_models/enduro_five_bins_noop_earlystop.params -------------------------------------------------------------------------------- /drex-atari/learned_models/pong_five_bins_noop_earlystop.params: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/learned_models/pong_five_bins_noop_earlystop.params -------------------------------------------------------------------------------- /drex-atari/learned_models/qbert_five_bins_noop_earlystop.params: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/learned_models/qbert_five_bins_noop_earlystop.params -------------------------------------------------------------------------------- /drex-atari/learned_models/seaquest_five_bins_noop_earlystop.params: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/learned_models/seaquest_five_bins_noop_earlystop.params -------------------------------------------------------------------------------- /drex-atari/learned_models/spaceinvaders_five_bins_noop_earlystop.params: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-atari/learned_models/spaceinvaders_five_bins_noop_earlystop.params -------------------------------------------------------------------------------- /drex-atari/preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from collections import deque 4 | 5 | class Preprocessor: 6 | 7 | def __init__(self): 8 | self.preprocess_stack = deque([], 2) 9 | 10 | def add(self, aleRGB): 11 | self.preprocess_stack.append(aleRGB) 12 | 13 | ''' 14 | Implement the preprocessing step phi, from 15 | the Nature paper. It takes the maximum pixel 16 | values of two consecutive frames. It then 17 | grayscales the image, and resizes it to 84x84. 18 | ''' 19 | def preprocess(self): 20 | assert len(self.preprocess_stack) == 2 21 | return self.resize(self.grayscale(np.maximum(self.preprocess_stack[0], 22 | self.preprocess_stack[1]))) 23 | 24 | ''' 25 | Takes in an RGB image and returns a grayscaled 26 | image. 27 | ''' 28 | def grayscale(self, img): 29 | return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 30 | 31 | ''' 32 | Resizes the input to an 84x84 image. 33 | ''' 34 | def resize(self, image): 35 | return cv2.resize(image, (84, 84)) -------------------------------------------------------------------------------- /drex-atari/state.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import numpy as np 3 | 4 | class State: 5 | def __init__(self, hist_len): 6 | # Initialize a 1 x hist_len x 84 x 84 state 7 | self.hist_len = hist_len 8 | self.state = np.zeros((1, hist_len, 84, 84), dtype=np.float32) 9 | ''' 10 | index of next image, indicates that the first 84 x 84 11 | image should be at (0, 0) 12 | ''' 13 | self.insertLoc = 0 14 | 15 | def add_frame(self, img): 16 | self.state[0, self.insertLoc, ...] = img.astype(np.float32)/255.0 17 | # The index to insert at cycles between 0, 1, 2, 3 18 | self.insertLoc = (self.insertLoc + 1) % self.hist_len 19 | 20 | def get_state(self): 21 | ''' 22 | return the stacked four frames in the correct order 23 | Example: Suppose the state contains the following frames: 24 | [f4 f1 f2 f3]. The return value should be [f1 f2 f3 f4]. 25 | Since the most recent frame inserted in this example was 26 | f4 at index 0, self.insertLoc equals 1.
Thus, we 27 | "roll" the image a single space to the left, resulting in 28 | [f1 f2 f3 f4] 29 | ''' 30 | return np.roll(self.state, 0 - self.insertLoc, axis=1) -------------------------------------------------------------------------------- /drex-atari/utils.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torch 3 | import numpy as np 4 | ''' 5 | checkpointing source: 6 | https://blog.floydhub.com/checkpointing-tutorial-for-tensorflow-keras-and-pytorch/ 7 | ''' 8 | def save_checkpoint(state, checkpoint_dir, env_name, extra_info): 9 | filename = checkpoint_dir + '/' + env_name +'_' + extra_info + '_network.pth.tar' 10 | print("Saving checkpoint at " + filename + " ...") 11 | torch.save(state, filename) # save checkpoint 12 | print("Saved checkpoint.") 13 | 14 | def get_checkpoint(checkpoint_dir): 15 | resume_weights = checkpoint_dir + '/network.pth.tar' 16 | if torch.cuda.is_available(): 17 | print("Attempting to load Cuda weights...") 18 | checkpoint = torch.load(resume_weights) 19 | print("Loaded weights.") 20 | else: 21 | print("Attempting to load weights for CPU...") 22 | # Load GPU model on CPU 23 | checkpoint = torch.load(resume_weights, 24 | map_location=lambda storage, 25 | loc: storage) 26 | print("Loaded weights.") 27 | return checkpoint 28 | 29 | def long_tensor(input): 30 | if torch.cuda.is_available(): 31 | return torch.cuda.LongTensor(input) 32 | else: 33 | return torch.LongTensor(input) 34 | 35 | def float_tensor(input): 36 | if torch.cuda.is_available(): 37 | return torch.cuda.FloatTensor(input) 38 | else: 39 | return torch.FloatTensor(input) 40 | 41 | def perform_no_ops(ale, no_op_max, preprocessor, state): 42 | # perform nullops 43 | num_no_ops = np.random.randint(1, no_op_max + 1) 44 | for _ in range(num_no_ops): 45 | ale.act(0) 46 | preprocessor.add(ale.getScreenRGB()) 47 | if len(preprocessor.preprocess_stack) < 2: 48 | ale.act(0) 49 | preprocessor.add(ale.getScreenRGB()) 50 | state.add_frame(preprocessor.preprocess()) 51 | -------------------------------------------------------------------------------- /drex-mujoco/.gitignore: -------------------------------------------------------------------------------- 1 | log/ 2 | demos/full_demos/ 3 | -------------------------------------------------------------------------------- /drex-mujoco/README.md: -------------------------------------------------------------------------------- 1 | # D-REX mujoco 2 | 3 | ## Requirements 4 | 5 | This code is tested with `python 3.6` and `tensorflow-gpu==v1.14.0`. For other library dependencies, please refer to `requirements.txt` or check `env.yml`. 6 | 7 | 8 | ## Run Experiment 9 | 10 | 1. Behavior Cloning (BC) 11 | 12 | ``` 13 | python bc_mujoco.py --env_id HalfCheetah-v2 --log_path ./log/drex/halfcheetah/bc --demo_trajs demos/suboptimal_demos/halfcheetah/dataset.pkl 14 | python bc_mujoco.py --env_id Hopper-v2 --log_path ./log/drex/hopper/bc --demo_trajs demos/suboptimal_demos/hopper/dataset.pkl 15 | ``` 16 | 17 | 2. Generate Noise Injected Trajectories 18 | 19 | ``` 20 | python bc_noise_dataset.py --log_dir ./log/drex/halfcheetah --env_id HalfCheetah-v2 --bc_agent ./log/drex/halfcheetah/bc/model.ckpt --demo_trajs ./demos/suboptimal_demos/halfcheetah/dataset.pkl 21 | python bc_noise_dataset.py --log_dir ./log/drex/hopper --env_id Hopper-v2 --bc_agent ./log/drex/hopper/bc/model.ckpt --demo_trajs ./demos/suboptimal_demos/hopper/dataset.pkl 22 | ``` 23 | 24 | 3.
Run T-REX 25 | 26 | ``` 27 | python drex.py --log_dir ./log/drex/halfcheetah --env_id HalfCheetah-v2 --bc_trajs ./demos/suboptimal_demos/halfcheetah/dataset.pkl --unseen_trajs ./demos/full_demos/halfcheetah/trajs.pkl --noise_injected_trajs ./log/drex/halfcheetah/prebuilt.pkl 28 | python drex.py --log_dir ./log/drex/hopper --env_id Hopper-v2 --bc_trajs ./demos/suboptimal_demos/hopper/dataset.pkl --unseen_trajs ./demos/full_demos/hopper/trajs.pkl --noise_injected_trajs ./log/drex/hopper/prebuilt.pkl 29 | ``` 30 | 31 | You can download pregenerated unseen trajectories from [here](https://github.com/dsbrown1331/CoRL2019-DREX/releases). Alternatively, you can simply omit the `--unseen_trajs` option; it is only used to generate the plot shown in the paper. 32 | 33 | 4. Run PPO 34 | 35 | ``` 36 | python drex.py --log_dir ./log/drex/halfcheetah --env_id HalfCheetah-v2 --mode train_rl --ctrl_coeff 0.1 37 | python drex.py --log_dir ./log/drex/hopper --env_id Hopper-v2 --mode train_rl --ctrl_coeff 0.001 38 | ``` 39 | -------------------------------------------------------------------------------- /drex-mujoco/demos/suboptimal_demos/halfcheetah/dataset.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/demos/suboptimal_demos/halfcheetah/dataset.pkl -------------------------------------------------------------------------------- /drex-mujoco/demos/suboptimal_demos/hopper/dataset.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/demos/suboptimal_demos/hopper/dataset.pkl -------------------------------------------------------------------------------- /drex-mujoco/learner/.gitignore: -------------------------------------------------------------------------------- 1 | *.mp4 2 | *.meta.json 3 | demo_models/ 4 | -------------------------------------------------------------------------------- /drex-mujoco/learner/README.md: -------------------------------------------------------------------------------- 1 | ## Clone appropriately 2 | 3 | ``` 4 | git submodule update --init --recursive 5 | ``` 6 | 7 | You should see the contents of `baselines`. 8 | 9 | ## Train 10 | 11 | ``` 12 | OPENAI_LOG_FORMAT='stdout,log,csv,tensorboard' OPENAI_LOGDIR=/home/user/workspace/LfL_new/learner/models/log/reacher python -m baselines.run --alg=ppo2 --env=Reacher-v2 --save_interval=20 13 | ``` 14 | 15 | ### Create & Use Custom Reward Function 16 | 17 | - First, define your reward function. You need to change two files: `baselines/common/custom_reward_wrapper.py` and `baselines/run.py`. 18 | - Then, start training with the custom reward function by passing the `custom_reward` argument.
For example, 19 | ``` 20 | python -m baselines.run --alg=ppo2 --env=Reacher-v2 --save_interval=20 --custom_reward 'live_long' 21 | ``` 22 | 23 | ### Train with preference-based reward 24 | 25 | ``` 26 | OPENAI_LOG_FORMAT='stdout,log,csv,tensorboard' OPENAI_LOGDIR=/home/user/workspace/LfL/learner/models_preference/swimmer python -m baselines.run --alg=ppo2 --env=Swimmer-v2 --num_timesteps=1e6 --save_interval=10 --custom_reward 'preference' --custom_reward_kwargs='{"num_models":3,"model_dir":"../../log_preference/Swimmer-v2"}' 27 | ``` 28 | 29 | ## Generate Trajectory and Videos 30 | 31 | First, download the pretrained models [link anonymized for review], and extract them under the `models` directory. 32 | 33 | ``` 34 | ./run_test.py --env_id BreakoutNoFrameskip-v4 --env_type atari --model_path ./models/breakout/checkpoints/03600 --record_video 35 | ``` 36 | or, 37 | ``` 38 | ./run_test.py --env_id Reacher-v2 --env_type mujoco --model_path ./models/reacher/checkpoints/01000 --record_video 39 | ``` 40 | 41 | 42 | Replace the arguments as needed. Currently, models for every 100 learning steps (up to 3600 learning steps) are uploaded. 43 | 44 | You can omit the last flag `--record_video`. When it is turned on, the videos will be recorded under the current directory. 45 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/.benchmark_pattern: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.pyc 3 | *.pkl 4 | *.py~ 5 | .pytest_cache 6 | .DS_Store 7 | .idea 8 | 9 | # Setuptools distribution and build folders. 10 | /dist/ 11 | /build 12 | keys/ 13 | 14 | # Virtualenv 15 | /env 16 | 17 | 18 | *.sublime-project 19 | *.sublime-workspace 20 | 21 | .idea 22 | 23 | logs/ 24 | 25 | .ipynb_checkpoints 26 | ghostdriver.log 27 | 28 | htmlcov 29 | 30 | junk 31 | src 32 | 33 | *.egg-info 34 | .cache 35 | 36 | MUJOCO_LOG.TXT 37 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.6" 4 | 5 | services: 6 | - docker 7 | 8 | install: 9 | - pip install flake8 10 | - docker build . -t baselines-test 11 | 12 | script: 13 | - flake8 . --show-source --statistics 14 | - docker run baselines-test pytest -v --forked . 15 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.6 2 | 3 | RUN apt-get -y update && apt-get -y install ffmpeg 4 | # RUN apt-get -y update && apt-get -y install git wget python-dev python3-dev libopenmpi-dev python-pip zlib1g-dev cmake python-opencv 5 | 6 | ENV CODE_DIR /root/code 7 | 8 | COPY . $CODE_DIR/baselines 9 | WORKDIR $CODE_DIR/baselines 10 | 11 | # Clean up pycache and pyc files 12 | RUN rm -rf __pycache__ && \ 13 | find .
-name "*.pyc" -delete && \ 14 | pip install tensorflow && \ 15 | pip install -e .[test] 16 | 17 | 18 | CMD /bin/bash 19 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2017 OpenAI (http://openai.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/__init__.py -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/a2c/README.md: -------------------------------------------------------------------------------- 1 | # A2C 2 | 3 | - Original paper: https://arxiv.org/abs/1602.01783 4 | - Baselines blog post: https://blog.openai.com/baselines-acktr-a2c/ 5 | - `python -m baselines.run --alg=a2c --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options 6 | - also refer to the repo-wide [README.md](../../README.md#training-models) 7 | 8 | ## Files 9 | - `run_atari`: file used to run the algorithm. 10 | - `policies.py`: contains the different versions of the A2C architecture (MlpPolicy, CNNPolicy, LstmPolicy...). 11 | - `a2c.py`: - Model : class used to initialize the step_model (sampling) and train_model (training) 12 | - learn : Main entrypoint for A2C algorithm. Train a policy with given network architecture on a given environment using a2c algorithm. 
13 | - `runner.py`: class used to generate a batch of experiences 14 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/a2c/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/a2c/__init__.py -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/acer/README.md: -------------------------------------------------------------------------------- 1 | # ACER 2 | 3 | - Original paper: https://arxiv.org/abs/1611.01224 4 | - `python -m baselines.run --alg=acer --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options. 5 | - also refer to the repo-wide [README.md](../../README.md#training-models) 6 | 7 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/acer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/acer/__init__.py -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/acer/defaults.py: -------------------------------------------------------------------------------- 1 | def atari(): 2 | return dict( 3 | lrschedule='constant' 4 | ) 5 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/acktr/README.md: -------------------------------------------------------------------------------- 1 | # ACKTR 2 | 3 | - Original paper: https://arxiv.org/abs/1708.05144 4 | - Baselines blog post: https://blog.openai.com/baselines-acktr-a2c/ 5 | - `python -m baselines.run --alg=acktr --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options. 6 | - also refer to the repo-wide [README.md](../../README.md#training-models) 7 | 8 | ## ACKTR with continuous action spaces 9 | The code of ACKTR has been refactored to handle both discrete and continuous action spaces uniformly. In the original version, discrete and continuous action spaces were handled by different code (acktr_disc.py and acktr_cont.py) with little overlap. If you are interested in the original version of ACKTR for continuous action spaces, use the `old_acktr_cont` branch. Note that the original code performs better on the mujoco tasks than the refactored version; we are still investigating why.
10 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/acktr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/acktr/__init__.py -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/acktr/defaults.py: -------------------------------------------------------------------------------- 1 | def mujoco(): 2 | return dict( 3 | nsteps=2500, 4 | value_network='copy' 5 | ) 6 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/acktr/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def dense(x, size, name, weight_init=None, bias_init=0, weight_loss_dict=None, reuse=None): 4 | with tf.variable_scope(name, reuse=reuse): 5 | assert (len(tf.get_variable_scope().name.split('/')) == 2) 6 | 7 | w = tf.get_variable("w", [x.get_shape()[1], size], initializer=weight_init) 8 | b = tf.get_variable("b", [size], initializer=tf.constant_initializer(bias_init)) 9 | weight_decay_fc = 3e-4 10 | 11 | if weight_loss_dict is not None: 12 | weight_decay = tf.multiply(tf.nn.l2_loss(w), weight_decay_fc, name='weight_decay_loss') 13 | if weight_loss_dict is not None: 14 | weight_loss_dict[w] = weight_decay_fc 15 | weight_loss_dict[b] = 0.0 16 | 17 | tf.add_to_collection(tf.get_variable_scope().name.split('/')[0] + '_' + 'losses', weight_decay) 18 | 19 | return tf.nn.bias_add(tf.matmul(x, w), b) 20 | 21 | def kl_div(action_dist1, action_dist2, action_size): 22 | mean1, std1 = action_dist1[:, :action_size], action_dist1[:, action_size:] 23 | mean2, std2 = action_dist2[:, :action_size], action_dist2[:, action_size:] 24 | 25 | numerator = tf.square(mean1 - mean2) + tf.square(std1) - tf.square(std2) 26 | denominator = 2 * tf.square(std2) + 1e-8 27 | return tf.reduce_sum( 28 | numerator/denominator + tf.log(std2) - tf.log(std1),reduction_indices=-1) 29 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/bench/__init__.py: -------------------------------------------------------------------------------- 1 | from baselines.bench.benchmarks import * 2 | from baselines.bench.monitor import * 3 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa F403 2 | from baselines.common.console_util import * 3 | from baselines.common.dataset import Dataset 4 | from baselines.common.math_util import * 5 | from baselines.common.misc_util import * 6 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/cg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | def cg(f_Ax, b, cg_iters=10, callback=None, verbose=False, residual_tol=1e-10): 3 | """ 4 | Demmel p 312 5 | """ 6 | p = b.copy() 7 | r = b.copy() 8 | x = np.zeros_like(b) 9 | rdotr = r.dot(r) 10 | 11 | fmtstr = "%10i %10.3g %10.3g" 12 | titlestr = "%10s %10s %10s" 13 | if verbose: print(titlestr % ("iter", "residual norm", "soln norm")) 14 
| 15 | for i in range(cg_iters): 16 | if callback is not None: 17 | callback(x) 18 | if verbose: print(fmtstr % (i, rdotr, np.linalg.norm(x))) 19 | z = f_Ax(p) 20 | v = rdotr / p.dot(z) 21 | x += v*p 22 | r -= v*z 23 | newrdotr = r.dot(r) 24 | mu = newrdotr/rdotr 25 | p = r + mu*p 26 | 27 | rdotr = newrdotr 28 | if rdotr < residual_tol: 29 | break 30 | 31 | if callback is not None: 32 | callback(x) 33 | if verbose: print(fmtstr % (i+1, rdotr, np.linalg.norm(x))) # pylint: disable=W0631 34 | return x 35 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Dataset(object): 4 | def __init__(self, data_map, deterministic=False, shuffle=True): 5 | self.data_map = data_map 6 | self.deterministic = deterministic 7 | self.enable_shuffle = shuffle 8 | self.n = next(iter(data_map.values())).shape[0] 9 | self._next_id = 0 10 | self.shuffle() 11 | 12 | def shuffle(self): 13 | if self.deterministic: 14 | return 15 | perm = np.arange(self.n) 16 | np.random.shuffle(perm) 17 | 18 | for key in self.data_map: 19 | self.data_map[key] = self.data_map[key][perm] 20 | 21 | self._next_id = 0 22 | 23 | def next_batch(self, batch_size): 24 | if self._next_id >= self.n and self.enable_shuffle: 25 | self.shuffle() 26 | 27 | cur_id = self._next_id 28 | cur_batch_size = min(batch_size, self.n - self._next_id) 29 | self._next_id += cur_batch_size 30 | 31 | data_map = dict() 32 | for key in self.data_map: 33 | data_map[key] = self.data_map[key][cur_id:cur_id+cur_batch_size] 34 | return data_map 35 | 36 | def iterate_once(self, batch_size): 37 | if self.enable_shuffle: self.shuffle() 38 | 39 | while self._next_id <= self.n - batch_size: 40 | yield self.next_batch(batch_size) 41 | self._next_id = 0 42 | 43 | def subset(self, num_elements, deterministic=True): 44 | data_map = dict() 45 | for key in self.data_map: 46 | data_map[key] = self.data_map[key][:num_elements] 47 | return Dataset(data_map, deterministic) 48 | 49 | 50 | def iterbatches(arrays, *, num_batches=None, batch_size=None, shuffle=True, include_final_partial_batch=True): 51 | assert (num_batches is None) != (batch_size is None), 'Provide num_batches or batch_size, but not both' 52 | arrays = tuple(map(np.asarray, arrays)) 53 | n = arrays[0].shape[0] 54 | assert all(a.shape[0] == n for a in arrays[1:]) 55 | inds = np.arange(n) 56 | if shuffle: np.random.shuffle(inds) 57 | sections = np.arange(0, n, batch_size)[1:] if num_batches is None else num_batches 58 | for batch_inds in np.array_split(inds, sections): 59 | if include_final_partial_batch or len(batch_inds) == batch_size: 60 | yield tuple(a[batch_inds] for a in arrays) 61 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/input.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from gym.spaces import Discrete, Box, MultiDiscrete 4 | 5 | def observation_placeholder(ob_space, batch_size=None, name='Ob'): 6 | ''' 7 | Create placeholder to feed observations into of the size appropriate to the observation space 8 | 9 | Parameters: 10 | ---------- 11 | 12 | ob_space: gym.Space observation space 13 | 14 | batch_size: int size of the batch to be fed into input. Can be left None in most cases. 
15 | 16 | name: str name of the placeholder 17 | 18 | Returns: 19 | ------- 20 | 21 | tensorflow placeholder tensor 22 | ''' 23 | 24 | assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box) or isinstance(ob_space, MultiDiscrete), \ 25 | 'Can only deal with Discrete and Box observation spaces for now' 26 | 27 | dtype = ob_space.dtype 28 | if dtype == np.int8: 29 | dtype = np.uint8 30 | 31 | return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name) 32 | 33 | 34 | def observation_input(ob_space, batch_size=None, name='Ob'): 35 | ''' 36 | Create placeholder to feed observations into of the size appropriate to the observation space, and add input 37 | encoder of the appropriate type. 38 | ''' 39 | 40 | placeholder = observation_placeholder(ob_space, batch_size, name) 41 | return placeholder, encode_observation(ob_space, placeholder) 42 | 43 | def encode_observation(ob_space, placeholder): 44 | ''' 45 | Encode input in the way that is appropriate to the observation space 46 | 47 | Parameters: 48 | ---------- 49 | 50 | ob_space: gym.Space observation space 51 | 52 | placeholder: tf.placeholder observation input placeholder 53 | ''' 54 | if isinstance(ob_space, Discrete): 55 | return tf.to_float(tf.one_hot(placeholder, ob_space.n)) 56 | elif isinstance(ob_space, Box): 57 | return tf.to_float(placeholder) 58 | elif isinstance(ob_space, MultiDiscrete): 59 | placeholder = tf.cast(placeholder, tf.int32) 60 | one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])] 61 | return tf.concat(one_hots, axis=-1) 62 | else: 63 | raise NotImplementedError 64 | 65 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/mpi_adam_optimizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from mpi4py import MPI 4 | 5 | class MpiAdamOptimizer(tf.train.AdamOptimizer): 6 | """Adam optimizer that averages gradients across mpi processes.""" 7 | def __init__(self, comm, **kwargs): 8 | self.comm = comm 9 | tf.train.AdamOptimizer.__init__(self, **kwargs) 10 | def compute_gradients(self, loss, var_list, **kwargs): 11 | grads_and_vars = tf.train.AdamOptimizer.compute_gradients(self, loss, var_list, **kwargs) 12 | grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None] 13 | flat_grad = tf.concat([tf.reshape(g, (-1,)) for g, v in grads_and_vars], axis=0) 14 | shapes = [v.shape.as_list() for g, v in grads_and_vars] 15 | sizes = [int(np.prod(s)) for s in shapes] 16 | 17 | num_tasks = self.comm.Get_size() 18 | buf = np.zeros(sum(sizes), np.float32) 19 | 20 | def _collect_grads(flat_grad): 21 | self.comm.Allreduce(flat_grad, buf, op=MPI.SUM) 22 | np.divide(buf, float(num_tasks), out=buf) 23 | return buf 24 | 25 | avg_flat_grad = tf.py_func(_collect_grads, [flat_grad], tf.float32) 26 | avg_flat_grad.set_shape(flat_grad.shape) 27 | avg_grads = tf.split(avg_flat_grad, sizes, axis=0) 28 | avg_grads_and_vars = [(tf.reshape(g, v.shape), v) 29 | for g, (_, v) in zip(avg_grads, grads_and_vars)] 30 | 31 | return avg_grads_and_vars 32 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/mpi_fork.py: -------------------------------------------------------------------------------- 1 | import os, subprocess, sys 2 | 3 | def mpi_fork(n, bind_to_core=False): 4 | """Re-launches the current script 
with workers 5 | Returns "parent" for original parent, "child" for MPI children 6 | """ 7 | if n<=1: 8 | return "child" 9 | if os.getenv("IN_MPI") is None: 10 | env = os.environ.copy() 11 | env.update( 12 | MKL_NUM_THREADS="1", 13 | OMP_NUM_THREADS="1", 14 | IN_MPI="1" 15 | ) 16 | args = ["mpirun", "-np", str(n)] 17 | if bind_to_core: 18 | args += ["-bind-to", "core"] 19 | args += [sys.executable] + sys.argv 20 | subprocess.check_call(args, env=env) 21 | return "parent" 22 | else: 23 | return "child" 24 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/mpi_moments.py: -------------------------------------------------------------------------------- 1 | from mpi4py import MPI 2 | import numpy as np 3 | from baselines.common import zipsame 4 | 5 | 6 | def mpi_mean(x, axis=0, comm=None, keepdims=False): 7 | x = np.asarray(x) 8 | assert x.ndim > 0 9 | if comm is None: comm = MPI.COMM_WORLD 10 | xsum = x.sum(axis=axis, keepdims=keepdims) 11 | n = xsum.size 12 | localsum = np.zeros(n+1, x.dtype) 13 | localsum[:n] = xsum.ravel() 14 | localsum[n] = x.shape[axis] 15 | globalsum = np.zeros_like(localsum) 16 | comm.Allreduce(localsum, globalsum, op=MPI.SUM) 17 | return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n] 18 | 19 | def mpi_moments(x, axis=0, comm=None, keepdims=False): 20 | x = np.asarray(x) 21 | assert x.ndim > 0 22 | mean, count = mpi_mean(x, axis=axis, comm=comm, keepdims=True) 23 | sqdiffs = np.square(x - mean) 24 | meansqdiff, count1 = mpi_mean(sqdiffs, axis=axis, comm=comm, keepdims=True) 25 | assert count1 == count 26 | std = np.sqrt(meansqdiff) 27 | if not keepdims: 28 | newshape = mean.shape[:axis] + mean.shape[axis+1:] 29 | mean = mean.reshape(newshape) 30 | std = std.reshape(newshape) 31 | return mean, std, count 32 | 33 | 34 | def test_runningmeanstd(): 35 | import subprocess 36 | subprocess.check_call(['mpirun', '-np', '3', 37 | 'python','-c', 38 | 'from baselines.common.mpi_moments import _helper_runningmeanstd; _helper_runningmeanstd()']) 39 | 40 | def _helper_runningmeanstd(): 41 | comm = MPI.COMM_WORLD 42 | np.random.seed(0) 43 | for (triple,axis) in [ 44 | ((np.random.randn(3), np.random.randn(4), np.random.randn(5)),0), 45 | ((np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)),0), 46 | ((np.random.randn(2,3), np.random.randn(2,4), np.random.randn(2,4)),1), 47 | ]: 48 | 49 | 50 | x = np.concatenate(triple, axis=axis) 51 | ms1 = [x.mean(axis=axis), x.std(axis=axis), x.shape[axis]] 52 | 53 | 54 | ms2 = mpi_moments(triple[comm.Get_rank()],axis=axis) 55 | 56 | for (a1,a2) in zipsame(ms1, ms2): 57 | print(a1, a2) 58 | assert np.allclose(a1, a2) 59 | print("ok!") 60 | 61 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/runners.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from abc import ABC, abstractmethod 3 | 4 | class AbstractEnvRunner(ABC): 5 | def __init__(self, *, env, model, nsteps): 6 | self.env = env 7 | self.model = model 8 | self.nenv = nenv = env.num_envs if hasattr(env, 'num_envs') else 1 9 | self.batch_ob_shape = (nenv*nsteps,) + env.observation_space.shape 10 | self.obs = np.zeros((nenv,) + env.observation_space.shape, dtype=env.observation_space.dtype.name) 11 | self.obs[:] = env.reset() 12 | self.nsteps = nsteps 13 | self.states = model.initial_state 14 | self.dones = [False for _ in range(nenv)] 15 | 16 | 
@abstractmethod 17 | def run(self): 18 | raise NotImplementedError 19 | 20 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/common/tests/__init__.py -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/tests/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/common/tests/envs/__init__.py -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/tests/envs/fixed_sequence_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import Env 3 | from gym.spaces import Discrete 4 | 5 | 6 | class FixedSequenceEnv(Env): 7 | def __init__( 8 | self, 9 | n_actions=10, 10 | seed=0, 11 | episode_len=100 12 | ): 13 | self.np_random = np.random.RandomState() 14 | self.np_random.seed(seed) 15 | self.sequence = [self.np_random.randint(0, n_actions-1) for _ in range(episode_len)] 16 | 17 | self.action_space = Discrete(n_actions) 18 | self.observation_space = Discrete(1) 19 | 20 | self.episode_len = episode_len 21 | self.time = 0 22 | self.reset() 23 | 24 | def reset(self): 25 | self.time = 0 26 | return 0 27 | 28 | def step(self, actions): 29 | rew = self._get_reward(actions) 30 | self._choose_next_state() 31 | done = False 32 | if self.episode_len and self.time >= self.episode_len: 33 | rew = 0 34 | done = True 35 | 36 | return 0, rew, done, {} 37 | 38 | def _choose_next_state(self): 39 | self.time += 1 40 | 41 | def _get_reward(self, actions): 42 | return 1 if actions == self.sequence[self.time] else 0 43 | 44 | 45 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/tests/envs/identity_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from abc import abstractmethod 3 | from gym import Env 4 | from gym.spaces import MultiDiscrete, Discrete, Box 5 | 6 | 7 | class IdentityEnv(Env): 8 | def __init__( 9 | self, 10 | episode_len=None 11 | ): 12 | 13 | self.episode_len = episode_len 14 | self.time = 0 15 | self.reset() 16 | 17 | def reset(self): 18 | self._choose_next_state() 19 | self.time = 0 20 | self.observation_space = self.action_space 21 | 22 | return self.state 23 | 24 | def step(self, actions): 25 | rew = self._get_reward(actions) 26 | self._choose_next_state() 27 | done = False 28 | if self.episode_len and self.time >= self.episode_len: 29 | rew = 0 30 | done = True 31 | 32 | return self.state, rew, done, {} 33 | 34 | def _choose_next_state(self): 35 | self.state = self.action_space.sample() 36 | self.time += 1 37 | 38 | @abstractmethod 39 | def _get_reward(self, actions): 40 | raise NotImplementedError 41 | 42 | 43 | class DiscreteIdentityEnv(IdentityEnv): 44 | def __init__( 45 | self, 46 | dim, 47 | episode_len=None, 48 | ): 49 | 50 | self.action_space = Discrete(dim) 51 | super().__init__(episode_len=episode_len) 52 | 53 | def _get_reward(self, 
actions): 54 | return 1 if self.state == actions else 0 55 | 56 | class MultiDiscreteIdentityEnv(IdentityEnv): 57 | def __init__( 58 | self, 59 | dims, 60 | episode_len=None, 61 | ): 62 | 63 | self.action_space = MultiDiscrete(dims) 64 | super().__init__(episode_len=episode_len) 65 | 66 | def _get_reward(self, actions): 67 | return 1 if all(self.state == actions) else 0 68 | 69 | 70 | class BoxIdentityEnv(IdentityEnv): 71 | def __init__( 72 | self, 73 | shape, 74 | episode_len=None, 75 | ): 76 | 77 | self.action_space = Box(low=-1.0, high=1.0, shape=shape) 78 | super().__init__(episode_len=episode_len) 79 | 80 | def _get_reward(self, actions): 81 | diff = actions - self.state 82 | diff = diff[:] 83 | return -0.5 * np.dot(diff, diff) 84 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/tests/envs/mnist_env.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import tempfile 4 | from gym import Env 5 | from gym.spaces import Discrete, Box 6 | 7 | 8 | 9 | class MnistEnv(Env): 10 | def __init__( 11 | self, 12 | seed=0, 13 | episode_len=None, 14 | no_images=None 15 | ): 16 | import filelock 17 | from tensorflow.examples.tutorials.mnist import input_data 18 | # we could use a temporary directory for this with a context manager and 19 | # TemporaryDirectory, but then each test that uses mnist would re-download the data 20 | # this way the data is not cleaned up, but we only download it once per machine 21 | mnist_path = osp.join(tempfile.gettempdir(), 'MNIST_data') 22 | with filelock.FileLock(mnist_path + '.lock'): 23 | self.mnist = input_data.read_data_sets(mnist_path) 24 | 25 | self.np_random = np.random.RandomState() 26 | self.np_random.seed(seed) 27 | 28 | self.observation_space = Box(low=0.0, high=1.0, shape=(28,28,1)) 29 | self.action_space = Discrete(10) 30 | self.episode_len = episode_len 31 | self.time = 0 32 | self.no_images = no_images 33 | 34 | self.train_mode() 35 | self.reset() 36 | 37 | def reset(self): 38 | self._choose_next_state() 39 | self.time = 0 40 | 41 | return self.state[0] 42 | 43 | def step(self, actions): 44 | rew = self._get_reward(actions) 45 | self._choose_next_state() 46 | done = False 47 | if self.episode_len and self.time >= self.episode_len: 48 | rew = 0 49 | done = True 50 | 51 | return self.state[0], rew, done, {} 52 | 53 | def train_mode(self): 54 | self.dataset = self.mnist.train 55 | 56 | def test_mode(self): 57 | self.dataset = self.mnist.test 58 | 59 | def _choose_next_state(self): 60 | max_index = (self.no_images if self.no_images is not None else self.dataset.num_examples) - 1 61 | index = self.np_random.randint(0, max_index) 62 | image = self.dataset.images[index].reshape(28,28,1)*255 63 | label = self.dataset.labels[index] 64 | self.state = (image, label) 65 | self.time += 1 66 | 67 | def _get_reward(self, actions): 68 | return 1 if self.state[1] == actions else 0 69 | 70 | 71 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/tests/test_cartpole.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import gym 3 | 4 | from baselines.run import get_learn_function 5 | from baselines.common.tests.util import reward_per_episode_test 6 | 7 | common_kwargs = dict( 8 | total_timesteps=30000, 9 | network='mlp', 10 | gamma=1.0, 11 | seed=0, 12 | ) 13 | 14 | learn_kwargs
= { 15 | 'a2c' : dict(nsteps=32, value_network='copy', lr=0.05), 16 | 'acer': dict(value_network='copy'), 17 | 'acktr': dict(nsteps=32, value_network='copy', is_async=False), 18 | 'deepq': dict(total_timesteps=20000), 19 | 'ppo2': dict(value_network='copy'), 20 | 'trpo_mpi': {} 21 | } 22 | 23 | @pytest.mark.slow 24 | @pytest.mark.parametrize("alg", learn_kwargs.keys()) 25 | def test_cartpole(alg): 26 | ''' 27 | Test if the algorithm (with an mlp policy) 28 | can learn to balance the cartpole 29 | ''' 30 | 31 | kwargs = common_kwargs.copy() 32 | kwargs.update(learn_kwargs[alg]) 33 | 34 | learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs) 35 | def env_fn(): 36 | 37 | env = gym.make('CartPole-v0') 38 | env.seed(0) 39 | return env 40 | 41 | reward_per_episode_test(env_fn, learn_fn, 100) 42 | 43 | if __name__ == '__main__': 44 | test_cartpole('acer') 45 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/tests/test_doc_examples.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | try: 3 | import mujoco_py 4 | _mujoco_present = True 5 | except BaseException: 6 | mujoco_py = None 7 | _mujoco_present = False 8 | 9 | 10 | @pytest.mark.skipif( 11 | not _mujoco_present, 12 | reason='error loading mujoco - either mujoco / mujoco key not present, or LD_LIBRARY_PATH is not pointing to mujoco library' 13 | ) 14 | def test_lstm_example(): 15 | import tensorflow as tf 16 | from baselines.common import policies, models, cmd_util 17 | from baselines.common.vec_env.dummy_vec_env import DummyVecEnv 18 | 19 | # create vectorized environment 20 | venv = DummyVecEnv([lambda: cmd_util.make_mujoco_env('Reacher-v2', seed=0)]) 21 | 22 | with tf.Session() as sess: 23 | # build policy based on lstm network with 128 units 24 | policy = policies.build_policy(venv, models.lstm(128))(nbatch=1, nsteps=1) 25 | 26 | # initialize tensorflow variables 27 | sess.run(tf.global_variables_initializer()) 28 | 29 | # prepare environment variables 30 | ob = venv.reset() 31 | state = policy.initial_state 32 | done = [False] 33 | step_counter = 0 34 | 35 | # run a single episode until the end (i.e. until done) 36 | while True: 37 | action, _, state, _ = policy.step(ob, S=state, M=done) 38 | ob, reward, done, _ = venv.step(action) 39 | step_counter += 1 40 | if done: 41 | break 42 | 43 | 44 | assert step_counter > 5 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/tests/test_env_after_learn.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import gym 3 | import tensorflow as tf 4 | 5 | from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv 6 | from baselines.run import get_learn_function 7 | from baselines.common.tf_util import make_session 8 | 9 | algos = ['a2c', 'acer', 'acktr', 'deepq', 'ppo2', 'trpo_mpi'] 10 | 11 | @pytest.mark.parametrize('algo', algos) 12 | def test_env_after_learn(algo): 13 | def make_env(): 14 | # acktr requires too much RAM, fails on travis 15 | env = gym.make('CartPole-v1' if algo == 'acktr' else 'PongNoFrameskip-v4') 16 | return env 17 | 18 | make_session(make_default=True, graph=tf.Graph()) 19 | env = SubprocVecEnv([make_env]) 20 | 21 | learn = get_learn_function(algo) 22 | 23 | # Commenting out the following line resolves the issue, though crash happens at env.reset(). 
24 | learn(network='mlp', env=env, total_timesteps=0, load_path=None, seed=None) 25 | 26 | env.reset() 27 | env.close() 28 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/tests/test_fixed_sequence.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from baselines.common.tests.envs.fixed_sequence_env import FixedSequenceEnv 3 | 4 | from baselines.common.tests.util import simple_test 5 | from baselines.run import get_learn_function 6 | 7 | common_kwargs = dict( 8 | seed=0, 9 | total_timesteps=50000, 10 | ) 11 | 12 | learn_kwargs = { 13 | 'a2c': {}, 14 | 'ppo2': dict(nsteps=10, ent_coef=0.0, nminibatches=1), 15 | # TODO enable sequential models for trpo_mpi (proper handling of nbatch and nsteps) 16 | # github issue: https://github.com/openai/baselines/issues/188 17 | # 'trpo_mpi': lambda e, p: trpo_mpi.learn(policy_fn=p(env=e), env=e, max_timesteps=30000, timesteps_per_batch=100, cg_iters=10, gamma=0.9, lam=1.0, max_kl=0.001) 18 | } 19 | 20 | 21 | alg_list = learn_kwargs.keys() 22 | rnn_list = ['lstm'] 23 | 24 | @pytest.mark.slow 25 | @pytest.mark.parametrize("alg", alg_list) 26 | @pytest.mark.parametrize("rnn", rnn_list) 27 | def test_fixed_sequence(alg, rnn): 28 | ''' 29 | Test if the algorithm (with a recurrent policy) 30 | can learn to reproduce a fixed sequence of actions 31 | ''' 32 | 33 | kwargs = learn_kwargs[alg] 34 | kwargs.update(common_kwargs) 35 | 36 | episode_len = 5 37 | env_fn = lambda: FixedSequenceEnv(10, episode_len=episode_len) 38 | learn = lambda e: get_learn_function(alg)( 39 | env=e, 40 | network=rnn, 41 | **kwargs 42 | ) 43 | 44 | simple_test(env_fn, learn, 0.7) 45 | 46 | 47 | if __name__ == '__main__': 48 | test_fixed_sequence('ppo2', 'lstm') 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/tests/test_mnist.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | # from baselines.acer import acer_simple as acer 4 | from baselines.common.tests.envs.mnist_env import MnistEnv 5 | from baselines.common.tests.util import simple_test 6 | from baselines.run import get_learn_function 7 | 8 | 9 | # TODO investigate a2c and ppo2 failures - is it due to bad hyperparameters for this problem? 10 | # GitHub issue https://github.com/openai/baselines/issues/189 11 | common_kwargs = { 12 | 'seed': 0, 13 | 'network':'cnn', 14 | 'gamma':0.9, 15 | 'pad':'SAME' 16 | } 17 | 18 | learn_args = { 19 | 'a2c': dict(total_timesteps=50000), 20 | 'acer': dict(total_timesteps=20000), 21 | 'deepq': dict(total_timesteps=5000), 22 | 'acktr': dict(total_timesteps=30000), 23 | 'ppo2': dict(total_timesteps=50000, lr=1e-3, nsteps=128, ent_coef=0.0), 24 | 'trpo_mpi': dict(total_timesteps=80000, timesteps_per_batch=100, cg_iters=10, lam=1.0, max_kl=0.001) 25 | } 26 | 27 | 28 | # tests pass, but are too slow on travis. Same algorithms are covered 29 | # by other tests with less compute-hungry nn's and by benchmarks 30 | @pytest.mark.skip 31 | @pytest.mark.slow 32 | @pytest.mark.parametrize("alg", learn_args.keys()) 33 | def test_mnist(alg): 34 | ''' 35 | Test if the algorithm can learn to classify MNIST digits. 36 | Uses CNN policy.
37 | ''' 38 | 39 | learn_kwargs = learn_args[alg] 40 | learn_kwargs.update(common_kwargs) 41 | 42 | learn = get_learn_function(alg) 43 | learn_fn = lambda e: learn(env=e, **learn_kwargs) 44 | env_fn = lambda: MnistEnv(seed=0, episode_len=100) 45 | 46 | simple_test(env_fn, learn_fn, 0.6) 47 | 48 | if __name__ == '__main__': 49 | test_mnist('acer') 50 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/tests/test_schedules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from baselines.common.schedules import ConstantSchedule, PiecewiseSchedule 4 | 5 | 6 | def test_piecewise_schedule(): 7 | ps = PiecewiseSchedule([(-5, 100), (5, 200), (10, 50), (100, 50), (200, -50)], outside_value=500) 8 | 9 | assert np.isclose(ps.value(-10), 500) 10 | assert np.isclose(ps.value(0), 150) 11 | assert np.isclose(ps.value(5), 200) 12 | assert np.isclose(ps.value(9), 80) 13 | assert np.isclose(ps.value(50), 50) 14 | assert np.isclose(ps.value(80), 50) 15 | assert np.isclose(ps.value(150), 0) 16 | assert np.isclose(ps.value(175), -25) 17 | assert np.isclose(ps.value(201), 500) 18 | assert np.isclose(ps.value(500), 500) 19 | 20 | assert np.isclose(ps.value(200 - 1e-10), -50) 21 | 22 | 23 | def test_constant_schedule(): 24 | cs = ConstantSchedule(5) 25 | for i in range(-100, 100): 26 | assert np.isclose(cs.value(i), 5) 27 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/tests/test_tf_util.py: -------------------------------------------------------------------------------- 1 | # tests for tf_util 2 | import tensorflow as tf 3 | from baselines.common.tf_util import ( 4 | function, 5 | initialize, 6 | single_threaded_session 7 | ) 8 | 9 | 10 | def test_function(): 11 | with tf.Graph().as_default(): 12 | x = tf.placeholder(tf.int32, (), name="x") 13 | y = tf.placeholder(tf.int32, (), name="y") 14 | z = 3 * x + 2 * y 15 | lin = function([x, y], z, givens={y: 0}) 16 | 17 | with single_threaded_session(): 18 | initialize() 19 | 20 | assert lin(2) == 6 21 | assert lin(2, 2) == 10 22 | 23 | 24 | def test_multikwargs(): 25 | with tf.Graph().as_default(): 26 | x = tf.placeholder(tf.int32, (), name="x") 27 | with tf.variable_scope("other"): 28 | x2 = tf.placeholder(tf.int32, (), name="x") 29 | z = 3 * x + 2 * x2 30 | 31 | lin = function([x, x2], z, givens={x2: 0}) 32 | with single_threaded_session(): 33 | initialize() 34 | assert lin(2) == 6 35 | assert lin(2, 2) == 10 36 | 37 | 38 | if __name__ == '__main__': 39 | test_function() 40 | test_multikwargs() 41 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/tile_images.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def tile_images(img_nhwc): 4 | """ 5 | Tile N images into one big PxQ image 6 | (P,Q) are chosen to be as close as possible, and if N 7 | is square, then P=Q. 
8 | 9 | input: img_nhwc, list or array of images, ndim=4 once turned into array 10 | n = batch index, h = height, w = width, c = channel 11 | returns: 12 | bigim_HWc, ndarray with ndim=3 13 | """ 14 | img_nhwc = np.asarray(img_nhwc) 15 | N, h, w, c = img_nhwc.shape 16 | H = int(np.ceil(np.sqrt(N))) 17 | W = int(np.ceil(float(N)/H)) 18 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)]) 19 | img_HWhwc = img_nhwc.reshape(H, W, h, w, c) 20 | img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4) 21 | img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c) 22 | return img_Hh_Ww_c 23 | 24 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/vec_env/test_video_recorder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for asynchronous vectorized environments. 3 | """ 4 | 5 | import gym 6 | import pytest 7 | import os 8 | import glob 9 | import tempfile 10 | 11 | from .dummy_vec_env import DummyVecEnv 12 | from .shmem_vec_env import ShmemVecEnv 13 | from .subproc_vec_env import SubprocVecEnv 14 | from .vec_video_recorder import VecVideoRecorder 15 | 16 | @pytest.mark.parametrize('klass', (DummyVecEnv, ShmemVecEnv, SubprocVecEnv)) 17 | @pytest.mark.parametrize('num_envs', (1, 4)) 18 | @pytest.mark.parametrize('video_length', (10, 100)) 19 | @pytest.mark.parametrize('video_interval', (1, 50)) 20 | def test_video_recorder(klass, num_envs, video_length, video_interval): 21 | """ 22 | Wrap an existing VecEnv with VecVideoRecorder, 23 | make (video_interval + video_length + 1) steps, 24 | then check that the video files are present 25 | """ 26 | 27 | def make_fn(): 28 | env = gym.make('PongNoFrameskip-v4') 29 | return env 30 | fns = [make_fn for _ in range(num_envs)] 31 | env = klass(fns) 32 | 33 | with tempfile.TemporaryDirectory() as video_path: 34 | env = VecVideoRecorder(env, video_path, record_video_trigger=lambda x: x % video_interval == 0, video_length=video_length) 35 | 36 | env.reset() 37 | for _ in range(video_interval + video_length + 1): 38 | env.step([0] * num_envs) 39 | env.close() 40 | 41 | 42 | recorded_video = glob.glob(os.path.join(video_path, "*.mp4")) 43 | 44 | # videos from the first and second recording trigger 45 | assert len(recorded_video) == 2 46 | # Files are not empty 47 | assert all(os.stat(p).st_size != 0 for p in recorded_video) 48 | 49 | 50 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/vec_env/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for dealing with vectorized environments. 3 | """ 4 | 5 | from collections import OrderedDict 6 | 7 | import gym 8 | import numpy as np 9 | 10 | 11 | def copy_obs_dict(obs): 12 | """ 13 | Deep-copy an observation dict. 14 | """ 15 | return {k: np.copy(v) for k, v in obs.items()} 16 | 17 | 18 | def dict_to_obs(obs_dict): 19 | """ 20 | Convert an observation dict into a raw array if the 21 | original observation space was not a Dict space. 22 | """ 23 | if set(obs_dict.keys()) == {None}: 24 | return obs_dict[None] 25 | return obs_dict 26 | 27 | 28 | def obs_space_info(obs_space): 29 | """ 30 | Get dict-structured information about a gym.Space. 31 | 32 | Returns: 33 | A tuple (keys, shapes, dtypes): 34 | keys: a list of dict keys. 35 | shapes: a dict mapping keys to shapes. 36 | dtypes: a dict mapping keys to dtypes.
37 | """ 38 | if isinstance(obs_space, gym.spaces.Dict): 39 | assert isinstance(obs_space.spaces, OrderedDict) 40 | subspaces = obs_space.spaces 41 | else: 42 | subspaces = {None: obs_space} 43 | keys = [] 44 | shapes = {} 45 | dtypes = {} 46 | for key, box in subspaces.items(): 47 | keys.append(key) 48 | shapes[key] = box.shape 49 | dtypes[key] = box.dtype 50 | return keys, shapes, dtypes 51 | 52 | 53 | def obs_to_dict(obs): 54 | """ 55 | Convert an observation into a dict. 56 | """ 57 | if isinstance(obs, dict): 58 | return obs 59 | return {None: obs} 60 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/vec_env/vec_frame_stack.py: -------------------------------------------------------------------------------- 1 | from . import VecEnvWrapper 2 | import numpy as np 3 | from gym import spaces 4 | 5 | 6 | class VecFrameStack(VecEnvWrapper): 7 | def __init__(self, venv, nstack): 8 | self.venv = venv 9 | self.nstack = nstack 10 | wos = venv.observation_space # wrapped ob space 11 | low = np.repeat(wos.low, self.nstack, axis=-1) 12 | high = np.repeat(wos.high, self.nstack, axis=-1) 13 | self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype) 14 | observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype) 15 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 16 | 17 | def step_wait(self): 18 | obs, rews, news, infos = self.venv.step_wait() 19 | self.stackedobs = np.roll(self.stackedobs, shift=-1, axis=-1) 20 | for (i, new) in enumerate(news): 21 | if new: 22 | self.stackedobs[i] = 0 23 | self.stackedobs[..., -obs.shape[-1]:] = obs 24 | return self.stackedobs, rews, news, infos 25 | 26 | def reset(self): 27 | obs = self.venv.reset() 28 | self.stackedobs[...] = 0 29 | self.stackedobs[..., -obs.shape[-1]:] = obs 30 | return self.stackedobs 31 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/common/vec_env/vec_monitor.py: -------------------------------------------------------------------------------- 1 | from . 
import VecEnvWrapper 2 | from baselines.bench.monitor import ResultsWriter 3 | import numpy as np 4 | import time 5 | 6 | 7 | class VecMonitor(VecEnvWrapper): 8 | def __init__(self, venv, filename=None): 9 | VecEnvWrapper.__init__(self, venv) 10 | self.eprets = None 11 | self.eplens = None 12 | self.tstart = time.time() 13 | self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart}) 14 | 15 | def reset(self): 16 | obs = self.venv.reset() 17 | self.eprets = np.zeros(self.num_envs, 'f') 18 | self.eplens = np.zeros(self.num_envs, 'i') 19 | return obs 20 | 21 | def step_wait(self): 22 | obs, rews, dones, infos = self.venv.step_wait() 23 | self.eprets += rews 24 | self.eplens += 1 25 | newinfos = [] 26 | for (i, (done, ret, eplen, info)) in enumerate(zip(dones, self.eprets, self.eplens, infos)): 27 | info = info.copy() 28 | if done: 29 | epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)} 30 | info['episode'] = epinfo 31 | self.eprets[i] = 0 32 | self.eplens[i] = 0 33 | self.results_writer.write_row(epinfo) 34 | 35 | newinfos.append(info) 36 | 37 | return obs, rews, dones, newinfos 38 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/ddpg/README.md: -------------------------------------------------------------------------------- 1 | # DDPG 2 | 3 | - Original paper: https://arxiv.org/abs/1509.02971 4 | - Baselines post: https://blog.openai.com/better-exploration-with-parameter-noise/ 5 | - `python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6` runs the algorithm for 1M timesteps on a Mujoco environment. See help (`-h`) for more options. 6 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/ddpg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/ddpg/__init__.py -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/ddpg/models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from baselines.common.models import get_network_builder 3 | 4 | 5 | class Model(object): 6 | def __init__(self, name, network='mlp', **network_kwargs): 7 | self.name = name 8 | self.network_builder = get_network_builder(network)(**network_kwargs) 9 | 10 | @property 11 | def vars(self): 12 | return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name) 13 | 14 | @property 15 | def trainable_vars(self): 16 | return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name) 17 | 18 | @property 19 | def perturbable_vars(self): 20 | return [var for var in self.trainable_vars if 'LayerNorm' not in var.name] 21 | 22 | 23 | class Actor(Model): 24 | def __init__(self, nb_actions, name='actor', network='mlp', **network_kwargs): 25 | super().__init__(name=name, network=network, **network_kwargs) 26 | self.nb_actions = nb_actions 27 | 28 | def __call__(self, obs, reuse=False): 29 | with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE): 30 | x = self.network_builder(obs) 31 | x = tf.layers.dense(x, self.nb_actions, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3)) 32 | x = tf.nn.tanh(x) 33 | return x 34 | 35 | 36 | class Critic(Model): 37 | def
__init__(self, name='critic', network='mlp', **network_kwargs): 38 | super().__init__(name=name, network=network, **network_kwargs) 39 | self.layer_norm = True 40 | 41 | def __call__(self, obs, action, reuse=False): 42 | with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE): 43 | x = tf.concat([obs, action], axis=-1) # this assumes observation and action can be concatenated 44 | x = self.network_builder(x) 45 | x = tf.layers.dense(x, 1, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3)) 46 | return x 47 | 48 | @property 49 | def output_vars(self): 50 | output_vars = [var for var in self.trainable_vars if 'output' in var.name] 51 | return output_vars 52 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/deepq/README.md: -------------------------------------------------------------------------------- 1 | ## If you are curious. 2 | 3 | ##### Train a Cartpole agent and watch it play once it converges! 4 | 5 | Here's a list of commands to run to quickly get a working example: 6 | 7 | 8 | 9 | 10 | ```bash 11 | # Train model and save the results to cartpole_model.pkl 12 | python -m baselines.run --alg=deepq --env=CartPole-v0 --save_path=./cartpole_model.pkl --num_timesteps=1e5 13 | # Load the model saved in cartpole_model.pkl and visualize the learned policy 14 | python -m baselines.run --alg=deepq --env=CartPole-v0 --load_path=./cartpole_model.pkl --num_timesteps=0 --play 15 | ``` 16 | 17 | ## If you wish to apply DQN to solve a problem. 18 | 19 | Check out our simple agent trained with the one-stop-shop `deepq.learn` function. 20 | 21 | - [baselines/deepq/experiments/train_cartpole.py](experiments/train_cartpole.py) - train a Cartpole agent. 22 | 23 | In particular, notice that once `deepq.learn` finishes training it returns an `act` function which can be used to select actions in the environment. Once trained, you can easily save it and load it at a later time. The complementary file `enjoy_cartpole.py` loads and visualizes the learned policy. 24 | 25 | ## If you wish to experiment with the algorithm 26 | 27 | ##### Check out the examples 28 | 29 | - [baselines/deepq/experiments/custom_cartpole.py](experiments/custom_cartpole.py) - Cartpole training with more fine-grained control over the internals of the DQN algorithm. 30 | - [baselines/deepq/defaults.py](defaults.py) - settings for training on atari.
Run 31 | 32 | ```bash 33 | python -m baselines.run --alg=deepq --env=PongNoFrameskip-v4 34 | ``` 35 | to train on Atari Pong (see more in repo-wide [README.md](../../README.md#training-models)) 36 | 37 | 38 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/deepq/__init__.py: -------------------------------------------------------------------------------- 1 | from baselines.deepq import models # noqa 2 | from baselines.deepq.build_graph import build_act, build_train # noqa 3 | from baselines.deepq.deepq import learn, load_act # noqa 4 | from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer # noqa 5 | 6 | def wrap_atari_dqn(env): 7 | from baselines.common.atari_wrappers import wrap_deepmind 8 | return wrap_deepmind(env, frame_stack=True, scale=True) 9 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/deepq/defaults.py: -------------------------------------------------------------------------------- 1 | def atari(): 2 | return dict( 3 | network='conv_only', 4 | lr=1e-4, 5 | buffer_size=10000, 6 | exploration_fraction=0.1, 7 | exploration_final_eps=0.01, 8 | train_freq=4, 9 | learning_starts=10000, 10 | target_network_update_freq=1000, 11 | gamma=0.99, 12 | prioritized_replay=True, 13 | prioritized_replay_alpha=0.6, 14 | checkpoint_freq=10000, 15 | checkpoint_path=None, 16 | dueling=True 17 | ) 18 | 19 | def retro(): 20 | return atari() 21 | 22 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/deepq/experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/deepq/experiments/__init__.py -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/deepq/experiments/enjoy_cartpole.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | from baselines import deepq 4 | 5 | 6 | def main(): 7 | env = gym.make("CartPole-v0") 8 | act = deepq.learn(env, network='mlp', total_timesteps=0, load_path="cartpole_model.pkl") 9 | 10 | while True: 11 | obs, done = env.reset(), False 12 | episode_rew = 0 13 | while not done: 14 | env.render() 15 | obs, rew, done, _ = env.step(act(obs[None])[0]) 16 | episode_rew += rew 17 | print("Episode reward", episode_rew) 18 | 19 | 20 | if __name__ == '__main__': 21 | main() 22 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/deepq/experiments/enjoy_mountaincar.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | from baselines import deepq 4 | from baselines.common import models 5 | 6 | 7 | def main(): 8 | env = gym.make("MountainCar-v0") 9 | act = deepq.learn( 10 | env, 11 | network=models.mlp(num_layers=1, num_hidden=64), 12 | total_timesteps=0, 13 | load_path='mountaincar_model.pkl' 14 | ) 15 | 16 | while True: 17 | obs, done = env.reset(), False 18 | episode_rew = 0 19 | while not done: 20 | env.render() 21 | obs, rew, done, _ = env.step(act(obs[None])[0]) 22 | episode_rew += rew 23 | print("Episode reward", episode_rew) 24 | 25 | 26 | if __name__ == '__main__': 27 | main() 28 | 
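The enjoy_* and train_* example scripts in this directory all follow the save/load round trip described in the deepq README above: `deepq.learn` returns an `act` callable, `act.save(path)` writes it to disk, and calling `learn` again with `total_timesteps=0` and a `load_path` restores it. A condensed sketch of that round trip, pieced together from train_cartpole.py and enjoy_cartpole.py (the environment name, step count, and file name here are illustrative only):

```python
import gym
from baselines import deepq

env = gym.make("CartPole-v0")

# Train briefly; learn() returns an act(obs) callable regardless of step count.
act = deepq.learn(env, network='mlp', total_timesteps=10000)
act.save("cartpole_model.pkl")

# Later (or in another process): rebuild the act callable from disk by running
# learn() with zero training steps and a load_path, then use it to pick actions.
act = deepq.learn(env, network='mlp', total_timesteps=0, load_path="cartpole_model.pkl")

obs, done = env.reset(), False
while not done:
    obs, rew, done, _ = env.step(act(obs[None])[0])
```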
-------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/deepq/experiments/enjoy_pong.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from baselines import deepq 3 | 4 | 5 | def main(): 6 | env = gym.make("PongNoFrameskip-v4") 7 | env = deepq.wrap_atari_dqn(env) 8 | model = deepq.learn( 9 | env, 10 | "conv_only", 11 | convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], 12 | hiddens=[256], 13 | dueling=True, 14 | total_timesteps=0 15 | ) 16 | 17 | while True: 18 | obs, done = env.reset(), False 19 | episode_rew = 0 20 | while not done: 21 | env.render() 22 | obs, rew, done, _ = env.step(model(obs[None])[0]) 23 | episode_rew += rew 24 | print("Episode reward", episode_rew) 25 | 26 | 27 | if __name__ == '__main__': 28 | main() 29 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/deepq/experiments/train_cartpole.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | from baselines import deepq 4 | 5 | 6 | def callback(lcl, _glb): 7 | # stop training if reward exceeds 199 8 | is_solved = lcl['t'] > 100 and sum(lcl['episode_rewards'][-101:-1]) / 100 >= 199 9 | return is_solved 10 | 11 | 12 | def main(): 13 | env = gym.make("CartPole-v0") 14 | act = deepq.learn( 15 | env, 16 | network='mlp', 17 | lr=1e-3, 18 | total_timesteps=100000, 19 | buffer_size=50000, 20 | exploration_fraction=0.1, 21 | exploration_final_eps=0.02, 22 | print_freq=10, 23 | callback=callback 24 | ) 25 | print("Saving model to cartpole_model.pkl") 26 | act.save("cartpole_model.pkl") 27 | 28 | 29 | if __name__ == '__main__': 30 | main() 31 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/deepq/experiments/train_mountaincar.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | from baselines import deepq 4 | from baselines.common import models 5 | 6 | 7 | def main(): 8 | env = gym.make("MountainCar-v0") 9 | # Enabling layer_norm here is important for parameter space noise!
10 | act = deepq.learn( 11 | env, 12 | network=models.mlp(num_hidden=64, num_layers=1), 13 | lr=1e-3, 14 | total_timesteps=100000, 15 | buffer_size=50000, 16 | exploration_fraction=0.1, 17 | exploration_final_eps=0.1, 18 | print_freq=10, 19 | param_noise=True 20 | ) 21 | print("Saving model to mountaincar_model.pkl") 22 | act.save("mountaincar_model.pkl") 23 | 24 | 25 | if __name__ == '__main__': 26 | main() 27 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/deepq/experiments/train_pong.py: -------------------------------------------------------------------------------- 1 | from baselines import deepq 2 | from baselines import bench 3 | from baselines import logger 4 | from baselines.common.atari_wrappers import make_atari 5 | 6 | 7 | def main(): 8 | logger.configure() 9 | env = make_atari('PongNoFrameskip-v4') 10 | env = bench.Monitor(env, logger.get_dir()) 11 | env = deepq.wrap_atari_dqn(env) 12 | 13 | model = deepq.learn( 14 | env, 15 | "conv_only", 16 | convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], 17 | hiddens=[256], 18 | dueling=True, 19 | lr=1e-4, 20 | total_timesteps=int(1e7), 21 | buffer_size=10000, 22 | exploration_fraction=0.1, 23 | exploration_final_eps=0.01, 24 | train_freq=4, 25 | learning_starts=10000, 26 | target_network_update_freq=1000, 27 | gamma=0.99, 28 | ) 29 | 30 | model.save('pong_model.pkl') 31 | env.close() 32 | 33 | if __name__ == '__main__': 34 | main() 35 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/deepq/utils.py: -------------------------------------------------------------------------------- 1 | from baselines.common.input import observation_input 2 | from baselines.common.tf_util import adjust_shape 3 | 4 | # ================================================================ 5 | # Placeholders 6 | # ================================================================ 7 | 8 | 9 | class TfInput(object): 10 | def __init__(self, name="(unnamed)"): 11 | """Generalized Tensorflow placeholder. The main differences are: 12 | - possibly uses multiple placeholders internally and returns multiple values 13 | - can apply light postprocessing to the value fed to the placeholder. 14 | """ 15 | self.name = name 16 | 17 | def get(self): 18 | """Return the tf variable(s) representing the possibly postprocessed value 19 | of placeholder(s). 20 | """ 21 | raise NotImplementedError 22 | 23 | def make_feed_dict(self, data): 24 | """Given data, input it to the placeholder(s).""" 25 | raise NotImplementedError 26 | 27 | 28 | class PlaceholderTfInput(TfInput): 29 | def __init__(self, placeholder): 30 | """Wrapper for regular tensorflow placeholder.""" 31 | super().__init__(placeholder.name) 32 | self._placeholder = placeholder 33 | 34 | def get(self): 35 | return self._placeholder 36 | 37 | def make_feed_dict(self, data): 38 | return {self._placeholder: adjust_shape(self._placeholder, data)} 39 | 40 | 41 | class ObservationInput(PlaceholderTfInput): 42 | def __init__(self, observation_space, name=None): 43 | """Creates an input placeholder tailored to a specific observation space 44 | 45 | Parameters 46 | ---------- 47 | 48 | observation_space: 49 | observation space of the environment.
Should be one of the gym.spaces types 50 | name: str 51 | tensorflow name of the underlying placeholder 52 | """ 53 | inpt, self.processed_inpt = observation_input(observation_space, name=name) 54 | super().__init__(inpt) 55 | 56 | def get(self): 57 | return self.processed_inpt 58 | 59 | 60 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/__init__.py -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/dataset/__init__.py -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/HalfCheetah-normalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/HalfCheetah-normalized-deterministic-scores.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/HalfCheetah-normalized-stochastic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/HalfCheetah-normalized-stochastic-scores.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/HalfCheetah-unnormalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/HalfCheetah-unnormalized-deterministic-scores.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/HalfCheetah-unnormalized-stochastic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/HalfCheetah-unnormalized-stochastic-scores.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/Hopper-normalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/Hopper-normalized-deterministic-scores.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/Hopper-normalized-stochastic-scores.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/Hopper-normalized-stochastic-scores.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/Hopper-unnormalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/Hopper-unnormalized-deterministic-scores.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/Hopper-unnormalized-stochastic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/Hopper-unnormalized-stochastic-scores.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/Humanoid-normalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/Humanoid-normalized-deterministic-scores.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/Humanoid-normalized-stochastic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/Humanoid-normalized-stochastic-scores.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/Humanoid-unnormalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/Humanoid-unnormalized-deterministic-scores.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/Humanoid-unnormalized-stochastic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/Humanoid-unnormalized-stochastic-scores.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/HumanoidStandup-normalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/HumanoidStandup-normalized-deterministic-scores.png 
-------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/HumanoidStandup-normalized-stochastic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/HumanoidStandup-normalized-stochastic-scores.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/HumanoidStandup-unnormalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/HumanoidStandup-unnormalized-deterministic-scores.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/HumanoidStandup-unnormalized-stochastic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/HumanoidStandup-unnormalized-stochastic-scores.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/Walker2d-normalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/Walker2d-normalized-deterministic-scores.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/Walker2d-normalized-stochastic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/Walker2d-normalized-stochastic-scores.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/Walker2d-unnormalized-deterministic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/Walker2d-unnormalized-deterministic-scores.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/Walker2d-unnormalized-stochastic-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/Walker2d-unnormalized-stochastic-scores.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/halfcheetah-training.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/halfcheetah-training.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/hopper-training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/hopper-training.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/humanoid-training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/humanoid-training.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/humanoidstandup-training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/humanoidstandup-training.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/result/walker2d-training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/gail/result/walker2d-training.png -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/gail/statistics.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is heavily based on https://github.com/carpedm20/deep-rl-tensorflow/blob/master/agents/statistic.py 3 | ''' 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | 8 | import baselines.common.tf_util as U 9 | 10 | 11 | class stats(): 12 | 13 | def __init__(self, scalar_keys=[], histogram_keys=[]): 14 | self.scalar_keys = scalar_keys 15 | self.histogram_keys = histogram_keys 16 | self.scalar_summaries = [] 17 | self.scalar_summaries_ph = [] 18 | self.histogram_summaries_ph = [] 19 | self.histogram_summaries = [] 20 | with tf.variable_scope('summary'): 21 | for k in scalar_keys: 22 | ph = tf.placeholder('float32', None, name=k+'.scalar.summary') 23 | sm = tf.summary.scalar(k+'.scalar.summary', ph) 24 | self.scalar_summaries_ph.append(ph) 25 | self.scalar_summaries.append(sm) 26 | for k in histogram_keys: 27 | ph = tf.placeholder('float32', None, name=k+'.histogram.summary') 28 | sm = tf.summary.scalar(k+'.histogram.summary', ph) 29 | self.histogram_summaries_ph.append(ph) 30 | self.histogram_summaries.append(sm) 31 | 32 | self.summaries = tf.summary.merge(self.scalar_summaries+self.histogram_summaries) 33 | 34 | def add_all_summary(self, writer, values, iter): 35 | # Note that the order of the incoming ```values``` should be the same as that of the 36 | # ```scalar_keys``` given in ```__init__``` 37 | if np.sum(np.isnan(values)+0) != 0: 38 | return 39 | sess = U.get_session() 40 | keys = self.scalar_summaries_ph + self.histogram_summaries_ph
41 | feed_dict = {} 42 | for k, v in zip(keys, values): 43 | feed_dict.update({k: v}) 44 | summaries_str = sess.run(self.summaries, feed_dict) 45 | writer.add_summary(summaries_str, iter) 46 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/her/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/her/__init__.py -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/her/actor_critic.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from baselines.her.util import store_args, nn 3 | 4 | 5 | class ActorCritic: 6 | @store_args 7 | def __init__(self, inputs_tf, dimo, dimg, dimu, max_u, o_stats, g_stats, hidden, layers, 8 | **kwargs): 9 | """The actor-critic network and related training code. 10 | 11 | Args: 12 | inputs_tf (dict of tensors): all necessary inputs for the network: the 13 | observation (o), the goal (g), and the action (u) 14 | dimo (int): the dimension of the observations 15 | dimg (int): the dimension of the goals 16 | dimu (int): the dimension of the actions 17 | max_u (float): the maximum magnitude of actions; action outputs will be scaled 18 | accordingly 19 | o_stats (baselines.her.Normalizer): normalizer for observations 20 | g_stats (baselines.her.Normalizer): normalizer for goals 21 | hidden (int): number of hidden units that should be used in hidden layers 22 | layers (int): number of hidden layers 23 | """ 24 | self.o_tf = inputs_tf['o'] 25 | self.g_tf = inputs_tf['g'] 26 | self.u_tf = inputs_tf['u'] 27 | 28 | # Prepare inputs for actor and critic. 29 | o = self.o_stats.normalize(self.o_tf) 30 | g = self.g_stats.normalize(self.g_tf) 31 | input_pi = tf.concat(axis=1, values=[o, g]) # for actor 32 | 33 | # Networks. 
34 | with tf.variable_scope('pi'): 35 | self.pi_tf = self.max_u * tf.tanh(nn( 36 | input_pi, [self.hidden] * self.layers + [self.dimu])) 37 | with tf.variable_scope('Q'): 38 | # for policy training 39 | input_Q = tf.concat(axis=1, values=[o, g, self.pi_tf / self.max_u]) 40 | self.Q_pi_tf = nn(input_Q, [self.hidden] * self.layers + [1]) 41 | # for critic training 42 | input_Q = tf.concat(axis=1, values=[o, g, self.u_tf / self.max_u]) 43 | self._input_Q = input_Q # exposed for tests 44 | self.Q_tf = nn(input_Q, [self.hidden] * self.layers + [1], reuse=True) 45 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/her/experiment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/her/experiment/__init__.py -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/her/experiment/play.py: -------------------------------------------------------------------------------- 1 | import click 2 | import numpy as np 3 | import pickle 4 | 5 | from baselines import logger 6 | from baselines.common import set_global_seeds 7 | import baselines.her.experiment.config as config 8 | from baselines.her.rollout import RolloutWorker 9 | 10 | 11 | @click.command() 12 | @click.argument('policy_file', type=str) 13 | @click.option('--seed', type=int, default=0) 14 | @click.option('--n_test_rollouts', type=int, default=10) 15 | @click.option('--render', type=int, default=1) 16 | def main(policy_file, seed, n_test_rollouts, render): 17 | set_global_seeds(seed) 18 | 19 | # Load policy. 20 | with open(policy_file, 'rb') as f: 21 | policy = pickle.load(f) 22 | env_name = policy.info['env_name'] 23 | 24 | # Prepare params. 25 | params = config.DEFAULT_PARAMS 26 | if env_name in config.DEFAULT_ENV_PARAMS: 27 | params.update(config.DEFAULT_ENV_PARAMS[env_name]) # merge env-specific parameters in 28 | params['env_name'] = env_name 29 | params = config.prepare_params(params) 30 | config.log_params(params, logger=logger) 31 | 32 | dims = config.configure_dims(params) 33 | 34 | eval_params = { 35 | 'exploit': True, 36 | 'use_target_net': params['test_with_polyak'], 37 | 'compute_Q': True, 38 | 'rollout_batch_size': 1, 39 | 'render': bool(render), 40 | } 41 | 42 | for name in ['T', 'gamma', 'noise_eps', 'random_eps']: 43 | eval_params[name] = params[name] 44 | 45 | evaluator = RolloutWorker(params['make_env'], policy, dims, logger, **eval_params) 46 | evaluator.seed(seed) 47 | 48 | # Run evaluation. 49 | evaluator.clear_history() 50 | for _ in range(n_test_rollouts): 51 | evaluator.generate_rollouts() 52 | 53 | # record logs 54 | for key, val in evaluator.logs('test'): 55 | logger.record_tabular(key, np.mean(val)) 56 | logger.dump_tabular() 57 | 58 | 59 | if __name__ == '__main__': 60 | main() 61 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/ppo1/README.md: -------------------------------------------------------------------------------- 1 | # PPOSGD 2 | 3 | - Original paper: https://arxiv.org/abs/1707.06347 4 | - Baselines blog post: https://blog.openai.com/openai-baselines-ppo/ 5 | - `mpirun -np 8 python -m baselines.ppo1.run_atari` runs the algorithm for 40M frames = 10M timesteps on an Atari game. See help (`-h`) for more options. 
6 | - `python -m baselines.ppo1.run_mujoco` runs the algorithm for 1M frames on a Mujoco environment. 7 | 8 | - Train mujoco 3d humanoid (with optimal-ish hyperparameters): `mpirun -np 16 python -m baselines.ppo1.run_humanoid --model-path=/path/to/model` 9 | - Render the 3d humanoid: `python -m baselines.ppo1.run_humanoid --play --model-path=/path/to/model` 10 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/ppo1/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/ppo1/__init__.py -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/ppo1/run_atari.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from mpi4py import MPI 4 | from baselines.common import set_global_seeds 5 | from baselines import bench 6 | import os.path as osp 7 | from baselines import logger 8 | from baselines.common.atari_wrappers import make_atari, wrap_deepmind 9 | from baselines.common.cmd_util import atari_arg_parser 10 | 11 | def train(env_id, num_timesteps, seed): 12 | from baselines.ppo1 import pposgd_simple, cnn_policy 13 | import baselines.common.tf_util as U 14 | rank = MPI.COMM_WORLD.Get_rank() 15 | sess = U.single_threaded_session() 16 | sess.__enter__() 17 | if rank == 0: 18 | logger.configure() 19 | else: 20 | logger.configure(format_strs=[]) 21 | workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank() if seed is not None else None 22 | set_global_seeds(workerseed) 23 | env = make_atari(env_id) 24 | def policy_fn(name, ob_space, ac_space): #pylint: disable=W0613 25 | return cnn_policy.CnnPolicy(name=name, ob_space=ob_space, ac_space=ac_space) 26 | env = bench.Monitor(env, logger.get_dir() and 27 | osp.join(logger.get_dir(), str(rank))) 28 | env.seed(workerseed) 29 | 30 | env = wrap_deepmind(env) 31 | env.seed(workerseed) 32 | 33 | pposgd_simple.learn(env, policy_fn, 34 | max_timesteps=int(num_timesteps * 1.1), 35 | timesteps_per_actorbatch=256, 36 | clip_param=0.2, entcoeff=0.01, 37 | optim_epochs=4, optim_stepsize=1e-3, optim_batchsize=64, 38 | gamma=0.99, lam=0.95, 39 | schedule='linear' 40 | ) 41 | env.close() 42 | 43 | def main(): 44 | args = atari_arg_parser().parse_args() 45 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed) 46 | 47 | if __name__ == '__main__': 48 | main() 49 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/ppo1/run_mujoco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from baselines.common.cmd_util import make_mujoco_env, mujoco_arg_parser 4 | from baselines.common import tf_util as U 5 | from baselines import logger 6 | 7 | def train(env_id, num_timesteps, seed): 8 | from baselines.ppo1 import mlp_policy, pposgd_simple 9 | U.make_session(num_cpu=1).__enter__() 10 | def policy_fn(name, ob_space, ac_space): 11 | return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space, 12 | hid_size=64, num_hid_layers=2) 13 | env = make_mujoco_env(env_id, seed) 14 | pposgd_simple.learn(env, policy_fn, 15 | max_timesteps=num_timesteps, 16 | timesteps_per_actorbatch=2048, 17 | clip_param=0.2, entcoeff=0.0, 18 | optim_epochs=10, 
optim_stepsize=3e-4, optim_batchsize=64, 19 | gamma=0.99, lam=0.95, schedule='linear', 20 | ) 21 | env.close() 22 | 23 | def main(): 24 | args = mujoco_arg_parser().parse_args() 25 | logger.configure() 26 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed) 27 | 28 | if __name__ == '__main__': 29 | main() 30 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/ppo1/run_robotics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from mpi4py import MPI 4 | from baselines.common import set_global_seeds 5 | from baselines import logger 6 | from baselines.common.cmd_util import make_robotics_env, robotics_arg_parser 7 | import mujoco_py 8 | 9 | 10 | def train(env_id, num_timesteps, seed): 11 | from baselines.ppo1 import mlp_policy, pposgd_simple 12 | import baselines.common.tf_util as U 13 | rank = MPI.COMM_WORLD.Get_rank() 14 | sess = U.single_threaded_session() 15 | sess.__enter__() 16 | mujoco_py.ignore_mujoco_warnings().__enter__() 17 | workerseed = seed + 10000 * rank 18 | set_global_seeds(workerseed) 19 | env = make_robotics_env(env_id, workerseed, rank=rank) 20 | def policy_fn(name, ob_space, ac_space): 21 | return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space, 22 | hid_size=256, num_hid_layers=3) 23 | 24 | pposgd_simple.learn(env, policy_fn, 25 | max_timesteps=num_timesteps, 26 | timesteps_per_actorbatch=2048, 27 | clip_param=0.2, entcoeff=0.0, 28 | optim_epochs=5, optim_stepsize=3e-4, optim_batchsize=256, 29 | gamma=0.99, lam=0.95, schedule='linear', 30 | ) 31 | env.close() 32 | 33 | 34 | def main(): 35 | args = robotics_arg_parser().parse_args() 36 | train(args.env, num_timesteps=args.num_timesteps, seed=args.seed) 37 | 38 | 39 | if __name__ == '__main__': 40 | main() 41 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/ppo2/README.md: -------------------------------------------------------------------------------- 1 | # PPO2 2 | 3 | - Original paper: https://arxiv.org/abs/1707.06347 4 | - Baselines blog post: https://blog.openai.com/openai-baselines-ppo/ 5 | 6 | - `python -m baselines.run --alg=ppo2 --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options. 7 | - `python -m baselines.run --alg=ppo2 --env=Ant-v2 --num_timesteps=1e6` runs the algorithm for 1M frames on a Mujoco Ant environment. 
8 | - also refer to the repo-wide [README.md](../../README.md#training-models) 9 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/ppo2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/ppo2/__init__.py -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/ppo2/defaults.py: -------------------------------------------------------------------------------- 1 | def mujoco(): 2 | return dict( 3 | nsteps=2048, 4 | nminibatches=32, 5 | lam=0.95, 6 | gamma=0.99, 7 | noptepochs=10, 8 | log_interval=1, 9 | ent_coef=0.0, 10 | lr=lambda f: 3e-4 * f, 11 | cliprange=0.2, 12 | value_network='copy' 13 | ) 14 | 15 | def robosuite(): 16 | return mujoco() 17 | 18 | def atari(): 19 | return dict( 20 | nsteps=128, nminibatches=4, 21 | lam=0.95, gamma=0.99, noptepochs=4, log_interval=1, 22 | ent_coef=.01, 23 | lr=lambda f : f * 2.5e-4, 24 | cliprange=lambda f : f * 0.1, 25 | ) 26 | 27 | def retro(): 28 | return atari() 29 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/ppo2/test_microbatches.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import tensorflow as tf 3 | import numpy as np 4 | from functools import partial 5 | 6 | from baselines.common.vec_env.dummy_vec_env import DummyVecEnv 7 | from baselines.common.tf_util import make_session 8 | from baselines.ppo2.ppo2 import learn 9 | 10 | from baselines.ppo2.microbatched_model import MicrobatchedModel 11 | 12 | def test_microbatches(): 13 | def env_fn(): 14 | env = gym.make('CartPole-v0') 15 | env.seed(0) 16 | return env 17 | 18 | learn_fn = partial(learn, network='mlp', nsteps=32, total_timesteps=32, seed=0) 19 | 20 | env_ref = DummyVecEnv([env_fn]) 21 | sess_ref = make_session(make_default=True, graph=tf.Graph()) 22 | learn_fn(env=env_ref) 23 | vars_ref = {v.name: sess_ref.run(v) for v in tf.trainable_variables()} 24 | 25 | env_test = DummyVecEnv([env_fn]) 26 | sess_test = make_session(make_default=True, graph=tf.Graph()) 27 | learn_fn(env=env_test, model_fn=partial(MicrobatchedModel, microbatch_size=2)) 28 | vars_test = {v.name: sess_test.run(v) for v in tf.trainable_variables()} 29 | 30 | for v in vars_ref: 31 | np.testing.assert_allclose(vars_ref[v], vars_test[v], atol=1e-3) 32 | 33 | if __name__ == '__main__': 34 | test_microbatches() 35 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/trpo_mpi/README.md: -------------------------------------------------------------------------------- 1 | # trpo_mpi 2 | 3 | - Original paper: https://arxiv.org/abs/1502.05477 4 | - Baselines blog post https://blog.openai.com/openai-baselines-ppo/ 5 | - `mpirun -np 16 python -m baselines.run --alg=trpo_mpi --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options. 6 | - `python -m baselines.run --alg=trpo_mpi --env=Ant-v2 --num_timesteps=1e6` runs the algorithm for 1M timesteps on a Mujoco Ant environment. 
7 | - also refer to the repo-wide [README.md](../../README.md#training-models) 8 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/trpo_mpi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsbrown1331/CoRL2019-DREX/23ffdd0982cb38f438a9dd84408e290f3c1b3bcc/drex-mujoco/learner/baselines/baselines/trpo_mpi/__init__.py -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/baselines/trpo_mpi/defaults.py: -------------------------------------------------------------------------------- 1 | from baselines.common.models import mlp, cnn_small 2 | 3 | 4 | def atari(): 5 | return dict( 6 | network = cnn_small(), 7 | timesteps_per_batch=512, 8 | max_kl=0.001, 9 | cg_iters=10, 10 | cg_damping=1e-3, 11 | gamma=0.98, 12 | lam=1.0, 13 | vf_iters=3, 14 | vf_stepsize=1e-4, 15 | entcoeff=0.00, 16 | ) 17 | 18 | def mujoco(): 19 | return dict( 20 | network = mlp(num_hidden=32, num_layers=2), 21 | timesteps_per_batch=1024, 22 | max_kl=0.01, 23 | cg_iters=10, 24 | cg_damping=0.1, 25 | gamma=0.99, 26 | lam=0.98, 27 | vf_iters=5, 28 | vf_stepsize=1e-3, 29 | normalize_observations=True, 30 | ) 31 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | select = F,E999,W291,W293 3 | exclude = 4 | .git, 5 | __pycache__, 6 | baselines/her, 7 | baselines/ppo1, 8 | baselines/bench, 9 | -------------------------------------------------------------------------------- /drex-mujoco/learner/baselines/setup.py: -------------------------------------------------------------------------------- 1 | import re 2 | from setuptools import setup, find_packages 3 | import sys 4 | 5 | if sys.version_info.major != 3: 6 | print('This Python is only compatible with Python 3, but you are running ' 7 | 'Python {}. 
The installation will likely fail.'.format(sys.version_info.major)) 8 | 9 | 10 | extras = { 11 | 'test': [ 12 | 'filelock', 13 | 'pytest', 14 | 'pytest-forked', 15 | 'atari-py' 16 | ], 17 | 'bullet': [ 18 | 'pybullet', 19 | ], 20 | 'mpi': [ 21 | 'mpi4py' 22 | ] 23 | } 24 | 25 | all_deps = [] 26 | for group_name in extras: 27 | all_deps += extras[group_name] 28 | 29 | extras['all'] = all_deps 30 | 31 | setup(name='baselines', 32 | packages=[package for package in find_packages() 33 | if package.startswith('baselines')], 34 | install_requires=[ 35 | 'gym', 36 | 'scipy', 37 | 'tqdm', 38 | 'joblib', 39 | 'dill', 40 | 'progressbar2', 41 | 'cloudpickle', 42 | 'click', 43 | 'opencv-python' 44 | ], 45 | extras_require=extras, 46 | description='OpenAI baselines: high quality implementations of reinforcement learning algorithms', 47 | author='OpenAI', 48 | url='https://github.com/openai/baselines', 49 | author_email='gym@openai.com', 50 | version='0.1.5') 51 | 52 | 53 | # ensure there is some tensorflow build with version above 1.4 54 | import pkg_resources 55 | tf_pkg = None 56 | for tf_pkg_name in ['tensorflow', 'tensorflow-gpu']: 57 | try: 58 | tf_pkg = pkg_resources.get_distribution(tf_pkg_name) 59 | except pkg_resources.DistributionNotFound: 60 | pass 61 | assert tf_pkg is not None, 'TensorFlow needed, of version above 1.4' 62 | from distutils.version import StrictVersion 63 | assert StrictVersion(re.sub(r'-?rc\d+$', '', tf_pkg.version)) >= StrictVersion('1.4.0') 64 | --------------------------------------------------------------------------------
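The `defaults.py` modules above (`ppo2/defaults.py`, `trpo_mpi/defaults.py`) all follow the same convention: one function per environment family (`mujoco()`, `atari()`, `retro()`, ...) returning a dict of keyword arguments for that algorithm's `learn()` function. The sketch below illustrates how such a module can be resolved and merged with explicit overrides; it is a minimal, hypothetical helper written for this document, not the actual lookup logic in `baselines/run.py`, and the name `learn_kwargs_for` is made up for illustration.

```python
import importlib


def learn_kwargs_for(alg, env_type, **overrides):
    """Illustrative only: resolve baselines.<alg>.defaults, call the function
    named after the environment family (e.g. 'mujoco' or 'atari'), and let
    explicit keyword overrides take precedence over the returned defaults."""
    try:
        defaults_module = importlib.import_module('baselines.{}.defaults'.format(alg))
    except ImportError:
        return dict(overrides)  # algorithm ships no defaults module
    get_defaults = getattr(defaults_module, env_type, None)
    kwargs = dict(get_defaults()) if callable(get_defaults) else {}
    kwargs.update(overrides)
    return kwargs


# Example (assuming the baselines package defined above is installed):
#   learn_kwargs_for('ppo2', 'mujoco', nsteps=1024)
#   -> {'nsteps': 1024, 'nminibatches': 32, 'lam': 0.95, 'gamma': 0.99, ...}
```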