├── .codecov.yml ├── .dockerignore ├── .editorconfig ├── .github └── workflows │ ├── ci-release-2021.03.yml │ ├── ci.yml │ └── ok-to-test.yml ├── .gitignore ├── .mdlrc ├── .mergify.yml ├── .pre-commit-config.yaml ├── .pylintrc ├── CHANGELOG.md ├── CODEOWNERS ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── VERSION ├── benchmarks ├── README.md ├── setup.py └── src │ └── garage_benchmarks │ ├── __init__.py │ ├── benchmark_algos.py │ ├── benchmark_auto.py │ ├── benchmark_baselines.py │ ├── benchmark_policies.py │ ├── benchmark_q_functions.py │ ├── benchmarks.py │ ├── experiments │ ├── __init__.py │ ├── algos │ │ ├── __init__.py │ │ ├── ddpg_garage_tf.py │ │ ├── her_garage_tf.py │ │ ├── ppo_garage_pytorch.py │ │ ├── ppo_garage_tf.py │ │ ├── td3_garage_pytorch.py │ │ ├── td3_garage_tf.py │ │ ├── trpo_garage_pytorch.py │ │ ├── trpo_garage_tf.py │ │ ├── vpg_garage_pytorch.py │ │ └── vpg_garage_tf.py │ ├── baselines │ │ ├── __init__.py │ │ ├── continuous_mlp_baseline.py │ │ ├── gaussian_cnn_baseline.py │ │ └── gaussian_mlp_baseline.py │ ├── policies │ │ ├── __init__.py │ │ ├── categorical_cnn_policy.py │ │ ├── categorical_gru_policy.py │ │ ├── categorical_lstm_policy.py │ │ ├── categorical_mlp_policy.py │ │ ├── continuous_mlp_policy.py │ │ ├── gaussian_gru_policy.py │ │ ├── gaussian_lstm_policy.py │ │ └── gaussian_mlp_policy.py │ └── q_functions │ │ ├── __init__.py │ │ └── continuous_mlp_q_function.py │ ├── helper.py │ ├── parameters.py │ └── run_benchmarks.py ├── docker ├── Dockerfile ├── entrypoint-headless.sh ├── entrypoint-runtime.sh └── hooks │ └── build ├── docs ├── Makefile ├── _static │ ├── resl_logo.png │ ├── theme_overrides.css │ └── viterbi_logo.png ├── _templates │ └── footer.html ├── autoapi_templates │ └── python │ │ └── module.rst ├── bibtex.json ├── conf.py ├── index.md ├── requirements.txt └── user │ ├── algo_bc.md │ ├── algo_cem.md │ ├── algo_ddpg.md │ ├── algo_dqn.md │ ├── algo_erwr.md │ ├── algo_maml.md │ ├── algo_mtppo.md │ ├── algo_mtsac.md │ ├── algo_mttrpo.md │ ├── algo_pearl.md │ ├── algo_ppo.md │ ├── algo_rl2.md │ ├── algo_sac.md │ ├── algo_td3.md │ ├── algo_trpo.md │ ├── algo_vpg.md │ ├── benchmarking.md │ ├── cluster_setup.md │ ├── concept_experiment.md │ ├── custom_worker.md │ ├── docker.md │ ├── docker_dev.md │ ├── ensure_your_experiments_are_reproducible.md │ ├── environment.md │ ├── environment_libraries.md │ ├── experiments.rst │ ├── get_started.md │ ├── git_workflow.md │ ├── images │ ├── bc_meanLoss.png │ ├── bc_stdLoss.png │ ├── dqn_plots.png │ ├── numpy.png │ ├── pytorch.png │ ├── td3_tf_HalfCheetah-v2.png │ ├── td3_tf_Hopper-v2.png │ ├── td3_tf_InvertedDoublePendulum-v2.png │ ├── td3_tf_InvertedPendulum-v2.png │ ├── td3_tf_Swimmer-v2.png │ ├── td3_torch_HalfCheetah-v2.png │ ├── td3_torch_Hopper-v2.png │ ├── td3_torch_InvertedDoublePendulum-v2.png │ ├── td3_torch_InvertedPendulum-v2.png │ ├── td3_torch_Swimmer-v2.png │ ├── td3_torch_Walker2d-v2.png │ └── tf.png │ ├── implement_algo.md │ ├── implement_env.md │ ├── implement_env.rst │ ├── implement_worker.md │ ├── installation.rst │ ├── logging_plotting.md │ ├── matplotlib_example.png │ ├── max_resource_usage.md │ ├── meta_multi_task_rl_exp.md │ ├── monitor_experiments_with_tensorboard.md │ ├── pixel_observations.md │ ├── preparing_a_pr.md │ ├── references.bib │ ├── reuse_garage_policy.md │ ├── sampling.md │ ├── save_load_resume_exp.md │ ├── setting_up_your_development_environment.md │ ├── testing.md │ ├── training_a_policy.md │ ├── use_pretrained_network_to_start_new_experiment.md │ 
└── writing_documentation.md ├── readthedocs.yml ├── scripts ├── check_commit_message ├── ci │ ├── check_docs_only.sh │ ├── check_no_deps_changed.sh │ └── check_precommit.sh ├── garage ├── setup_colab.sh ├── setup_linux.sh └── setup_macos.sh ├── setup.cfg ├── setup.py ├── src └── garage │ ├── __init__.py │ ├── _dtypes.py │ ├── _environment.py │ ├── _functions.py │ ├── envs │ ├── __init__.py │ ├── bullet │ │ ├── __init__.py │ │ └── bullet_env.py │ ├── dm_control │ │ ├── __init__.py │ │ ├── dm_control_env.py │ │ └── dm_control_viewer.py │ ├── grid_world_env.py │ ├── gym_env.py │ ├── metaworld_set_task_env.py │ ├── mujoco │ │ ├── __init__.py │ │ ├── half_cheetah_dir_env.py │ │ ├── half_cheetah_env_meta_base.py │ │ └── half_cheetah_vel_env.py │ ├── multi_env_wrapper.py │ ├── normalized_env.py │ ├── point_env.py │ ├── task_name_wrapper.py │ ├── task_onehot_wrapper.py │ └── wrappers │ │ ├── __init__.py │ │ ├── atari_env.py │ │ ├── clip_reward.py │ │ ├── episodic_life.py │ │ ├── fire_reset.py │ │ ├── grayscale.py │ │ ├── max_and_skip.py │ │ ├── noop.py │ │ ├── pixel_observation.py │ │ ├── resize.py │ │ └── stack_frames.py │ ├── examples │ ├── jupyter │ │ ├── custom_env.ipynb │ │ └── trpo_gym_tf_cartpole.ipynb │ ├── np │ │ ├── cem_cartpole.py │ │ ├── cma_es_cartpole.py │ │ └── tutorial_cem.py │ ├── sim_policy.py │ ├── step_bullet_kuka_env.py │ ├── step_dm_control_env.py │ ├── step_gym_env.py │ ├── tf │ │ ├── ddpg_pendulum.py │ │ ├── dqn_cartpole.py │ │ ├── dqn_pong.py │ │ ├── erwr_cartpole.py │ │ ├── her_ddpg_fetchreach.py │ │ ├── multi_env_ppo.py │ │ ├── multi_env_trpo.py │ │ ├── ppo_memorize_digits.py │ │ ├── ppo_pendulum.py │ │ ├── reps_gym_cartpole.py │ │ ├── resume_training.py │ │ ├── rl2_ppo_halfcheetah.py │ │ ├── rl2_ppo_halfcheetah_meta_test.py │ │ ├── rl2_ppo_metaworld_ml10.py │ │ ├── rl2_ppo_metaworld_ml1_push.py │ │ ├── rl2_ppo_metaworld_ml45.py │ │ ├── rl2_trpo_halfcheetah.py │ │ ├── td3_pendulum.py │ │ ├── te_ppo_metaworld_mt10.py │ │ ├── te_ppo_metaworld_mt1_push.py │ │ ├── te_ppo_metaworld_mt50.py │ │ ├── te_ppo_point.py │ │ ├── trpo_cartpole.py │ │ ├── trpo_cartpole_bullet.py │ │ ├── trpo_cartpole_recurrent.py │ │ ├── trpo_cubecrash.py │ │ ├── trpo_gym_tf_cartpole.py │ │ ├── trpo_gym_tf_cartpole_pretrained.py │ │ ├── trpo_swimmer.py │ │ ├── trpo_swimmer_ray_sampler.py │ │ ├── tutorial_vpg.py │ │ └── vpg_cartpole.py │ └── torch │ │ ├── bc_point.py │ │ ├── bc_point_deterministic_policy.py │ │ ├── ddpg_pendulum.py │ │ ├── dqn_atari.py │ │ ├── dqn_cartpole.py │ │ ├── maml_ppo_half_cheetah_dir.py │ │ ├── maml_trpo_half_cheetah_dir.py │ │ ├── maml_trpo_metaworld_ml10.py │ │ ├── maml_trpo_metaworld_ml1_push.py │ │ ├── maml_trpo_metaworld_ml45.py │ │ ├── maml_vpg_half_cheetah_dir.py │ │ ├── mtppo_metaworld_mt10.py │ │ ├── mtppo_metaworld_mt1_push.py │ │ ├── mtppo_metaworld_mt50.py │ │ ├── mtsac_metaworld_mt10.py │ │ ├── mtsac_metaworld_mt1_pick_place.py │ │ ├── mtsac_metaworld_mt50.py │ │ ├── mttrpo_metaworld_mt10.py │ │ ├── mttrpo_metaworld_mt1_push.py │ │ ├── mttrpo_metaworld_mt50.py │ │ ├── pearl_half_cheetah_vel.py │ │ ├── pearl_metaworld_ml10.py │ │ ├── pearl_metaworld_ml1_push.py │ │ ├── pearl_metaworld_ml45.py │ │ ├── ppo_pendulum.py │ │ ├── resume_training.py │ │ ├── sac_half_cheetah_batch.py │ │ ├── td3_halfcheetah.py │ │ ├── td3_pendulum.py │ │ ├── trpo_pendulum.py │ │ ├── trpo_pendulum_ray_sampler.py │ │ ├── tutorial_vpg.py │ │ ├── vpg_pendulum.py │ │ └── watch_atari.py │ ├── experiment │ ├── __init__.py │ ├── deterministic.py │ ├── experiment.py │ ├── meta_evaluator.py │ ├── 
snapshotter.py │ └── task_sampler.py │ ├── np │ ├── __init__.py │ ├── _functions.py │ ├── algos │ │ ├── __init__.py │ │ ├── cem.py │ │ ├── cma_es.py │ │ ├── meta_rl_algorithm.py │ │ ├── nop.py │ │ └── rl_algorithm.py │ ├── baselines │ │ ├── __init__.py │ │ ├── baseline.py │ │ ├── linear_feature_baseline.py │ │ ├── linear_multi_feature_baseline.py │ │ └── zero_baseline.py │ ├── embeddings │ │ ├── __init__.py │ │ └── encoder.py │ ├── exploration_policies │ │ ├── __init__.py │ │ ├── add_gaussian_noise.py │ │ ├── add_ornstein_uhlenbeck_noise.py │ │ ├── epsilon_greedy_policy.py │ │ └── exploration_policy.py │ ├── optimizers │ │ ├── __init__.py │ │ └── minibatch_dataset.py │ ├── policies │ │ ├── __init__.py │ │ ├── fixed_policy.py │ │ ├── policy.py │ │ ├── scripted_policy.py │ │ └── uniform_random_policy.py │ └── q_functions │ │ ├── __init__.py │ │ └── q_function.py │ ├── plotter │ ├── __init__.py │ └── plotter.py │ ├── replay_buffer │ ├── __init__.py │ ├── her_replay_buffer.py │ ├── path_buffer.py │ └── replay_buffer.py │ ├── sampler │ ├── __init__.py │ ├── _dtypes.py │ ├── _functions.py │ ├── default_worker.py │ ├── env_update.py │ ├── fragment_worker.py │ ├── local_sampler.py │ ├── multiprocessing_sampler.py │ ├── ray_sampler.py │ ├── sampler.py │ ├── utils.py │ ├── vec_worker.py │ ├── worker.py │ └── worker_factory.py │ ├── tf │ ├── __init__.py │ ├── _functions.py │ ├── algos │ │ ├── __init__.py │ │ ├── _rl2npo.py │ │ ├── ddpg.py │ │ ├── dqn.py │ │ ├── erwr.py │ │ ├── npo.py │ │ ├── ppo.py │ │ ├── reps.py │ │ ├── rl2.py │ │ ├── rl2ppo.py │ │ ├── rl2trpo.py │ │ ├── td3.py │ │ ├── te.py │ │ ├── te_npo.py │ │ ├── te_ppo.py │ │ ├── tnpg.py │ │ ├── trpo.py │ │ └── vpg.py │ ├── baselines │ │ ├── __init__.py │ │ ├── continuous_mlp_baseline.py │ │ ├── gaussian_cnn_baseline.py │ │ ├── gaussian_cnn_baseline_model.py │ │ ├── gaussian_mlp_baseline.py │ │ └── gaussian_mlp_baseline_model.py │ ├── embeddings │ │ ├── __init__.py │ │ ├── encoder.py │ │ └── gaussian_mlp_encoder.py │ ├── models │ │ ├── __init__.py │ │ ├── categorical_cnn_model.py │ │ ├── categorical_gru_model.py │ │ ├── categorical_lstm_model.py │ │ ├── categorical_mlp_model.py │ │ ├── cnn.py │ │ ├── cnn_mlp_merge_model.py │ │ ├── cnn_model.py │ │ ├── cnn_model_max_pooling.py │ │ ├── gaussian_cnn_model.py │ │ ├── gaussian_gru_model.py │ │ ├── gaussian_lstm_model.py │ │ ├── gaussian_mlp_model.py │ │ ├── gru.py │ │ ├── gru_model.py │ │ ├── lstm.py │ │ ├── lstm_model.py │ │ ├── mlp.py │ │ ├── mlp_dueling_model.py │ │ ├── mlp_merge_model.py │ │ ├── mlp_model.py │ │ ├── model.py │ │ ├── module.py │ │ ├── normalized_input_mlp_model.py │ │ ├── parameter.py │ │ └── sequential.py │ ├── optimizers │ │ ├── __init__.py │ │ ├── _dtypes.py │ │ ├── conjugate_gradient_optimizer.py │ │ ├── first_order_optimizer.py │ │ ├── lbfgs_optimizer.py │ │ └── penalty_lbfgs_optimizer.py │ ├── plotter │ │ ├── __init__.py │ │ └── plotter.py │ ├── policies │ │ ├── __init__.py │ │ ├── categorical_cnn_policy.py │ │ ├── categorical_gru_policy.py │ │ ├── categorical_lstm_policy.py │ │ ├── categorical_mlp_policy.py │ │ ├── continuous_mlp_policy.py │ │ ├── discrete_qf_argmax_policy.py │ │ ├── gaussian_gru_policy.py │ │ ├── gaussian_lstm_policy.py │ │ ├── gaussian_mlp_policy.py │ │ ├── gaussian_mlp_task_embedding_policy.py │ │ ├── policy.py │ │ ├── task_embedding_policy.py │ │ └── uniform_control_policy.py │ ├── q_functions │ │ ├── __init__.py │ │ ├── continuous_cnn_q_function.py │ │ ├── continuous_mlp_q_function.py │ │ ├── discrete_cnn_q_function.py │ │ ├── 
discrete_mlp_dueling_q_function.py │ │ └── discrete_mlp_q_function.py │ └── samplers │ │ ├── __init__.py │ │ └── worker.py │ ├── torch │ ├── __init__.py │ ├── _functions.py │ ├── algos │ │ ├── __init__.py │ │ ├── bc.py │ │ ├── ddpg.py │ │ ├── dqn.py │ │ ├── maml.py │ │ ├── maml_ppo.py │ │ ├── maml_trpo.py │ │ ├── maml_vpg.py │ │ ├── mtsac.py │ │ ├── pearl.py │ │ ├── ppo.py │ │ ├── sac.py │ │ ├── td3.py │ │ ├── trpo.py │ │ └── vpg.py │ ├── distributions │ │ ├── __init__.py │ │ └── tanh_normal.py │ ├── embeddings │ │ ├── __init__.py │ │ └── mlp_encoder.py │ ├── modules │ │ ├── __init__.py │ │ ├── cnn_module.py │ │ ├── discrete_cnn_module.py │ │ ├── gaussian_mlp_module.py │ │ ├── mlp_module.py │ │ └── multi_headed_mlp_module.py │ ├── optimizers │ │ ├── __init__.py │ │ ├── conjugate_gradient_optimizer.py │ │ ├── differentiable_sgd.py │ │ └── optimizer_wrapper.py │ ├── policies │ │ ├── __init__.py │ │ ├── categorical_cnn_policy.py │ │ ├── context_conditioned_policy.py │ │ ├── deterministic_mlp_policy.py │ │ ├── discrete_cnn_policy.py │ │ ├── discrete_qf_argmax_policy.py │ │ ├── gaussian_mlp_policy.py │ │ ├── policy.py │ │ ├── stochastic_policy.py │ │ └── tanh_gaussian_mlp_policy.py │ ├── q_functions │ │ ├── __init__.py │ │ ├── continuous_mlp_q_function.py │ │ ├── discrete_cnn_q_function.py │ │ ├── discrete_dueling_cnn_q_function.py │ │ └── discrete_mlp_q_function.py │ └── value_functions │ │ ├── __init__.py │ │ ├── gaussian_mlp_value_function.py │ │ └── value_function.py │ └── trainer.py └── tests ├── __init__.py ├── fixtures ├── __init__.py ├── algos │ ├── __init__.py │ ├── dummy_algo.py │ └── dummy_tf_algo.py ├── envs │ ├── __init__.py │ ├── dummy │ │ ├── __init__.py │ │ ├── base.py │ │ ├── dummy_box_env.py │ │ ├── dummy_dict_env.py │ │ ├── dummy_discrete_2d_env.py │ │ ├── dummy_discrete_env.py │ │ ├── dummy_discrete_pixel_env.py │ │ ├── dummy_discrete_pixel_env_baselines.py │ │ ├── dummy_multitask_box_env.py │ │ └── dummy_reward_box_env.py │ └── wrappers │ │ ├── __init__.py │ │ └── reshape_observation.py ├── experiment │ ├── __init__.py │ └── fixture_experiment.py ├── fixtures.py ├── logger.py ├── models │ ├── __init__.py │ ├── simple_categorical_gru_model.py │ ├── simple_categorical_lstm_model.py │ ├── simple_categorical_mlp_model.py │ ├── simple_cnn_model.py │ ├── simple_cnn_model_with_max_pooling.py │ ├── simple_gru_model.py │ ├── simple_lstm_model.py │ ├── simple_mlp_merge_model.py │ └── simple_mlp_model.py ├── policies │ ├── __init__.py │ ├── dummy_policy.py │ └── dummy_recurrent_policy.py ├── q_functions │ ├── __init__.py │ └── simple_q_function.py ├── sampler │ ├── __init__.py │ └── ray_fixtures.py └── tf │ ├── __init__.py │ └── algos │ └── dummy_off_policy_algo.py ├── garage ├── .pylintrc ├── __init__.py ├── envs │ ├── __init__.py │ ├── box2d │ │ └── parser │ │ │ └── __init__.py │ ├── bullet │ │ ├── __init__.py │ │ └── test_bullet_env.py │ ├── dm_control │ │ ├── __init__.py │ │ ├── test_dm_control_env.py │ │ └── test_dm_control_tf_policy.py │ ├── test_grid_world_env.py │ ├── test_gym_env.py │ ├── test_half_cheetah_meta_envs.py │ ├── test_metaworld_set_task_env.py │ ├── test_multi_env_wrapper.py │ ├── test_normalized_env.py │ ├── test_normalized_gym.py │ ├── test_point_env.py │ ├── test_rl2_env.py │ ├── test_task_onehot_wrapper.py │ └── wrappers │ │ ├── __init__.py │ │ ├── test_atari_env.py │ │ ├── test_clip_reward.py │ │ ├── test_episodic_life.py │ │ ├── test_fire_reset.py │ │ ├── test_grayscale_env.py │ │ ├── test_max_and_skip.py │ │ ├── test_noop.py │ │ ├── 
test_pixel_observation_wrapper.py │ │ ├── test_resize_env.py │ │ └── test_stack_frames_env.py ├── experiment │ ├── __init__.py │ ├── test_deterministic.py │ ├── test_experiment.py │ ├── test_meta_evaluator.py │ ├── test_resume.py │ ├── test_snapshotter.py │ ├── test_snapshotter_integration.py │ ├── test_task_sampler.py │ └── test_trainer.py ├── np │ ├── __init__.py │ ├── algos │ │ ├── __init__.py │ │ ├── test_cem.py │ │ └── test_cma_es.py │ ├── exploration_strategies │ │ ├── __init__.py │ │ ├── test_add_gaussian_noise.py │ │ └── test_epsilon_greedy_policy.py │ ├── policies │ │ ├── test_fixed_policy.py │ │ ├── test_scripted_policy.py │ │ └── test_uniform_random_policy.py │ └── test_functions.py ├── replay_buffer │ ├── __init__.py │ ├── test_her_replay_buffer.py │ └── test_path_buffer.py ├── sampler │ ├── __init__.py │ ├── test_env_update.py │ ├── test_fragment_worker.py │ ├── test_local_sampler.py │ ├── test_multiprocessing_sampler.py │ ├── test_ray_batched_sampler.py │ ├── test_rl2_worker.py │ └── test_vec_worker.py ├── test_dtypes.py ├── test_environment.py ├── test_functions.py ├── tf │ ├── __init__.py │ ├── algos │ │ ├── __init__.py │ │ ├── test_ddpg.py │ │ ├── test_dqn.py │ │ ├── test_erwr.py │ │ ├── test_npo.py │ │ ├── test_ppo.py │ │ ├── test_reps.py │ │ ├── test_rl2ppo.py │ │ ├── test_rl2trpo.py │ │ ├── test_td3.py │ │ ├── test_te.py │ │ ├── test_tnpg.py │ │ ├── test_trpo.py │ │ └── test_vpg.py │ ├── baselines │ │ ├── __init__.py │ │ ├── test_baselines.py │ │ ├── test_continuous_mlp_baseline.py │ │ ├── test_gaussian_cnn_baseline.py │ │ └── test_gaussian_mlp_baseline.py │ ├── embeddings │ │ ├── __init__.py │ │ └── test_gaussian_mlp_encoder.py │ ├── envs │ │ ├── __init__.py │ │ └── test_gym_base.py │ ├── experiment │ │ ├── __init__.py │ │ └── test_trainer.py │ ├── models │ │ ├── __init__.py │ │ ├── test_categorical_cnn_model.py │ │ ├── test_categorical_gru_model.py │ │ ├── test_categorical_lstm_model.py │ │ ├── test_categorical_mlp_model.py │ │ ├── test_cnn.py │ │ ├── test_cnn_mlp_merge_model.py │ │ ├── test_cnn_model.py │ │ ├── test_gaussian_cnn_model.py │ │ ├── test_gaussian_gru_model.py │ │ ├── test_gaussian_lstm_model.py │ │ ├── test_gaussian_mlp_model.py │ │ ├── test_gru.py │ │ ├── test_gru_model.py │ │ ├── test_lstm.py │ │ ├── test_lstm_model.py │ │ ├── test_mlp.py │ │ ├── test_mlp_concat.py │ │ ├── test_mlp_model.py │ │ ├── test_model.py │ │ └── test_parameter.py │ ├── optimizers │ │ ├── __init__.py │ │ └── test_conjugate_gradient_optimizer.py │ ├── policies │ │ ├── __init__.py │ │ ├── test_categorical_cnn_policy.py │ │ ├── test_categorical_gru_policy.py │ │ ├── test_categorical_lstm_policy.py │ │ ├── test_categorical_mlp_policy.py │ │ ├── test_categorical_policies.py │ │ ├── test_continuous_mlp_policy.py │ │ ├── test_discrete_qf_argmax_policy.py │ │ ├── test_gaussian_gru_policy.py │ │ ├── test_gaussian_lstm_policy.py │ │ ├── test_gaussian_mlp_policy.py │ │ ├── test_gaussian_mlp_task_embedding_policy.py │ │ ├── test_gaussian_policies.py │ │ └── test_policies.py │ ├── q_functions │ │ ├── __init__.py │ │ ├── test_continuous_cnn_q_function.py │ │ ├── test_continuous_mlp_q_function.py │ │ ├── test_discrete_cnn_q_function.py │ │ ├── test_discrete_mlp_dueling_q_function.py │ │ └── test_discrete_mlp_q_function.py │ ├── samplers │ │ ├── __init__.py │ │ ├── test_ray_batched_sampler_tf.py │ │ ├── test_task_embedding_worker.py │ │ └── test_tf_worker.py │ └── test_functions.py └── torch │ ├── __init__.py │ ├── algos │ ├── __init__.py │ ├── test_bc.py │ ├── test_ddpg.py │ ├── test_dqn.py │ 
├── test_maml.py │ ├── test_maml_ppo.py │ ├── test_maml_trpo.py │ ├── test_maml_vpg.py │ ├── test_mtsac.py │ ├── test_pearl.py │ ├── test_pearl_worker.py │ ├── test_ppo.py │ ├── test_sac.py │ ├── test_td3.py │ ├── test_trpo.py │ └── test_vpg.py │ ├── distributions │ └── test_tanh_normal_dist.py │ ├── modules │ ├── __init__.py │ ├── test_cnn_module.py │ ├── test_discrete_cnn_module.py │ ├── test_gaussian_mlp_module.py │ ├── test_mlp_module.py │ └── test_multi_headed_mlp_module.py │ ├── optimizers │ ├── test_differentiable_sgd.py │ └── test_torch_conjugate_gradient_optimizer.py │ ├── policies │ ├── __init__.py │ ├── test_categorical_cnn_policy.py │ ├── test_context_conditioned_policy.py │ ├── test_deterministic_mlp_policy.py │ ├── test_discrete_cnn_policy.py │ ├── test_discrete_qf_argmax_policy.py │ ├── test_gaussian_mlp_policy.py │ └── test_tanh_gaussian_mlp_policy.py │ ├── q_functions │ ├── __init__.py │ ├── test_continuous_mlp_q_function.py │ ├── test_discrete_cnn_q_function.py │ ├── test_discrete_dueling_cnn_q_function.py │ └── test_discrete_mlp_q_function.py │ └── test_functions.py ├── helpers.py ├── integration_tests ├── __init__.py ├── test_examples.py └── test_sigint.py ├── mock.py ├── quirks.py └── wrappers.py /.codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | range: 60..100 3 | status: 4 | patch: 5 | default: 6 | target: 90% 7 | 8 | codecov: 9 | ci: 10 | - "travis-ci.com" 11 | notify: 12 | wait_for_ci: yes 13 | after_n_builds: 4 14 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | **/data 2 | Pipfile* 3 | .pytest_cache 4 | .idea 5 | **/dist 6 | __pycache__ 7 | venv 8 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # garage EditorConfig 2 | 3 | # top-most EditorConfig file 4 | root = true 5 | 6 | # Unix-style newlines with a newline ending every file 7 | [*] 8 | charset = utf-8 9 | end_of_line = lf 10 | insert_final_newline = true 11 | trim_trailing_whitespace = true 12 | 13 | # Python config 14 | [*.py] 15 | indent_style = space 16 | indent_size = 4 17 | max_line_length = 79 18 | 19 | # HTML, XML, and XML-like formats 20 | [*.{urdf,html,launch,mako,urdf,world,xml,css}] 21 | indent_style = space 22 | indent_size = 2 23 | 24 | # JS 25 | [*.js] 26 | indent_style = space 27 | indent_size = 2 28 | max_line_length = 80 29 | 30 | # YAML files 31 | [*.{yaml, yml}] 32 | indent_style = space 33 | indent_size = 2 34 | 35 | # Shell scripts 36 | [*.sh] 37 | indent_style = space 38 | indent_size = 2 39 | max_line_length = 80 40 | 41 | # Markdown 42 | [*.md] 43 | indent_style = space 44 | max_line_length = 80 45 | -------------------------------------------------------------------------------- /.github/workflows/ok-to-test.yml: -------------------------------------------------------------------------------- 1 | # If someone with write access comments "/ok-to-test" on a pull request, emit a repository_dispatch event 2 | name: Ok To Test 3 | on: 4 | issue_comment: 5 | types: [created] 6 | jobs: 7 | ok-to-test: 8 | if: ${{ github.event.issue.pull_request }} 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Get PR details 12 | id: pr_details 13 | run: | 14 | pr_details=$(curl -v -H "Accept: application/vnd.github.sailor-v-preview+json" -u ${{ secrets.CI_REGISTRY_TOKEN }} 
${{ github.event.issue.pull_request.url }}) 15 | echo "::set-output name=is_from_fork::$(echo $pr_details | jq '.head.repo.fork')" 16 | echo "::set-output name=base_ref::$(echo $pr_details | jq '.base.ref' | sed 's/\"//g')" 17 | - name: Slash Command Dispatch 18 | uses: peter-evans/slash-command-dispatch@v1 19 | if: ${{ steps.pr_details.outputs.is_from_fork && steps.pr_details.outputs.base_ref == 'master' }} 20 | with: 21 | token: ${{ secrets.CI_REGISTRY_TOKEN }} 22 | reaction-token: ${{ secrets.GITHUB_TOKEN }} 23 | issue-type: pull-request 24 | commands: ok-to-test 25 | named-args: true 26 | permission: write 27 | -------------------------------------------------------------------------------- /.mdlrc: -------------------------------------------------------------------------------- 1 | tags "tables", "~MD013" 2 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | # Check docstring completeness 3 | load-plugins = pylint.extensions.docparams, pylint.extensions.docstyle 4 | 5 | # Unit tests have a special configuration which is checked separately 6 | ignore = tests/garage, benchmarks 7 | 8 | # Go as fast as you can 9 | jobs = 0 10 | 11 | # Packages which we need to load so we can see their C extensions 12 | extension-pkg-whitelist = 13 | numpy.random, 14 | mpi4py.MPI, 15 | 16 | 17 | [MESSAGES CONTROL] 18 | enable = all 19 | disable = 20 | # Style rules handled by yapf/flake8/isort 21 | bad-continuation, 22 | invalid-name, 23 | line-too-long, 24 | wrong-import-order, 25 | # Class must have 2 public methods seems arbitary 26 | too-few-public-methods, 27 | # Algorithms and neural networks generally have a lot of variables 28 | too-many-instance-attributes, 29 | too-many-arguments, 30 | too-many-locals, 31 | # Detection seems buggy or unhelpful 32 | duplicate-code, 33 | # Allow more readable code 34 | no-else-return, 35 | # Discourages small interfaces 36 | too-few-public-methods, 37 | 38 | [REPORTS] 39 | msg-template = {path}:{line:3d},{column}: {msg} ({symbol}) 40 | output-format = colorized 41 | 42 | 43 | [TYPECHECK] 44 | # Packages which might not admit static analysis because they have C extensions 45 | generated-members = torch.* 46 | 47 | 48 | [PARAMETER_DOCUMENTATION] 49 | # Docstrings are required and should be complete 50 | accept-no-param-doc=no 51 | accept-no-raise-doc=no 52 | accept-no-return-doc=no 53 | accept-no-yields-doc=no 54 | default-docstring-type=google 55 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # garage Maintainers at root 2 | * @rlworkgroup/maintainers 3 | 4 | # TensorFlow tree 5 | /garage/tf/* @ahtsan 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Reinforcement Learning Working Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following 
conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # project metadata 2 | # 3 | # NOTE: README.md and VERSION are required to run setup.py. Failure to include 4 | # them will create a broken PyPI distribution. 5 | include README.md 6 | include VERSION 7 | include LICENSE 8 | include CONTRIBUTING.md 9 | include CHANGELOG.md 10 | 11 | # tests 12 | graft tests 13 | include setup.cfg 14 | 15 | # documentation 16 | graft docs 17 | prune docs/_build 18 | 19 | # examples, scripts, etc. 20 | include Makefile 21 | graft docker 22 | graft src/garage/examples 23 | graft scripts 24 | 25 | # ignored files 26 | global-exclude *.py[co] 27 | global-exclude .DS_Store 28 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | v2020.09.0rc2-dev 2 | -------------------------------------------------------------------------------- /benchmarks/setup.py: -------------------------------------------------------------------------------- 1 | """Setup script for garage benchmarking scripts. 2 | 3 | This package is generally not needed by users of garage. 
4 | """ 5 | import os 6 | 7 | from setuptools import find_packages, setup 8 | 9 | GARAGE_GH_TOKEN = os.environ.get('GARAGE_GH_TOKEN') or 'git' 10 | 11 | REQUIRED = [ 12 | # Please keep alphabetized 13 | 'google-cloud-storage', 14 | 'matplotlib' 15 | ] # yapf: disable 16 | 17 | setup(name='garage_benchmarks', 18 | packages=find_packages(where='src'), 19 | package_dir={'': 'src'}, 20 | install_requires=REQUIRED, 21 | include_package_data=True, 22 | entry_points=''' 23 | [console_scripts] 24 | garage_benchmark=garage_benchmarks.run_benchmarks:cli 25 | ''') 26 | -------------------------------------------------------------------------------- /benchmarks/src/garage_benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | """garage benchmarks.""" 2 | -------------------------------------------------------------------------------- /benchmarks/src/garage_benchmarks/benchmark_baselines.py: -------------------------------------------------------------------------------- 1 | """Benchmarking for baselines.""" 2 | import random 3 | 4 | from garage_benchmarks.experiments.baselines import (continuous_mlp_baseline, 5 | gaussian_cnn_baseline, 6 | gaussian_mlp_baseline) 7 | from garage_benchmarks.helper import benchmark, iterate_experiments 8 | from garage_benchmarks.parameters import MuJoCo1M_ENV_SET, PIXEL_ENV_SET 9 | 10 | _seeds = random.sample(range(100), 3) 11 | 12 | 13 | @benchmark 14 | def continuous_mlp_baseline_tf_ppo_benchmarks(): 15 | """Run benchmarking experiments for Continuous MLP Baseline on TF-PPO.""" 16 | iterate_experiments(continuous_mlp_baseline, 17 | MuJoCo1M_ENV_SET, 18 | seeds=_seeds) 19 | 20 | 21 | @benchmark 22 | def gaussian_cnn_baseline_tf_ppo_benchmarks(): 23 | """Run benchmarking experiments for Gaussian CNN Baseline on TF-PPO.""" 24 | iterate_experiments(gaussian_cnn_baseline, PIXEL_ENV_SET, seeds=_seeds) 25 | 26 | 27 | @benchmark 28 | def gaussian_mlp_baseline_tf_ppo_benchmarks(): 29 | """Run benchmarking experiments for Gaussian MLP Baseline on TF-PPO.""" 30 | iterate_experiments(gaussian_mlp_baseline, MuJoCo1M_ENV_SET, seeds=_seeds) 31 | -------------------------------------------------------------------------------- /benchmarks/src/garage_benchmarks/benchmark_q_functions.py: -------------------------------------------------------------------------------- 1 | """Benchmarking for q-functions.""" 2 | import random 3 | 4 | from garage_benchmarks.experiments.q_functions import continuous_mlp_q_function 5 | from garage_benchmarks.helper import benchmark, iterate_experiments 6 | from garage_benchmarks.parameters import MuJoCo1M_ENV_SET 7 | 8 | _seeds = random.sample(range(100), 5) 9 | 10 | 11 | @benchmark 12 | def continuous_mlp_q_function_tf_ddpg_benchmarks(): 13 | """Run benchmarking experiments for Continuous MLP QFunction on TF-DDPG.""" 14 | iterate_experiments(continuous_mlp_q_function, 15 | MuJoCo1M_ENV_SET, 16 | seeds=_seeds) 17 | -------------------------------------------------------------------------------- /benchmarks/src/garage_benchmarks/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | """Benchmarking experiments.""" 2 | -------------------------------------------------------------------------------- /benchmarks/src/garage_benchmarks/experiments/algos/__init__.py: -------------------------------------------------------------------------------- 1 | """Benchmarking experiments for algorithms.""" 2 | from garage_benchmarks.experiments.algos.ddpg_garage_tf 
import ddpg_garage_tf 3 | from garage_benchmarks.experiments.algos.her_garage_tf import her_garage_tf 4 | from garage_benchmarks.experiments.algos.ppo_garage_pytorch import ( 5 | ppo_garage_pytorch) 6 | from garage_benchmarks.experiments.algos.ppo_garage_tf import ppo_garage_tf 7 | from garage_benchmarks.experiments.algos.td3_garage_pytorch import ( 8 | td3_garage_pytorch) 9 | from garage_benchmarks.experiments.algos.td3_garage_tf import td3_garage_tf 10 | from garage_benchmarks.experiments.algos.trpo_garage_pytorch import ( 11 | trpo_garage_pytorch) 12 | from garage_benchmarks.experiments.algos.trpo_garage_tf import trpo_garage_tf 13 | from garage_benchmarks.experiments.algos.vpg_garage_pytorch import ( 14 | vpg_garage_pytorch) 15 | from garage_benchmarks.experiments.algos.vpg_garage_tf import vpg_garage_tf 16 | 17 | __all__ = [ 18 | 'ddpg_garage_tf', 'her_garage_tf', 'ppo_garage_pytorch', 'ppo_garage_tf', 19 | 'td3_garage_pytorch', 'td3_garage_tf', 'trpo_garage_pytorch', 20 | 'trpo_garage_tf', 'vpg_garage_pytorch', 'vpg_garage_tf' 21 | ] 22 | -------------------------------------------------------------------------------- /benchmarks/src/garage_benchmarks/experiments/baselines/__init__.py: -------------------------------------------------------------------------------- 1 | """Benchmarking experiments for baselines.""" 2 | from garage_benchmarks.experiments.baselines.continuous_mlp_baseline import ( 3 | continuous_mlp_baseline) 4 | from garage_benchmarks.experiments.baselines.gaussian_cnn_baseline import ( 5 | gaussian_cnn_baseline) 6 | from garage_benchmarks.experiments.baselines.gaussian_mlp_baseline import ( 7 | gaussian_mlp_baseline) 8 | 9 | __all__ = [ 10 | 'continuous_mlp_baseline', 'gaussian_cnn_baseline', 'gaussian_mlp_baseline' 11 | ] 12 | -------------------------------------------------------------------------------- /benchmarks/src/garage_benchmarks/experiments/policies/__init__.py: -------------------------------------------------------------------------------- 1 | """Benchmarking experiments for baselines.""" 2 | from garage_benchmarks.experiments.policies.categorical_cnn_policy import ( 3 | categorical_cnn_policy) 4 | from garage_benchmarks.experiments.policies.categorical_gru_policy import ( 5 | categorical_gru_policy) 6 | from garage_benchmarks.experiments.policies.categorical_lstm_policy import ( 7 | categorical_lstm_policy) 8 | from garage_benchmarks.experiments.policies.categorical_mlp_policy import ( 9 | categorical_mlp_policy) 10 | from garage_benchmarks.experiments.policies.continuous_mlp_policy import ( 11 | continuous_mlp_policy) 12 | from garage_benchmarks.experiments.policies.gaussian_gru_policy import ( 13 | gaussian_gru_policy) 14 | from garage_benchmarks.experiments.policies.gaussian_lstm_policy import ( 15 | gaussian_lstm_policy) 16 | from garage_benchmarks.experiments.policies.gaussian_mlp_policy import ( 17 | gaussian_mlp_policy) 18 | 19 | __all__ = [ 20 | 'categorical_cnn_policy', 'categorical_gru_policy', 21 | 'categorical_lstm_policy', 'categorical_mlp_policy', 22 | 'continuous_mlp_policy', 'gaussian_gru_policy', 'gaussian_lstm_policy', 23 | 'gaussian_mlp_policy' 24 | ] 25 | -------------------------------------------------------------------------------- /benchmarks/src/garage_benchmarks/experiments/policies/categorical_gru_policy.py: -------------------------------------------------------------------------------- 1 | """Benchmarking experiment of the CategoricalGRUPolicy.""" 2 | import tensorflow as tf 3 | 4 | from garage import wrap_experiment 5 | 
from garage.envs import GymEnv, normalize 6 | from garage.experiment import deterministic 7 | from garage.np.baselines import LinearFeatureBaseline 8 | from garage.sampler import LocalSampler 9 | from garage.tf.algos import PPO 10 | from garage.tf.policies import CategoricalGRUPolicy 11 | from garage.trainer import TFTrainer 12 | 13 | 14 | @wrap_experiment 15 | def categorical_gru_policy(ctxt, env_id, seed): 16 | """Create Categorical CNN Policy on TF-PPO. 17 | 18 | Args: 19 | ctxt (garage.experiment.ExperimentContext): The experiment 20 | configuration used by Trainer to create the 21 | snapshotter. 22 | env_id (str): Environment id of the task. 23 | seed (int): Random positive integer for the trial. 24 | 25 | """ 26 | deterministic.set_seed(seed) 27 | 28 | with TFTrainer(ctxt) as trainer: 29 | env = normalize(GymEnv(env_id)) 30 | 31 | policy = CategoricalGRUPolicy( 32 | env_spec=env.spec, 33 | hidden_dim=32, 34 | hidden_nonlinearity=tf.nn.tanh, 35 | ) 36 | 37 | baseline = LinearFeatureBaseline(env_spec=env.spec) 38 | 39 | sampler = LocalSampler(agents=policy, 40 | envs=env, 41 | max_episode_length=env.spec.max_episode_length) 42 | 43 | algo = PPO( 44 | env_spec=env.spec, 45 | policy=policy, 46 | baseline=baseline, 47 | sampler=sampler, 48 | discount=0.99, 49 | gae_lambda=0.95, 50 | lr_clip_range=0.2, 51 | policy_ent_coeff=0.0, 52 | optimizer_args=dict( 53 | batch_size=32, 54 | max_optimization_epochs=10, 55 | learning_rate=1e-3, 56 | ), 57 | ) 58 | 59 | trainer.setup(algo, env) 60 | trainer.train(n_epochs=488, batch_size=2048) 61 | -------------------------------------------------------------------------------- /benchmarks/src/garage_benchmarks/experiments/q_functions/__init__.py: -------------------------------------------------------------------------------- 1 | """Benchmarking experiments for Q-functions.""" 2 | from garage_benchmarks.experiments.q_functions.continuous_mlp_q_function import ( # isort:skip # noqa: E501 3 | continuous_mlp_q_function) 4 | 5 | __all__ = ['continuous_mlp_q_function'] 6 | -------------------------------------------------------------------------------- /benchmarks/src/garage_benchmarks/parameters.py: -------------------------------------------------------------------------------- 1 | """Global parameters for benchmarking.""" 2 | from garage_benchmarks import benchmarks 3 | 4 | Fetch1M_ENV_SET = [ 5 | task['env_id'] for task in benchmarks.get_benchmark('Fetch1M')['tasks'] 6 | ] 7 | 8 | MuJoCo1M_ENV_SET = [ 9 | task['env_id'] for task in benchmarks.get_benchmark('Mujoco1M')['tasks'] 10 | ] 11 | 12 | Atari10M_ENV_SET = [ 13 | task['env_id'] for task in benchmarks.get_benchmark('Atari10M')['tasks'] 14 | ] 15 | 16 | PIXEL_ENV_SET = ['CubeCrash-v0', 'MemorizeDigits-v0'] 17 | 18 | STATE_ENV_SET = [ 19 | 'LunarLander-v2', 20 | 'Assault-ramDeterministic-v4', 21 | 'Breakout-ramDeterministic-v4', 22 | 'ChopperCommand-ramDeterministic-v4', 23 | 'Tutankham-ramDeterministic-v4', 24 | 'CartPole-v1', 25 | ] 26 | -------------------------------------------------------------------------------- /docker/entrypoint-headless.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Dockerfile entrypoint 3 | set -e 4 | 5 | # Get MuJoCo key from the environment 6 | if [ -z "${MJKEY}" ]; then 7 | : 8 | else 9 | echo "${MJKEY}" > ${HOME}/.mujoco/mjkey.txt 10 | fi 11 | 12 | # Setup dummy X server display 13 | # Socket for display :0 may already be in use if the container is connected 14 | # to the network of the host, and other 
low-numbered socket could also be in 15 | # use, that's why we use 100. 16 | display_num=100 17 | export DISPLAY=:"${display_num}" 18 | Xvfb "${DISPLAY}" -screen 0 1024x768x24 & 19 | pulseaudio -D --exit-idle-time=-1 20 | 21 | # Wait for X to come up 22 | file="/tmp/.X11-unix/X${display_num}" 23 | for i in $(seq 1 10); do 24 | if [ -e "$file" ]; then 25 | break 26 | fi 27 | echo "Waiting for X to start (i.e. $file to be created) (attempt $i/10)" 28 | sleep "$i" 29 | done 30 | if ! [ -e "$file" ]; then 31 | echo "Timed out waiting for X to start: $file was not created" 32 | exit 1 33 | fi 34 | 35 | exec "$@" 36 | -------------------------------------------------------------------------------- /docker/entrypoint-runtime.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Dockerfile entrypoint 3 | set -e 4 | 5 | # Get MuJoCo key from the environment 6 | if [ -z "${MJKEY}" ]; then 7 | : 8 | else 9 | echo "${MJKEY}" > ${HOME}/.mujoco/mjkey.txt 10 | fi 11 | 12 | exec "$@" 13 | -------------------------------------------------------------------------------- /docker/hooks/build: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # https://docs.docker.com/docker-cloud/builds/advanced/ 4 | # DockerHub starts this script in the /docker directory 5 | 6 | set -e 7 | 8 | echo DOCKERFILE_PATH="$DOCKERFILE_PATH" 9 | echo DOCKER_REPO="$DOCKER_REPO" 10 | echo IMAGE_NAME="$IMAGE_NAME" 11 | 12 | if [[ "$DOCKER_REPO" = *"rlworkgroup/garage" ]]; then 13 | echo "Building target garage" 14 | docker build \ 15 | -f "../$DOCKERFILE_PATH" \ 16 | --build-arg BUILDKIT_INLINE_CACHE=1 \ 17 | --target garage \ 18 | -t "$IMAGE_NAME" \ 19 | .. 20 | elif [[ "$DOCKER_REPO" = *"rlworkgroup/garage-nvidia" ]]; then 21 | echo "Building target garage-nvidia" 22 | docker build \ 23 | -f "../$DOCKERFILE_PATH" \ 24 | --build-arg BUILDKIT_INLINE_CACHE=1 \ 25 | --target garage-nvidia \ 26 | -t "$IMAGE_NAME" \ 27 | --build-arg PARENT_IMAGE="nvidia/cuda:11.0-cudnn8-runtime-ubuntu18.04" \ 28 | .. 
29 | fi 30 | -------------------------------------------------------------------------------- /docs/_static/resl_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/_static/resl_logo.png -------------------------------------------------------------------------------- /docs/_static/theme_overrides.css: -------------------------------------------------------------------------------- 1 | /* override table width restrictions*/ 2 | @media screen and (min-width: 767px) { 3 | 4 | .wy-table-responsive table td { 5 | /* !important prevents the common CSS stylesheets from overriding 6 | this as on RTD they are loaded after this stylesheet */ 7 | white-space: normal !important; 8 | } 9 | 10 | .wy-table-responsive { 11 | overflow: visible !important; 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /docs/_static/viterbi_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/_static/viterbi_logo.png -------------------------------------------------------------------------------- /docs/_templates/footer.html: -------------------------------------------------------------------------------- 1 | {% extends "!footer.html" %} 2 | {% block extrafooter %} 3 |
4 | 5 | {{ super() }} 6 | {% endblock %} 7 | -------------------------------------------------------------------------------- /docs/bibtex.json: -------------------------------------------------------------------------------- 1 | { 2 | "cited": { 3 | "user/algo_bc": [ 4 | "ho2016model" 5 | ], 6 | "user/algo_cem": [ 7 | "rubinstein2004cross" 8 | ], 9 | "user/algo_ddpg": [ 10 | "lillicrap2015continuous" 11 | ], 12 | "user/algo_erwr": [ 13 | "peters2007reward", 14 | "2009koberpolicy" 15 | ], 16 | "user/algo_maml": [ 17 | "finn2017modelagnostic" 18 | ], 19 | "user/algo_mtppo": [ 20 | "yu2019metaworld", 21 | "schulman2017proximal" 22 | ], 23 | "user/algo_mttrpo": [ 24 | "schulman2015trust", 25 | "yu2019metaworld" 26 | ], 27 | "user/algo_pearl": [ 28 | "rakelly2019efficient" 29 | ], 30 | "user/algo_ppo": [ 31 | "schulman2017proximal", 32 | "levine2018reinforcement" 33 | ], 34 | "user/algo_rl2": [ 35 | "duan2016rl" 36 | ], 37 | "user/algo_sac": [ 38 | "haarnoja2018soft" 39 | ], 40 | "user/algo_td3": [ 41 | "Fujimoto2018AddressingFA" 42 | ], 43 | "user/algo_trpo": [ 44 | "schulman2015trust" 45 | ], 46 | "user/algo_vpg": [ 47 | "williams1992simple" 48 | ], 49 | "user/custom_worker": [ 50 | "duan2016rl" 51 | ], 52 | "user/implement_algo": [ 53 | "williams1992simple", 54 | "rubinstein2004cross" 55 | ] 56 | } 57 | } -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx-autoapi 2 | sphinxcontrib-bibtex 3 | 4 | akro 5 | click 6 | cloudpickle 7 | cma==2.7.0 8 | dm_env 9 | dowel==0.0.3 10 | gym[atari, box2d, classic_control]==0.17.2 11 | matplotlib 12 | psutil 13 | pyprind 14 | python-dateutil 15 | scikit-image 16 | scipy 17 | setproctitle 18 | tensorflow 19 | tensorflow-probability 20 | docutils<0.17 21 | -------------------------------------------------------------------------------- /docs/user/concept_experiment.md: -------------------------------------------------------------------------------- 1 | # Experiment 2 | 3 | The experiment in Garage is a function we use to run an algorithm. This 4 | function is wrapped by a decorator called `wrap_experiment`, which defines the 5 | scope of an experiment, sets up the log directory and what to be saved in 6 | snapshots. 7 | 8 | Below is a simple experiment launcher. The first parameter of the experiment 9 | function must be `ctxt`, which is used to pass the experiment's context into 10 | the inner function. 11 | 12 | ```py 13 | from garage import wrap_experiment 14 | 15 | @wrap_experiment 16 | def my_first_experiment(ctxt=None): 17 | print('Hello World!') 18 | 19 | my_first_experiment() 20 | ``` 21 | 22 | Running the example launcher will generate outputs like the following 23 | (`CUR_DIR` is the current directory). 24 | 25 | ```sh 26 | 2020-08-20 15:18:53 | [my_first_experiment] Logging to CUR_DIR/data/local/experiment/my_first_experiment 27 | Hello World! 28 | ``` 29 | 30 | The followings are some useful parameters of `wrap_experiment`. You can see 31 | details of its parameters [here](../_autoapi/garage/index.html#garage.wrap_experiment). 32 | 33 | * log_dir: The custom log directory to log to. 34 | * snapshot_mode: Policy for which snapshots to keep. The last iteration will be 35 | saved by default. Here are acceptable inputs. 36 | * `'last'`, only the last iteration will be saved. 37 | * `'all'`, all iterations will be saved. 38 | * `'gap'`, every snapshot_gap iterations are saved. 
39 | * `'none'`, do not save snapshots. 40 | * snapshot_gap: Gap between snapshot iterations. Waits this number of 41 | iterations before taking another snapshot. 42 | 43 | Here is an example to set a custom log directory: 44 | 45 | ```py 46 | from garage import wrap_experiment 47 | 48 | @wrap_experiment(log_dir='my_custom_log_fir') 49 | def my_first_experiment(ctxt=None): 50 | print('Hello World!') 51 | ``` 52 | 53 | You can check [this user guide](experiments) for how to write and run an 54 | experiment in detail. 55 | 56 | ---- 57 | 58 | *This page was authored by Ruofu Wang ([@yeukfu](https://github.com/yeukfu)).* 59 | -------------------------------------------------------------------------------- /docs/user/images/bc_meanLoss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/user/images/bc_meanLoss.png -------------------------------------------------------------------------------- /docs/user/images/bc_stdLoss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/user/images/bc_stdLoss.png -------------------------------------------------------------------------------- /docs/user/images/dqn_plots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/user/images/dqn_plots.png -------------------------------------------------------------------------------- /docs/user/images/numpy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/user/images/numpy.png -------------------------------------------------------------------------------- /docs/user/images/pytorch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/user/images/pytorch.png -------------------------------------------------------------------------------- /docs/user/images/td3_tf_HalfCheetah-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/user/images/td3_tf_HalfCheetah-v2.png -------------------------------------------------------------------------------- /docs/user/images/td3_tf_Hopper-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/user/images/td3_tf_Hopper-v2.png -------------------------------------------------------------------------------- /docs/user/images/td3_tf_InvertedDoublePendulum-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/user/images/td3_tf_InvertedDoublePendulum-v2.png -------------------------------------------------------------------------------- /docs/user/images/td3_tf_InvertedPendulum-v2.png: -------------------------------------------------------------------------------- 
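Following up on the `wrap_experiment` parameters described in `docs/user/concept_experiment.md` above (`snapshot_mode`, `snapshot_gap`), here is a minimal sketch of combining them in the decorator call; the experiment function name below is illustrative and not taken from the repository:

```py
from garage import wrap_experiment

# Keep a snapshot every 5th iteration instead of only the last one.
@wrap_experiment(snapshot_mode='gap', snapshot_gap=5)
def my_gap_snapshot_experiment(ctxt=None):
    print('Snapshots are saved every 5 iterations.')

my_gap_snapshot_experiment()
```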
https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/user/images/td3_tf_InvertedPendulum-v2.png -------------------------------------------------------------------------------- /docs/user/images/td3_tf_Swimmer-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/user/images/td3_tf_Swimmer-v2.png -------------------------------------------------------------------------------- /docs/user/images/td3_torch_HalfCheetah-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/user/images/td3_torch_HalfCheetah-v2.png -------------------------------------------------------------------------------- /docs/user/images/td3_torch_Hopper-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/user/images/td3_torch_Hopper-v2.png -------------------------------------------------------------------------------- /docs/user/images/td3_torch_InvertedDoublePendulum-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/user/images/td3_torch_InvertedDoublePendulum-v2.png -------------------------------------------------------------------------------- /docs/user/images/td3_torch_InvertedPendulum-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/user/images/td3_torch_InvertedPendulum-v2.png -------------------------------------------------------------------------------- /docs/user/images/td3_torch_Swimmer-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/user/images/td3_torch_Swimmer-v2.png -------------------------------------------------------------------------------- /docs/user/images/td3_torch_Walker2d-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/user/images/td3_torch_Walker2d-v2.png -------------------------------------------------------------------------------- /docs/user/images/tf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/user/images/tf.png -------------------------------------------------------------------------------- /docs/user/implement_worker.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/user/implement_worker.md -------------------------------------------------------------------------------- /docs/user/matplotlib_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/docs/user/matplotlib_example.png 
-------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | sphinx: 4 | configuration: docs/conf.py 5 | 6 | python: 7 | version: 3.7 8 | install: 9 | - requirements: docs/requirements.txt 10 | system_packages: true 11 | -------------------------------------------------------------------------------- /scripts/ci/check_docs_only.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | status=0 3 | 4 | echo "Checking commit range ${TRAVIS_COMMIT_RANGE}" 5 | SOURCE="${TRAVIS_COMMIT_RANGE%...*}" 6 | ORIGIN="${TRAVIS_COMMIT_RANGE#*...}" 7 | status="$((${status} | ${?}))" 8 | 9 | while read commit; do 10 | echo "Checking for docs only in ${commit}" 11 | not_docs="$(git show --name-only --oneline ${commit} \ 12 | | tail -n +2 \ 13 | | awk -F . '{print $NF}' \ 14 | | uniq \ 15 | | grep -v 'md\|rst\|png\|bib\|html\|css')" 16 | test -z "${not_docs}" 17 | pass=$? 18 | status="$((${status} | ${pass}))" 19 | 20 | # Print message if it fails 21 | if [[ "${pass}" -ne 0 ]]; then 22 | echo "Found non-documentation changes in commit ${commit}" 23 | fi 24 | 25 | done < <(git log --cherry-pick --left-only --pretty="%H" \ 26 | "${ORIGIN}...${SOURCE}") 27 | 28 | exit "${status}" 29 | -------------------------------------------------------------------------------- /scripts/ci/check_no_deps_changed.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | status=0 3 | 4 | function join_by { local d="${1}"; shift; echo -n "${1}"; shift; printf "%s" "${@/#/$d}"; } 5 | 6 | deps_files=( 7 | '^setup.py$' 8 | '^benchmarks/setup.py$' 9 | '^Makefile$' 10 | '^docker/' 11 | ) 12 | deps_regex="$(join_by '\|' ${deps_files[@]})" 13 | 14 | echo "Checking commit range ${TRAVIS_COMMIT_RANGE}" 15 | SOURCE="${TRAVIS_COMMIT_RANGE%...*}" 16 | ORIGIN="${TRAVIS_COMMIT_RANGE#*...}" 17 | status="$((${status} | ${?}))" 18 | 19 | while read commit; do 20 | echo "Checking for dependency changes in ${commit}" 21 | deps_change="$(git show --name-only --oneline ${commit} \ 22 | | tail -n +2 \ 23 | | grep "${deps_regex}" \ 24 | )" 25 | test -z "${deps_change}" 26 | pass=$? 27 | status="$((${status} | ${pass}))" 28 | 29 | # Print message if it fails 30 | if [[ "${pass}" -ne 0 ]]; then 31 | echo -e "Found dependency changes in ${commit}" 32 | echo -e "Matched with changes in files:\n${deps_change}" 33 | fi 34 | 35 | done < <(git log --cherry-pick --left-only --pretty="%H" \ 36 | "${ORIGIN}...${SOURCE}") 37 | 38 | exit "${status}" 39 | -------------------------------------------------------------------------------- /scripts/ci/check_precommit.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | status=0 3 | 4 | echo "Checking commit range ${PR_COMMIT_RANGE}" 5 | SOURCE="${PR_COMMIT_RANGE%...*}" 6 | ORIGIN="${PR_COMMIT_RANGE#*...}" 7 | pre-commit run --source "${SOURCE}" --origin "${ORIGIN}" 8 | status="$((${status} | ${?}))" 9 | 10 | while read commit; do 11 | echo "Checking commit message for ${commit}" 12 | commit_msg="$(mktemp)" 13 | git log --format=%B -n 1 "${commit}" > "${commit_msg}" 14 | scripts/check_commit_message "${commit_msg}" 15 | pass=$? 
16 | status="$((${status} | ${pass}))" 17 | 18 | # Print message if it fails 19 | if [[ "${pass}" -ne 0 ]]; then 20 | echo "Failing commit message:" 21 | cat "${commit_msg}" 22 | fi 23 | 24 | done < <(git log --cherry-pick --left-only --pretty="%H" \ 25 | "${ORIGIN}...${SOURCE}") 26 | 27 | exit "${status}" 28 | -------------------------------------------------------------------------------- /src/garage/__init__.py: -------------------------------------------------------------------------------- 1 | """Garage Base.""" 2 | # yapf: disable 3 | 4 | from garage._dtypes import EpisodeBatch, TimeStep, TimeStepBatch 5 | from garage._environment import (Environment, EnvSpec, EnvStep, InOutSpec, 6 | StepType, Wrapper) 7 | from garage._functions import (_Default, log_multitask_performance, 8 | log_performance, make_optimizer, 9 | obtain_evaluation_episodes, rollout) 10 | from garage.experiment.experiment import wrap_experiment 11 | from garage.trainer import TFTrainer, Trainer 12 | 13 | # yapf: enable 14 | 15 | __all__ = [ 16 | '_Default', 17 | 'make_optimizer', 18 | 'wrap_experiment', 19 | 'TimeStep', 20 | 'EpisodeBatch', 21 | 'log_multitask_performance', 22 | 'log_performance', 23 | 'InOutSpec', 24 | 'TimeStepBatch', 25 | 'Environment', 26 | 'StepType', 27 | 'EnvStep', 28 | 'EnvSpec', 29 | 'Wrapper', 30 | 'rollout', 31 | 'obtain_evaluation_episodes', 32 | 'Trainer', 33 | 'TFTrainer', 34 | ] 35 | -------------------------------------------------------------------------------- /src/garage/envs/__init__.py: -------------------------------------------------------------------------------- 1 | """Garage wrappers for gym environments.""" 2 | 3 | from garage.envs.grid_world_env import GridWorldEnv 4 | from garage.envs.gym_env import GymEnv 5 | from garage.envs.metaworld_set_task_env import MetaWorldSetTaskEnv 6 | from garage.envs.multi_env_wrapper import MultiEnvWrapper 7 | from garage.envs.normalized_env import normalize 8 | from garage.envs.point_env import PointEnv 9 | from garage.envs.task_name_wrapper import TaskNameWrapper 10 | from garage.envs.task_onehot_wrapper import TaskOnehotWrapper 11 | 12 | __all__ = [ 13 | 'GymEnv', 14 | 'GridWorldEnv', 15 | 'MetaWorldSetTaskEnv', 16 | 'MultiEnvWrapper', 17 | 'normalize', 18 | 'PointEnv', 19 | 'TaskOnehotWrapper', 20 | 'TaskNameWrapper', 21 | ] 22 | -------------------------------------------------------------------------------- /src/garage/envs/bullet/__init__.py: -------------------------------------------------------------------------------- 1 | """Wrappers for the py_bullet based gym environments. 2 | 3 | See https://github.com/bulletphysics/bullet3/tree/master/examples/pybullet 4 | """ 5 | try: 6 | import pybullet_envs # noqa: F401 7 | except Exception as e: 8 | raise ImportError('To use garage\'s bullet wrappers, please install ' 9 | 'garage[bullet]') 10 | 11 | from garage.envs.bullet.bullet_env import BulletEnv 12 | 13 | __all__ = ['BulletEnv'] 14 | -------------------------------------------------------------------------------- /src/garage/envs/dm_control/__init__.py: -------------------------------------------------------------------------------- 1 | """Wrappers for the DeepMind Control Suite. 
2 | 3 | See https://github.com/deepmind/dm_control 4 | """ 5 | try: 6 | import dm_control # noqa: F401 7 | except ImportError: 8 | raise ImportError("To use garage's dm_control wrappers, please install " 9 | 'garage[dm_control].') 10 | 11 | from garage.envs.dm_control.dm_control_env import DMControlEnv 12 | from garage.envs.dm_control.dm_control_viewer import DmControlViewer 13 | 14 | __all__ = ['DmControlViewer', 'DMControlEnv'] 15 | -------------------------------------------------------------------------------- /src/garage/envs/mujoco/__init__.py: -------------------------------------------------------------------------------- 1 | """Garage wrappers for mujoco based gym environments.""" 2 | try: 3 | import mujoco_py # noqa: F401 4 | except Exception as e: 5 | raise e 6 | 7 | from garage.envs.mujoco.half_cheetah_dir_env import HalfCheetahDirEnv 8 | from garage.envs.mujoco.half_cheetah_vel_env import HalfCheetahVelEnv 9 | 10 | __all__ = [ 11 | 'HalfCheetahDirEnv', 12 | 'HalfCheetahVelEnv', 13 | ] 14 | -------------------------------------------------------------------------------- /src/garage/envs/task_name_wrapper.py: -------------------------------------------------------------------------------- 1 | """Wrapper for adding an environment info to track task ID.""" 2 | from garage import Wrapper 3 | 4 | 5 | class TaskNameWrapper(Wrapper): 6 | """Add task_name or task_id to env infos. 7 | 8 | Args: 9 | env (gym.Env): The environment to wrap. 10 | task_name (str or None): Task name to be added, if any. 11 | task_id (int or None): Task ID to be added, if any. 12 | 13 | """ 14 | 15 | def __init__(self, env, *, task_name=None, task_id=None): 16 | super().__init__(env) 17 | self._task_name = task_name 18 | self._task_id = task_id 19 | 20 | def step(self, action): 21 | """gym.Env step for the active task env. 22 | 23 | Args: 24 | action (np.ndarray): Action performed by the agent in the 25 | environment. 26 | 27 | Returns: 28 | tuple: 29 | np.ndarray: Agent's observation of the current environment. 30 | float: Amount of reward yielded by previous action. 31 | bool: True iff the episode has ended. 32 | dict[str, np.ndarray]: Contains auxiliary diagnostic 33 | information about this time-step. 34 | 35 | """ 36 | es = super().step(action) 37 | if self._task_name is not None: 38 | es.env_info['task_name'] = self._task_name 39 | if self._task_id is not None: 40 | es.env_info['task_id'] = self._task_id 41 | return es 42 | -------------------------------------------------------------------------------- /src/garage/envs/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | """gym.Env wrappers. 2 | 3 | Used to transform an environment in a modular way. 4 | It is also possible to apply multiple wrappers at the same 5 | time. 
6 | 7 | Example: 8 | StackFrames(Grayscale(gym.make('env'))) 9 | 10 | """ 11 | from garage.envs.wrappers.atari_env import AtariEnv 12 | from garage.envs.wrappers.clip_reward import ClipReward 13 | from garage.envs.wrappers.episodic_life import EpisodicLife 14 | from garage.envs.wrappers.fire_reset import FireReset 15 | from garage.envs.wrappers.grayscale import Grayscale 16 | from garage.envs.wrappers.max_and_skip import MaxAndSkip 17 | from garage.envs.wrappers.noop import Noop 18 | from garage.envs.wrappers.pixel_observation import PixelObservationWrapper 19 | from garage.envs.wrappers.resize import Resize 20 | from garage.envs.wrappers.stack_frames import StackFrames 21 | 22 | __all__ = [ 23 | 'AtariEnv', 'ClipReward', 'EpisodicLife', 'FireReset', 'Grayscale', 24 | 'MaxAndSkip', 'Noop', 'PixelObservationWrapper', 'Resize', 'StackFrames' 25 | ] 26 |
-------------------------------------------------------------------------------- /src/garage/envs/wrappers/atari_env.py: -------------------------------------------------------------------------------- 1 | """Atari environment wrapper for gym.Env.""" 2 | import gym 3 | import numpy as np 4 | 5 | 6 | class AtariEnv(gym.Wrapper): 7 | """Atari environment wrapper for gym.Env. 8 | 9 | This wrapper converts the observations returned from a baselines-wrapped 10 | environment, which are LazyFrames objects, into numpy arrays. 11 | 12 | Args: 13 | env (gym.Env): The environment to be wrapped. 14 | """ 15 | 16 | def __init__(self, env): 17 | super().__init__(env) 18 | 19 | def step(self, action): 20 | """gym.Env step function.""" 21 | obs, reward, done, info = self.env.step(action) 22 | return np.asarray(obs), reward, done, info 23 | 24 | def reset(self, **kwargs): 25 | """gym.Env reset function.""" 26 | return np.asarray(self.env.reset()) 27 |
-------------------------------------------------------------------------------- /src/garage/envs/wrappers/clip_reward.py: -------------------------------------------------------------------------------- 1 | """Clip reward for gym.Env.""" 2 | import gym 3 | import numpy as np 4 | 5 | 6 | class ClipReward(gym.Wrapper): 7 | """Clip the reward by its sign.""" 8 | 9 | def step(self, ac): 10 | """gym.Env step function.""" 11 | obs, reward, done, info = self.env.step(ac) 12 | return obs, np.sign(reward), done, info 13 | 14 | def reset(self): 15 | """gym.Env reset.""" 16 | return self.env.reset() 17 |
-------------------------------------------------------------------------------- /src/garage/envs/wrappers/episodic_life.py: -------------------------------------------------------------------------------- 1 | """Episodic life wrapper for gym.Env.""" 2 | import gym 3 | 4 | 5 | class EpisodicLife(gym.Wrapper): 6 | """Episodic life wrapper for gym.Env. 7 | 8 | This wrapper makes the episode end when a life is lost, but only resets 9 | when all lives are lost. 10 | 11 | Args: 12 | env: The environment to be wrapped. 13 | """ 14 | 15 | def __init__(self, env): 16 | super().__init__(env) 17 | self._lives = 0 18 | self._was_real_done = True 19 | 20 | def step(self, action): 21 | """gym.Env step function.""" 22 | obs, reward, done, info = self.env.step(action) 23 | self._was_real_done = done 24 | lives = self.env.unwrapped.ale.lives() 25 | if lives < self._lives and lives > 0: 26 | done = True 27 | self._lives = lives 28 | return obs, reward, done, info 29 | 30 | def reset(self, **kwargs): 31 | """ 32 | gym.Env reset function. 33 | 34 | Reset the underlying environment only when all lives are lost.
35 | """ 36 | if self._was_real_done: 37 | obs = self.env.reset(**kwargs) 38 | else: 39 | # no-op step 40 | obs, _, _, _ = self.env.step(0) 41 | self._lives = self.env.unwrapped.ale.lives() 42 | return obs 43 | -------------------------------------------------------------------------------- /src/garage/envs/wrappers/fire_reset.py: -------------------------------------------------------------------------------- 1 | """Fire reset wrapper for gym.Env.""" 2 | import gym 3 | 4 | 5 | class FireReset(gym.Wrapper): 6 | """Fire reset wrapper for gym.Env. 7 | 8 | Take action "fire" on reset. 9 | 10 | Args: 11 | env (gym.Env): The environment to be wrapped. 12 | """ 13 | 14 | def __init__(self, env): 15 | super().__init__(env) 16 | assert env.unwrapped.get_action_meanings()[1] == 'FIRE', ( 17 | 'Only use fire reset wrapper for suitable environment!') 18 | assert len(env.unwrapped.get_action_meanings()) >= 3, ( 19 | 'Only use fire reset wrapper for suitable environment!') 20 | 21 | def step(self, action): 22 | """gym.Env step function. 23 | 24 | Args: 25 | action (int): index of the action to take. 26 | 27 | Returns: 28 | np.ndarray: Observation conforming to observation_space 29 | float: Reward for this step 30 | bool: Termination signal 31 | dict: Extra information from the environment. 32 | """ 33 | return self.env.step(action) 34 | 35 | def reset(self, **kwargs): 36 | """gym.Env reset function. 37 | 38 | Args: 39 | kwargs (dict): extra arguments passed to gym.Env.reset() 40 | 41 | Returns: 42 | np.ndarray: next observation. 43 | """ 44 | self.env.reset(**kwargs) 45 | obs, _, done, _ = self.env.step(1) 46 | if done: 47 | self.env.reset(**kwargs) 48 | obs, _, done, _ = self.env.step(2) 49 | if done: 50 | self.env.reset(**kwargs) 51 | return obs 52 | -------------------------------------------------------------------------------- /src/garage/envs/wrappers/noop.py: -------------------------------------------------------------------------------- 1 | """Noop wrapper for gym.Env.""" 2 | import gym 3 | import numpy as np 4 | 5 | 6 | class Noop(gym.Wrapper): 7 | """Noop wrapper for gym.Env. 8 | 9 | It samples initial states by taking random number of no-ops on reset. 10 | No-op is assumed to be action 0. 11 | 12 | Args: 13 | env (gym.Env): The environment to be wrapped. 14 | noop_max (int): Maximum number no-op to be performed on reset. 15 | """ 16 | 17 | def __init__(self, env, noop_max=30): 18 | super().__init__(env) 19 | self._noop_max = noop_max 20 | self._noop_action = 0 21 | assert noop_max > 0, 'noop_max should be larger than 0!' 22 | assert env.unwrapped.get_action_meanings()[0] == 'NOOP', ( 23 | "No-op should be the 0-th action but it's not in {}!".format(env)) 24 | 25 | def step(self, action): 26 | """gym.Env step function.""" 27 | return self.env.step(action) 28 | 29 | def reset(self, **kwargs): 30 | """gym.Env reset function.""" 31 | obs = self.env.reset(**kwargs) 32 | noops = np.random.randint(1, self._noop_max + 1) 33 | for _ in range(noops): 34 | obs, _, done, _ = self.step(self._noop_action) 35 | if done: 36 | obs = self.env.reset(**kwargs) 37 | return obs 38 | -------------------------------------------------------------------------------- /src/garage/examples/np/cem_cartpole.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """This is an example to train a task with Cross Entropy Method. 3 | 4 | Here it runs CartPole-v1 environment with 100 epoches. 
5 | 6 | Results: 7 | AverageReturn: 100 8 | RiseTime: epoch 8 9 | """ 10 | from garage import wrap_experiment 11 | from garage.envs import GymEnv 12 | from garage.experiment.deterministic import set_seed 13 | from garage.np.algos import CEM 14 | from garage.sampler import LocalSampler 15 | from garage.tf.policies import CategoricalMLPPolicy 16 | from garage.trainer import TFTrainer 17 | 18 | 19 | @wrap_experiment 20 | def cem_cartpole(ctxt=None, seed=1): 21 | """Train CEM with Cartpole-v1 environment. 22 | 23 | Args: 24 | ctxt (garage.experiment.ExperimentContext): The experiment 25 | configuration used by Trainer to create the snapshotter. 26 | seed (int): Used to seed the random number generator to produce 27 | determinism. 28 | 29 | """ 30 | set_seed(seed) 31 | with TFTrainer(snapshot_config=ctxt) as trainer: 32 | env = GymEnv('CartPole-v1') 33 | 34 | policy = CategoricalMLPPolicy(name='policy', 35 | env_spec=env.spec, 36 | hidden_sizes=(32, 32)) 37 | 38 | n_samples = 20 39 | 40 | sampler = LocalSampler(agents=policy, 41 | envs=env, 42 | max_episode_length=env.spec.max_episode_length, 43 | is_tf_worker=True) 44 | 45 | algo = CEM(env_spec=env.spec, 46 | policy=policy, 47 | sampler=sampler, 48 | best_frac=0.05, 49 | n_samples=n_samples) 50 | 51 | trainer.setup(algo, env) 52 | trainer.train(n_epochs=100, batch_size=1000) 53 | 54 | 55 | cem_cartpole(seed=1) 56 | -------------------------------------------------------------------------------- /src/garage/examples/np/cma_es_cartpole.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """This is an example to train a task with CMA-ES. 3 | 4 | Here it runs CartPole-v1 environment with 100 epoches. 5 | 6 | Results: 7 | AverageReturn: 100 8 | RiseTime: epoch 38 (itr 760), 9 | but regression is observed in the course of training. 10 | """ 11 | from garage import wrap_experiment 12 | from garage.envs import GymEnv 13 | from garage.experiment.deterministic import set_seed 14 | from garage.np.algos import CMAES 15 | from garage.sampler import LocalSampler 16 | from garage.tf.policies import CategoricalMLPPolicy 17 | from garage.trainer import TFTrainer 18 | 19 | 20 | @wrap_experiment 21 | def cma_es_cartpole(ctxt=None, seed=1): 22 | """Train CMA_ES with Cartpole-v1 environment. 23 | 24 | Args: 25 | ctxt (garage.experiment.ExperimentContext): The experiment 26 | configuration used by Trainer to create the snapshotter. 27 | seed (int): Used to seed the random number generator to produce 28 | determinism. 29 | 30 | """ 31 | set_seed(seed) 32 | with TFTrainer(ctxt) as trainer: 33 | env = GymEnv('CartPole-v1') 34 | 35 | policy = CategoricalMLPPolicy(name='policy', 36 | env_spec=env.spec, 37 | hidden_sizes=(32, 32)) 38 | 39 | n_samples = 20 40 | 41 | sampler = LocalSampler(agents=policy, 42 | envs=env, 43 | max_episode_length=env.spec.max_episode_length, 44 | is_tf_worker=True) 45 | 46 | algo = CMAES(env_spec=env.spec, 47 | policy=policy, 48 | sampler=sampler, 49 | n_samples=n_samples) 50 | 51 | trainer.setup(algo, env) 52 | trainer.train(n_epochs=100, batch_size=1000) 53 | 54 | 55 | cma_es_cartpole() 56 | -------------------------------------------------------------------------------- /src/garage/examples/step_bullet_kuka_env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Example of how to load, step, and visualize a Bullet Kuka environment. 3 | 4 | This example requires that garage[bullet] be installed. 
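(A typical way to get that extra, assuming a standard pip setup, is
``pip install 'garage[bullet]'``.)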
5 | 6 | Note that pybullet_envs is imported so that bullet environments are 7 | registered in gym registry. 8 | """ 9 | # yapf: disable 10 | 11 | import click 12 | import gym 13 | import pybullet_envs # noqa: F401 # pylint: disable=unused-import 14 | 15 | from garage.envs import GymEnv 16 | 17 | # yapf: enable 18 | 19 | 20 | @click.command() 21 | @click.option('--n_steps', 22 | default=1000, 23 | type=int, 24 | help='Number of steps to run') 25 | def step_bullet_kuka_env(n_steps=1000): 26 | """Load, step, and visualize a Bullet Kuka environment. 27 | 28 | Args: 29 | n_steps (int): number of steps to run. 30 | 31 | """ 32 | # Construct the environment 33 | env = GymEnv(gym.make('KukaBulletEnv-v0', renders=True, isDiscrete=True)) 34 | 35 | # Reset the environment and launch the viewer 36 | env.reset() 37 | env.visualize() 38 | 39 | step_count = 0 40 | es = env.step(env.action_space.sample()) 41 | while not es.last and step_count < n_steps: 42 | es = env.step(env.action_space.sample()) 43 | step_count += 1 44 | 45 | 46 | step_bullet_kuka_env() 47 | -------------------------------------------------------------------------------- /src/garage/examples/step_dm_control_env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Example of how to load, step, and visualize an environment. 3 | 4 | This example requires that garage[dm_control] be installed. 5 | """ 6 | import argparse 7 | 8 | from garage.envs.dm_control import DMControlEnv 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--n_steps', 12 | type=int, 13 | default=1000, 14 | help='Number of steps to run') 15 | args = parser.parse_args() 16 | 17 | # Construct the environment 18 | env = DMControlEnv.from_suite('walker', 'run') 19 | 20 | # Reset the environment and launch the viewer 21 | env.reset() 22 | env.visualize() 23 | 24 | # Step randomly until interrupted 25 | step_count = 0 26 | es = env.step(env.action_space.sample()) 27 | while not es.last and step_count < args.n_steps: 28 | es = env.step(env.action_space.sample()) 29 | step_count += 1 30 | 31 | env.close() 32 | -------------------------------------------------------------------------------- /src/garage/examples/step_gym_env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Example of how to load, step, and visualize an environment.""" 3 | import argparse 4 | 5 | from garage.envs import GymEnv 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--n_steps', 9 | type=int, 10 | default=1000, 11 | help='Number of steps to run') 12 | args = parser.parse_args() 13 | 14 | # Construct the environment 15 | env = GymEnv('MountainCar-v0') 16 | 17 | # Reset the environment and launch the viewer 18 | env.reset() 19 | env.visualize() 20 | 21 | step_count = 0 22 | es = env.step(env.action_space.sample()) 23 | 24 | while not es.last and step_count < args.n_steps: 25 | es = env.step(env.action_space.sample()) 26 | step_count += 1 27 | 28 | env.close() 29 | -------------------------------------------------------------------------------- /src/garage/examples/tf/erwr_cartpole.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """This is an example to train a task with ERWR algorithm. 3 | 4 | Here it runs CartpoleEnv on ERWR with 100 iterations. 
5 | 6 | Results: 7 | AverageReturn: 100 8 | RiseTime: itr 34 9 | """ 10 | from garage import wrap_experiment 11 | from garage.envs import GymEnv 12 | from garage.experiment.deterministic import set_seed 13 | from garage.np.baselines import LinearFeatureBaseline 14 | from garage.sampler import RaySampler 15 | from garage.tf.algos import ERWR 16 | from garage.tf.policies import CategoricalMLPPolicy 17 | from garage.trainer import TFTrainer 18 | 19 | 20 | @wrap_experiment 21 | def erwr_cartpole(ctxt=None, seed=1): 22 | """Train with ERWR on CartPole-v1 environment. 23 | 24 | Args: 25 | ctxt (garage.experiment.ExperimentContext): The experiment 26 | configuration used by Trainer to create the snapshotter. 27 | seed (int): Used to seed the random number generator to produce 28 | determinism. 29 | 30 | """ 31 | set_seed(seed) 32 | with TFTrainer(snapshot_config=ctxt) as trainer: 33 | env = GymEnv('CartPole-v1') 34 | 35 | policy = CategoricalMLPPolicy(name='policy', 36 | env_spec=env.spec, 37 | hidden_sizes=(32, 32)) 38 | 39 | baseline = LinearFeatureBaseline(env_spec=env.spec) 40 | 41 | sampler = RaySampler(agents=policy, 42 | envs=env, 43 | max_episode_length=env.spec.max_episode_length, 44 | is_tf_worker=True) 45 | 46 | algo = ERWR(env_spec=env.spec, 47 | policy=policy, 48 | baseline=baseline, 49 | sampler=sampler, 50 | discount=0.99) 51 | 52 | trainer.setup(algo=algo, env=env) 53 | 54 | trainer.train(n_epochs=100, batch_size=10000, plot=False) 55 | 56 | 57 | erwr_cartpole(seed=1) 58 | -------------------------------------------------------------------------------- /src/garage/examples/tf/multi_env_trpo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """This is an example to train multiple tasks with TRPO algorithm.""" 3 | from garage import wrap_experiment 4 | from garage.envs import normalize, PointEnv 5 | from garage.envs.multi_env_wrapper import MultiEnvWrapper 6 | from garage.experiment.deterministic import set_seed 7 | from garage.np.baselines import LinearFeatureBaseline 8 | from garage.sampler import RaySampler 9 | from garage.tf.algos import TRPO 10 | from garage.tf.policies import GaussianMLPPolicy 11 | from garage.trainer import TFTrainer 12 | 13 | 14 | @wrap_experiment 15 | def multi_env_trpo(ctxt=None, seed=1): 16 | """Train TRPO on two different PointEnv instances. 17 | 18 | Args: 19 | ctxt (garage.experiment.ExperimentContext): The experiment 20 | configuration used by Trainer to create the snapshotter. 21 | seed (int): Used to seed the random number generator to produce 22 | determinism. 
23 | 24 | """ 25 | set_seed(seed) 26 | with TFTrainer(ctxt) as trainer: 27 | env1 = normalize(PointEnv(goal=(-1., 0.), max_episode_length=100)) 28 | env2 = normalize(PointEnv(goal=(1., 0.), max_episode_length=100)) 29 | env = MultiEnvWrapper([env1, env2]) 30 | 31 | policy = GaussianMLPPolicy(env_spec=env.spec) 32 | 33 | baseline = LinearFeatureBaseline(env_spec=env.spec) 34 | 35 | sampler = RaySampler(agents=policy, 36 | envs=env, 37 | max_episode_length=env.spec.max_episode_length, 38 | is_tf_worker=True) 39 | 40 | algo = TRPO(env_spec=env.spec, 41 | policy=policy, 42 | baseline=baseline, 43 | sampler=sampler, 44 | discount=0.99, 45 | gae_lambda=0.95, 46 | lr_clip_range=0.2, 47 | policy_ent_coeff=0.0) 48 | 49 | trainer.setup(algo, env) 50 | trainer.train(n_epochs=40, batch_size=2048, plot=False) 51 | 52 | 53 | multi_env_trpo() 54 | -------------------------------------------------------------------------------- /src/garage/examples/tf/reps_gym_cartpole.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """This is an example to train a task with REPS algorithm. 3 | 4 | Here it runs gym CartPole env with 100 iterations. 5 | 6 | Results: 7 | AverageReturn: 100 +/- 40 8 | RiseTime: itr 10 +/- 5 9 | 10 | """ 11 | 12 | from garage import wrap_experiment 13 | from garage.envs import GymEnv 14 | from garage.experiment.deterministic import set_seed 15 | from garage.np.baselines import LinearFeatureBaseline 16 | from garage.sampler import RaySampler 17 | from garage.tf.algos import REPS 18 | from garage.tf.policies import CategoricalMLPPolicy 19 | from garage.trainer import TFTrainer 20 | 21 | 22 | @wrap_experiment 23 | def reps_gym_cartpole(ctxt=None, seed=1): 24 | """Train REPS with CartPole-v0 environment. 25 | 26 | Args: 27 | ctxt (garage.experiment.ExperimentContext): The experiment 28 | configuration used by Trainer to create the snapshotter. 29 | seed (int): Used to seed the random number generator to produce 30 | determinism. 31 | 32 | """ 33 | set_seed(seed) 34 | with TFTrainer(snapshot_config=ctxt) as trainer: 35 | env = GymEnv('CartPole-v0') 36 | 37 | policy = CategoricalMLPPolicy(env_spec=env.spec, hidden_sizes=[32, 32]) 38 | 39 | baseline = LinearFeatureBaseline(env_spec=env.spec) 40 | 41 | sampler = RaySampler(agents=policy, 42 | envs=env, 43 | max_episode_length=env.spec.max_episode_length, 44 | is_tf_worker=True) 45 | 46 | algo = REPS(env_spec=env.spec, 47 | policy=policy, 48 | baseline=baseline, 49 | sampler=sampler, 50 | discount=0.99) 51 | 52 | trainer.setup(algo, env) 53 | trainer.train(n_epochs=100, batch_size=4000, plot=False) 54 | 55 | 56 | reps_gym_cartpole() 57 | -------------------------------------------------------------------------------- /src/garage/examples/tf/resume_training.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """This is an example to resume training programmatically.""" 3 | # pylint: disable=no-value-for-parameter 4 | import click 5 | 6 | from garage import wrap_experiment 7 | from garage.trainer import TFTrainer 8 | 9 | 10 | @click.command() 11 | @click.option('--saved_dir', 12 | required=True, 13 | help='Path where snapshots are saved.') 14 | @wrap_experiment 15 | def resume_experiment(ctxt, saved_dir): 16 | """Resume a Tensorflow experiment. 17 | 18 | Args: 19 | ctxt (garage.experiment.ExperimentContext): The experiment 20 | configuration used by Trainer to create the snapshotter. 
21 | saved_dir (str): Path where snapshots are saved. 22 | 23 | """ 24 | with TFTrainer(snapshot_config=ctxt) as trainer: 25 | trainer.restore(from_dir=saved_dir) 26 | trainer.resume() 27 | 28 | 29 | resume_experiment() 30 | -------------------------------------------------------------------------------- /src/garage/examples/tf/trpo_cartpole.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """This is an example to train a task with TRPO algorithm. 3 | 4 | Here it runs CartPole-v1 environment with 100 iterations. 5 | 6 | Results: 7 | AverageReturn: 100 8 | RiseTime: itr 13 9 | """ 10 | from garage import wrap_experiment 11 | from garage.envs import GymEnv 12 | from garage.experiment.deterministic import set_seed 13 | from garage.np.baselines import LinearFeatureBaseline 14 | from garage.sampler import LocalSampler 15 | from garage.tf.algos import TRPO 16 | from garage.tf.policies import CategoricalMLPPolicy 17 | from garage.trainer import TFTrainer 18 | 19 | 20 | @wrap_experiment 21 | def trpo_cartpole(ctxt=None, seed=1): 22 | """Train TRPO with CartPole-v1 environment. 23 | 24 | Args: 25 | ctxt (garage.experiment.ExperimentContext): The experiment 26 | configuration used by Trainer to create the snapshotter. 27 | seed (int): Used to seed the random number generator to produce 28 | determinism. 29 | 30 | """ 31 | set_seed(seed) 32 | with TFTrainer(ctxt) as trainer: 33 | env = GymEnv('CartPole-v1') 34 | 35 | policy = CategoricalMLPPolicy(name='policy', 36 | env_spec=env.spec, 37 | hidden_sizes=(32, 32)) 38 | 39 | baseline = LinearFeatureBaseline(env_spec=env.spec) 40 | 41 | sampler = LocalSampler(agents=policy, 42 | envs=env, 43 | max_episode_length=env.spec.max_episode_length, 44 | is_tf_worker=True) 45 | 46 | algo = TRPO(env_spec=env.spec, 47 | policy=policy, 48 | baseline=baseline, 49 | sampler=sampler, 50 | discount=0.99, 51 | max_kl_step=0.01) 52 | 53 | trainer.setup(algo, env) 54 | trainer.train(n_epochs=100, batch_size=4000) 55 | 56 | 57 | trpo_cartpole() 58 | -------------------------------------------------------------------------------- /src/garage/examples/tf/trpo_cartpole_bullet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """This is an example to train a task with TRPO algorithm. 3 | 4 | Here it runs CartPoleBulletEnv environment with 100 iterations. 5 | 6 | """ 7 | import gym 8 | 9 | from garage import wrap_experiment 10 | from garage.envs.bullet import BulletEnv 11 | from garage.experiment.deterministic import set_seed 12 | from garage.np.baselines import LinearFeatureBaseline 13 | from garage.sampler import RaySampler 14 | from garage.tf.algos import TRPO 15 | from garage.tf.policies import CategoricalMLPPolicy 16 | from garage.trainer import TFTrainer 17 | 18 | 19 | @wrap_experiment 20 | def trpo_cartpole_bullet(ctxt=None, seed=1): 21 | """Train TRPO with Pybullet's CartPoleBulletEnv environment. 22 | 23 | Args: 24 | ctxt (garage.experiment.ExperimentContext): The experiment 25 | configuration used by Trainer to create the snapshotter. 26 | seed (int): Used to seed the random number generator to produce 27 | determinism. 
28 | 29 | """ 30 | set_seed(seed) 31 | with TFTrainer(ctxt) as trainer: 32 | env = BulletEnv( 33 | gym.make('CartPoleBulletEnv-v1', 34 | renders=False, 35 | discrete_actions=True)) 36 | 37 | policy = CategoricalMLPPolicy(name='policy', 38 | env_spec=env.spec, 39 | hidden_sizes=(32, 32)) 40 | 41 | baseline = LinearFeatureBaseline(env_spec=env.spec) 42 | 43 | sampler = RaySampler(agents=policy, 44 | envs=env, 45 | max_episode_length=env.spec.max_episode_length, 46 | is_tf_worker=True) 47 | 48 | algo = TRPO(env_spec=env.spec, 49 | policy=policy, 50 | baseline=baseline, 51 | sampler=sampler, 52 | discount=0.99, 53 | max_kl_step=0.01) 54 | 55 | trainer.setup(algo, env) 56 | trainer.train(n_epochs=100, batch_size=4000) 57 | 58 | 59 | trpo_cartpole_bullet() 60 | -------------------------------------------------------------------------------- /src/garage/examples/tf/trpo_gym_tf_cartpole.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """An example to train a task with TRPO algorithm.""" 3 | from garage import wrap_experiment 4 | from garage.envs import GymEnv 5 | from garage.experiment.deterministic import set_seed 6 | from garage.np.baselines import LinearFeatureBaseline 7 | from garage.sampler import RaySampler 8 | from garage.tf.algos import TRPO 9 | from garage.tf.policies import CategoricalMLPPolicy 10 | from garage.trainer import TFTrainer 11 | 12 | 13 | @wrap_experiment 14 | def trpo_gym_tf_cartpole(ctxt=None, seed=1): 15 | """Train TRPO with CartPole-v0 environment. 16 | 17 | Args: 18 | ctxt (garage.experiment.ExperimentContext): The experiment 19 | configuration used by Trainer to create the snapshotter. 20 | seed (int): Used to seed the random number generator to produce 21 | determinism. 22 | 23 | """ 24 | set_seed(seed) 25 | with TFTrainer(snapshot_config=ctxt) as trainer: 26 | env = GymEnv('CartPole-v0') 27 | 28 | policy = CategoricalMLPPolicy(name='policy', 29 | env_spec=env.spec, 30 | hidden_sizes=(32, 32)) 31 | 32 | baseline = LinearFeatureBaseline(env_spec=env.spec) 33 | 34 | sampler = RaySampler(agents=policy, 35 | envs=env, 36 | max_episode_length=env.spec.max_episode_length, 37 | is_tf_worker=True) 38 | 39 | algo = TRPO( 40 | env_spec=env.spec, 41 | policy=policy, 42 | baseline=baseline, 43 | sampler=sampler, 44 | discount=0.99, 45 | max_kl_step=0.01, 46 | ) 47 | 48 | trainer.setup(algo, env) 49 | trainer.train(n_epochs=120, batch_size=4000) 50 | 51 | 52 | trpo_gym_tf_cartpole() 53 | -------------------------------------------------------------------------------- /src/garage/examples/tf/trpo_swimmer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """An example to train a task with TRPO algorithm.""" 3 | from garage import wrap_experiment 4 | from garage.envs import GymEnv 5 | from garage.experiment.deterministic import set_seed 6 | from garage.np.baselines import LinearFeatureBaseline 7 | from garage.sampler import RaySampler 8 | from garage.tf.algos import TRPO 9 | from garage.tf.policies import GaussianMLPPolicy 10 | from garage.trainer import TFTrainer 11 | 12 | 13 | @wrap_experiment 14 | def trpo_swimmer(ctxt=None, seed=1, batch_size=4000): 15 | """Train TRPO with Swimmer-v2 environment. 16 | 17 | Args: 18 | ctxt (garage.experiment.ExperimentContext): The experiment 19 | configuration used by Trainer to create the snapshotter. 20 | seed (int): Used to seed the random number generator to produce 21 | determinism. 
22 | batch_size (int): Number of timesteps to use in each training step. 23 | 24 | """ 25 | set_seed(seed) 26 | with TFTrainer(ctxt) as trainer: 27 | env = GymEnv('Swimmer-v2') 28 | 29 | policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=(32, 32)) 30 | 31 | baseline = LinearFeatureBaseline(env_spec=env.spec) 32 | 33 | sampler = RaySampler(agents=policy, 34 | envs=env, 35 | max_episode_length=env.spec.max_episode_length, 36 | is_tf_worker=True) 37 | 38 | algo = TRPO(env_spec=env.spec, 39 | policy=policy, 40 | baseline=baseline, 41 | sampler=sampler, 42 | discount=0.99, 43 | max_kl_step=0.01) 44 | 45 | trainer.setup(algo, env) 46 | trainer.train(n_epochs=40, batch_size=batch_size) 47 | 48 | 49 | trpo_swimmer() 50 | -------------------------------------------------------------------------------- /src/garage/examples/tf/vpg_cartpole.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """This is an example to train a task with VPG algorithm. 3 | 4 | Here it runs CartPole-v1 environment with 100 iterations. 5 | 6 | Results: 7 | AverageReturn: 100 8 | RiseTime: itr 16 9 | """ 10 | from garage import wrap_experiment 11 | from garage.envs import GymEnv 12 | from garage.experiment.deterministic import set_seed 13 | from garage.np.baselines import LinearFeatureBaseline 14 | from garage.sampler import RaySampler 15 | from garage.tf.algos import VPG 16 | from garage.tf.policies import CategoricalMLPPolicy 17 | from garage.trainer import TFTrainer 18 | 19 | 20 | @wrap_experiment 21 | def vpg_cartpole(ctxt=None, seed=1): 22 | """Train VPG with CartPole-v1 environment. 23 | 24 | Args: 25 | ctxt (garage.experiment.ExperimentContext): The experiment 26 | configuration used by Trainer to create the snapshotter. 27 | seed (int): Used to seed the random number generator to produce 28 | determinism. 29 | 30 | """ 31 | set_seed(seed) 32 | with TFTrainer(snapshot_config=ctxt) as trainer: 33 | env = GymEnv('CartPole-v1') 34 | 35 | policy = CategoricalMLPPolicy(name='policy', 36 | env_spec=env.spec, 37 | hidden_sizes=(32, 32)) 38 | 39 | baseline = LinearFeatureBaseline(env_spec=env.spec) 40 | 41 | sampler = RaySampler(agents=policy, 42 | envs=env, 43 | max_episode_length=env.spec.max_episode_length, 44 | is_tf_worker=True) 45 | 46 | algo = VPG(env_spec=env.spec, 47 | policy=policy, 48 | baseline=baseline, 49 | sampler=sampler, 50 | discount=0.99, 51 | optimizer_args=dict(learning_rate=0.01, )) 52 | 53 | trainer.setup(algo, env) 54 | trainer.train(n_epochs=100, batch_size=10000) 55 | 56 | 57 | vpg_cartpole() 58 | -------------------------------------------------------------------------------- /src/garage/examples/torch/resume_training.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """This is an example to resume training programmatically.""" 3 | # pylint: disable=no-value-for-parameter 4 | import click 5 | 6 | from garage import wrap_experiment 7 | from garage.trainer import Trainer 8 | 9 | 10 | @click.command() 11 | @click.option('--saved_dir', 12 | required=True, 13 | help='Path where snapshots are saved.') 14 | @wrap_experiment 15 | def resume_experiment(ctxt, saved_dir): 16 | """Resume a PyTorch experiment. 17 | 18 | Args: 19 | ctxt (garage.experiment.ExperimentContext): The experiment 20 | configuration used by Trainer to create the snapshotter. 21 | saved_dir (str): Path where snapshots are saved. 
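Example (illustrative; the snapshot directory below is hypothetical):
    python resume_training.py --saved_dir=data/local/experiment/my_exp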
22 | 23 | """ 24 | trainer = Trainer(snapshot_config=ctxt) 25 | trainer.restore(from_dir=saved_dir) 26 | trainer.resume() 27 | 28 | 29 | resume_experiment() 30 |
-------------------------------------------------------------------------------- /src/garage/experiment/__init__.py: -------------------------------------------------------------------------------- 1 | """Experiment functions.""" 2 | # yapf: disable 3 | from garage.experiment.meta_evaluator import MetaEvaluator 4 | from garage.experiment.snapshotter import SnapshotConfig, Snapshotter 5 | from garage.experiment.task_sampler import (ConstructEnvsSampler, 6 | EnvPoolSampler, 7 | MetaWorldTaskSampler, 8 | SetTaskSampler, TaskSampler) 9 | 10 | # yapf: enable 11 | 12 | __all__ = [ 13 | 'MetaEvaluator', 14 | 'Snapshotter', 15 | 'SnapshotConfig', 16 | 'TaskSampler', 17 | 'ConstructEnvsSampler', 18 | 'EnvPoolSampler', 19 | 'SetTaskSampler', 20 | 'MetaWorldTaskSampler', 21 | ] 22 |
-------------------------------------------------------------------------------- /src/garage/experiment/deterministic.py: -------------------------------------------------------------------------------- 1 | """Utilities for ensuring that experiments are deterministic.""" 2 | import random 3 | import sys 4 | import warnings 5 | 6 | import numpy as np 7 | 8 | seed_ = None 9 | seed_stream_ = None 10 | 11 | 12 | def set_seed(seed): 13 | """Set the process-wide random seed. 14 | 15 | Args: 16 | seed (int): A positive integer. 17 | 18 | """ 19 | seed %= 4294967294 20 | # pylint: disable=global-statement 21 | global seed_ 22 | global seed_stream_ 23 | seed_ = seed 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | if 'tensorflow' in sys.modules: 27 | import tensorflow as tf # pylint: disable=import-outside-toplevel 28 | tf.compat.v1.set_random_seed(seed) 29 | try: 30 | # pylint: disable=import-outside-toplevel 31 | import tensorflow_probability as tfp 32 | seed_stream_ = tfp.util.SeedStream(seed_, salt='garage') 33 | except ImportError: 34 | pass 35 | if 'torch' in sys.modules: 36 | warnings.warn( 37 | 'Enabling deterministic mode in PyTorch can have a performance ' 38 | 'impact when using a GPU.') 39 | import torch # pylint: disable=import-outside-toplevel 40 | torch.manual_seed(seed) 41 | torch.backends.cudnn.deterministic = True 42 | torch.backends.cudnn.benchmark = False 43 | 44 | 45 | def get_seed(): 46 | """Get the process-wide random seed. 47 | 48 | Returns: 49 | int: The process-wide random seed 50 | 51 | """ 52 | return seed_ 53 | 54 | 55 | def get_tf_seed_stream(): 56 | """Get the pseudo-random number generator (PRNG) for TensorFlow ops. 57 | 58 | Returns: 59 | int: A seed generated by a PRNG with fixed global seed.
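Example (illustrative; the seed value is arbitrary):
    set_seed(24)
    op_seed = get_tf_seed_stream()  # deterministic given the global seed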
60 | 61 | """ 62 | if seed_stream_ is None: 63 | set_seed(0) 64 | return seed_stream_() % 4294967294 65 | -------------------------------------------------------------------------------- /src/garage/np/__init__.py: -------------------------------------------------------------------------------- 1 | """Reinforcement Learning Algorithms which use NumPy as a numerical backend.""" 2 | # yapf: disable 3 | from garage.np._functions import (concat_tensor_dict_list, discount_cumsum, 4 | explained_variance_1d, flatten_tensors, 5 | pad_batch_array, pad_tensor, pad_tensor_dict, 6 | pad_tensor_n, rrse, slice_nested_dict, 7 | sliding_window, 8 | stack_and_pad_tensor_dict_list, 9 | stack_tensor_dict_list, truncate_tensor_dict, 10 | unflatten_tensors) 11 | 12 | # yapf: enable 13 | 14 | __all__ = [ 15 | 'discount_cumsum', 16 | 'explained_variance_1d', 17 | 'flatten_tensors', 18 | 'unflatten_tensors', 19 | 'pad_batch_array', 20 | 'pad_tensor', 21 | 'pad_tensor_n', 22 | 'pad_tensor_dict', 23 | 'stack_tensor_dict_list', 24 | 'stack_and_pad_tensor_dict_list', 25 | 'concat_tensor_dict_list', 26 | 'truncate_tensor_dict', 27 | 'slice_nested_dict', 28 | 'rrse', 29 | 'sliding_window', 30 | ] 31 | -------------------------------------------------------------------------------- /src/garage/np/algos/__init__.py: -------------------------------------------------------------------------------- 1 | """Reinforcement learning algorithms which use NumPy as a numerical backend.""" 2 | from garage.np.algos.cem import CEM 3 | from garage.np.algos.cma_es import CMAES 4 | from garage.np.algos.meta_rl_algorithm import MetaRLAlgorithm 5 | from garage.np.algos.nop import NOP 6 | from garage.np.algos.rl_algorithm import RLAlgorithm 7 | 8 | __all__ = [ 9 | 'RLAlgorithm', 10 | 'CEM', 11 | 'CMAES', 12 | 'MetaRLAlgorithm', 13 | 'NOP', 14 | ] 15 | -------------------------------------------------------------------------------- /src/garage/np/algos/meta_rl_algorithm.py: -------------------------------------------------------------------------------- 1 | """Interface of Meta-RL ALgorithms.""" 2 | import abc 3 | 4 | from garage.np.algos.rl_algorithm import RLAlgorithm 5 | 6 | 7 | class MetaRLAlgorithm(RLAlgorithm, abc.ABC): 8 | """Base class for Meta-RL Algorithms.""" 9 | 10 | @abc.abstractmethod 11 | def get_exploration_policy(self): 12 | """Return a policy used before adaptation to a specific task. 13 | 14 | Each time it is retrieved, this policy should only be evaluated in one 15 | task. 16 | 17 | Returns: 18 | Policy: The policy used to obtain samples, which are later used for 19 | meta-RL adaptation. 20 | 21 | """ 22 | 23 | @abc.abstractmethod 24 | def adapt_policy(self, exploration_policy, exploration_episodes): 25 | """Produce a policy adapted for a task. 26 | 27 | Args: 28 | exploration_policy (Policy): A policy which was returned from 29 | get_exploration_policy(), and which generated 30 | exploration_trajectories by interacting with an environment. 31 | The caller may not use this object after passing it into this 32 | method. 33 | exploration_episodes (EpisodeBatch): Episodes with which to adapt. 34 | These are generated by exploration_policy while exploring the 35 | environment. 36 | 37 | Returns: 38 | Policy: A policy adapted to the task represented by the 39 | exploration_episodes. 
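Example (illustrative sketch; ``collect_episodes`` is a hypothetical helper
that gathers an EpisodeBatch with the exploration policy):
    exploration_policy = algo.get_exploration_policy()
    episodes = collect_episodes(env, exploration_policy)  # hypothetical helper
    adapted_policy = algo.adapt_policy(exploration_policy, episodes)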
40 | 41 | """ 42 | -------------------------------------------------------------------------------- /src/garage/np/algos/nop.py: -------------------------------------------------------------------------------- 1 | """NOP (no optimization performed) policy search algorithm.""" 2 | from garage.np.algos.rl_algorithm import RLAlgorithm 3 | 4 | 5 | class NOP(RLAlgorithm): 6 | """NOP (no optimization performed) policy search algorithm.""" 7 | 8 | def init_opt(self): 9 | """Initialize the optimization procedure.""" 10 | 11 | def optimize_policy(self, paths): 12 | """Optimize the policy using the samples. 13 | 14 | Args: 15 | paths (list[dict]): A list of collected paths. 16 | 17 | """ 18 | 19 | def train(self, trainer): 20 | """Obtain samplers and start actual training for each epoch. 21 | 22 | Args: 23 | trainer (Trainer): Trainer is passed to give algorithm 24 | the access to trainer.step_epochs(), which provides services 25 | such as snapshotting and sampler control. 26 | 27 | """ 28 | -------------------------------------------------------------------------------- /src/garage/np/algos/rl_algorithm.py: -------------------------------------------------------------------------------- 1 | """Interface of RLAlgorithm.""" 2 | import abc 3 | 4 | 5 | class RLAlgorithm(abc.ABC): 6 | """Base class for all the algorithms. 7 | 8 | Note: 9 | If the field sampler_cls exists, it will be by Trainer.setup to 10 | initialize a sampler. 11 | 12 | """ 13 | 14 | # pylint: disable=too-few-public-methods 15 | 16 | @abc.abstractmethod 17 | def train(self, trainer): 18 | """Obtain samplers and start actual training for each epoch. 19 | 20 | Args: 21 | trainer (Trainer): Trainer is passed to give algorithm 22 | the access to trainer.step_epochs(), which provides services 23 | such as snapshotting and sampler control. 24 | 25 | """ 26 | -------------------------------------------------------------------------------- /src/garage/np/baselines/__init__.py: -------------------------------------------------------------------------------- 1 | """Baselines (value functions) which use NumPy as a numerical backend.""" 2 | from garage.np.baselines.baseline import Baseline 3 | from garage.np.baselines.linear_feature_baseline import LinearFeatureBaseline 4 | from garage.np.baselines.linear_multi_feature_baseline import ( 5 | LinearMultiFeatureBaseline) 6 | from garage.np.baselines.zero_baseline import ZeroBaseline 7 | 8 | __all__ = [ 9 | 'Baseline', 'LinearFeatureBaseline', 'LinearMultiFeatureBaseline', 10 | 'ZeroBaseline' 11 | ] 12 | -------------------------------------------------------------------------------- /src/garage/np/baselines/baseline.py: -------------------------------------------------------------------------------- 1 | """Base class for all baselines.""" 2 | import abc 3 | 4 | 5 | class Baseline(abc.ABC): 6 | """Base class for all baselines.""" 7 | 8 | @abc.abstractmethod 9 | def fit(self, paths): 10 | """Fit regressor based on paths. 11 | 12 | Args: 13 | paths (dict[numpy.ndarray]): Sample paths. 14 | 15 | """ 16 | 17 | @abc.abstractmethod 18 | def predict(self, paths): 19 | """Predict value based on paths. 20 | 21 | Args: 22 | paths (dict[numpy.ndarray]): Sample paths. 23 | 24 | Returns: 25 | numpy.ndarray: Predicted value. 
26 | 27 | """ 28 | -------------------------------------------------------------------------------- /src/garage/np/baselines/linear_multi_feature_baseline.py: -------------------------------------------------------------------------------- 1 | """Linear Multi-Feature Baseline.""" 2 | import numpy as np 3 | 4 | from garage.np.baselines import LinearFeatureBaseline 5 | 6 | 7 | class LinearMultiFeatureBaseline(LinearFeatureBaseline): 8 | """A linear value function (baseline) based on features. 9 | 10 | Args: 11 | env_spec (garage.envs.env_spec.EnvSpec): Environment specification. 12 | reg_coeff (float): Regularization coefficient. 13 | features (list[str]): Name of features. 14 | name (str): Name of baseline. 15 | 16 | """ 17 | 18 | def __init__(self, 19 | env_spec, 20 | features=None, 21 | reg_coeff=1e-5, 22 | name='LinearMultiFeatureBaseline'): 23 | super().__init__(env_spec, reg_coeff, name) 24 | features = features or ['observations'] 25 | self._feature_names = features 26 | 27 | def _features(self, path): 28 | """Extract features from path. 29 | 30 | Args: 31 | path (list[dict]): Sample paths. 32 | 33 | Returns: 34 | numpy.ndarray: Extracted features. 35 | 36 | """ 37 | features = [ 38 | np.clip(path[feature_name], -10, 10) 39 | for feature_name in self._feature_names 40 | ] 41 | n = len(path['observations']) 42 | return np.concatenate(sum([[f, f**2] 43 | for f in features], []) + [np.ones((n, 1))], 44 | axis=1) 45 | -------------------------------------------------------------------------------- /src/garage/np/baselines/zero_baseline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from garage.np.baselines.baseline import Baseline 4 | 5 | 6 | class ZeroBaseline(Baseline): 7 | 8 | def __init__(self, env_spec): 9 | pass 10 | 11 | def get_param_values(self, **kwargs): 12 | return None 13 | 14 | def set_param_values(self, val, **kwargs): 15 | pass 16 | 17 | def fit(self, paths): 18 | pass 19 | 20 | def predict(self, path): 21 | return np.zeros_like(path['rewards']) 22 | 23 | def predict_n(self, paths): 24 | return [np.zeros_like(path['rewards']) for path in paths] 25 | -------------------------------------------------------------------------------- /src/garage/np/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | """Embedding encoders and decoders which use NumPy as a numerical backend.""" 2 | from garage.np.embeddings.encoder import Encoder, StochasticEncoder 3 | 4 | __all__ = ['Encoder', 'StochasticEncoder'] 5 | -------------------------------------------------------------------------------- /src/garage/np/embeddings/encoder.py: -------------------------------------------------------------------------------- 1 | """Base class for context encoder.""" 2 | import abc 3 | 4 | 5 | class Encoder(abc.ABC): 6 | """Base class of context encoders for training meta-RL algorithms.""" 7 | 8 | @property 9 | @abc.abstractmethod 10 | def spec(self): 11 | """garage.InOutSpec: Input and output space.""" 12 | 13 | @property 14 | @abc.abstractmethod 15 | def input_dim(self): 16 | """int: Dimension of the encoder input.""" 17 | 18 | @property 19 | @abc.abstractmethod 20 | def output_dim(self): 21 | """int: Dimension of the encoder output (embedding).""" 22 | 23 | def reset(self, do_resets=None): 24 | """Reset the encoder. 25 | 26 | This is effective only to recurrent encoder. do_resets is effective 27 | only to vectoried encoder. 
28 | 29 | For a vectorized encoder, do_resets is an array of boolean indicating 30 | which internal states to be reset. The length of do_resets should be 31 | equal to the length of inputs. 32 | 33 | Args: 34 | do_resets (numpy.ndarray): Bool array indicating which states 35 | to be reset. 36 | 37 | """ 38 | 39 | 40 | class StochasticEncoder(Encoder): 41 | """An stochastic context encoders. 42 | 43 | An stochastic encoder maps an input to a distribution, but not a 44 | deterministic vector. 45 | 46 | """ 47 | 48 | @property 49 | @abc.abstractmethod 50 | def distribution(self): 51 | """object: Embedding distribution.""" 52 | -------------------------------------------------------------------------------- /src/garage/np/exploration_policies/__init__.py: -------------------------------------------------------------------------------- 1 | """Exploration strategies which use NumPy as a numerical backend.""" 2 | from garage.np.exploration_policies.add_gaussian_noise import AddGaussianNoise 3 | from garage.np.exploration_policies.add_ornstein_uhlenbeck_noise import ( 4 | AddOrnsteinUhlenbeckNoise) 5 | from garage.np.exploration_policies.epsilon_greedy_policy import ( 6 | EpsilonGreedyPolicy) 7 | from garage.np.exploration_policies.exploration_policy import ExplorationPolicy 8 | 9 | __all__ = [ 10 | 'EpsilonGreedyPolicy', 'ExplorationPolicy', 'AddGaussianNoise', 11 | 'AddOrnsteinUhlenbeckNoise' 12 | ] 13 | -------------------------------------------------------------------------------- /src/garage/np/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | """Optimizers which use NumPy as a numerical backend.""" 2 | from garage.np.optimizers.minibatch_dataset import BatchDataset 3 | 4 | __all__ = ['BatchDataset'] 5 | -------------------------------------------------------------------------------- /src/garage/np/optimizers/minibatch_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class BatchDataset: 5 | def __init__(self, inputs, batch_size, extra_inputs=None): 6 | self._inputs = [i for i in inputs] 7 | if extra_inputs is None: 8 | extra_inputs = [] 9 | self._extra_inputs = extra_inputs 10 | self._batch_size = batch_size 11 | if batch_size is not None: 12 | self._ids = np.arange(self._inputs[0].shape[0]) 13 | self.update() 14 | 15 | @property 16 | def number_batches(self): 17 | if self._batch_size is None: 18 | return 1 19 | return int(np.ceil(self._inputs[0].shape[0] * 1.0 / self._batch_size)) 20 | 21 | def iterate(self, update=True): 22 | if self._batch_size is None: 23 | yield list(self._inputs) + list(self._extra_inputs) 24 | else: 25 | for itr in range(self.number_batches): 26 | batch_start = itr * self._batch_size 27 | batch_end = (itr + 1) * self._batch_size 28 | batch_ids = self._ids[batch_start:batch_end] 29 | batch = [d[batch_ids] for d in self._inputs] 30 | yield list(batch) + list(self._extra_inputs) 31 | if update: 32 | self.update() 33 | 34 | def update(self): 35 | np.random.shuffle(self._ids) 36 | -------------------------------------------------------------------------------- /src/garage/np/policies/__init__.py: -------------------------------------------------------------------------------- 1 | """Policies which use NumPy as a numerical backend.""" 2 | 3 | from garage.np.policies.fixed_policy import FixedPolicy 4 | from garage.np.policies.policy import Policy 5 | from garage.np.policies.scripted_policy import ScriptedPolicy 6 | from 
garage.np.policies.uniform_random_policy import UniformRandomPolicy 7 | 8 | __all__ = ['FixedPolicy', 'Policy', 'ScriptedPolicy', 'UniformRandomPolicy'] 9 | -------------------------------------------------------------------------------- /src/garage/np/policies/uniform_random_policy.py: -------------------------------------------------------------------------------- 1 | """Uniform random exploration strategy.""" 2 | import gym 3 | 4 | from garage.np.policies.policy import Policy 5 | 6 | 7 | class UniformRandomPolicy(Policy): 8 | """Action taken is uniformly random. 9 | 10 | Args: 11 | env_spec (EnvSpec): Environment spec to explore. 12 | 13 | """ 14 | 15 | def __init__( 16 | self, 17 | env_spec, 18 | ): 19 | assert isinstance(env_spec.action_space, gym.spaces.Box) 20 | assert len(env_spec.action_space.shape) == 1 21 | self._env_spec = env_spec 22 | self._action_space = env_spec.action_space 23 | self._iteration = 0 24 | 25 | def reset(self, do_resets=None): 26 | """Reset the state of the exploration. 27 | 28 | Args: 29 | do_resets (List[bool] or numpy.ndarray or None): Which 30 | vectorization states to reset. 31 | 32 | """ 33 | self._iteration += 1 34 | super().reset(do_resets) 35 | 36 | def get_action(self, observation): 37 | """Get action from this policy for the input observation. 38 | 39 | Args: 40 | observation(numpy.ndarray): Observation from the environment. 41 | 42 | Returns: 43 | np.ndarray: Actions with noise. 44 | List[dict]: Arbitrary policy state information (agent_info). 45 | 46 | """ 47 | return self._env_spec.action_space.sample(), dict() 48 | 49 | def get_actions(self, observations): 50 | """Get actions from this policy for the input observation. 51 | 52 | Args: 53 | observations(list): Observations from the environment. 54 | 55 | Returns: 56 | np.ndarray: Actions with noise. 57 | List[dict]: Arbitrary policy state information (agent_info). 58 | 59 | """ 60 | return [self._env_spec.action_space.sample() 61 | for obs in observations], dict() 62 | -------------------------------------------------------------------------------- /src/garage/np/q_functions/__init__.py: -------------------------------------------------------------------------------- 1 | """Q-functions which use NumPy as a numerical backend.""" 2 | from garage.np.q_functions.q_function import QFunction 3 | 4 | __all__ = ['QFunction'] 5 | -------------------------------------------------------------------------------- /src/garage/np/q_functions/q_function.py: -------------------------------------------------------------------------------- 1 | """Base class for Q Functions implemented in numpy.""" 2 | 3 | 4 | class QFunction: 5 | """Q-Function interface.""" 6 | pass 7 | -------------------------------------------------------------------------------- /src/garage/plotter/__init__.py: -------------------------------------------------------------------------------- 1 | from garage.plotter.plotter import Plotter 2 | 3 | __all__ = ['Plotter'] 4 | -------------------------------------------------------------------------------- /src/garage/replay_buffer/__init__.py: -------------------------------------------------------------------------------- 1 | """Replay buffers. 2 | 3 | The replay buffer primitives can be used for RL algorithms. 
4 | """ 5 | from garage.replay_buffer.her_replay_buffer import HERReplayBuffer 6 | from garage.replay_buffer.path_buffer import PathBuffer 7 | from garage.replay_buffer.replay_buffer import ReplayBuffer 8 | 9 | __all__ = ['ReplayBuffer', 'HERReplayBuffer', 'PathBuffer'] 10 | -------------------------------------------------------------------------------- /src/garage/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | """Samplers which run agents in environments.""" 2 | # yapf: disable 3 | from garage.sampler._dtypes import InProgressEpisode 4 | from garage.sampler._functions import _apply_env_update 5 | from garage.sampler.default_worker import DefaultWorker 6 | from garage.sampler.env_update import (EnvUpdate, 7 | ExistingEnvUpdate, 8 | NewEnvUpdate, 9 | SetTaskUpdate) 10 | from garage.sampler.fragment_worker import FragmentWorker 11 | from garage.sampler.local_sampler import LocalSampler 12 | from garage.sampler.multiprocessing_sampler import MultiprocessingSampler 13 | from garage.sampler.ray_sampler import RaySampler 14 | from garage.sampler.sampler import Sampler 15 | from garage.sampler.vec_worker import VecWorker 16 | from garage.sampler.worker import Worker 17 | from garage.sampler.worker_factory import WorkerFactory 18 | 19 | # yapf: enable 20 | 21 | __all__ = [ 22 | '_apply_env_update', 23 | 'InProgressEpisode', 24 | 'FragmentWorker', 25 | 'Sampler', 26 | 'LocalSampler', 27 | 'RaySampler', 28 | 'MultiprocessingSampler', 29 | 'VecWorker', 30 | 'WorkerFactory', 31 | 'Worker', 32 | 'DefaultWorker', 33 | 'EnvUpdate', 34 | 'NewEnvUpdate', 35 | 'SetTaskUpdate', 36 | 'ExistingEnvUpdate', 37 | ] 38 | -------------------------------------------------------------------------------- /src/garage/sampler/_functions.py: -------------------------------------------------------------------------------- 1 | """Functions used by multiple Samplers or Workers.""" 2 | from garage import Environment 3 | from garage.sampler.env_update import EnvUpdate 4 | 5 | 6 | def _apply_env_update(old_env, env_update): 7 | """Use any non-None env_update as a new environment. 8 | 9 | A simple env update function. If env_update is not None, it should be 10 | the complete new environment. 11 | 12 | This allows changing environments by passing the new environment as 13 | `env_update` into `obtain_samples`. 14 | 15 | Args: 16 | old_env (Environment): Environment to updated. 17 | env_update (Environment or EnvUpdate or None): The environment to 18 | replace the existing env with. Note that other implementations 19 | of `Worker` may take different types for this parameter. 20 | 21 | Returns: 22 | Environment: The updated environment (may be a different object from 23 | `old_env`). 24 | bool: True if an update happened. 25 | 26 | Raises: 27 | TypeError: If env_update is not one of the documented types. 
28 | 29 | """ 30 | if env_update is not None: 31 | if isinstance(env_update, EnvUpdate): 32 | return env_update(old_env), True 33 | elif isinstance(env_update, Environment): 34 | if old_env is not None: 35 | old_env.close() 36 | return env_update, True 37 | else: 38 | raise TypeError('Unknown environment update type.') 39 | else: 40 | return old_env, False 41 | -------------------------------------------------------------------------------- /src/garage/tf/__init__.py: -------------------------------------------------------------------------------- 1 | """Tensorflow Branch.""" 2 | # yapf: disable 3 | from garage.tf._functions import (center_advs, compile_function, 4 | compute_advantages, concat_tensor_dict_list, 5 | concat_tensor_list, discounted_returns, 6 | filter_valids, filter_valids_dict, 7 | flatten_batch, flatten_batch_dict, 8 | flatten_inputs, flatten_tensor_variables, 9 | get_target_ops, graph_inputs, new_tensor, 10 | new_tensor_like, pad_tensor, pad_tensor_dict, 11 | pad_tensor_n, positive_advs, 12 | split_tensor_dict_list, 13 | stack_tensor_dict_list) 14 | 15 | # yapf: enable 16 | 17 | __all__ = [ 18 | 'compile_function', 19 | 'get_target_ops', 20 | 'flatten_batch', 21 | 'flatten_batch_dict', 22 | 'filter_valids', 23 | 'filter_valids_dict', 24 | 'graph_inputs', 25 | 'flatten_inputs', 26 | 'flatten_tensor_variables', 27 | 'new_tensor', 28 | 'new_tensor_like', 29 | 'concat_tensor_list', 30 | 'concat_tensor_dict_list', 31 | 'stack_tensor_dict_list', 32 | 'split_tensor_dict_list', 33 | 'pad_tensor', 34 | 'pad_tensor_n', 35 | 'pad_tensor_dict', 36 | 'compute_advantages', 37 | 'center_advs', 38 | 'positive_advs', 39 | 'discounted_returns', 40 | ] 41 | -------------------------------------------------------------------------------- /src/garage/tf/algos/__init__.py: -------------------------------------------------------------------------------- 1 | """Tensorflow implementation of reinforcement learning algorithms.""" 2 | from garage.tf.algos.ddpg import DDPG 3 | from garage.tf.algos.dqn import DQN 4 | from garage.tf.algos.erwr import ERWR 5 | from garage.tf.algos.npo import NPO 6 | from garage.tf.algos.ppo import PPO 7 | from garage.tf.algos.reps import REPS 8 | from garage.tf.algos.rl2 import RL2 9 | from garage.tf.algos.rl2ppo import RL2PPO 10 | from garage.tf.algos.rl2trpo import RL2TRPO 11 | from garage.tf.algos.td3 import TD3 12 | from garage.tf.algos.te_npo import TENPO 13 | from garage.tf.algos.te_ppo import TEPPO 14 | from garage.tf.algos.tnpg import TNPG 15 | from garage.tf.algos.trpo import TRPO 16 | from garage.tf.algos.vpg import VPG 17 | 18 | __all__ = [ 19 | 'DDPG', 20 | 'DQN', 21 | 'ERWR', 22 | 'NPO', 23 | 'PPO', 24 | 'REPS', 25 | 'RL2', 26 | 'RL2PPO', 27 | 'RL2TRPO', 28 | 'TD3', 29 | 'TNPG', 30 | 'TRPO', 31 | 'VPG', 32 | 'TENPO', 33 | 'TEPPO', 34 | ] 35 | -------------------------------------------------------------------------------- /src/garage/tf/baselines/__init__.py: -------------------------------------------------------------------------------- 1 | """Baseline estimators for TensorFlow-based algorithms.""" 2 | from garage.tf.baselines.continuous_mlp_baseline import ContinuousMLPBaseline 3 | from garage.tf.baselines.gaussian_cnn_baseline import GaussianCNNBaseline 4 | from garage.tf.baselines.gaussian_mlp_baseline import GaussianMLPBaseline 5 | 6 | __all__ = [ 7 | 'ContinuousMLPBaseline', 8 | 'GaussianCNNBaseline', 9 | 'GaussianMLPBaseline', 10 | ] 11 | -------------------------------------------------------------------------------- 
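A brief, hedged sketch (not a file from the repository): any baseline exported above can stand in for LinearFeatureBaseline in the TF examples shown earlier, assuming only the environment spec is required by its constructor:

    from garage.tf.baselines import GaussianMLPBaseline
    baseline = GaussianMLPBaseline(env_spec=env.spec)
    # pass `baseline` to TRPO/PPO/VPG exactly as in the CartPole examples above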
/src/garage/tf/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | """Embeddings.""" 2 | from garage.tf.embeddings.encoder import Encoder, StochasticEncoder 3 | from garage.tf.embeddings.gaussian_mlp_encoder import GaussianMLPEncoder 4 | 5 | __all__ = ['Encoder', 'StochasticEncoder', 'GaussianMLPEncoder'] 6 | -------------------------------------------------------------------------------- /src/garage/tf/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Network Models.""" 2 | from garage.tf.models.categorical_cnn_model import CategoricalCNNModel 3 | from garage.tf.models.categorical_gru_model import CategoricalGRUModel 4 | from garage.tf.models.categorical_lstm_model import CategoricalLSTMModel 5 | from garage.tf.models.categorical_mlp_model import CategoricalMLPModel 6 | from garage.tf.models.cnn_mlp_merge_model import CNNMLPMergeModel 7 | from garage.tf.models.cnn_model import CNNModel 8 | from garage.tf.models.cnn_model_max_pooling import CNNModelWithMaxPooling 9 | from garage.tf.models.gaussian_cnn_model import GaussianCNNModel 10 | from garage.tf.models.gaussian_gru_model import GaussianGRUModel 11 | from garage.tf.models.gaussian_lstm_model import GaussianLSTMModel 12 | from garage.tf.models.gaussian_mlp_model import GaussianMLPModel 13 | from garage.tf.models.gru_model import GRUModel 14 | from garage.tf.models.lstm_model import LSTMModel 15 | from garage.tf.models.mlp_dueling_model import MLPDuelingModel 16 | from garage.tf.models.mlp_merge_model import MLPMergeModel 17 | from garage.tf.models.mlp_model import MLPModel 18 | from garage.tf.models.model import BaseModel, Model 19 | from garage.tf.models.module import Module, StochasticModule 20 | from garage.tf.models.normalized_input_mlp_model import NormalizedInputMLPModel 21 | from garage.tf.models.sequential import Sequential 22 | 23 | __all__ = [ 24 | 'BaseModel', 'CategoricalCNNModel', 'CategoricalGRUModel', 25 | 'CategoricalLSTMModel', 'CategoricalMLPModel', 'CNNMLPMergeModel', 26 | 'CNNModel', 'CNNModelWithMaxPooling', 'LSTMModel', 'Model', 'Module', 27 | 'GaussianCNNModel', 'GaussianGRUModel', 'GaussianLSTMModel', 28 | 'GaussianMLPModel', 'GRUModel', 'MLPDuelingModel', 'MLPMergeModel', 29 | 'MLPModel', 'NormalizedInputMLPModel', 'Sequential', 'StochasticModule' 30 | ] 31 | -------------------------------------------------------------------------------- /src/garage/tf/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | """TensorFlow optimizers.""" 2 | # yapf: disable 3 | from garage.tf.optimizers.conjugate_gradient_optimizer import ( 4 | ConjugateGradientOptimizer) # noqa: E501 5 | from garage.tf.optimizers.conjugate_gradient_optimizer import ( 6 | FiniteDifferenceHVP) # noqa: E501 7 | from garage.tf.optimizers.conjugate_gradient_optimizer import PearlmutterHVP 8 | from garage.tf.optimizers.first_order_optimizer import FirstOrderOptimizer 9 | from garage.tf.optimizers.lbfgs_optimizer import LBFGSOptimizer 10 | from garage.tf.optimizers.penalty_lbfgs_optimizer import PenaltyLBFGSOptimizer 11 | 12 | # yapf: enable 13 | 14 | __all__ = [ 15 | 'ConjugateGradientOptimizer', 'PearlmutterHVP', 'FiniteDifferenceHVP', 16 | 'FirstOrderOptimizer', 'LBFGSOptimizer', 'PenaltyLBFGSOptimizer' 17 | ] 18 | -------------------------------------------------------------------------------- /src/garage/tf/optimizers/_dtypes.py: 
--------------------------------------------------------------------------------
1 | """Data types for TensorFlow optimizers."""
2 | 
3 | 
4 | class LazyDict:
5 |     """A lazily-evaluated dict whose values are memoized after first access.
6 | 
7 |     Args:
8 |         **kwargs (dict[Callable]): Initial lazy key-value pairs.
9 |     """
10 | 
11 |     def __init__(self, **kwargs):
12 |         self._lazy_dict = kwargs
13 |         self._dict = {}
14 | 
15 |     def __getitem__(self, key):
16 |         """See :meth:`object.__getitem__`.
17 | 
18 |         Args:
19 |             key (Hashable): Key associated with the value to retrieve.
20 | 
21 |         Returns:
22 |             object: Lazily-evaluated value of the :class:`Callable` associated
23 |                 with key.
24 | 
25 |         """
26 |         if key not in self._dict:
27 |             self._dict[key] = self._lazy_dict[key]()
28 |         return self._dict[key]
29 | 
30 |     def __setitem__(self, key, value):
31 |         """See :meth:`object.__setitem__`.
32 | 
33 |         Args:
34 |             key (Hashable): Key associated with value.
35 |             value (Callable): Function which returns the lazy value associated
36 |                 with key.
37 | 
38 |         """
39 |         self.set(key, value)
40 | 
41 |     def get(self, key, default=None):
42 |         """See :meth:`dict.get`.
43 | 
44 |         Args:
45 |             key (Hashable): Key associated with the value to retrieve.
46 |             default (object): Value to return if key is not present in this
47 |                 :class:`LazyDict`.
48 | 
49 |         Returns:
50 |             object: Value associated with key if the key is present, otherwise
51 |                 default.
52 |         """
53 |         if key in self._lazy_dict:
54 |             return self[key]
55 | 
56 |         return default
57 | 
58 |     def set(self, key, value):
59 |         """Set the :class:`Callable` associated with key.
60 | 
61 |         Args:
62 |             key (Hashable): Key associated with value.
63 |             value (Callable): Function which returns the lazy value associated
64 |                 with key.
65 | 
66 |         """
67 |         self._lazy_dict[key] = value
68 | 
-------------------------------------------------------------------------------- /src/garage/tf/plotter/__init__.py: --------------------------------------------------------------------------------
1 | from garage.tf.plotter.plotter import Plotter
2 | 
3 | __all__ = ['Plotter']
4 | 
-------------------------------------------------------------------------------- /src/garage/tf/policies/__init__.py: --------------------------------------------------------------------------------
1 | """Policies for TensorFlow-based algorithms."""
2 | from garage.tf.policies.categorical_cnn_policy import CategoricalCNNPolicy
3 | from garage.tf.policies.categorical_gru_policy import CategoricalGRUPolicy
4 | from garage.tf.policies.categorical_lstm_policy import CategoricalLSTMPolicy
5 | from garage.tf.policies.categorical_mlp_policy import CategoricalMLPPolicy
6 | from garage.tf.policies.continuous_mlp_policy import ContinuousMLPPolicy
7 | from garage.tf.policies.discrete_qf_argmax_policy import DiscreteQFArgmaxPolicy
8 | from garage.tf.policies.gaussian_gru_policy import GaussianGRUPolicy
9 | from garage.tf.policies.gaussian_lstm_policy import GaussianLSTMPolicy
10 | from garage.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy
11 | from garage.tf.policies.gaussian_mlp_task_embedding_policy import (
12 |     GaussianMLPTaskEmbeddingPolicy)
13 | from garage.tf.policies.policy import Policy
14 | from garage.tf.policies.task_embedding_policy import TaskEmbeddingPolicy
15 | 
16 | __all__ = [
17 |     'Policy', 'CategoricalCNNPolicy', 'CategoricalGRUPolicy',
18 |     'CategoricalLSTMPolicy', 'CategoricalMLPPolicy', 'ContinuousMLPPolicy',
19 |     'DiscreteQFArgmaxPolicy', 'GaussianGRUPolicy', 'GaussianLSTMPolicy',
20 |     'GaussianMLPPolicy', 'GaussianMLPTaskEmbeddingPolicy',
21 |     'TaskEmbeddingPolicy'
22 | ]
23 | 
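A brief usage sketch of LazyDict (the key names and helper below are invented for illustration): values are registered as callables and evaluated only on first access, after which the result is cached.

def _expensive():
    print('computing...')
    return 42

lazy = LazyDict(answer=_expensive)
assert lazy['answer'] == 42  # calls _expensive() once and caches the result
assert lazy['answer'] == 42  # served from the cache; not recomputed
lazy['more'] = lambda: 'evaluated on first access'  # __setitem__ stores a callable
assert lazy.get('missing') is None  # get() falls back to the default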
-------------------------------------------------------------------------------- /src/garage/tf/policies/policy.py: -------------------------------------------------------------------------------- 1 | """Base class for policies in TensorFlow.""" 2 | import abc 3 | 4 | from garage.np.policies import Policy as BasePolicy 5 | 6 | 7 | class Policy(BasePolicy): 8 | """Base class for policies in TensorFlow.""" 9 | 10 | @abc.abstractmethod 11 | def get_action(self, observation): 12 | """Get action sampled from the policy. 13 | 14 | Args: 15 | observation (np.ndarray): Observation from the environment. 16 | 17 | Returns: 18 | Tuple[np.ndarray, dict[str,np.ndarray]]: Action and extra agent 19 | info. 20 | 21 | """ 22 | 23 | @abc.abstractmethod 24 | def get_actions(self, observations): 25 | """Get actions given observations. 26 | 27 | Args: 28 | observations (np.ndarray): Observations from the environment. 29 | 30 | Returns: 31 | Tuple[np.ndarray, dict[str,np.ndarray]]: Actions and extra agent 32 | infos. 33 | 34 | """ 35 | 36 | @property 37 | def state_info_specs(self): 38 | """State info specification. 39 | 40 | Returns: 41 | List[str]: keys and shapes for the information related to the 42 | module's state when taking an action. 43 | 44 | """ 45 | return list() 46 | 47 | @property 48 | def state_info_keys(self): 49 | """State info keys. 50 | 51 | Returns: 52 | List[str]: keys for the information related to the module's state 53 | when taking an input. 54 | 55 | """ 56 | return [k for k, _ in self.state_info_specs] 57 | -------------------------------------------------------------------------------- /src/garage/tf/policies/uniform_control_policy.py: -------------------------------------------------------------------------------- 1 | """Uniform control policy.""" 2 | from garage.tf.policies.policy import Policy 3 | 4 | 5 | class UniformControlPolicy(Policy): 6 | """Policy that output random action uniformly. 7 | 8 | Args: 9 | env_spec (garage.envs.env_spec.EnvSpec): Environment specification. 10 | 11 | """ 12 | 13 | def __init__(self, env_spec): 14 | self._env_spec = env_spec 15 | 16 | def get_action(self, observation): 17 | """Get single action from this policy for the input observation. 18 | 19 | Args: 20 | observation (numpy.ndarray): Observation from environment. 21 | 22 | Returns: 23 | numpy.ndarray: Action 24 | dict: Predicted action and agent information. It returns an empty 25 | dict since there is no parameterization. 26 | 27 | """ 28 | return self.action_space.sample(), dict() 29 | 30 | def get_actions(self, observations): 31 | """Get multiple actions from this policy for the input observations. 32 | 33 | Args: 34 | observations (numpy.ndarray): Observations from environment. 35 | 36 | Returns: 37 | numpy.ndarray: Actions 38 | dict: Predicted action and agent information. It returns an empty 39 | dict since there is no parameterization. 
40 | 41 | """ 42 | return self.action_space.sample_n(len(observations)), dict() 43 | -------------------------------------------------------------------------------- /src/garage/tf/q_functions/__init__.py: -------------------------------------------------------------------------------- 1 | """Q-Functions for TensorFlow-based algorithms.""" 2 | # isort:skip_file 3 | 4 | from garage.tf.q_functions.continuous_cnn_q_function import ( 5 | ContinuousCNNQFunction) 6 | from garage.tf.q_functions.continuous_mlp_q_function import ( 7 | ContinuousMLPQFunction) 8 | from garage.tf.q_functions.discrete_cnn_q_function import DiscreteCNNQFunction 9 | from garage.tf.q_functions.discrete_mlp_q_function import DiscreteMLPQFunction 10 | from garage.tf.q_functions.discrete_mlp_dueling_q_function import ( 11 | DiscreteMLPDuelingQFunction) 12 | 13 | __all__ = [ 14 | 'ContinuousMLPQFunction', 'DiscreteCNNQFunction', 'DiscreteMLPQFunction', 15 | 'DiscreteMLPDuelingQFunction', 'ContinuousCNNQFunction' 16 | ] 17 | -------------------------------------------------------------------------------- /src/garage/tf/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | """Samplers which run agents that use Tensorflow in environments.""" 2 | 3 | from garage.tf.samplers.worker import TFWorkerClassWrapper, TFWorkerWrapper 4 | 5 | __all__ = ['TFWorkerClassWrapper', 'TFWorkerWrapper'] 6 | -------------------------------------------------------------------------------- /src/garage/torch/__init__.py: -------------------------------------------------------------------------------- 1 | """PyTorch-backed modules and algorithms.""" 2 | # yapf: disable 3 | from garage.torch._functions import (as_torch_dict, compute_advantages, 4 | expand_var, filter_valids, flatten_batch, 5 | flatten_to_single_vector, global_device, 6 | NonLinearity, np_to_torch, 7 | output_height_2d, output_width_2d, 8 | pad_to_last, prefer_gpu, 9 | product_of_gaussians, set_gpu_mode, 10 | soft_update_model, state_dict_to, 11 | torch_to_np, update_module_params) 12 | 13 | # yapf: enable 14 | __all__ = [ 15 | 'NonLinearity', 16 | 'as_torch_dict', 17 | 'compute_advantages', 18 | 'expand_var', 19 | 'filter_valids', 20 | 'flatten_batch', 21 | 'flatten_to_single_vector', 22 | 'global_device', 23 | 'np_to_torch', 24 | 'output_height_2d', 25 | 'output_width_2d', 26 | 'pad_to_last', 27 | 'prefer_gpu', 28 | 'product_of_gaussians', 29 | 'set_gpu_mode', 30 | 'soft_update_model', 31 | 'state_dict_to', 32 | 'torch_to_np', 33 | 'update_module_params', 34 | ] 35 | -------------------------------------------------------------------------------- /src/garage/torch/algos/__init__.py: -------------------------------------------------------------------------------- 1 | """PyTorch algorithms.""" 2 | # isort:skip_file 3 | 4 | from garage.torch.algos.bc import BC 5 | from garage.torch.algos.ddpg import DDPG 6 | # VPG has to be imported first because it is depended by PPO and TRPO. 
7 | # PPO, TRPO, and VPG need to be imported before their MAML variants 8 | from garage.torch.algos.dqn import DQN 9 | from garage.torch.algos.vpg import VPG 10 | from garage.torch.algos.maml_vpg import MAMLVPG 11 | from garage.torch.algos.ppo import PPO 12 | from garage.torch.algos.maml_ppo import MAMLPPO 13 | from garage.torch.algos.td3 import TD3 14 | from garage.torch.algos.trpo import TRPO 15 | from garage.torch.algos.maml_trpo import MAMLTRPO 16 | # SAC needs to be imported before MTSAC 17 | from garage.torch.algos.sac import SAC 18 | from garage.torch.algos.mtsac import MTSAC 19 | from garage.torch.algos.pearl import PEARL 20 | 21 | __all__ = [ 22 | 'BC', 'DDPG', 'DQN', 'VPG', 'PPO', 'TD3', 'TRPO', 'MAMLPPO', 'MAMLTRPO', 23 | 'MAMLVPG', 'MTSAC', 'PEARL', 'SAC' 24 | ] 25 | -------------------------------------------------------------------------------- /src/garage/torch/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | """PyTorch Custom Distributions.""" 2 | from garage.torch.distributions.tanh_normal import TanhNormal 3 | 4 | __all__ = ['TanhNormal'] 5 | -------------------------------------------------------------------------------- /src/garage/torch/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | """PyTorch embedding modules for meta-learning algorithms.""" 2 | 3 | from garage.torch.embeddings.mlp_encoder import MLPEncoder 4 | 5 | __all__ = ['MLPEncoder'] 6 | -------------------------------------------------------------------------------- /src/garage/torch/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """PyTorch Modules.""" 2 | # yapf: disable 3 | # isort:skip_file 4 | from garage.torch.modules.cnn_module import CNNModule 5 | from garage.torch.modules.gaussian_mlp_module import ( 6 | GaussianMLPIndependentStdModule) # noqa: E501 7 | from garage.torch.modules.gaussian_mlp_module import ( 8 | GaussianMLPTwoHeadedModule) # noqa: E501 9 | from garage.torch.modules.gaussian_mlp_module import GaussianMLPModule 10 | from garage.torch.modules.mlp_module import MLPModule 11 | from garage.torch.modules.multi_headed_mlp_module import MultiHeadedMLPModule 12 | # DiscreteCNNModule must go after MLPModule 13 | from garage.torch.modules.discrete_cnn_module import DiscreteCNNModule 14 | # yapf: enable 15 | 16 | __all__ = [ 17 | 'CNNModule', 18 | 'DiscreteCNNModule', 19 | 'MLPModule', 20 | 'MultiHeadedMLPModule', 21 | 'GaussianMLPModule', 22 | 'GaussianMLPIndependentStdModule', 23 | 'GaussianMLPTwoHeadedModule', 24 | ] 25 | -------------------------------------------------------------------------------- /src/garage/torch/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | """PyTorch optimizers.""" 2 | from garage.torch.optimizers.conjugate_gradient_optimizer import ( 3 | ConjugateGradientOptimizer) 4 | from garage.torch.optimizers.differentiable_sgd import DifferentiableSGD 5 | from garage.torch.optimizers.optimizer_wrapper import OptimizerWrapper 6 | 7 | __all__ = [ 8 | 'OptimizerWrapper', 'ConjugateGradientOptimizer', 'DifferentiableSGD' 9 | ] 10 | -------------------------------------------------------------------------------- /src/garage/torch/policies/__init__.py: -------------------------------------------------------------------------------- 1 | """PyTorch Policies.""" 2 | from garage.torch.policies.categorical_cnn_policy import CategoricalCNNPolicy 3 | 
from garage.torch.policies.context_conditioned_policy import ( 4 | ContextConditionedPolicy) 5 | from garage.torch.policies.deterministic_mlp_policy import ( 6 | DeterministicMLPPolicy) 7 | from garage.torch.policies.discrete_cnn_policy import DiscreteCNNPolicy 8 | from garage.torch.policies.discrete_qf_argmax_policy import ( 9 | DiscreteQFArgmaxPolicy) 10 | from garage.torch.policies.gaussian_mlp_policy import GaussianMLPPolicy 11 | from garage.torch.policies.policy import Policy 12 | from garage.torch.policies.tanh_gaussian_mlp_policy import ( 13 | TanhGaussianMLPPolicy) 14 | 15 | __all__ = [ 16 | 'CategoricalCNNPolicy', 17 | 'DeterministicMLPPolicy', 18 | 'DiscreteCNNPolicy', 19 | 'DiscreteQFArgmaxPolicy', 20 | 'GaussianMLPPolicy', 21 | 'Policy', 22 | 'TanhGaussianMLPPolicy', 23 | 'ContextConditionedPolicy', 24 | ] 25 | -------------------------------------------------------------------------------- /src/garage/torch/q_functions/__init__.py: -------------------------------------------------------------------------------- 1 | """PyTorch Q-functions.""" 2 | from garage.torch.q_functions.continuous_mlp_q_function import ( 3 | ContinuousMLPQFunction) 4 | from garage.torch.q_functions.discrete_cnn_q_function import ( 5 | DiscreteCNNQFunction) 6 | from garage.torch.q_functions.discrete_dueling_cnn_q_function import ( 7 | DiscreteDuelingCNNQFunction) 8 | from garage.torch.q_functions.discrete_mlp_q_function import ( 9 | DiscreteMLPQFunction) 10 | 11 | __all__ = [ 12 | 'ContinuousMLPQFunction', 'DiscreteCNNQFunction', 13 | 'DiscreteDuelingCNNQFunction', 'DiscreteMLPQFunction' 14 | ] 15 | -------------------------------------------------------------------------------- /src/garage/torch/q_functions/continuous_mlp_q_function.py: -------------------------------------------------------------------------------- 1 | """This modules creates a continuous Q-function network.""" 2 | 3 | import torch 4 | 5 | from garage.torch.modules import MLPModule 6 | 7 | 8 | class ContinuousMLPQFunction(MLPModule): 9 | """Implements a continuous MLP Q-value network. 10 | 11 | It predicts the Q-value for all actions based on the input state. It uses 12 | a PyTorch neural network module to fit the function of Q(s, a). 13 | """ 14 | 15 | def __init__(self, env_spec, **kwargs): 16 | """Initialize class with multiple attributes. 17 | 18 | Args: 19 | env_spec (EnvSpec): Environment specification. 20 | **kwargs: Keyword arguments. 21 | 22 | """ 23 | self._env_spec = env_spec 24 | self._obs_dim = env_spec.observation_space.flat_dim 25 | self._action_dim = env_spec.action_space.flat_dim 26 | 27 | MLPModule.__init__(self, 28 | input_dim=self._obs_dim + self._action_dim, 29 | output_dim=1, 30 | **kwargs) 31 | 32 | # pylint: disable=arguments-differ 33 | def forward(self, observations, actions): 34 | """Return Q-value(s). 35 | 36 | Args: 37 | observations (np.ndarray): observations. 38 | actions (np.ndarray): actions. 
39 | 40 | Returns: 41 | torch.Tensor: Output value 42 | """ 43 | return super().forward(torch.cat([observations, actions], 1)) 44 | -------------------------------------------------------------------------------- /src/garage/torch/value_functions/__init__.py: -------------------------------------------------------------------------------- 1 | """Value functions which use PyTorch.""" 2 | from garage.torch.value_functions.gaussian_mlp_value_function import ( 3 | GaussianMLPValueFunction) 4 | from garage.torch.value_functions.value_function import ValueFunction 5 | 6 | __all__ = ['ValueFunction', 'GaussianMLPValueFunction'] 7 | -------------------------------------------------------------------------------- /src/garage/torch/value_functions/value_function.py: -------------------------------------------------------------------------------- 1 | """Base class for all baselines.""" 2 | import abc 3 | 4 | import torch.nn as nn 5 | 6 | 7 | class ValueFunction(abc.ABC, nn.Module): 8 | """Base class for all baselines. 9 | 10 | Args: 11 | env_spec (EnvSpec): Environment specification. 12 | name (str): Value function name, also the variable scope. 13 | 14 | """ 15 | 16 | def __init__(self, env_spec, name): 17 | super(ValueFunction, self).__init__() 18 | 19 | self._mdp_spec = env_spec 20 | self.name = name 21 | 22 | @abc.abstractmethod 23 | def compute_loss(self, obs, returns): 24 | r"""Compute mean value of loss. 25 | 26 | Args: 27 | obs (torch.Tensor): Observation from the environment 28 | with shape :math:`(N \dot [T], O*)`. 29 | returns (torch.Tensor): Acquired returns with shape :math:`(N, )`. 30 | 31 | Returns: 32 | torch.Tensor: Calculated negative mean scalar value of 33 | objective (float). 34 | 35 | """ 36 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/__init__.py -------------------------------------------------------------------------------- /tests/fixtures/__init__.py: -------------------------------------------------------------------------------- 1 | """Test fixtures.""" 2 | # yapf: disable 3 | from tests.fixtures.fixtures import (snapshot_config, 4 | TfGraphTestCase, 5 | TfTestCase) 6 | 7 | # yapf: enable 8 | 9 | __all__ = ['snapshot_config', 'TfGraphTestCase', 'TfTestCase'] 10 | -------------------------------------------------------------------------------- /tests/fixtures/algos/__init__.py: -------------------------------------------------------------------------------- 1 | from tests.fixtures.algos.dummy_algo import DummyAlgo 2 | from tests.fixtures.algos.dummy_tf_algo import DummyTFAlgo 3 | 4 | __all__ = ['DummyAlgo', 'DummyTFAlgo'] 5 | -------------------------------------------------------------------------------- /tests/fixtures/algos/dummy_algo.py: -------------------------------------------------------------------------------- 1 | """A dummy algorithm fixture.""" 2 | from garage.np.algos import RLAlgorithm 3 | 4 | 5 | class DummyAlgo(RLAlgorithm): # pylint: disable=too-few-public-methods 6 | """Dummy algo for test. 7 | 8 | Args: 9 | env_spec (garage.envs.EnvSpec): Environment specification. 10 | policy (garage.np.policies.Policy): Policy. 11 | baseline (garage.np.baselines.Baseline): The baseline. 
12 | 13 | """ 14 | 15 | def __init__(self, env_spec, policy, baseline): 16 | self.env_spec = env_spec 17 | self.policy = policy 18 | self.baseline = baseline 19 | self.discount = 0.9 20 | self.max_episode_length = 1 21 | self.n_samples = 10 22 | 23 | def train(self, trainer): 24 | """Obtain samplers and start actual training for each epoch. 25 | 26 | See garage.np.algos.RLAlgorithm train(). 27 | 28 | Args: 29 | trainer (Trainer): Trainer is passed to give algorithm 30 | the access to trainer.step_epochs(), which provides services 31 | such as snapshotting and sampler control. 32 | 33 | """ 34 | -------------------------------------------------------------------------------- /tests/fixtures/algos/dummy_tf_algo.py: -------------------------------------------------------------------------------- 1 | """Dummy algorithm.""" 2 | from garage.np.algos import RLAlgorithm 3 | 4 | 5 | class DummyTFAlgo(RLAlgorithm): 6 | """Dummy algorithm.""" 7 | 8 | def init_opt(self): 9 | """Initialize the optimization procedure. 10 | 11 | If using tensorflow, this may include declaring all the variables and 12 | compiling functions. 13 | 14 | """ 15 | 16 | def optimize_policy(self, samples_data): 17 | """Optimize the policy using the samples. 18 | 19 | Args: 20 | samples_data (dict): Processed sample data. 21 | 22 | """ 23 | -------------------------------------------------------------------------------- /tests/fixtures/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/fixtures/envs/__init__.py -------------------------------------------------------------------------------- /tests/fixtures/envs/dummy/__init__.py: -------------------------------------------------------------------------------- 1 | """Collection of dummy environments used in testing.""" 2 | from tests.fixtures.envs.dummy.base import DummyEnv 3 | from tests.fixtures.envs.dummy.dummy_box_env import DummyBoxEnv 4 | from tests.fixtures.envs.dummy.dummy_dict_env import DummyDictEnv 5 | from tests.fixtures.envs.dummy.dummy_discrete_2d_env import DummyDiscrete2DEnv 6 | from tests.fixtures.envs.dummy.dummy_discrete_env import DummyDiscreteEnv 7 | from tests.fixtures.envs.dummy.dummy_discrete_pixel_env import ( 8 | DummyDiscretePixelEnv) 9 | from tests.fixtures.envs.dummy.dummy_discrete_pixel_env_baselines import ( 10 | DummyDiscretePixelEnvBaselines) 11 | from tests.fixtures.envs.dummy.dummy_multitask_box_env import ( 12 | DummyMultiTaskBoxEnv) 13 | from tests.fixtures.envs.dummy.dummy_reward_box_env import DummyRewardBoxEnv 14 | 15 | __all__ = [ 16 | 'DummyEnv', 'DummyBoxEnv', 'DummyDictEnv', 'DummyDiscrete2DEnv', 17 | 'DummyDiscreteEnv', 'DummyDiscretePixelEnv', 18 | 'DummyDiscretePixelEnvBaselines', 'DummyMultiTaskBoxEnv', 19 | 'DummyRewardBoxEnv' 20 | ] 21 | -------------------------------------------------------------------------------- /tests/fixtures/envs/dummy/base.py: -------------------------------------------------------------------------------- 1 | """Dummy environment for testing purpose.""" 2 | import gym 3 | 4 | 5 | class DummyEnv(gym.Env): 6 | """Base dummy environment. 7 | 8 | Args: 9 | random (bool): If observations are randomly generated or not. 10 | obs_dim (iterable): Observation space dimension. 11 | action_dim (iterable): Action space dimension. 
12 | 13 | """ 14 | 15 | def __init__(self, random, obs_dim=(4, ), action_dim=(2, )): 16 | self.random = random 17 | self.state = None 18 | self._obs_dim = obs_dim 19 | self._action_dim = action_dim 20 | 21 | @property 22 | def observation_space(self): 23 | """Return an observation space.""" 24 | raise NotImplementedError 25 | 26 | @property 27 | def action_space(self): 28 | """Return an action space.""" 29 | raise NotImplementedError 30 | 31 | def reset(self): 32 | """Reset the environment.""" 33 | raise NotImplementedError 34 | 35 | def step(self, action): 36 | """Step the environment. 37 | 38 | Args: 39 | action (int): Action input. 40 | 41 | """ 42 | raise NotImplementedError 43 | 44 | def render(self, mode='human'): 45 | """Render. 46 | 47 | Args: 48 | mode (str): Render mode. 49 | 50 | """ 51 | -------------------------------------------------------------------------------- /tests/fixtures/envs/dummy/dummy_box_env.py: -------------------------------------------------------------------------------- 1 | """Dummy akro.Box environment for testing purpose.""" 2 | import akro 3 | import numpy as np 4 | 5 | from tests.fixtures.envs.dummy import DummyEnv 6 | 7 | 8 | class DummyBoxEnv(DummyEnv): 9 | """A dummy gym.spaces.Box environment. 10 | 11 | Args: 12 | random (bool): If observations are randomly generated or not. 13 | obs_dim (iterable): Observation space dimension. 14 | action_dim (iterable): Action space dimension. 15 | 16 | """ 17 | 18 | def __init__(self, random=True, obs_dim=(4, ), action_dim=(2, )): 19 | super().__init__(random, obs_dim, action_dim) 20 | 21 | @property 22 | def observation_space(self): 23 | """Return an observation space. 24 | 25 | Returns: 26 | gym.spaces: The observation space of the environment. 27 | 28 | """ 29 | return akro.Box(low=-1, high=1, shape=self._obs_dim, dtype=np.float32) 30 | 31 | @property 32 | def action_space(self): 33 | """Return an action space. 34 | 35 | Returns: 36 | gym.spaces: The action space of the environment. 37 | 38 | """ 39 | return akro.Box(low=-5.0, 40 | high=5.0, 41 | shape=self._action_dim, 42 | dtype=np.float32) 43 | 44 | def reset(self): 45 | """Reset the environment. 46 | 47 | Returns: 48 | np.ndarray: The observation obtained after reset. 49 | 50 | """ 51 | return np.ones(self._obs_dim, dtype=np.float32) 52 | 53 | def step(self, action): 54 | """Step the environment. 55 | 56 | Args: 57 | action (int): Action input. 58 | 59 | Returns: 60 | np.ndarray: Observation. 61 | float: Reward. 62 | bool: If the environment is terminated. 63 | dict: Environment information. 
64 | 65 | """ 66 | return self.observation_space.sample(), 0, False, dict(dummy='dummy') 67 | -------------------------------------------------------------------------------- /tests/fixtures/envs/dummy/dummy_discrete_2d_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | from tests.fixtures.envs.dummy import DummyEnv 5 | 6 | 7 | class DummyDiscrete2DEnv(DummyEnv): 8 | """A dummy discrete environment.""" 9 | 10 | def __init__(self, random=True): 11 | super().__init__(random) 12 | self.shape = (2, 2) 13 | self._observation_space = gym.spaces.Box( 14 | low=-1, high=1, shape=self.shape, dtype=np.float32) 15 | 16 | @property 17 | def observation_space(self): 18 | """Return an observation space.""" 19 | return self._observation_space 20 | 21 | @observation_space.setter 22 | def observation_space(self, observation_space): 23 | self._observation_space = observation_space 24 | 25 | @property 26 | def action_space(self): 27 | """Return an action space.""" 28 | return gym.spaces.Discrete(2) 29 | 30 | def reset(self): 31 | """Reset the environment.""" 32 | self.state = np.ones(self.shape) 33 | return self.state 34 | 35 | def step(self, action): 36 | """Step the environment.""" 37 | if self.state is not None: 38 | if self.random: 39 | obs = self.observation_space.sample() 40 | else: 41 | obs = self.state + action / 10. 42 | else: 43 | raise RuntimeError( 44 | "DummyEnv: reset() must be called before step()!") 45 | return obs, 0, True, dict() 46 | -------------------------------------------------------------------------------- /tests/fixtures/envs/dummy/dummy_discrete_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | from tests.fixtures.envs.dummy import DummyEnv 5 | 6 | 7 | class DummyDiscreteEnv(DummyEnv): 8 | """A dummy discrete environment.""" 9 | 10 | def __init__(self, obs_dim=(1, ), action_dim=1, random=True): 11 | super().__init__(random, obs_dim, action_dim) 12 | 13 | @property 14 | def observation_space(self): 15 | """Return an observation space.""" 16 | return gym.spaces.Box( 17 | low=-1, high=1, shape=self._obs_dim, dtype=np.float32) 18 | 19 | @property 20 | def action_space(self): 21 | """Return an action space.""" 22 | return gym.spaces.Discrete(self._action_dim) 23 | 24 | def reset(self): 25 | """Reset the environment.""" 26 | self.state = np.ones(self._obs_dim, dtype=np.float32) 27 | return self.state 28 | 29 | def step(self, action): 30 | """Step the environment.""" 31 | if self.state is not None: 32 | if self.random: 33 | obs = self.observation_space.sample() 34 | else: 35 | obs = self.state + action / 10. 36 | else: 37 | raise RuntimeError( 38 | "DummyEnv: reset() must be called before step()!") 39 | return obs, 0, True, dict() 40 | -------------------------------------------------------------------------------- /tests/fixtures/envs/dummy/dummy_discrete_pixel_env_baselines.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | from tests.fixtures.envs.dummy import DummyEnv 5 | 6 | 7 | class LazyFrames(object): 8 | def __init__(self, frames): 9 | """ 10 | LazyFrames class from baselines. 11 | 12 | Openai baselines use this class for FrameStack environment 13 | wrapper. It is used for testing garage.envs.wrappers.AtariEnv. 14 | garge.envs.wrapper.AtariEnv is used when algorithms are trained 15 | using baselines wrappers, e.g. during benchmarking. 
16 | """ 17 | self._frames = frames 18 | self._out = None 19 | 20 | def _force(self): 21 | if self._out is None: 22 | self._out = np.concatenate(self._frames, axis=-1) 23 | self._frames = None 24 | return self._out 25 | 26 | def __array__(self, dtype=None): 27 | out = self._force() 28 | if dtype is not None: 29 | out = out.astype(dtype) 30 | return out 31 | 32 | 33 | class DummyDiscretePixelEnvBaselines(DummyEnv): 34 | """ 35 | A dummy discrete pixel environment. 36 | 37 | This environment is for testing garge.envs.wrapper.AtariEnv. 38 | """ 39 | 40 | def __init__(self): 41 | super().__init__(random=False, obs_dim=(10, 10, 3), action_dim=5) 42 | self._observation_space = gym.spaces.Box( 43 | low=0, high=255, shape=self._obs_dim, dtype=np.uint8) 44 | 45 | @property 46 | def observation_space(self): 47 | """Return an observation space.""" 48 | return self._observation_space 49 | 50 | @property 51 | def action_space(self): 52 | """Return an action space.""" 53 | return gym.spaces.Discrete(self._action_dim) 54 | 55 | def step(self, action): 56 | """gym.Env step function.""" 57 | obs = self.observation_space.sample() 58 | return LazyFrames([obs]), 0, True, dict() 59 | 60 | def reset(self, **kwargs): 61 | """gym.Env reset function.""" 62 | obs = np.ones(self._obs_dim, dtype=np.uint8) 63 | return LazyFrames([obs]) 64 | -------------------------------------------------------------------------------- /tests/fixtures/envs/dummy/dummy_multitask_box_env.py: -------------------------------------------------------------------------------- 1 | """Dummy gym.spaces.Box environment for testing purpose.""" 2 | from random import choices 3 | 4 | from tests.fixtures.envs.dummy import DummyBoxEnv 5 | 6 | 7 | class DummyMultiTaskBoxEnv(DummyBoxEnv): 8 | """A dummy gym.spaces.Box multitask environment. 9 | 10 | Args: 11 | random (bool): If observations are randomly generated or not. 12 | obs_dim (iterable): Observation space dimension. 13 | action_dim (iterable): Action space dimension. 14 | 15 | """ 16 | 17 | def __init__(self, random=True, obs_dim=(4, ), action_dim=(2, )): 18 | super().__init__(random, obs_dim, action_dim) 19 | self.task = 'dummy1' 20 | 21 | def sample_tasks(self, n): 22 | """Sample a list of `num_tasks` tasks. 23 | 24 | Args: 25 | n (int): Number of tasks to sample. 26 | 27 | Returns: 28 | list[str]: A list of tasks. 29 | 30 | """ 31 | return choices(self.all_task_names, k=n) 32 | 33 | @property 34 | def all_task_names(self): 35 | """list[str]: Return a list of dummy task names.""" 36 | return ['dummy1', 'dummy2', 'dummy3'] 37 | 38 | def set_task(self, task): 39 | """Reset with a task. 40 | 41 | Args: 42 | task (str): A task. 43 | 44 | """ 45 | self.task = task 46 | 47 | def step(self, action): 48 | """Step the environment. 49 | 50 | Args: 51 | action (int): Action input. 52 | 53 | Returns: 54 | np.ndarray: Observation. 55 | float: Reward. 56 | bool: If the environment is terminated. 57 | dict: Environment information. 
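        Example:
            A sketch of how the multitask interface above is meant to be
            exercised in tests::

                env = DummyMultiTaskBoxEnv()
                tasks = env.sample_tasks(2)
                env.set_task(tasks[0])
                obs, reward, done, info = env.step(env.action_space.sample())
                assert info['task_name'] == tasks[0]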
58 | 59 | """ 60 | return (self.observation_space.sample(), 0, False, 61 | dict(dummy='dummy', task_name=self.task)) 62 | -------------------------------------------------------------------------------- /tests/fixtures/envs/dummy/dummy_reward_box_env.py: -------------------------------------------------------------------------------- 1 | from tests.fixtures.envs.dummy import DummyBoxEnv 2 | 3 | 4 | class DummyRewardBoxEnv(DummyBoxEnv): 5 | """A dummy box environment.""" 6 | 7 | def __init__(self, random=True): 8 | super().__init__(random) 9 | 10 | def step(self, action): 11 | """Step the environment.""" 12 | if action == 0: 13 | reward = 10 14 | else: 15 | reward = -10 16 | return self.observation_space.sample(), reward, True, dict() 17 | -------------------------------------------------------------------------------- /tests/fixtures/envs/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from tests.fixtures.envs.wrappers.reshape_observation import ReshapeObservation 2 | 3 | __all__ = ['ReshapeObservation'] 4 | -------------------------------------------------------------------------------- /tests/fixtures/envs/wrappers/reshape_observation.py: -------------------------------------------------------------------------------- 1 | """Reshaping Observation for gym.Env.""" 2 | import gym 3 | import numpy as np 4 | 5 | 6 | class ReshapeObservation(gym.Wrapper): 7 | """ 8 | Reshaping Observation wrapper for gym.Env. 9 | 10 | This wrapper convert the observations into the given shape. 11 | 12 | Args: 13 | env (gym.Env): The environment to be wrapped. 14 | shape (list[int]): Target shape to be applied on the observations. 15 | """ 16 | 17 | def __init__(self, env, shape): 18 | super().__init__(env) 19 | print(env.observation_space.shape) 20 | assert np.prod(shape) == np.prod(env.observation_space.shape) 21 | _low = env.observation_space.low.flatten()[0] 22 | _high = env.observation_space.high.flatten()[0] 23 | self._observation_space = gym.spaces.Box( 24 | _low, _high, shape=shape, dtype=env.observation_space.dtype) 25 | self._shape = shape 26 | 27 | @property 28 | def observation_space(self): 29 | """gym.Env observation space.""" 30 | return self._observation_space 31 | 32 | @observation_space.setter 33 | def observation_space(self, observation_space): 34 | self._observation_space = observation_space 35 | 36 | def _observation(self, obs): 37 | return obs.reshape(self._shape) 38 | 39 | def reset(self): 40 | """gym.Env reset function.""" 41 | return self._observation(self.env.reset()) 42 | 43 | def step(self, action): 44 | """gym.Env step function.""" 45 | obs, reward, done, info = self.env.step(action) 46 | return self._observation(obs), reward, done, info 47 | -------------------------------------------------------------------------------- /tests/fixtures/experiment/__init__.py: -------------------------------------------------------------------------------- 1 | from tests.fixtures.experiment.fixture_experiment import fixture_exp 2 | 3 | __all__ = ['fixture_exp'] 4 | -------------------------------------------------------------------------------- /tests/fixtures/experiment/fixture_experiment.py: -------------------------------------------------------------------------------- 1 | """A dummy experiment fixture.""" 2 | from garage.envs import GymEnv 3 | from garage.np.baselines import LinearFeatureBaseline 4 | from garage.sampler import LocalSampler 5 | from garage.tf.algos import VPG 6 | from garage.tf.policies import CategoricalMLPPolicy 7 | 
from garage.trainer import TFTrainer 8 | 9 | 10 | # pylint: disable=missing-return-type-doc 11 | def fixture_exp(snapshot_config, sess): 12 | """Dummy fixture experiment function. 13 | 14 | Args: 15 | snapshot_config (garage.experiment.SnapshotConfig): The snapshot 16 | configuration used by Trainer to create the snapshotter. 17 | If None, it will create one with default settings. 18 | sess (tf.Session): An optional TensorFlow session. 19 | A new session will be created immediately if not provided. 20 | 21 | Returns: 22 | np.ndarray: Values of the parameters evaluated in 23 | the current session 24 | 25 | """ 26 | with TFTrainer(snapshot_config=snapshot_config, sess=sess) as trainer: 27 | env = GymEnv('CartPole-v1', max_episode_length=100) 28 | 29 | policy = CategoricalMLPPolicy(name='policy', 30 | env_spec=env.spec, 31 | hidden_sizes=(8, 8)) 32 | 33 | baseline = LinearFeatureBaseline(env_spec=env.spec) 34 | 35 | sampler = LocalSampler(agents=policy, 36 | envs=env, 37 | max_episode_length=env.spec.max_episode_length) 38 | 39 | algo = VPG(env_spec=env.spec, 40 | policy=policy, 41 | baseline=baseline, 42 | sampler=sampler, 43 | discount=0.99, 44 | optimizer_args=dict(learning_rate=0.01, )) 45 | 46 | trainer.setup(algo, env) 47 | trainer.train(n_epochs=5, batch_size=100) 48 | 49 | return policy.get_param_values() 50 | -------------------------------------------------------------------------------- /tests/fixtures/logger.py: -------------------------------------------------------------------------------- 1 | from dowel import LogOutput, TabularInput 2 | 3 | 4 | class NullOutput(LogOutput): 5 | """Dummy output to disable 'no logger output' warnings.""" 6 | 7 | @property 8 | def types_accepted(self): 9 | """Accept all output types.""" 10 | return (object, ) 11 | 12 | def record(self, data, prefix=''): 13 | """Don't do anything.""" 14 | if isinstance(data, TabularInput): 15 | data.mark_all() 16 | -------------------------------------------------------------------------------- /tests/fixtures/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Mock models for testing.""" 2 | from tests.fixtures.models.simple_categorical_gru_model import ( 3 | SimpleCategoricalGRUModel) 4 | from tests.fixtures.models.simple_categorical_lstm_model import ( 5 | SimpleCategoricalLSTMModel) 6 | from tests.fixtures.models.simple_categorical_mlp_model import ( 7 | SimpleCategoricalMLPModel) 8 | from tests.fixtures.models.simple_cnn_model import SimpleCNNModel 9 | from tests.fixtures.models.simple_cnn_model_with_max_pooling import ( 10 | SimpleCNNModelWithMaxPooling) 11 | from tests.fixtures.models.simple_gru_model import SimpleGRUModel 12 | from tests.fixtures.models.simple_lstm_model import SimpleLSTMModel 13 | from tests.fixtures.models.simple_mlp_merge_model import SimpleMLPMergeModel 14 | from tests.fixtures.models.simple_mlp_model import SimpleMLPModel 15 | 16 | __all__ = [ 17 | 'SimpleCategoricalGRUModel', 18 | 'SimpleCategoricalLSTMModel', 19 | 'SimpleCategoricalMLPModel', 20 | 'SimpleCNNModel', 21 | 'SimpleCNNModelWithMaxPooling', 22 | 'SimpleGRUModel', 23 | 'SimpleLSTMModel', 24 | 'SimpleMLPMergeModel', 25 | 'SimpleMLPModel', 26 | ] 27 | -------------------------------------------------------------------------------- /tests/fixtures/models/simple_categorical_gru_model.py: -------------------------------------------------------------------------------- 1 | """Simple CategoricalGRUModel for testing.""" 2 | import tensorflow_probability as tfp 3 | 4 | from 
tests.fixtures.models.simple_gru_model import SimpleGRUModel 5 | 6 | 7 | class SimpleCategoricalGRUModel(SimpleGRUModel): 8 | """Simple CategoricalGRUModel for testing. 9 | 10 | Args: 11 | output_dim (int): Dimension of the network output. 12 | hidden_dim (int): Hidden dimension for GRU cell. 13 | name (str): Policy name, also the variable scope. 14 | args: Extra arguments which are not used. 15 | kwargs: Extra keyword arguments which are not used. 16 | 17 | """ 18 | 19 | def __init__(self, output_dim, hidden_dim, name, *args, **kwargs): 20 | super().__init__(output_dim, hidden_dim, name) 21 | 22 | def network_output_spec(self): 23 | """Network output spec. 24 | 25 | Returns: 26 | list[str]: Name of the model outputs, in order. 27 | 28 | """ 29 | return [ 30 | 'all_output', 'step_output', 'step_hidden', 'init_hidden', 'dist' 31 | ] 32 | 33 | def _build(self, obs_input, step_obs_input, step_hidden, name=None): 34 | """Build model. 35 | 36 | Args: 37 | obs_input (tf.Tensor): Entire time-series observation input. 38 | step_obs_input (tf.Tensor): Single timestep observation input. 39 | step_hidden (tf.Tensor): Hidden state for step. 40 | name (str): Name of the model, also the name scope. 41 | 42 | Returns: 43 | tf.Tensor: Entire time-series outputs. 44 | tf.Tensor: Step output. 45 | tf.Tensor: Step hidden state. 46 | tf.Tensor: Initial hidden state. 47 | tfp.distributions.OneHotCategorical: Distribution. 48 | 49 | """ 50 | outputs, output, step_hidden, hidden_init_var = super()._build( 51 | obs_input, step_obs_input, step_hidden, name) 52 | dist = tfp.distributions.OneHotCategorical(outputs) 53 | return outputs, output, step_hidden, hidden_init_var, dist 54 | -------------------------------------------------------------------------------- /tests/fixtures/models/simple_categorical_mlp_model.py: -------------------------------------------------------------------------------- 1 | """Simple CategoricalMLPModel for testing.""" 2 | import tensorflow_probability as tfp 3 | 4 | from tests.fixtures.models.simple_mlp_model import SimpleMLPModel 5 | 6 | 7 | class SimpleCategoricalMLPModel(SimpleMLPModel): 8 | """Simple CategoricalMLPModel for testing. 9 | 10 | Args: 11 | output_dim (int): Dimension of the network output. 12 | name (str): Policy name, also the variable scope. 13 | args: Extra arguments which are not used. 14 | kwargs: Extra keyword arguments which are not used. 15 | 16 | """ 17 | 18 | def __init__(self, output_dim, name, *args, **kwargs): 19 | super().__init__(output_dim, name) 20 | 21 | def network_output_spec(self): 22 | """Network output spec. 23 | 24 | Returns: 25 | list[str]: Name of the model outputs, in order. 26 | 27 | """ 28 | return ['prob', 'dist'] 29 | 30 | def _build(self, obs_input, name=None): 31 | """Build model. 32 | 33 | Args: 34 | obs_input (tf.Tensor): Observation inputs. 35 | name (str): Name of the model, also the name scope. 36 | 37 | Returns: 38 | tf.Tensor: Network outputs. 39 | tfp.distributions.OneHotCategorical: Distribution. 
40 | 41 | """ 42 | prob = super()._build(obs_input, name) 43 | dist = tfp.distributions.OneHotCategorical(prob) 44 | return prob, dist 45 | -------------------------------------------------------------------------------- /tests/fixtures/models/simple_mlp_merge_model.py: -------------------------------------------------------------------------------- 1 | """Simple MLPMergeModel for testing.""" 2 | import tensorflow as tf 3 | 4 | from garage.tf.models import Model 5 | 6 | 7 | class SimpleMLPMergeModel(Model): 8 | """Simple SimpleMLPMergeModel for testing. 9 | 10 | Args: 11 | output_dim (int): Dimension of the network output. 12 | name (str): Model name, also the variable scope. 13 | args (list): Unused positionl arguments. 14 | kwargs (dict): Unused keyword arguments. 15 | 16 | """ 17 | 18 | # pylint: disable=arguments-differ 19 | 20 | def __init__(self, output_dim, *args, name=None, **kwargs): 21 | del args 22 | del kwargs 23 | super().__init__(name) 24 | self.output_dim = output_dim 25 | 26 | def network_input_spec(self): 27 | """Network input spec. 28 | 29 | Return: 30 | list[str]: List of key(str) for the network outputs. 31 | 32 | """ 33 | return ['input_var1', 'input_var2'] 34 | 35 | def _build(self, obs_input, act_input, name=None): 36 | """Build model given input placeholder(s). 37 | 38 | Args: 39 | obs_input (tf.Tensor): Tensor input for state. 40 | act_input (tf.Tensor): Tensor input for action. 41 | name (str): Inner model name, also the variable scope of the 42 | inner model, if exist. One example is 43 | garage.tf.models.Sequential. 44 | 45 | Return: 46 | tf.Tensor: Tensor output of the model. 47 | 48 | """ 49 | del name 50 | del act_input 51 | return_var = tf.compat.v1.get_variable( 52 | 'return_var', (), initializer=tf.constant_initializer(0.5)) 53 | return tf.fill((tf.shape(obs_input)[0], self.output_dim), return_var) 54 | -------------------------------------------------------------------------------- /tests/fixtures/models/simple_mlp_model.py: -------------------------------------------------------------------------------- 1 | """Simple MLPModel for testing.""" 2 | import tensorflow as tf 3 | 4 | from garage.tf.models import Model 5 | 6 | 7 | class SimpleMLPModel(Model): 8 | """Simple MLPModel for testing. 9 | 10 | Args: 11 | output_dim (int): Dimension of the network output. 12 | name (str): Model name, also the variable scope. 13 | args (list): Unused positionl arguments. 14 | kwargs (dict): Unused keyword arguments. 15 | 16 | """ 17 | 18 | # pylint: disable=arguments-differ 19 | def __init__(self, output_dim, *args, name=None, **kwargs): 20 | del args 21 | del kwargs 22 | super().__init__(name) 23 | self.output_dim = output_dim 24 | 25 | def _build(self, obs_input, name=None): 26 | """Build model given input placeholder(s). 27 | 28 | Args: 29 | obs_input (tf.Tensor): Tensor input for state. 30 | name (str): Inner model name, also the variable scope of the 31 | inner model, if exist. One example is 32 | garage.tf.models.Sequential. 33 | 34 | Return: 35 | tf.Tensor: Tensor output of the model. 
36 | 37 | """ 38 | del name 39 | return_var = tf.compat.v1.get_variable( 40 | 'return_var', (), initializer=tf.constant_initializer(0.5)) 41 | return tf.fill((tf.shape(obs_input)[0], self.output_dim), return_var) 42 | -------------------------------------------------------------------------------- /tests/fixtures/policies/__init__.py: -------------------------------------------------------------------------------- 1 | """Fake policies for writing tests.""" 2 | from tests.fixtures.policies.dummy_policy import (DummyPolicy, 3 | DummyPolicyWithoutVectorized) 4 | from tests.fixtures.policies.dummy_recurrent_policy import DummyRecurrentPolicy 5 | 6 | __all__ = [ 7 | 'DummyPolicy', 'DummyRecurrentPolicy', 'DummyPolicyWithoutVectorized' 8 | ] 9 | -------------------------------------------------------------------------------- /tests/fixtures/policies/dummy_recurrent_policy.py: -------------------------------------------------------------------------------- 1 | """Dummy Recurrent Policy for algo tests.""" 2 | import numpy as np 3 | 4 | from garage.np.policies import Policy 5 | 6 | 7 | class DummyRecurrentPolicy(Policy): 8 | """Dummy Recurrent Policy. 9 | 10 | Args: 11 | env_spec (garage.envs.env_spec.EnvSpec): Environment specification. 12 | 13 | """ 14 | 15 | def __init__( 16 | self, 17 | env_spec, 18 | ): 19 | super().__init__(env_spec=env_spec) 20 | self.params = [] 21 | self.param_values = np.random.uniform(-1, 1, 1000) 22 | 23 | def get_action(self, observation): 24 | """Get single action from this policy for the input observation. 25 | 26 | Args: 27 | observation (numpy.ndarray): Observation from environment. 28 | 29 | Returns: 30 | numpy.ndarray: Predicted action. 31 | dict: Distribution parameters. Empty because no distribution is 32 | used. 33 | 34 | """ 35 | return self.action_space.sample(), dict() 36 | 37 | def get_params_internal(self): 38 | """Return a list of policy internal params. 39 | 40 | Returns: 41 | list: Policy parameters. 42 | 43 | """ 44 | return self.params 45 | 46 | def get_param_values(self): 47 | """Return values of params. 48 | 49 | Returns: 50 | np.ndarray: Policy parameters values. 51 | 52 | """ 53 | return self.param_values 54 | -------------------------------------------------------------------------------- /tests/fixtures/q_functions/__init__.py: -------------------------------------------------------------------------------- 1 | from tests.fixtures.q_functions.simple_q_function import SimpleQFunction 2 | 3 | __all__ = ['SimpleQFunction'] 4 | -------------------------------------------------------------------------------- /tests/fixtures/q_functions/simple_q_function.py: -------------------------------------------------------------------------------- 1 | """Simple QFunction for testing.""" 2 | import tensorflow as tf 3 | 4 | from tests.fixtures.models import SimpleMLPModel 5 | 6 | 7 | class SimpleQFunction(SimpleMLPModel): 8 | """Simple QFunction for testing. 9 | 10 | Args: 11 | env_spec (garage.envs.env_spec.EnvSpec): Environment specification. 12 | name (str): Name of the q-function, also serves as the variable scope. 
13 | 
14 |     """
15 | 
16 |     def __init__(self, env_spec, name='SimpleQFunction'):
17 |         self.obs_dim = (env_spec.observation_space.flat_dim, )
18 |         action_dim = env_spec.observation_space.flat_dim
19 |         super().__init__(output_dim=action_dim, name=name)
20 | 
21 |         self._q_val = None
22 | 
23 |         self._initialize()
24 | 
25 |     def _initialize(self):
26 |         """Initialize QFunction."""
27 |         obs_ph = tf.compat.v1.placeholder(tf.float32, (None, ) + self.obs_dim,
28 |                                           name='obs')
29 | 
30 |         self._q_val = super().build(obs_ph).outputs
31 | 
32 |     @property
33 |     def q_vals(self):
34 |         """Return the Q values, the output of the network.
35 | 
36 |         Returns:
37 |             list[tf.Tensor]: Q values.
38 | 
39 |         """
40 |         return self._q_val
41 | 
42 |     def __setstate__(self, state):
43 |         """Object.__setstate__.
44 | 
45 |         Args:
46 |             state (dict): Unpickled state.
47 | 
48 |         """
49 |         super().__setstate__(state)
50 |         self._initialize()
51 | 
52 |     def __getstate__(self):
53 |         """Object.__getstate__.
54 | 
55 |         Returns:
56 |             dict: The state to be pickled for the instance.
57 | 
58 |         """
59 |         new_dict = super().__getstate__()
60 |         del new_dict['_q_val']
61 |         return new_dict
62 | 
-------------------------------------------------------------------------------- /tests/fixtures/sampler/__init__.py: --------------------------------------------------------------------------------
1 | """Fixtures for testing samplers."""
2 | 
3 | from tests.fixtures.sampler.ray_fixtures import (ray_local_session_fixture,
4 |                                                  ray_session_fixture)
5 | 
6 | __all__ = ['ray_local_session_fixture', 'ray_session_fixture']
7 | 
-------------------------------------------------------------------------------- /tests/fixtures/sampler/ray_fixtures.py: --------------------------------------------------------------------------------
1 | """Pytest fixtures for initializing Ray during Ray-related tests."""
2 | import pytest
3 | import ray
4 | 
5 | 
6 | @pytest.fixture(scope='function')
7 | def ray_local_session_fixture():
8 |     """Initializes Ray and shuts down Ray in local mode.
9 | 
10 |     Yields:
11 |         None: Yield is for purposes of pytest module style.
12 |             All statements before the yield are part of module setup, and all
13 |             statements after the yield are part of module teardown.
14 | 
15 |     """
16 |     if not ray.is_initialized():
17 |         ray.init(local_mode=True,
18 |                  ignore_reinit_error=True,
19 |                  log_to_driver=False,
20 |                  include_dashboard=False)
21 |     yield
22 |     if ray.is_initialized():
23 |         ray.shutdown()
24 | 
25 | 
26 | @pytest.fixture(scope='function')
27 | def ray_session_fixture():
28 |     """Initializes Ray and shuts down Ray.
29 | 
30 |     Yields:
31 |         None: Yield is for purposes of pytest module style.
32 |             All statements before the yield are part of module setup, and all
33 |             statements after the yield are part of module teardown.
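    Example:
        A test opts into this fixture by accepting it as an argument; the
        test body below is illustrative only::

            def test_needs_ray(ray_session_fixture):
                del ray_session_fixture  # Ray is initialized for the test body
                assert ray.is_initialized()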
34 | 35 | """ 36 | if not ray.is_initialized(): 37 | ray.init(_memory=52428800, 38 | object_store_memory=78643200, 39 | ignore_reinit_error=True, 40 | log_to_driver=False, 41 | include_dashboard=False) 42 | yield 43 | if ray.is_initialized(): 44 | ray.shutdown() 45 | -------------------------------------------------------------------------------- /tests/fixtures/tf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/fixtures/tf/__init__.py -------------------------------------------------------------------------------- /tests/fixtures/tf/algos/dummy_off_policy_algo.py: -------------------------------------------------------------------------------- 1 | """A dummy off-policy algorithm.""" 2 | from garage.np.algos import RLAlgorithm 3 | 4 | 5 | class DummyOffPolicyAlgo(RLAlgorithm): 6 | """A dummy off-policy algorithm.""" 7 | 8 | def init_opt(self): 9 | """Initialize the optimization procedure.""" 10 | 11 | def train(self, trainer): 12 | """Obtain samplers and start actual training for each epoch. 13 | 14 | Args: 15 | trainer (Trainer): Trainer is passed to give algorithm 16 | the access to trainer.step_epochs(), which provides services 17 | such as snapshotting and sampler control. 18 | 19 | """ 20 | 21 | def train_once(self, itr, paths): 22 | """Perform one step of policy optimization given one batch of samples. 23 | 24 | Args: 25 | itr (int): Iteration number. 26 | paths (list[dict]): A list of collected paths. 27 | 28 | """ 29 | 30 | def optimize_policy(self, samples_data): 31 | """Optimize the policy using the samples. 32 | 33 | Args: 34 | samples_data (dict): Processed sample data. 35 | 36 | """ 37 | -------------------------------------------------------------------------------- /tests/garage/.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | # Check docstring completeness 3 | load-plugins = pylint.extensions.docparams, pylint.extensions.docstyle 4 | 5 | # Unit tests have a special configuration which is checked separately 6 | ignore = tests/garage 7 | 8 | # Go as fast as you can 9 | jobs = 0 10 | 11 | # Packages which we need to load so we can see their C extensions 12 | extension-pkg-whitelist = 13 | numpy.random, 14 | mpi4py.MPI, 15 | 16 | 17 | [MESSAGES CONTROL] 18 | enable = all 19 | disable = 20 | # Style rules handled by yapf/flake8/isort 21 | bad-continuation, 22 | invalid-name, 23 | line-too-long, 24 | ungrouped-imports, 25 | wrong-import-order, 26 | # Algorithms and neural networks generally have a lot of variables 27 | too-many-instance-attributes, 28 | too-many-arguments, 29 | too-many-locals, 30 | # Detection seems buggy or unhelpful 31 | duplicate-code, 32 | # Rules disabled *for unit tests only* 33 | attribute-defined-outside-init, 34 | differing-param-doc, 35 | differing-type-doc, 36 | docstring-first-line-empty, 37 | missing-docstring, 38 | missing-param-doc, 39 | missing-return-doc, 40 | missing-return-type-doc, 41 | missing-type-doc, 42 | no-self-use, 43 | protected-access, 44 | redefined-outer-name, 45 | too-few-public-methods, 46 | unused-import, 47 | 48 | 49 | [REPORTS] 50 | msg-template = {path}:{line:3d},{column}: {msg} ({symbol}) 51 | output-format = colorized 52 | 53 | 54 | [TYPECHECK] 55 | # Packages which might not admit static analysis because they have C extensions 56 | generated-members = torch.* 57 | 
-------------------------------------------------------------------------------- /tests/garage/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/__init__.py -------------------------------------------------------------------------------- /tests/garage/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/envs/__init__.py -------------------------------------------------------------------------------- /tests/garage/envs/box2d/parser/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/envs/box2d/parser/__init__.py -------------------------------------------------------------------------------- /tests/garage/envs/bullet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/envs/bullet/__init__.py -------------------------------------------------------------------------------- /tests/garage/envs/dm_control/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/envs/dm_control/__init__.py -------------------------------------------------------------------------------- /tests/garage/envs/dm_control/test_dm_control_tf_policy.py: -------------------------------------------------------------------------------- 1 | from dm_control.suite import ALL_TASKS 2 | import pytest 3 | 4 | from garage.envs.dm_control import DMControlEnv 5 | from garage.np.baselines import LinearFeatureBaseline 6 | from garage.sampler import LocalSampler 7 | from garage.tf.algos import TRPO 8 | from garage.tf.policies import GaussianMLPPolicy 9 | from garage.trainer import TFTrainer 10 | 11 | from tests.fixtures import snapshot_config, TfGraphTestCase 12 | 13 | 14 | @pytest.mark.mujoco 15 | class TestDmControlTfPolicy(TfGraphTestCase): 16 | 17 | def test_dm_control_tf_policy(self): 18 | task = ALL_TASKS[0] 19 | 20 | with TFTrainer(snapshot_config, sess=self.sess) as trainer: 21 | env = DMControlEnv.from_suite(*task) 22 | 23 | policy = GaussianMLPPolicy( 24 | env_spec=env.spec, 25 | hidden_sizes=(32, 32), 26 | ) 27 | 28 | baseline = LinearFeatureBaseline(env_spec=env.spec) 29 | 30 | sampler = LocalSampler( 31 | agents=policy, 32 | envs=env, 33 | max_episode_length=env.spec.max_episode_length, 34 | is_tf_worker=True) 35 | 36 | algo = TRPO( 37 | env_spec=env.spec, 38 | policy=policy, 39 | baseline=baseline, 40 | sampler=sampler, 41 | discount=0.99, 42 | max_kl_step=0.01, 43 | ) 44 | 45 | trainer.setup(algo, env) 46 | trainer.train(n_epochs=1, batch_size=10) 47 | 48 | env.close() 49 | -------------------------------------------------------------------------------- /tests/garage/envs/test_grid_world_env.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | from garage.envs.grid_world_env import GridWorldEnv 4 | 5 | from tests.helpers import step_env 6 | 7 | 8 | class TestGridWorldEnv: 9 | 10 | def test_pickleable(self): 11 | env = 
GridWorldEnv(desc='8x8') 12 | round_trip = pickle.loads(pickle.dumps(env)) 13 | assert round_trip 14 | assert round_trip._start_state == env._start_state 15 | step_env(round_trip) 16 | round_trip.close() 17 | env.close() 18 | 19 | def test_does_not_modify_action(self): 20 | env = GridWorldEnv(desc='8x8') 21 | a = env.action_space.sample() 22 | a_copy = a 23 | env.reset() 24 | env.step(a) 25 | assert a == a_copy 26 | env.close() 27 | -------------------------------------------------------------------------------- /tests/garage/envs/test_half_cheetah_meta_envs.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | try: 7 | # pylint: disable=unused-import 8 | import mujoco_py # noqa: F401 9 | except ImportError: 10 | pytest.skip('To use mujoco-based features, please install garage[mujoco].', 11 | allow_module_level=True) 12 | except Exception: # pylint: disable=broad-except 13 | pytest.skip( 14 | 'Skipping tests, failed to import mujoco. Do you have a ' 15 | 'valid mujoco key installed?', 16 | allow_module_level=True) 17 | 18 | from garage.envs.mujoco.half_cheetah_dir_env import HalfCheetahDirEnv # isort:skip # noqa: E501 19 | from garage.envs.mujoco.half_cheetah_vel_env import HalfCheetahVelEnv # isort:skip # noqa: E501 20 | 21 | 22 | @pytest.mark.mujoco 23 | @pytest.mark.parametrize('env_type', [HalfCheetahVelEnv, HalfCheetahDirEnv]) 24 | def test_can_sim(env_type): 25 | env = env_type() 26 | task = env.sample_tasks(1)[0] 27 | env.set_task(task) 28 | for _ in range(3): 29 | env.step(env.action_space.sample()) 30 | 31 | 32 | @pytest.mark.mujoco 33 | @pytest.mark.parametrize('env_type', [HalfCheetahVelEnv, HalfCheetahDirEnv]) 34 | def test_pickling_keeps_goal(env_type): 35 | env = env_type() 36 | task = env.sample_tasks(1)[0] 37 | env.set_task(task) 38 | env_clone = pickle.loads(pickle.dumps(env)) 39 | assert env._task == env_clone._task 40 | 41 | 42 | @pytest.mark.mujoco 43 | @pytest.mark.parametrize('env_type', [HalfCheetahVelEnv, HalfCheetahDirEnv]) 44 | def test_env_infos(env_type): 45 | env = env_type() 46 | task = env.sample_tasks(1)[0] 47 | env.set_task(task) 48 | _, _, _, infos = env.step(env.action_space.sample()) 49 | for k in infos: 50 | if k == 'task_name': 51 | assert isinstance(infos[k], str) 52 | else: 53 | assert isinstance(infos[k], np.ndarray) 54 | -------------------------------------------------------------------------------- /tests/garage/envs/test_normalized_env.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import numpy as np 4 | 5 | from garage.envs import PointEnv 6 | from garage.envs.normalized_env import NormalizedEnv 7 | 8 | from tests.helpers import step_env 9 | 10 | 11 | class TestNormalizedEnv: 12 | 13 | def test_pickleable(self): 14 | inner_env = PointEnv(goal=(1., 2.)) 15 | env = NormalizedEnv(inner_env, scale_reward=10.) 16 | round_trip = pickle.loads(pickle.dumps(env)) 17 | assert round_trip 18 | assert round_trip._scale_reward == env._scale_reward 19 | assert np.array_equal(round_trip._env._goal, env._env._goal) 20 | step_env(round_trip, visualize=False) 21 | env.close() 22 | round_trip.close() 23 | 24 | def test_does_not_modify_action(self): 25 | inner_env = PointEnv(goal=(1., 2.)) 26 | env = NormalizedEnv(inner_env, scale_reward=10.) 27 | a = env.action_space.high + 1. 
28 | a_copy = a 29 | env.reset() 30 | env.step(a) 31 | assert np.array_equal(a, a_copy) 32 | env.close() 33 | 34 | def test_visualization(self): 35 | inner_env = PointEnv(goal=(1., 2.)) 36 | env = NormalizedEnv(inner_env) 37 | 38 | env.visualize() 39 | env.reset() 40 | assert inner_env.render_modes == env.render_modes 41 | mode = inner_env.render_modes[0] 42 | assert inner_env.render(mode) == env.render(mode) 43 | 44 | def test_no_flatten_obs(self): 45 | inner_env = PointEnv(goal=(1., 2.)) 46 | env = NormalizedEnv(inner_env, flatten_obs=False) 47 | obs = env.reset()[0] 48 | 49 | assert obs.shape == env.observation_space.shape 50 | -------------------------------------------------------------------------------- /tests/garage/envs/test_normalized_gym.py: -------------------------------------------------------------------------------- 1 | from garage.envs import GymEnv, normalize 2 | 3 | 4 | class TestNormalizedGym: 5 | 6 | def setup_method(self): 7 | self.env = normalize(GymEnv('CartPole-v1'), 8 | normalize_reward=True, 9 | normalize_obs=True, 10 | flatten_obs=True) 11 | 12 | def teardown_method(self): 13 | self.env.close() 14 | 15 | def test_does_not_modify_action(self): 16 | a = self.env.action_space.sample() 17 | a_copy = a 18 | self.env.reset() 19 | self.env.step(a) 20 | assert a == a_copy 21 | 22 | def test_flatten(self): 23 | for _ in range(10): 24 | self.env.reset() 25 | self.env.visualize() 26 | for _ in range(5): 27 | action = self.env.action_space.sample() 28 | es = self.env.step(action) 29 | next_obs, done = es.observation, es.terminal 30 | assert next_obs.shape == self.env.observation_space.low.shape 31 | if done: 32 | break 33 | 34 | def test_unflatten(self): 35 | for _ in range(10): 36 | self.env.reset() 37 | for _ in range(5): 38 | action = self.env.action_space.sample() 39 | es = self.env.step(action) 40 | next_obs, done = es.observation, es.terminal 41 | # yapf: disable 42 | assert (self.env.observation_space.flatten(next_obs).shape 43 | == self.env.observation_space.flat_dim) 44 | # yapf: enable 45 | if done: 46 | break 47 | -------------------------------------------------------------------------------- /tests/garage/envs/test_rl2_env.py: -------------------------------------------------------------------------------- 1 | from garage.envs import PointEnv 2 | from garage.tf.algos.rl2 import RL2Env 3 | 4 | 5 | class TestRL2Env: 6 | 7 | # pylint: disable=unsubscriptable-object 8 | def test_observation_dimension(self): 9 | env = PointEnv() 10 | wrapped_env = RL2Env(PointEnv()) 11 | assert wrapped_env.spec.observation_space.shape[0] == ( 12 | env.observation_space.shape[0] + env.action_space.shape[0] + 2) 13 | obs, _ = env.reset() 14 | obs2, _ = wrapped_env.reset() 15 | assert obs.shape[0] + env.action_space.shape[0] + 2 == obs2.shape[0] 16 | obs = env.step(env.action_space.sample()).observation 17 | obs2 = wrapped_env.step(env.action_space.sample()).observation 18 | assert obs.shape[0] + env.action_space.shape[0] + 2 == obs2.shape[0] 19 | 20 | def test_step(self): 21 | env = RL2Env(PointEnv()) 22 | 23 | env.reset() 24 | es = env.step(env.action_space.sample()) 25 | assert env.observation_space.contains(es.observation) 26 | 27 | def test_visualization(self): 28 | env = PointEnv() 29 | wrapped_env = RL2Env(env) 30 | 31 | assert env.render_modes == wrapped_env.render_modes 32 | mode = env.render_modes[0] 33 | assert env.render(mode) == wrapped_env.render(mode) 34 | 35 | wrapped_env.reset() 36 | wrapped_env.visualize() 37 | wrapped_env.step(wrapped_env.action_space.sample()) 38 | 
wrapped_env.close() 39 | -------------------------------------------------------------------------------- /tests/garage/envs/wrappers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/envs/wrappers/__init__.py -------------------------------------------------------------------------------- /tests/garage/envs/wrappers/test_atari_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from garage.envs.wrappers import AtariEnv 4 | 5 | from tests.fixtures.envs.dummy import DummyDiscretePixelEnvBaselines 6 | 7 | 8 | class TestFireReset: 9 | 10 | def test_atari_env(self): 11 | env = DummyDiscretePixelEnvBaselines() 12 | env_wrapped = AtariEnv(env) 13 | obs = env.reset() 14 | obs_wrapped = env_wrapped.reset() 15 | assert not isinstance(obs, np.ndarray) 16 | assert isinstance(obs_wrapped, np.ndarray) 17 | 18 | obs, _, _, _ = env.step(1) 19 | obs_wrapped, _, _, _ = env_wrapped.step(1) 20 | assert not isinstance(obs, np.ndarray) 21 | assert isinstance(obs_wrapped, np.ndarray) 22 | -------------------------------------------------------------------------------- /tests/garage/envs/wrappers/test_clip_reward.py: -------------------------------------------------------------------------------- 1 | from garage.envs.wrappers import ClipReward 2 | 3 | from tests.fixtures.envs.dummy import DummyRewardBoxEnv 4 | 5 | 6 | class TestClipReward: 7 | 8 | def test_clip_reward(self): 9 | # reward = 10 when action = 0, otherwise -10 10 | env = DummyRewardBoxEnv(random=True) 11 | env_wrap = ClipReward(env) 12 | env.reset() 13 | env_wrap.reset() 14 | 15 | _, reward, _, _ = env.step(0) 16 | _, reward_wrap, _, _ = env_wrap.step(0) 17 | 18 | assert reward == 10 19 | assert reward_wrap == 1 20 | 21 | _, reward, _, _ = env.step(1) 22 | _, reward_wrap, _, _ = env_wrap.step(1) 23 | 24 | assert reward == -10 25 | assert reward_wrap == -1 26 | -------------------------------------------------------------------------------- /tests/garage/envs/wrappers/test_episodic_life.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from garage.envs.wrappers import EpisodicLife 4 | 5 | from tests.fixtures.envs.dummy import DummyDiscretePixelEnv 6 | 7 | 8 | class TestEpisodicLife: 9 | 10 | def test_episodic_life_reset(self): 11 | env = EpisodicLife(DummyDiscretePixelEnv()) 12 | obs = env.reset() 13 | 14 | # env has reset 15 | assert np.array_equal(obs, np.ones(env.observation_space.shape)) 16 | assert env.unwrapped.ale.lives() == 5 17 | 18 | obs, _, d, info = env.step(0) 19 | assert d 20 | assert info['ale.lives'] == 4 21 | obs = env.reset() 22 | 23 | # env has not reset 24 | assert not np.array_equal(obs, np.ones(env.observation_space.shape)) 25 | 26 | for _ in range(3): 27 | obs, _, d, info = env.step(0) 28 | assert d 29 | assert info['ale.lives'] == 0 30 | obs = env.reset() 31 | # env has reset 32 | assert np.array_equal(obs, np.ones(env.observation_space.shape)) 33 | -------------------------------------------------------------------------------- /tests/garage/envs/wrappers/test_fire_reset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from garage.envs.wrappers import FireReset 4 | 5 | from tests.fixtures.envs.dummy import DummyDiscretePixelEnv 6 | 7 | 8 | class TestFireReset: 9 | 10 | def 
test_fire_reset(self): 11 | env = DummyDiscretePixelEnv(random=False) 12 | env_wrap = FireReset(env) 13 | obs = env.reset() 14 | obs_wrap = env_wrap.reset() 15 | 16 | assert np.array_equal(obs, np.ones(env.observation_space.shape)) 17 | assert np.array_equal(obs_wrap, np.full(env.observation_space.shape, 18 | 3)) 19 | 20 | env_wrap.step(2) 21 | obs_wrap = env_wrap.reset() # env will call reset again, after fire 22 | assert np.array_equal(obs_wrap, np.full(env.observation_space.shape, 23 | 3)) 24 | -------------------------------------------------------------------------------- /tests/garage/envs/wrappers/test_max_and_skip.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from garage.envs.wrappers import MaxAndSkip 4 | 5 | from tests.fixtures.envs.dummy import DummyDiscretePixelEnv 6 | 7 | 8 | class TestMaxAndSkip: 9 | 10 | def setup_method(self): 11 | self.env = DummyDiscretePixelEnv(random=False) 12 | self.env_wrap = MaxAndSkip(DummyDiscretePixelEnv(random=False), skip=4) 13 | 14 | def teardown_method(self): 15 | self.env.close() 16 | self.env_wrap.close() 17 | 18 | def test_max_and_skip_reset(self): 19 | np.testing.assert_array_equal(self.env.reset(), self.env_wrap.reset()) 20 | 21 | def test_max_and_skip_step(self): 22 | self.env.reset() 23 | self.env_wrap.reset() 24 | obs_wrap, reward_wrap, _, _ = self.env_wrap.step(1) 25 | reward = 0 26 | for _ in range(4): 27 | obs, r, _, _ = self.env.step(1) 28 | reward += r 29 | 30 | np.testing.assert_array_equal(obs, obs_wrap) 31 | np.testing.assert_array_equal(reward, reward_wrap) 32 | 33 | # done=True because both envs have stepped more than 4 times in total 34 | obs_wrap, _, done_wrap, _ = self.env_wrap.step(1) 35 | obs, _, done, _ = self.env.step(1) 36 | 37 | assert done 38 | assert done_wrap 39 | np.testing.assert_array_equal(obs, obs_wrap) 40 | -------------------------------------------------------------------------------- /tests/garage/envs/wrappers/test_noop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from garage.envs.wrappers import Noop 4 | 5 | from tests.fixtures.envs.dummy import DummyDiscretePixelEnv 6 | 7 | 8 | class TestNoop: 9 | 10 | def test_noop(self): 11 | env = Noop(DummyDiscretePixelEnv(), noop_max=3) 12 | 13 | for _ in range(1000): 14 | env.reset() 15 | assert 1 <= env.env.step_called <= 3 16 | 17 | env = Noop(DummyDiscretePixelEnv(), noop_max=10) 18 | for _ in range(1000): 19 | obs = env.reset() 20 | if env.env.step_called % 5 == 0: 21 | # There are only 5 lives in the environment, so if the number of 22 | # steps is a multiple of 5, the env will have called reset at the end.
23 | assert np.array_equal(obs, 24 | np.ones(env.observation_space.shape)) 25 | else: 26 | assert not np.array_equal(obs, 27 | np.ones(env.observation_space.shape)) 28 | -------------------------------------------------------------------------------- /tests/garage/envs/wrappers/test_pixel_observation_wrapper.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import pytest 4 | 5 | from garage.envs.wrappers import PixelObservationWrapper 6 | 7 | 8 | @pytest.mark.mujoco 9 | class TestPixelObservationWrapper: 10 | 11 | def setup_method(self): 12 | self.env = gym.make('InvertedDoublePendulum-v2') 13 | self.pixel_env = PixelObservationWrapper(self.env) 14 | 15 | def teardown_method(self): 16 | self.env.close() 17 | self.pixel_env.close() 18 | 19 | def test_pixel_env_invalid_environment_type(self): 20 | with pytest.raises(ValueError): 21 | self.env.observation_space = gym.spaces.Discrete(64) 22 | PixelObservationWrapper(self.env) 23 | 24 | def test_pixel_env_observation_space(self): 25 | assert isinstance(self.pixel_env.observation_space, gym.spaces.Box) 26 | assert (self.pixel_env.observation_space.low == 0).all() 27 | assert (self.pixel_env.observation_space.high == 255).all() 28 | 29 | def test_pixel_env_reset(self): 30 | obs = self.pixel_env.reset() 31 | assert (obs <= 255.).all() and (obs >= 0.).all() 32 | assert isinstance(obs, np.ndarray) 33 | 34 | def test_pixel_env_step(self): 35 | self.pixel_env.reset() 36 | action = np.full(self.pixel_env.action_space.shape, 0) 37 | obs, _, _, _ = self.pixel_env.step(action) 38 | assert (obs <= 255.).all() and (obs >= 0.).all() 39 | -------------------------------------------------------------------------------- /tests/garage/envs/wrappers/test_resize_env.py: -------------------------------------------------------------------------------- 1 | import gym.spaces 2 | import numpy as np 3 | import pytest 4 | 5 | from garage.envs.wrappers import Resize 6 | 7 | from tests.fixtures.envs.dummy import DummyDiscrete2DEnv 8 | 9 | 10 | class TestResize: 11 | 12 | def setup_method(self): 13 | self.width = 16 14 | self.height = 16 15 | self.env = DummyDiscrete2DEnv() 16 | self.env_r = Resize(DummyDiscrete2DEnv(), 17 | width=self.width, 18 | height=self.height) 19 | 20 | def teardown_method(self): 21 | self.env.close() 22 | self.env_r.close() 23 | 24 | def test_resize_invalid_environment_type(self): 25 | with pytest.raises(ValueError): 26 | self.env.observation_space = gym.spaces.Discrete(64) 27 | Resize(self.env, width=self.width, height=self.height) 28 | 29 | def test_resize_invalid_environment_shape(self): 30 | with pytest.raises(ValueError): 31 | self.env.observation_space = gym.spaces.Box(low=0, 32 | high=255, 33 | shape=(4, ), 34 | dtype=np.uint8) 35 | Resize(self.env, width=self.width, height=self.height) 36 | 37 | def test_resize_output_observation_space(self): 38 | assert self.env_r.observation_space.shape == (self.width, self.height) 39 | 40 | def test_resize_output_reset(self): 41 | assert self.env_r.reset().shape == (self.width, self.height) 42 | 43 | def test_resize_output_step(self): 44 | self.env_r.reset() 45 | obs_r, _, _, _ = self.env_r.step(1) 46 | assert obs_r.shape == (self.width, self.height) 47 | -------------------------------------------------------------------------------- /tests/garage/experiment/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/experiment/__init__.py -------------------------------------------------------------------------------- /tests/garage/experiment/test_snapshotter_integration.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import pytest 4 | 5 | from garage.envs import GymEnv 6 | from garage.experiment import SnapshotConfig, Snapshotter 7 | from garage.tf.algos import VPG 8 | from garage.tf.policies import CategoricalMLPPolicy 9 | 10 | from tests.fixtures import TfGraphTestCase 11 | from tests.fixtures.experiment import fixture_exp 12 | 13 | configurations = [('last', 4), ('first', 0), (3, 3)] 14 | 15 | 16 | class TestSnapshot(TfGraphTestCase): 17 | 18 | def setup_method(self): 19 | super().setup_method() 20 | self.temp_dir = tempfile.TemporaryDirectory() 21 | snapshot_config = SnapshotConfig(snapshot_dir=self.temp_dir.name, 22 | snapshot_mode='all', 23 | snapshot_gap=1) 24 | fixture_exp(snapshot_config, self.sess) 25 | for c in self.graph.collections: 26 | self.graph.clear_collection(c) 27 | 28 | def teardown_method(self): 29 | self.temp_dir.cleanup() 30 | super().teardown_method() 31 | 32 | @pytest.mark.parametrize('load_mode, last_epoch', [*configurations]) 33 | def test_load(self, load_mode, last_epoch): 34 | snapshotter = Snapshotter() 35 | saved = snapshotter.load(self.temp_dir.name, load_mode) 36 | 37 | assert isinstance(saved['algo'], VPG) 38 | assert isinstance(saved['env'], GymEnv) 39 | assert isinstance(saved['algo'].policy, CategoricalMLPPolicy) 40 | assert saved['stats'].total_epoch == last_epoch 41 | 42 | def test_load_with_invalid_load_mode(self): 43 | snapshotter = Snapshotter() 44 | with pytest.raises(ValueError): 45 | snapshotter.load(self.temp_dir.name, 'foo') 46 | -------------------------------------------------------------------------------- /tests/garage/np/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/np/__init__.py -------------------------------------------------------------------------------- /tests/garage/np/algos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/np/algos/__init__.py -------------------------------------------------------------------------------- /tests/garage/np/algos/test_cem.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from garage.envs import GymEnv 4 | from garage.np.algos import CEM 5 | from garage.sampler import LocalSampler 6 | from garage.tf.policies import CategoricalMLPPolicy 7 | from garage.trainer import TFTrainer 8 | 9 | from tests.fixtures import snapshot_config, TfGraphTestCase 10 | 11 | 12 | class TestCEM(TfGraphTestCase): 13 | 14 | @pytest.mark.large 15 | def test_cem_cartpole(self): 16 | """Test CEM with Cartpole-v1 environment.""" 17 | with TFTrainer(snapshot_config) as trainer: 18 | env = GymEnv('CartPole-v1') 19 | 20 | policy = CategoricalMLPPolicy(name='policy', 21 | env_spec=env.spec, 22 | hidden_sizes=(32, 32)) 23 | 24 | n_samples = 10 25 | 26 | sampler = LocalSampler( 27 | agents=policy, 28 | envs=env, 29 | max_episode_length=env.spec.max_episode_length, 30 | is_tf_worker=True) 31 | 32 | algo = 
CEM(env_spec=env.spec, 33 | policy=policy, 34 | sampler=sampler, 35 | best_frac=0.1, 36 | n_samples=n_samples) 37 | 38 | trainer.setup(algo, env) 39 | rtn = trainer.train(n_epochs=10, batch_size=2048) 40 | assert rtn > 40 41 | 42 | env.close() 43 | -------------------------------------------------------------------------------- /tests/garage/np/algos/test_cma_es.py: -------------------------------------------------------------------------------- 1 | from garage.envs import GymEnv 2 | from garage.np.algos import CMAES 3 | from garage.sampler import LocalSampler 4 | from garage.tf.policies import CategoricalMLPPolicy 5 | from garage.trainer import TFTrainer 6 | 7 | from tests.fixtures import snapshot_config, TfGraphTestCase 8 | 9 | 10 | class TestCMAES(TfGraphTestCase): 11 | 12 | def test_cma_es_cartpole(self): 13 | """Test CMAES with Cartpole-v1 environment.""" 14 | with TFTrainer(snapshot_config) as trainer: 15 | env = GymEnv('CartPole-v1') 16 | 17 | policy = CategoricalMLPPolicy(name='policy', 18 | env_spec=env.spec, 19 | hidden_sizes=(32, 32)) 20 | 21 | n_samples = 20 22 | 23 | sampler = LocalSampler( 24 | agents=policy, 25 | envs=env, 26 | max_episode_length=env.spec.max_episode_length, 27 | is_tf_worker=True) 28 | 29 | algo = CMAES(env_spec=env.spec, 30 | policy=policy, 31 | sampler=sampler, 32 | n_samples=n_samples) 33 | 34 | trainer.setup(algo, env) 35 | trainer.train(n_epochs=1, batch_size=1000) 36 | # No assertion on return because CMAES is not stable. 37 | 38 | env.close() 39 | -------------------------------------------------------------------------------- /tests/garage/np/exploration_strategies/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/np/exploration_strategies/__init__.py -------------------------------------------------------------------------------- /tests/garage/np/policies/test_fixed_policy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from garage.np.policies import FixedPolicy 5 | 6 | 7 | def test_vectorization_multi_raises(): 8 | policy = FixedPolicy(None, np.array([1, 2, 3])) 9 | with pytest.raises(ValueError): 10 | policy.reset([True, True]) 11 | with pytest.raises(ValueError): 12 | policy.get_actions(np.array([0, 0])) 13 | 14 | 15 | def test_get_actions(): 16 | policy = FixedPolicy(None, np.array([1, 2, 3])) 17 | assert policy.get_actions(np.array([0]).reshape(1, 1))[0] == 1 18 | assert policy.get_action(np.array([0]))[0] == 2 19 | assert policy.get_action(np.array([0]))[0] == 3 20 | with pytest.raises(IndexError): 21 | policy.get_action(np.ndarray([0])) 22 | -------------------------------------------------------------------------------- /tests/garage/np/policies/test_scripted_policy.py: -------------------------------------------------------------------------------- 1 | from garage.np.policies import ScriptedPolicy 2 | 3 | 4 | class TestScriptedPolicy: 5 | 6 | def setup_method(self): 7 | self.sp = ScriptedPolicy(scripted_actions=[1], agent_env_infos={0: 1}) 8 | 9 | """ 10 | potentially add more tests down the line 11 | """ 12 | 13 | def test_pass_codecov(self): 14 | self.sp.get_action(0) 15 | self.sp.get_actions([0]) 16 | -------------------------------------------------------------------------------- /tests/garage/np/policies/test_uniform_random_policy.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from garage.envs import GymEnv, normalize 5 | from garage.np.policies import UniformRandomPolicy 6 | 7 | 8 | @pytest.mark.mujoco 9 | def test_get_actions(): 10 | env = normalize(GymEnv('InvertedDoublePendulum-v2')) 11 | policy = UniformRandomPolicy(env.spec) 12 | assert policy.get_actions(np.array([0]).reshape(1, 1))[0] 13 | assert policy.get_action(np.array([0]))[0] 14 | -------------------------------------------------------------------------------- /tests/garage/replay_buffer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/replay_buffer/__init__.py -------------------------------------------------------------------------------- /tests/garage/sampler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/sampler/__init__.py -------------------------------------------------------------------------------- /tests/garage/sampler/test_env_update.py: -------------------------------------------------------------------------------- 1 | from garage.sampler import SetTaskUpdate 2 | 3 | from tests.fixtures.envs.dummy import DummyBoxEnv 4 | 5 | TEST_TASK = ['test_task'] 6 | 7 | 8 | class MTDummyEnv(DummyBoxEnv): 9 | 10 | def set_task(self, task): 11 | assert task == TEST_TASK 12 | 13 | 14 | class MTDummyEnvSubtype(MTDummyEnv): 15 | pass 16 | 17 | 18 | def test_set_task_update_with_subtype(): 19 | old_env = MTDummyEnvSubtype() 20 | env_update = SetTaskUpdate(MTDummyEnv, TEST_TASK, None) 21 | new_env = env_update(old_env) 22 | assert new_env is not old_env 23 | assert new_env is not None 24 | assert old_env is not None 25 | -------------------------------------------------------------------------------- /tests/garage/sampler/test_rl2_worker.py: -------------------------------------------------------------------------------- 1 | from garage.envs import GymEnv 2 | from garage.tf.algos.rl2 import RL2Worker 3 | 4 | from tests.fixtures import TfGraphTestCase 5 | from tests.fixtures.envs.dummy import DummyBoxEnv 6 | from tests.fixtures.policies import DummyPolicy 7 | 8 | 9 | class TestRL2Worker(TfGraphTestCase): 10 | 11 | def test_rl2_worker(self): 12 | env = GymEnv(DummyBoxEnv(obs_dim=(1, ))) 13 | policy = DummyPolicy(env_spec=env.spec) 14 | worker = RL2Worker(seed=1, 15 | max_episode_length=100, 16 | worker_number=1, 17 | n_episodes_per_trial=5) 18 | worker.update_agent(policy) 19 | worker.update_env(env) 20 | episodes = worker.rollout() 21 | assert episodes.rewards.shape[0] == 500 22 | -------------------------------------------------------------------------------- /tests/garage/tf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/tf/__init__.py -------------------------------------------------------------------------------- /tests/garage/tf/algos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/tf/algos/__init__.py 
-------------------------------------------------------------------------------- /tests/garage/tf/algos/test_erwr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from garage.envs import GymEnv 4 | from garage.experiment import deterministic 5 | from garage.np.baselines import LinearFeatureBaseline 6 | from garage.sampler import LocalSampler 7 | from garage.tf.algos import ERWR 8 | from garage.tf.policies import CategoricalMLPPolicy 9 | from garage.trainer import TFTrainer 10 | 11 | from tests.fixtures import snapshot_config, TfGraphTestCase 12 | 13 | 14 | class TestERWR(TfGraphTestCase): 15 | 16 | @pytest.mark.flaky 17 | @pytest.mark.large 18 | def test_erwr_cartpole(self): 19 | """Test ERWR with Cartpole-v1 environment.""" 20 | with TFTrainer(snapshot_config, sess=self.sess) as trainer: 21 | deterministic.set_seed(1) 22 | env = GymEnv('CartPole-v1') 23 | 24 | policy = CategoricalMLPPolicy(name='policy', 25 | env_spec=env.spec, 26 | hidden_sizes=(32, 32)) 27 | 28 | baseline = LinearFeatureBaseline(env_spec=env.spec) 29 | 30 | sampler = LocalSampler( 31 | agents=policy, 32 | envs=env, 33 | max_episode_length=env.spec.max_episode_length, 34 | is_tf_worker=True) 35 | 36 | algo = ERWR(env_spec=env.spec, 37 | policy=policy, 38 | baseline=baseline, 39 | sampler=sampler, 40 | discount=0.99) 41 | 42 | trainer.setup(algo, env) 43 | 44 | last_avg_ret = trainer.train(n_epochs=10, batch_size=10000) 45 | assert last_avg_ret > 60 46 | 47 | env.close() 48 | -------------------------------------------------------------------------------- /tests/garage/tf/algos/test_reps.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script creates a test that fails when garage.tf.algos.REPS performance is 3 | too low. 
4 | """ 5 | import pytest 6 | 7 | from garage.envs import GymEnv 8 | from garage.np.baselines import LinearFeatureBaseline 9 | from garage.sampler import LocalSampler 10 | from garage.tf.algos import REPS 11 | from garage.tf.policies import CategoricalMLPPolicy 12 | from garage.trainer import TFTrainer 13 | 14 | from tests.fixtures import snapshot_config, TfGraphTestCase 15 | 16 | 17 | class TestREPS(TfGraphTestCase): 18 | 19 | @pytest.mark.large 20 | def test_reps_cartpole(self): 21 | """Test REPS with gym Cartpole environment.""" 22 | with TFTrainer(snapshot_config, sess=self.sess) as trainer: 23 | env = GymEnv('CartPole-v0') 24 | 25 | policy = CategoricalMLPPolicy(env_spec=env.spec, 26 | hidden_sizes=[32, 32]) 27 | 28 | baseline = LinearFeatureBaseline(env_spec=env.spec) 29 | 30 | sampler = LocalSampler( 31 | agents=policy, 32 | envs=env, 33 | max_episode_length=env.spec.max_episode_length, 34 | is_tf_worker=True) 35 | 36 | algo = REPS(env_spec=env.spec, 37 | policy=policy, 38 | baseline=baseline, 39 | sampler=sampler, 40 | discount=0.99) 41 | 42 | trainer.setup(algo, env) 43 | 44 | last_avg_ret = trainer.train(n_epochs=10, batch_size=4000) 45 | assert last_avg_ret > 5 46 | 47 | env.close() 48 | -------------------------------------------------------------------------------- /tests/garage/tf/algos/test_tnpg.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from garage.envs import GymEnv, normalize 4 | from garage.np.baselines import LinearFeatureBaseline 5 | from garage.sampler import LocalSampler 6 | from garage.tf.algos import TNPG 7 | from garage.tf.policies import GaussianMLPPolicy 8 | from garage.trainer import TFTrainer 9 | 10 | from tests.fixtures import snapshot_config, TfGraphTestCase 11 | 12 | 13 | class TestTNPG(TfGraphTestCase): 14 | 15 | @pytest.mark.mujoco_long 16 | def test_tnpg_inverted_pendulum(self): 17 | """Test TNPG with InvertedPendulum-v2 environment.""" 18 | with TFTrainer(snapshot_config, sess=self.sess) as trainer: 19 | env = normalize(GymEnv('InvertedPendulum-v2')) 20 | 21 | policy = GaussianMLPPolicy(name='policy', 22 | env_spec=env.spec, 23 | hidden_sizes=(32, 32)) 24 | 25 | baseline = LinearFeatureBaseline(env_spec=env.spec) 26 | 27 | sampler = LocalSampler( 28 | agents=policy, 29 | envs=env, 30 | max_episode_length=env.spec.max_episode_length, 31 | is_tf_worker=True) 32 | 33 | algo = TNPG(env_spec=env.spec, 34 | policy=policy, 35 | baseline=baseline, 36 | sampler=sampler, 37 | discount=0.99, 38 | optimizer_args=dict(reg_coeff=5e-1)) 39 | 40 | trainer.setup(algo, env) 41 | 42 | last_avg_ret = trainer.train(n_epochs=10, batch_size=10000) 43 | assert last_avg_ret > 15 44 | 45 | env.close() 46 | -------------------------------------------------------------------------------- /tests/garage/tf/algos/test_vpg.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from garage.envs import GymEnv 4 | from garage.np.baselines import LinearFeatureBaseline 5 | from garage.sampler import LocalSampler 6 | from garage.tf.algos import VPG 7 | from garage.tf.policies import CategoricalMLPPolicy 8 | from garage.trainer import TFTrainer 9 | 10 | from tests.fixtures import snapshot_config, TfGraphTestCase 11 | 12 | 13 | class TestVPG(TfGraphTestCase): 14 | 15 | @pytest.mark.large 16 | def test_vpg_cartpole(self): 17 | """Test VPG with CartPole-v1 environment.""" 18 | with TFTrainer(snapshot_config, sess=self.sess) as trainer: 19 | env = GymEnv('CartPole-v1') 20 | 
21 | policy = CategoricalMLPPolicy(name='policy', 22 | env_spec=env.spec, 23 | hidden_sizes=(32, 32)) 24 | 25 | baseline = LinearFeatureBaseline(env_spec=env.spec) 26 | 27 | sampler = LocalSampler( 28 | agents=policy, 29 | envs=env, 30 | max_episode_length=env.spec.max_episode_length, 31 | is_tf_worker=True) 32 | 33 | algo = VPG(env_spec=env.spec, 34 | policy=policy, 35 | baseline=baseline, 36 | sampler=sampler, 37 | discount=0.99, 38 | optimizer_args=dict(learning_rate=0.01, )) 39 | 40 | trainer.setup(algo, env) 41 | 42 | last_avg_ret = trainer.train(n_epochs=10, batch_size=10000) 43 | assert last_avg_ret > 90 44 | 45 | env.close() 46 | -------------------------------------------------------------------------------- /tests/garage/tf/baselines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/tf/baselines/__init__.py -------------------------------------------------------------------------------- /tests/garage/tf/baselines/test_baselines.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script creates a test that fails when 3 | garage.tf.baselines failed to initialize. 4 | """ 5 | import tensorflow as tf 6 | 7 | from garage.envs import GymEnv 8 | from garage.tf.baselines import ContinuousMLPBaseline, GaussianMLPBaseline 9 | 10 | from tests.fixtures import TfGraphTestCase 11 | from tests.fixtures.envs.dummy import DummyBoxEnv 12 | 13 | 14 | class TestTfBaselines(TfGraphTestCase): 15 | 16 | def test_baseline(self): 17 | """Test the baseline initialization.""" 18 | box_env = GymEnv(DummyBoxEnv()) 19 | deterministic_mlp_baseline = ContinuousMLPBaseline(env_spec=box_env) 20 | gaussian_mlp_baseline = GaussianMLPBaseline(env_spec=box_env) 21 | 22 | self.sess.run(tf.compat.v1.global_variables_initializer()) 23 | deterministic_mlp_baseline.get_param_values() 24 | gaussian_mlp_baseline.get_param_values() 25 | 26 | box_env.close() 27 | -------------------------------------------------------------------------------- /tests/garage/tf/embeddings/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/tf/embeddings/__init__.py -------------------------------------------------------------------------------- /tests/garage/tf/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/tf/envs/__init__.py -------------------------------------------------------------------------------- /tests/garage/tf/envs/test_gym_base.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import gym 4 | import pytest 5 | 6 | from garage.envs import GymEnv 7 | 8 | from tests.helpers import step_env_with_gym_quirks 9 | 10 | 11 | class TestGymEnv: 12 | 13 | @pytest.mark.nightly 14 | @pytest.mark.parametrize('spec', list(gym.envs.registry.all())) 15 | def test_all_gym_envs(self, spec): 16 | if spec._env_name.startswith('Defender'): 17 | pytest.skip( 18 | 'Defender-* envs bundled in atari-py 0.2.x don\'t load') 19 | if spec._env_name.startswith('CarRacing'): 20 | pytest.skip( 21 | 'CarRacing-* envs bundled in atari-py 0.2.x don\'t load') 22 | if spec._env_name.startswith('KellyCoinflip'): 
23 | pytest.skip( 24 | 'KellyCoinflip env has tuple observation, not np.array') 25 | if 'Kuka' in spec.id: 26 | # Kuka environments call py_bullet.resetSimulation() in reset() 27 | # unconditionally, which globally resets other simulations. So 28 | # only one Kuka environment can be tested. 29 | pytest.skip('Skip Kuka Bullet environments') 30 | env = GymEnv(spec.id) 31 | step_env_with_gym_quirks(env, spec, visualize=False) 32 | 33 | @pytest.mark.nightly 34 | @pytest.mark.parametrize('spec', list(gym.envs.registry.all())) 35 | def test_all_gym_envs_pickleable(self, spec): 36 | if spec._env_name.startswith('Defender'): 37 | pytest.skip( 38 | 'Defender-* envs bundled in atari-py 0.2.x don\'t load') 39 | if 'Kuka' in spec.id: 40 | # Kuka environments call py_bullet.resetSimulation() in reset() 41 | # unconditionally, which globally resets other simulations. So 42 | # only one Kuka environment can be tested. 43 | pytest.skip('Skip Kuka Bullet environments') 44 | elif 'Minitaur' in spec.id: 45 | pytest.skip('Bullet Minitaur envs don\'t load') 46 | env = GymEnv(spec.id) 47 | round_trip = pickle.loads(pickle.dumps(env)) 48 | assert round_trip 49 | -------------------------------------------------------------------------------- /tests/garage/tf/experiment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/tf/experiment/__init__.py -------------------------------------------------------------------------------- /tests/garage/tf/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/tf/models/__init__.py -------------------------------------------------------------------------------- /tests/garage/tf/models/test_parameter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from garage.tf.models.parameter import parameter, recurrent_parameter 5 | 6 | from tests.fixtures import TfGraphTestCase 7 | 8 | 9 | class TestParameter(TfGraphTestCase): 10 | 11 | def setup_method(self): 12 | super().setup_method() 13 | self.input_vars = tf.compat.v1.placeholder(shape=[None, 2, 5], 14 | dtype=tf.float32) 15 | self.step_input_vars = tf.compat.v1.placeholder(shape=[None, 5], 16 | dtype=tf.float32) 17 | self.initial_params = np.array([48, 21, 33]) 18 | 19 | self.data = np.zeros(shape=[5, 2, 5]) 20 | self.step_data = np.zeros(shape=[5, 5]) 21 | self.feed_dict = { 22 | self.input_vars: self.data, 23 | self.step_input_vars: self.step_data 24 | } 25 | 26 | def test_param(self): 27 | param = parameter(input_var=self.input_vars, 28 | length=3, 29 | initializer=tf.constant_initializer( 30 | self.initial_params)) 31 | self.sess.run(tf.compat.v1.global_variables_initializer()) 32 | p = self.sess.run(param, feed_dict=self.feed_dict) 33 | 34 | assert p.shape == (5, 3) 35 | assert np.all(p == self.initial_params) 36 | 37 | def test_recurrent_param(self): 38 | param, _ = recurrent_parameter(input_var=self.input_vars, 39 | step_input_var=self.step_input_vars, 40 | length=3, 41 | initializer=tf.constant_initializer( 42 | self.initial_params)) 43 | self.sess.run(tf.compat.v1.global_variables_initializer()) 44 | p = self.sess.run(param, feed_dict=self.feed_dict) 45 | 46 | assert p.shape == (5, 2, 3) 47 | assert np.array_equal(p, np.full([5, 2,
3], self.initial_params)) 48 | -------------------------------------------------------------------------------- /tests/garage/tf/optimizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/tf/optimizers/__init__.py -------------------------------------------------------------------------------- /tests/garage/tf/policies/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/tf/policies/__init__.py -------------------------------------------------------------------------------- /tests/garage/tf/policies/test_gaussian_policies.py: -------------------------------------------------------------------------------- 1 | # yapf: disable 2 | import pytest 3 | 4 | from garage.envs import GymEnv, normalize 5 | from garage.np.baselines import LinearFeatureBaseline 6 | from garage.sampler import LocalSampler 7 | from garage.tf.algos import TRPO 8 | from garage.tf.optimizers import (ConjugateGradientOptimizer, 9 | FiniteDifferenceHVP) 10 | from garage.tf.policies import (GaussianGRUPolicy, GaussianLSTMPolicy, 11 | GaussianMLPPolicy) 12 | from garage.trainer import TFTrainer 13 | 14 | from tests.fixtures import snapshot_config, TfGraphTestCase 15 | 16 | # yapf: enable 17 | 18 | policies = [GaussianGRUPolicy, GaussianLSTMPolicy, GaussianMLPPolicy] 19 | 20 | 21 | class TestGaussianPolicies(TfGraphTestCase): 22 | 23 | @pytest.mark.parametrize('policy_cls', policies) 24 | def test_gaussian_policies(self, policy_cls): 25 | with TFTrainer(snapshot_config, sess=self.sess) as trainer: 26 | env = normalize(GymEnv('Pendulum-v0')) 27 | 28 | policy = policy_cls(name='policy', env_spec=env.spec) 29 | 30 | baseline = LinearFeatureBaseline(env_spec=env.spec) 31 | 32 | sampler = LocalSampler( 33 | agents=policy, 34 | envs=env, 35 | max_episode_length=env.spec.max_episode_length, 36 | is_tf_worker=True) 37 | 38 | algo = TRPO( 39 | env_spec=env.spec, 40 | policy=policy, 41 | baseline=baseline, 42 | sampler=sampler, 43 | discount=0.99, 44 | max_kl_step=0.01, 45 | optimizer=ConjugateGradientOptimizer, 46 | optimizer_args=dict(hvp_approach=FiniteDifferenceHVP( 47 | base_eps=1e-5)), 48 | ) 49 | 50 | trainer.setup(algo, env) 51 | trainer.train(n_epochs=1, batch_size=4000) 52 | env.close() 53 | -------------------------------------------------------------------------------- /tests/garage/tf/q_functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/tf/q_functions/__init__.py -------------------------------------------------------------------------------- /tests/garage/tf/samplers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/tf/samplers/__init__.py -------------------------------------------------------------------------------- /tests/garage/tf/samplers/test_ray_batched_sampler_tf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test whether tensorflow session is properly created and destroyed. 
3 | 4 | Other features of ray sampler are tested in 5 | tests/garage/sampler/test_ray_sampler.py 6 | 7 | """ 8 | 9 | from unittest.mock import Mock 10 | 11 | import ray 12 | 13 | # pylint: disable=unused-import 14 | from garage.envs import GridWorldEnv, GymEnv 15 | from garage.np.policies import ScriptedPolicy 16 | from garage.sampler import RaySampler, WorkerFactory 17 | 18 | from tests.fixtures.sampler import ray_local_session_fixture 19 | 20 | 21 | class TestRaySamplerTF(): 22 | """ 23 | Uses mock policy for 4x4 gridworldenv 24 | '4x4': [ 25 | 'SFFF', 26 | 'FHFH', 27 | 'FFFH', 28 | 'HFFG' 29 | ] 30 | 0: left 31 | 1: down 32 | 2: right 33 | 3: up 34 | -1: no move 35 | 'S' : starting point 36 | 'F' or '.': free space 37 | 'W' or 'x': wall 38 | 'H' or 'o': hole (terminates episode) 39 | 'G' : goal 40 | [2,2,1,0,3,1,1,1,2,2,1,1,1,2,2,1] 41 | """ 42 | 43 | def setup_method(self): 44 | self.env = GridWorldEnv(desc='4x4') 45 | self.policy = ScriptedPolicy( 46 | scripted_actions=[2, 2, 1, 0, 3, 1, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1]) 47 | self.algo = Mock(env_spec=self.env.spec, 48 | policy=self.policy, 49 | max_episode_length=16) 50 | 51 | def teardown_method(self): 52 | self.env.close() 53 | 54 | def test_ray_batch_sampler(self, ray_local_session_fixture): 55 | del ray_local_session_fixture 56 | assert ray.is_initialized() 57 | workers = WorkerFactory( 58 | seed=100, max_episode_length=self.algo.max_episode_length) 59 | sampler1 = RaySampler(self.policy, self.env, worker_factory=workers) 60 | sampler1.start_worker() 61 | sampler1.shutdown_worker() 62 | -------------------------------------------------------------------------------- /tests/garage/tf/samplers/test_task_embedding_worker.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock 2 | 3 | import numpy as np 4 | 5 | from garage.envs import GymEnv 6 | from garage.tf.algos.te import TaskEmbeddingWorker 7 | 8 | from tests.fixtures import TfGraphTestCase 9 | from tests.fixtures.envs.dummy import DummyBoxEnv 10 | 11 | 12 | class TestTaskEmbeddingWorker(TfGraphTestCase): 13 | 14 | def test_task_embedding_worker(self): 15 | env = GymEnv(DummyBoxEnv(obs_dim=(1, ))) 16 | env.active_task_one_hot = np.array([1., 0., 0., 0.]) 17 | env._active_task_one_hot = lambda: np.array([1., 0., 0., 0.]) 18 | 19 | a = np.random.random(env.action_space.shape) 20 | z = np.random.random(5) 21 | latent_info = dict(mean=np.random.random(5)) 22 | agent_info = dict(dummy='dummy') 23 | 24 | policy = Mock() 25 | policy.get_latent.return_value = (z, latent_info) 26 | policy.latent_space.flatten.return_value = z 27 | policy.get_action_given_latent.return_value = (a, agent_info) 28 | 29 | worker = TaskEmbeddingWorker(seed=1, 30 | max_episode_length=100, 31 | worker_number=1) 32 | worker.update_agent(policy) 33 | worker.update_env(env) 34 | 35 | episodes = worker.rollout() 36 | assert 'task_onehot' in episodes.env_infos 37 | assert np.array_equal(episodes.env_infos['task_onehot'][0], 38 | env.active_task_one_hot) 39 | assert 'latent' in episodes.agent_infos 40 | assert np.array_equal(episodes.agent_infos['latent'][0], z) 41 | assert 'latent_mean' in episodes.agent_infos 42 | assert np.array_equal(episodes.agent_infos['latent_mean'][0], 43 | latent_info['mean']) 44 | -------------------------------------------------------------------------------- /tests/garage/tf/samplers/test_tf_worker.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from 
garage.envs import GymEnv 4 | from garage.sampler import DefaultWorker 5 | from garage.tf.samplers import TFWorkerWrapper 6 | from garage.trainer import TFTrainer 7 | 8 | from tests.fixtures import snapshot_config 9 | from tests.fixtures.envs.dummy import DummyBoxEnv 10 | 11 | 12 | class TestTFWorker: 13 | 14 | def test_tf_worker_with_default_session(self): 15 | with TFTrainer(snapshot_config): 16 | tf_worker = TFWorkerWrapper() 17 | worker = DefaultWorker(seed=1, 18 | max_episode_length=100, 19 | worker_number=1) 20 | worker.update_env(GymEnv(DummyBoxEnv())) 21 | tf_worker._inner_worker = worker 22 | tf_worker.worker_init() 23 | assert tf_worker._sess == tf.compat.v1.get_default_session() 24 | assert tf_worker._sess._closed 25 | 26 | def test_tf_worker_without_default_session(self): 27 | tf_worker = TFWorkerWrapper() 28 | worker = DefaultWorker(seed=1, max_episode_length=100, worker_number=1) 29 | worker.update_env(GymEnv(DummyBoxEnv())) 30 | tf_worker._inner_worker = worker 31 | tf_worker.worker_init() 32 | assert tf_worker._sess == tf.compat.v1.get_default_session() 33 | tf_worker.shutdown() 34 | assert tf_worker._sess._closed 35 | -------------------------------------------------------------------------------- /tests/garage/torch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/torch/__init__.py -------------------------------------------------------------------------------- /tests/garage/torch/algos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/torch/algos/__init__.py -------------------------------------------------------------------------------- /tests/garage/torch/algos/test_ppo.py: -------------------------------------------------------------------------------- 1 | """This script creates a test that fails when PPO performance is too low.""" 2 | import pytest 3 | import torch 4 | 5 | from garage.envs import GymEnv, normalize 6 | from garage.experiment import deterministic 7 | from garage.sampler import LocalSampler 8 | from garage.torch.algos import PPO 9 | from garage.torch.policies import GaussianMLPPolicy 10 | from garage.torch.value_functions import GaussianMLPValueFunction 11 | from garage.trainer import Trainer 12 | 13 | from tests.fixtures import snapshot_config 14 | 15 | 16 | class TestPPO: 17 | """Test class for PPO.""" 18 | 19 | def setup_method(self): 20 | """Setup method which is called before every test.""" 21 | self.env = normalize( 22 | GymEnv('InvertedDoublePendulum-v2', max_episode_length=100)) 23 | self.policy = GaussianMLPPolicy( 24 | env_spec=self.env.spec, 25 | hidden_sizes=(64, 64), 26 | hidden_nonlinearity=torch.tanh, 27 | output_nonlinearity=None, 28 | ) 29 | self.value_function = GaussianMLPValueFunction(env_spec=self.env.spec) 30 | 31 | def teardown_method(self): 32 | """Teardown method which is called after every test.""" 33 | self.env.close() 34 | 35 | @pytest.mark.mujoco 36 | def test_ppo_pendulum(self): 37 | """Test PPO with Pendulum environment.""" 38 | deterministic.set_seed(0) 39 | sampler = LocalSampler( 40 | agents=self.policy, 41 | envs=self.env, 42 | max_episode_length=self.env.spec.max_episode_length) 43 | trainer = Trainer(snapshot_config) 44 | algo = PPO(env_spec=self.env.spec, 45 | policy=self.policy, 46 | 
value_function=self.value_function, 47 | sampler=sampler, 48 | discount=0.99, 49 | gae_lambda=0.97, 50 | lr_clip_range=2e-1) 51 | 52 | trainer.setup(algo, self.env) 53 | last_avg_ret = trainer.train(n_epochs=10, batch_size=100) 54 | assert last_avg_ret > 0 55 | -------------------------------------------------------------------------------- /tests/garage/torch/algos/test_trpo.py: -------------------------------------------------------------------------------- 1 | """This script creates a test that fails when TRPO performance is too low.""" 2 | import pytest 3 | import torch 4 | 5 | from garage.envs import GymEnv, normalize 6 | from garage.experiment import deterministic 7 | from garage.sampler import LocalSampler 8 | from garage.torch.algos import TRPO 9 | from garage.torch.policies import GaussianMLPPolicy 10 | from garage.torch.value_functions import GaussianMLPValueFunction 11 | from garage.trainer import Trainer 12 | 13 | from tests.fixtures import snapshot_config 14 | 15 | 16 | class TestTRPO: 17 | """Test class for TRPO.""" 18 | 19 | def setup_method(self): 20 | """Setup method which is called before every test.""" 21 | self.env = normalize( 22 | GymEnv('InvertedDoublePendulum-v2', max_episode_length=100)) 23 | self.policy = GaussianMLPPolicy( 24 | env_spec=self.env.spec, 25 | hidden_sizes=(64, 64), 26 | hidden_nonlinearity=torch.tanh, 27 | output_nonlinearity=None, 28 | ) 29 | self.value_function = GaussianMLPValueFunction(env_spec=self.env.spec) 30 | 31 | def teardown_method(self): 32 | """Teardown method which is called after every test.""" 33 | self.env.close() 34 | 35 | @pytest.mark.mujoco 36 | def test_trpo_pendulum(self): 37 | """Test TRPO with Pendulum environment.""" 38 | deterministic.set_seed(0) 39 | sampler = LocalSampler( 40 | agents=self.policy, 41 | envs=self.env, 42 | max_episode_length=self.env.spec.max_episode_length) 43 | trainer = Trainer(snapshot_config) 44 | algo = TRPO(env_spec=self.env.spec, 45 | policy=self.policy, 46 | value_function=self.value_function, 47 | sampler=sampler, 48 | discount=0.99, 49 | gae_lambda=0.98) 50 | 51 | trainer.setup(algo, self.env) 52 | last_avg_ret = trainer.train(n_epochs=10, batch_size=100) 53 | assert last_avg_ret > 0 54 | -------------------------------------------------------------------------------- /tests/garage/torch/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/torch/modules/__init__.py -------------------------------------------------------------------------------- /tests/garage/torch/optimizers/test_differentiable_sgd.py: -------------------------------------------------------------------------------- 1 | """Tests for DifferentialSGD optimizer.""" 2 | import torch 3 | 4 | from garage.torch import update_module_params 5 | from garage.torch.optimizers import DifferentiableSGD 6 | 7 | 8 | def test_differentiable_sgd(): 9 | """Test second order derivative after taking optimization step.""" 10 | policy = torch.nn.Linear(10, 10, bias=False) 11 | lr = 0.01 12 | diff_sgd = DifferentiableSGD(policy, lr=lr) 13 | 14 | named_theta = dict(policy.named_parameters()) 15 | theta = list(named_theta.values())[0] 16 | meta_loss = torch.sum(theta**2) 17 | meta_loss.backward(create_graph=True) 18 | 19 | diff_sgd.step() 20 | 21 | theta_prime = list(policy.parameters())[0] 22 | loss = torch.sum(theta_prime**2) 23 | update_module_params(policy, named_theta) 24 | diff_sgd.zero_grad() 25 
25 |     loss.backward()
26 | 
27 |     result = theta.grad
28 | 
29 |     dtheta_prime = 1 - 2 * lr  # dtheta_prime/dtheta
30 |     dloss = 2 * theta_prime  # dloss/dtheta_prime
31 |     expected_result = dloss * dtheta_prime  # dloss/dtheta
32 | 
33 |     assert torch.allclose(result, expected_result)
34 | 
--------------------------------------------------------------------------------
/tests/garage/torch/policies/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/torch/policies/__init__.py
--------------------------------------------------------------------------------
/tests/garage/torch/q_functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/garage/torch/q_functions/__init__.py
--------------------------------------------------------------------------------
/tests/integration_tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlworkgroup/garage/2d594803636e341660cab0e81343abbe9a325353/tests/integration_tests/__init__.py
--------------------------------------------------------------------------------
/tests/mock.py:
--------------------------------------------------------------------------------
 1 | """Mock objects used by the test suite."""
 2 | from unittest import mock
 3 | 
 4 | 
 5 | class PickleableMagicMock(mock.MagicMock):
 6 |     """A MagicMock that can be pickled."""
 7 | 
 8 |     def __reduce__(self):
 9 |         # Pickling produces a fresh MagicMock, since MagicMock instances are
10 |         # not otherwise pickleable.
11 |         return (mock.MagicMock, ())
--------------------------------------------------------------------------------
/tests/quirks.py:
--------------------------------------------------------------------------------
 1 | """Documented breakages and quirks caused by dependencies."""
 2 | 
 3 | # openai/gym environments known to not implement render()
 4 | #
 5 | # e.g.
 6 | # > gym/core.py", line 111, in render
 7 | # >     raise NotImplementedError
 8 | # > NotImplementedError
 9 | #
10 | # Tests calling render() on these should verify they raise NotImplementedError
11 | # ```
12 | # with pytest.raises(NotImplementedError):
13 | #     env.render()
14 | # ```
15 | KNOWN_GYM_RENDER_NOT_IMPLEMENTED = [
16 |     # Please keep alphabetized
17 |     'Blackjack-v0',
18 |     'GuessingGame-v0',
19 |     'HotterColder-v0',
20 |     'NChain-v0',
21 |     'Roulette-v0',
22 | ]
23 | 
--------------------------------------------------------------------------------
/tests/wrappers.py:
--------------------------------------------------------------------------------
 1 | """Test environment wrapper."""
 2 | 
 3 | import gym
 4 | 
 5 | 
 6 | class AutoStopEnv(gym.Wrapper):
 7 |     """Environment wrapper that stops the episode at step max_episode_length."""
 8 | 
 9 |     def __init__(self, env=None, env_name='', max_episode_length=100):
10 |         """Create an AutoStopEnv.
11 | 
12 |         Args:
13 |             env (gym.Env): Environment to be wrapped.
14 |             env_name (str): Name of the environment.
15 |             max_episode_length (int): Maximum length of the episode.
16 |         """
17 |         if env_name:
18 |             super().__init__(gym.make(env_name))
19 |         else:
20 |             super().__init__(env)
21 | 
22 |         self._episode_step = 0
23 |         self._max_episode_length = max_episode_length
24 | 
25 |     def step(self, action):
26 |         """Step the wrapped environment.
27 | 
28 |         Args:
29 |             action (np.ndarray): The action.
30 | 
31 |         Returns:
32 |             np.ndarray: Next observation
33 |             float: Reward
34 |             bool: Termination signal
35 |             dict: Environment information
36 |         """
37 |         self._episode_step += 1
38 |         next_obs, reward, done, info = self.env.step(action)
39 |         if self._episode_step == self._max_episode_length:
40 |             done = True
41 |             self._episode_step = 0
42 |         return next_obs, reward, done, info
43 | 
44 |     def reset(self, **kwargs):
45 |         """Reset the wrapped environment.
46 | 
47 |         Args:
48 |             **kwargs: Keyword arguments.
49 | 
50 |         Returns:
51 |             np.ndarray: Initial observation.
52 |         """
53 |         return self.env.reset(**kwargs)
54 | 
--------------------------------------------------------------------------------
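
For orientation, the sketch below shows how the AutoStopEnv wrapper listed above might be exercised in a standalone test. It is an illustrative example, not a file from the repository: the 'CartPole-v1' environment name, the test function name, and the import path are assumptions, and the snippet follows the old four-tuple gym step API used throughout these tests.

```python
"""Illustrative (non-repository) usage sketch for AutoStopEnv."""
from tests.wrappers import AutoStopEnv  # assumed import path


def test_auto_stop_env_limits_episode_length():
    # Wrap an assumed example environment by name; any gym env would do.
    env = AutoStopEnv(env_name='CartPole-v1', max_episode_length=5)
    env.reset()
    steps = 0
    done = False
    while not done:
        # Random actions are enough to exercise the step counter.
        _, _, done, _ = env.step(env.action_space.sample())
        steps += 1
    # The wrapper forces done=True no later than max_episode_length steps,
    # though the underlying env may terminate earlier on its own.
    assert steps <= 5
```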